ark/
encode.rs

1//!
2//! Definitions of protocol encodings.
3//!
4
5
6use std::{fmt, io};
7
8use bitcoin::hashes::{sha256, Hash};
9// We use bitcoin::io::{Read, Write} here but we shouldn't have to.
10// I created this issue in the hope that rust-bitcoin fixes this nuisance:
11//  https://github.com/rust-bitcoin/rust-bitcoin/issues/4530
12use bitcoin::secp256k1::{self, schnorr, PublicKey};
13
14
15/// Error occuring during protocol decoding.
16#[derive(Debug, thiserror::Error)]
17pub enum ProtocolDecodingError {
18	#[error("I/O error: {0}")]
19	Io(#[from] io::Error),
20	#[error("invalid protocol encoding: {message}")]
21	Invalid {
22		message: String,
23		source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
24	},
25}
26
27impl ProtocolDecodingError {
28	/// Create a new [ProtocolDecodingError::Invalid] with the given message.
29	pub fn invalid(message: impl fmt::Display) -> Self {
30		Self::Invalid {
31			message: message.to_string(),
32			source: None,
33		}
34	}
35
36	/// Create a new [ProtocolDecodingError::Invalid] with the given message and source error.
37	pub fn invalid_err<E>(source: E, message: impl fmt::Display) -> Self
38	where
39		E: std::error::Error + Send + Sync + 'static,
40	{
41		Self::Invalid {
42			message: message.to_string(),
43			source: Some(Box::new(source)),
44		}
45	}
46}
47
48impl From<bitcoin::consensus::encode::Error> for ProtocolDecodingError {
49	fn from(e: bitcoin::consensus::encode::Error) -> Self {
50		match e {
51			bitcoin::consensus::encode::Error::Io(e) => Self::Io(e.into()),
52			e => Self::invalid_err(e, "bitcoin protocol decoding error"),
53		}
54	}
55}
56
57impl From<bitcoin::io::Error> for ProtocolDecodingError {
58	fn from(e: bitcoin::io::Error) -> Self {
59	    Self::Io(e.into())
60	}
61}
62
63/// Trait for encoding objects according to the bark protocol encoding.
64pub trait ProtocolEncoding: Sized {
65	/// Encode the object into the writer.
66	//TODO(stevenroose) return nb of bytes written like bitcoin::consensus::Encodable does?
67	fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error>;
68
69	/// Decode the object from the writer.
70	fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError>;
71
72	/// Serialize the object into a byte vector.
73	fn serialize(&self) -> Vec<u8> {
74		let mut buf = Vec::new();
75		self.encode(&mut buf).expect("buffers don't produce I/O errors");
76		buf
77	}
78
79	/// Deserialize object from the given byte slice.
80	fn deserialize(mut byte_slice: &[u8]) -> Result<Self, ProtocolDecodingError> {
81		Self::decode(&mut byte_slice)
82	}
83
84	/// Serialize the object to a lowercase hex string.
85	fn serialize_hex(&self) -> String {
86		use hex_conservative::Case::Lower;
87		let mut buf = String::new();
88		let mut writer = hex_conservative::display::HexWriter::new(&mut buf, Lower);
89		self.encode(&mut writer).expect("no I/O errors for buffers");
90		buf
91	}
92
93	/// Deserialize object from hex slice.
94	fn deserialize_hex(hex_str: &str) -> Result<Self, ProtocolDecodingError> {
95		let mut iter = hex_conservative::HexToBytesIter::new(hex_str).map_err(|e| {
96			ProtocolDecodingError::Io(io::Error::new(io::ErrorKind::InvalidData, e))
97		})?;
98		Self::decode(&mut iter)
99	}
100}
101
/// Extension trait for writing the primitive values of our encoding format
/// to any [io::Write] implementor.
///
/// All multi-byte integers are written in little-endian byte order.
pub trait WriteExt: io::Write {
	/// Write a single byte.
	fn emit_u8(&mut self, v: u8) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write a `u16` as 2 little-endian bytes.
	fn emit_u16(&mut self, v: u16) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write a `u32` as 4 little-endian bytes.
	fn emit_u32(&mut self, v: u32) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write a `u64` as 8 little-endian bytes.
	fn emit_u64(&mut self, v: u64) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Write every byte of `slice`, without any length prefix.
	fn emit_slice(&mut self, slice: &[u8]) -> Result<(), io::Error> {
		self.write_all(slice)
	}

	/// Write `value` using the minimal compact size ("VarInt") encoding and
	/// return how many bytes were written (1, 3, 5 or 9).
	fn emit_compact_size(&mut self, value: impl Into<u64>) -> Result<usize, io::Error> {
		let n: u64 = value.into();
		if n < 0xFD {
			// Fits directly in the single prefix byte.
			self.emit_u8(n as u8)?;
			Ok(1)
		} else if n <= u64::from(u16::MAX) {
			self.emit_u8(0xFD)?;
			self.emit_u16(n as u16)?;
			Ok(3)
		} else if n <= u64::from(u32::MAX) {
			self.emit_u8(0xFE)?;
			self.emit_u32(n as u32)?;
			Ok(5)
		} else {
			self.emit_u8(0xFF)?;
			self.emit_u64(n)?;
			Ok(9)
		}
	}
}

impl<W: io::Write + ?Sized> WriteExt for W {}
157
/// Utility trait to read some primitive values from our encoding format.
///
/// All multi-byte integers are read in little-endian byte order. Every method
/// uses [io::Read::read_exact], so running out of input yields an error of
/// kind [io::ErrorKind::UnexpectedEof].
pub trait ReadExt: io::Read {
	/// Read an 8-bit unsigned integer.
	fn read_u8(&mut self) -> Result<u8, io::Error> {
		let mut buf = [0; 1];
		self.read_exact(&mut buf[..])?;
		Ok(u8::from_le_bytes(buf))
	}

	/// Read a 16-bit unsigned integer in little-endian.
	fn read_u16(&mut self) -> Result<u16, io::Error> {
		let mut buf = [0; 2];
		self.read_exact(&mut buf[..])?;
		Ok(u16::from_le_bytes(buf))
	}

	/// Read a 32-bit unsigned integer in little-endian.
	fn read_u32(&mut self) -> Result<u32, io::Error> {
		let mut buf = [0; 4];
		self.read_exact(&mut buf[..])?;
		Ok(u32::from_le_bytes(buf))
	}

	/// Read a 64-bit unsigned integer in little-endian.
	fn read_u64(&mut self) -> Result<u64, io::Error> {
		let mut buf = [0; 8];
		self.read_exact(&mut buf[..])?;
		Ok(u64::from_le_bytes(buf))
	}

	/// Read from the reader to fill the entire slice.
	fn read_slice(&mut self, slice: &mut [u8]) -> Result<(), io::Error> {
		self.read_exact(slice)
	}

	/// Read a value in compact size aka "VarInt" encoding.
	///
	/// Only minimal encodings are accepted. A value that would have fit in a
	/// shorter form is rejected with [io::ErrorKind::InvalidData].
	fn read_compact_size(&mut self) -> Result<u64, io::Error> {
		match self.read_u8()? {
			0xFF => {
				let x = self.read_u64()?;
				if x < 0x1_0000_0000 { // I.e., would have fit in a `u32`.
					Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
				} else {
					Ok(x)
				}
			},
			0xFE => {
				let x = self.read_u32()?;
				if x < 0x1_0000 { // I.e., would have fit in a `u16`.
					Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
				} else {
					Ok(x as u64)
				}
			},
			0xFD => {
				let x = self.read_u16()?;
				if x < 0xFD { // Could have been encoded as a `u8`.
					Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
				} else {
					Ok(x as u64)
				}
			},
			n => Ok(n as u64),
		}
	}
}

impl<R: io::Read + ?Sized> ReadExt for R {}
226
227
228impl ProtocolEncoding for PublicKey {
229	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
230		w.emit_slice(&self.serialize())
231	}
232
233	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
234		let mut buf = [0; secp256k1::constants::PUBLIC_KEY_SIZE];
235		r.read_slice(&mut buf[..])?;
236		PublicKey::from_slice(&buf).map_err(|e| {
237			ProtocolDecodingError::invalid_err(e, "invalid public key")
238		})
239	}
240}
241
242impl ProtocolEncoding for Option<PublicKey> {
243	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
244		if let Some(pk) = self {
245			w.emit_slice(&pk.serialize())
246		} else {
247			w.emit_u8(0)
248		}
249	}
250
251	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
252		let first = r.read_u8()?;
253		if first == 0 {
254			Ok(None)
255		} else {
256			let mut pk = [first; secp256k1::constants::PUBLIC_KEY_SIZE];
257			r.read_slice(&mut pk[1..])?;
258			Ok(Some(PublicKey::from_slice(&pk).map_err(|e| {
259				ProtocolDecodingError::invalid_err(e, "invalid public key")
260			})?))
261		}
262	}
263}
264
265impl ProtocolEncoding for schnorr::Signature {
266	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
267		w.emit_slice(&self.serialize())
268	}
269
270	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
271		let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
272		r.read_slice(&mut buf[..])?;
273		schnorr::Signature::from_slice(&buf).map_err(|e| {
274			ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
275		})
276	}
277}
278
279impl ProtocolEncoding for Option<schnorr::Signature> {
280	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
281		if let Some(sig) = self {
282			w.emit_slice(&sig.serialize())
283		} else {
284			w.emit_slice(&[0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE])
285		}
286	}
287
288	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
289		let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
290		r.read_slice(&mut buf[..])?;
291		if buf == [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE] {
292			Ok(None)
293		} else {
294			Ok(Some(schnorr::Signature::from_slice(&buf).map_err(|e| {
295				ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
296			})?))
297		}
298	}
299}
300
301impl ProtocolEncoding for sha256::Hash {
302	fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
303		w.emit_slice(&self[..])
304	}
305
306	fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
307		let mut buf = [0; sha256::Hash::LEN];
308		r.read_exact(&mut buf[..])?;
309		Ok(sha256::Hash::from_byte_array(buf))
310	}
311}
312
313/// A macro to implement our [ProtocolEncoding] for a rust-bitcoin type that
314/// implements their `consensus::Encodable/Decodable` traits.
315macro_rules! impl_bitcoin_encode {
316	($name:ty) => {
317		impl ProtocolEncoding for $name {
318			fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
319				let mut wrapped = bitcoin::io::FromStd::new(w);
320				bitcoin::consensus::Encodable::consensus_encode(self, &mut wrapped)?;
321				Ok(())
322			}
323
324			fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
325				let mut wrapped = bitcoin::io::FromStd::new(r);
326				let ret = bitcoin::consensus::Decodable::consensus_decode(&mut wrapped)?;
327				Ok(ret)
328			}
329		}
330	};
331}
332
333impl_bitcoin_encode!(bitcoin::OutPoint);
334impl_bitcoin_encode!(bitcoin::TxOut);
335
336
337pub mod serde {
338	//! Module that helps to encode [ProtocolEncoding] objects with serde.
339	//!
340	//! By default, the objects will be encoded as bytes for regular serializers,
341	//! and as hex for human-readable serializers.
342	//!
343	//! Can be used as follows:
344	//! ```no_run
345	//! # use ark::Vtxo;
346	//! # use serde::{Serialize, Deserialize};
347	//! #[derive(Serialize, Deserialize)]
348	//! struct SomeStruct {
349	//! 	#[serde(with = "ark::encode::serde")]
350	//! 	single: Vtxo,
351	//! 	#[serde(with = "ark::encode::serde::vec")]
352	//! 	multiple: Vec<Vtxo>,
353	//! }
354	//! ```
355
356	use std::fmt;
357	use std::borrow::Cow;
358	use std::marker::PhantomData;
359
360	use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
361
362	use super::ProtocolEncoding;
363
364	struct SerWrapper<'a, T>(&'a T);
365
366	impl<'a, T: ProtocolEncoding> Serialize for SerWrapper<'a, T> {
367		fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
368			if s.is_human_readable() {
369				s.serialize_str(&self.0.serialize_hex())
370			} else {
371				s.serialize_bytes(&self.0.serialize())
372			}
373		}
374	}
375
376	struct DeWrapper<T>(T);
377
378	impl<'de, T: ProtocolEncoding> Deserialize<'de> for DeWrapper<T> {
379		fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
380			if d.is_human_readable() {
381					let s = <Cow<'de, str>>::deserialize(d)?;
382					Ok(DeWrapper(ProtocolEncoding::deserialize_hex(s.as_ref())
383							.map_err(serde::de::Error::custom)?))
384			} else {
385					let b = <Cow<'de, [u8]>>::deserialize(d)?;
386					Ok(DeWrapper(ProtocolEncoding::deserialize(b.as_ref())
387							.map_err(serde::de::Error::custom)?))
388			}
389		}
390	}
391
392	pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &T, s: S) -> Result<S::Ok, S::Error> {
393		SerWrapper(v).serialize(s)
394	}
395
396	pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<T, D::Error> {
397		Ok(DeWrapper::<T>::deserialize(d)?.0)
398	}
399
400	pub mod vec {
401		use super::*;
402
403		pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &[T], s: S) -> Result<S::Ok, S::Error> {
404			let mut seq = s.serialize_seq(Some(v.len()))?;
405			for item in v {
406				ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
407			}
408			ser::SerializeSeq::end(seq)
409		}
410
411		pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<Vec<T>, D::Error> {
412			struct Visitor<T>(PhantomData<T>);
413
414			impl<'de, T: ProtocolEncoding> de::Visitor<'de> for Visitor<T> {
415				type Value = Vec<T>;
416
417				fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
418					f.write_str("a vector of objects implementing ProtocolEncoding")
419				}
420
421				fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
422					let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
423					while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
424						ret.push(v.0);
425					}
426					Ok(ret)
427				}
428			}
429			d.deserialize_seq(Visitor(PhantomData))
430		}
431	}
432}
433
434
#[cfg(any(test, feature = "test-util"))]
pub mod test {
	//! Helpers for testing encoding implementations.

