1use std::{fmt, io};
7
8use bitcoin::hashes::{sha256, Hash};
9use bitcoin::secp256k1::{self, schnorr, PublicKey};
13
14
15#[derive(Debug, thiserror::Error)]
17pub enum ProtocolDecodingError {
18 #[error("I/O error: {0}")]
19 Io(#[from] io::Error),
20 #[error("invalid protocol encoding: {message}")]
21 Invalid {
22 message: String,
23 source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
24 },
25}
26
27impl ProtocolDecodingError {
28 pub fn invalid(message: impl fmt::Display) -> Self {
30 Self::Invalid {
31 message: message.to_string(),
32 source: None,
33 }
34 }
35
36 pub fn invalid_err<E>(source: E, message: impl fmt::Display) -> Self
38 where
39 E: std::error::Error + Send + Sync + 'static,
40 {
41 Self::Invalid {
42 message: message.to_string(),
43 source: Some(Box::new(source)),
44 }
45 }
46}
47
48impl From<bitcoin::consensus::encode::Error> for ProtocolDecodingError {
49 fn from(e: bitcoin::consensus::encode::Error) -> Self {
50 match e {
51 bitcoin::consensus::encode::Error::Io(e) => Self::Io(e.into()),
52 e => Self::invalid_err(e, "bitcoin protocol decoding error"),
53 }
54 }
55}
56
57impl From<bitcoin::io::Error> for ProtocolDecodingError {
58 fn from(e: bitcoin::io::Error) -> Self {
59 Self::Io(e.into())
60 }
61}
62
63pub trait ProtocolEncoding: Sized {
65 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error>;
68
69 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError>;
71
72 fn serialize(&self) -> Vec<u8> {
74 let mut buf = Vec::new();
75 self.encode(&mut buf).expect("buffers don't produce I/O errors");
76 buf
77 }
78
79 fn deserialize(mut byte_slice: &[u8]) -> Result<Self, ProtocolDecodingError> {
81 Self::decode(&mut byte_slice)
82 }
83
84 fn serialize_hex(&self) -> String {
86 use hex_conservative::Case::Lower;
87 let mut buf = String::new();
88 let mut writer = hex_conservative::display::HexWriter::new(&mut buf, Lower);
89 self.encode(&mut writer).expect("no I/O errors for buffers");
90 buf
91 }
92
93 fn deserialize_hex(hex_str: &str) -> Result<Self, ProtocolDecodingError> {
95 let mut iter = hex_conservative::HexToBytesIter::new(hex_str).map_err(|e| {
96 ProtocolDecodingError::Io(io::Error::new(io::ErrorKind::InvalidData, e))
97 })?;
98 Self::decode(&mut iter)
99 }
100}
101
/// Extension methods for writing fixed-width little-endian integers, raw byte
/// slices and Bitcoin-style compact sizes to any [`io::Write`].
pub trait WriteExt: io::Write {
    /// Write a single byte.
    fn emit_u8(&mut self, v: u8) -> Result<(), io::Error> {
        self.write_all(&[v])
    }

    /// Write a `u16` in little-endian byte order.
    fn emit_u16(&mut self, v: u16) -> Result<(), io::Error> {
        let bytes = v.to_le_bytes();
        self.write_all(&bytes)
    }

    /// Write a `u32` in little-endian byte order.
    fn emit_u32(&mut self, v: u32) -> Result<(), io::Error> {
        let bytes = v.to_le_bytes();
        self.write_all(&bytes)
    }

    /// Write a `u64` in little-endian byte order.
    fn emit_u64(&mut self, v: u64) -> Result<(), io::Error> {
        let bytes = v.to_le_bytes();
        self.write_all(&bytes)
    }

    /// Write `slice` verbatim, without any length prefix.
    fn emit_slice(&mut self, slice: &[u8]) -> Result<(), io::Error> {
        self.write_all(slice)
    }

    /// Write `value` using the minimal compact-size encoding: a single byte
    /// below `0xFD`; otherwise a `0xFD`/`0xFE`/`0xFF` marker followed by a
    /// little-endian `u16`/`u32`/`u64`.
    ///
    /// Returns the number of bytes written (1, 3, 5 or 9).
    fn emit_compact_size(&mut self, value: impl Into<u64>) -> Result<usize, io::Error> {
        let value: u64 = value.into();

        if value < 0xFD {
            self.emit_u8(value as u8)?;
            return Ok(1);
        }
        if let Ok(short) = u16::try_from(value) {
            self.emit_u8(0xFD)?;
            self.emit_u16(short)?;
            return Ok(3);
        }
        if let Ok(word) = u32::try_from(value) {
            self.emit_u8(0xFE)?;
            self.emit_u32(word)?;
            return Ok(5);
        }
        self.emit_u8(0xFF)?;
        self.emit_u64(value)?;
        Ok(9)
    }
}

impl<W: io::Write + ?Sized> WriteExt for W {}
157
/// Extension methods for reading fixed-width little-endian integers, raw byte
/// slices and Bitcoin-style compact sizes from any [`io::Read`].
pub trait ReadExt: io::Read {
    /// Read a single byte.
    fn read_u8(&mut self) -> Result<u8, io::Error> {
        let mut byte = [0u8; 1];
        self.read_exact(&mut byte)?;
        Ok(byte[0])
    }

    /// Read a little-endian `u16`.
    fn read_u16(&mut self) -> Result<u16, io::Error> {
        let mut bytes = [0u8; 2];
        self.read_exact(&mut bytes)?;
        Ok(u16::from_le_bytes(bytes))
    }

    /// Read a little-endian `u32`.
    fn read_u32(&mut self) -> Result<u32, io::Error> {
        let mut bytes = [0u8; 4];
        self.read_exact(&mut bytes)?;
        Ok(u32::from_le_bytes(bytes))
    }

    /// Read a little-endian `u64`.
    fn read_u64(&mut self) -> Result<u64, io::Error> {
        let mut bytes = [0u8; 8];
        self.read_exact(&mut bytes)?;
        Ok(u64::from_le_bytes(bytes))
    }

    /// Fill `slice` completely from the reader.
    fn read_slice(&mut self, slice: &mut [u8]) -> Result<(), io::Error> {
        self.read_exact(slice)
    }

    /// Read a compact-size integer.
    ///
    /// # Errors
    ///
    /// Fails with [`io::ErrorKind::InvalidData`] when the value could have been
    /// encoded in a shorter form, and propagates read errors such as EOF.
    fn read_compact_size(&mut self) -> Result<u64, io::Error> {
        // Each multi-byte form carries the smallest value it may legally hold.
        let (value, minimum) = match self.read_u8()? {
            0xFD => (u64::from(self.read_u16()?), 0xFD),
            0xFE => (u64::from(self.read_u32()?), 0x1_0000),
            0xFF => (self.read_u64()?, 0x1_0000_0000),
            small => return Ok(u64::from(small)),
        };
        if value >= minimum {
            Ok(value)
        } else {
            Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
        }
    }
}

impl<R: io::Read + ?Sized> ReadExt for R {}
226
227
228impl ProtocolEncoding for PublicKey {
229 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
230 w.emit_slice(&self.serialize())
231 }
232
233 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
234 let mut buf = [0; secp256k1::constants::PUBLIC_KEY_SIZE];
235 r.read_slice(&mut buf[..])?;
236 PublicKey::from_slice(&buf).map_err(|e| {
237 ProtocolDecodingError::invalid_err(e, "invalid public key")
238 })
239 }
240}
241
242impl ProtocolEncoding for Option<PublicKey> {
243 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
244 if let Some(pk) = self {
245 w.emit_slice(&pk.serialize())
246 } else {
247 w.emit_u8(0)
248 }
249 }
250
251 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
252 let first = r.read_u8()?;
253 if first == 0 {
254 Ok(None)
255 } else {
256 let mut pk = [first; secp256k1::constants::PUBLIC_KEY_SIZE];
257 r.read_slice(&mut pk[1..])?;
258 Ok(Some(PublicKey::from_slice(&pk).map_err(|e| {
259 ProtocolDecodingError::invalid_err(e, "invalid public key")
260 })?))
261 }
262 }
263}
264
265impl ProtocolEncoding for schnorr::Signature {
266 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
267 w.emit_slice(&self.serialize())
268 }
269
270 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
271 let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
272 r.read_slice(&mut buf[..])?;
273 schnorr::Signature::from_slice(&buf).map_err(|e| {
274 ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
275 })
276 }
277}
278
279impl ProtocolEncoding for Option<schnorr::Signature> {
280 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
281 if let Some(sig) = self {
282 w.emit_slice(&sig.serialize())
283 } else {
284 w.emit_slice(&[0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE])
285 }
286 }
287
288 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
289 let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
290 r.read_slice(&mut buf[..])?;
291 if buf == [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE] {
292 Ok(None)
293 } else {
294 Ok(Some(schnorr::Signature::from_slice(&buf).map_err(|e| {
295 ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
296 })?))
297 }
298 }
299}
300
301impl ProtocolEncoding for sha256::Hash {
302 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
303 w.emit_slice(&self[..])
304 }
305
306 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
307 let mut buf = [0; sha256::Hash::LEN];
308 r.read_exact(&mut buf[..])?;
309 Ok(sha256::Hash::from_byte_array(buf))
310 }
311}
312
313macro_rules! impl_bitcoin_encode {
316 ($name:ty) => {
317 impl ProtocolEncoding for $name {
318 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
319 let mut wrapped = bitcoin::io::FromStd::new(w);
320 bitcoin::consensus::Encodable::consensus_encode(self, &mut wrapped)?;
321 Ok(())
322 }
323
324 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
325 let mut wrapped = bitcoin::io::FromStd::new(r);
326 let ret = bitcoin::consensus::Decodable::consensus_decode(&mut wrapped)?;
327 Ok(ret)
328 }
329 }
330 };
331}
332
333impl_bitcoin_encode!(bitcoin::OutPoint);
334impl_bitcoin_encode!(bitcoin::TxOut);
335
336
337pub mod serde {
338 use std::fmt;
357 use std::borrow::Cow;
358 use std::marker::PhantomData;
359
360 use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
361
362 use super::ProtocolEncoding;
363
364 struct SerWrapper<'a, T>(&'a T);
365
366 impl<'a, T: ProtocolEncoding> Serialize for SerWrapper<'a, T> {
367 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
368 if s.is_human_readable() {
369 s.serialize_str(&self.0.serialize_hex())
370 } else {
371 s.serialize_bytes(&self.0.serialize())
372 }
373 }
374 }
375
376 struct DeWrapper<T>(T);
377
378 impl<'de, T: ProtocolEncoding> Deserialize<'de> for DeWrapper<T> {
379 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
380 if d.is_human_readable() {
381 let s = <Cow<'de, str>>::deserialize(d)?;
382 Ok(DeWrapper(ProtocolEncoding::deserialize_hex(s.as_ref())
383 .map_err(serde::de::Error::custom)?))
384 } else {
385 let b = <Cow<'de, [u8]>>::deserialize(d)?;
386 Ok(DeWrapper(ProtocolEncoding::deserialize(b.as_ref())
387 .map_err(serde::de::Error::custom)?))
388 }
389 }
390 }
391
392 pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &T, s: S) -> Result<S::Ok, S::Error> {
393 SerWrapper(v).serialize(s)
394 }
395
396 pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<T, D::Error> {
397 Ok(DeWrapper::<T>::deserialize(d)?.0)
398 }
399
400 pub mod vec {
401 use super::*;
402
403 pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &[T], s: S) -> Result<S::Ok, S::Error> {
404 let mut seq = s.serialize_seq(Some(v.len()))?;
405 for item in v {
406 ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
407 }
408 ser::SerializeSeq::end(seq)
409 }
410
411 pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<Vec<T>, D::Error> {
412 struct Visitor<T>(PhantomData<T>);
413
414 impl<'de, T: ProtocolEncoding> de::Visitor<'de> for Visitor<T> {
415 type Value = Vec<T>;
416
417 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
418 f.write_str("a vector of objects implementing ProtocolEncoding")
419 }
420
421 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
422 let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
423 while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
424 ret.push(v.0);
425 }
426 Ok(ret)
427 }
428 }
429 d.deserialize_seq(Visitor(PhantomData))
430 }
431 }
432}
433
434
#[cfg(any(test, feature = "test-util"))]
pub mod test {
    use bitcoin::hex::DisplayHex;
    use ::serde::{Deserialize, Serialize};
    use serde_json;

    use super::*;

    /// Assert that `object` survives a [`ProtocolEncoding`] byte roundtrip and
    /// that re-encoding the decoded value reproduces the original bytes.
    ///
    /// # Panics
    ///
    /// Panics if decoding fails or either comparison does not match.
    pub fn encoding_roundtrip<T>(object: &T)
    where
        T: ProtocolEncoding + fmt::Debug + PartialEq,
    {
        let bytes = object.serialize();
        let decoded = T::deserialize(&bytes).unwrap();
        assert_eq!(*object, decoded);

        // Compare as hex so a mismatch prints readably.
        let original_hex = bytes.as_hex().to_string();
        let reencoded_hex = decoded.serialize().as_hex().to_string();
        assert_eq!(original_hex, reencoded_hex);
    }

    /// Assert that `object` survives a serde JSON roundtrip unchanged.
    ///
    /// # Panics
    ///
    /// Panics if JSON (de)serialization fails or the result differs.
    pub fn json_roundtrip<T>(object: &T)
    where
        T: fmt::Debug + PartialEq + Serialize + for<'de> Deserialize<'de>,
    {
        let json = serde_json::to_string(object).unwrap();
        let parsed = serde_json::from_str::<T>(&json).unwrap();
        assert_eq!(*object, parsed);
    }
}