	use bitcoin::hex::DisplayHex;
	use ::serde::{Deserialize, Serialize};

	use super::*;

	/// Test that the object's encoding round-trips.
	///
	/// Asserts two things:
	/// - decoding the serialized bytes yields an equal object;
	/// - re-serializing the decoded object reproduces the same bytes.
	///
	/// The bytes are compared as hex strings, which keeps assertion failure
	/// output readable.
	///
	/// # Panics
	///
	/// Panics if decoding fails or either assertion does not hold.
	pub fn encoding_roundtrip<T>(object: &T)
	where
		T: ProtocolEncoding + fmt::Debug + PartialEq,
	{
		let encoded = object.serialize();
		let decoded = T::deserialize(&encoded).unwrap();

		assert_eq!(*object, decoded);

		let re_encoded = decoded.serialize();
		assert_eq!(encoded.as_hex().to_string(), re_encoded.as_hex().to_string());
	}

	/// Test that the object's serde JSON encoding round-trips.
	///
	/// # Panics
	///
	/// Panics if JSON serialization or deserialization fails, or if the
	/// decoded object differs from the original.
	pub fn json_roundtrip<T>(object: &T)
	where
		T: fmt::Debug + PartialEq + Serialize + for<'de> Deserialize<'de>,
	{
		let encoded = serde_json::to_string(object).unwrap();
		let decoded: T = serde_json::from_str(&encoded).unwrap();

		assert_eq!(*object, decoded);
	}
}