1use std::borrow::Cow;
7use std::{fmt, io};
8
9use bitcoin::hashes::{sha256, Hash};
10use bitcoin::secp256k1::{self, schnorr, PublicKey};
14
15
/// Error returned when decoding a value through [`ProtocolEncoding`] fails.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolDecodingError {
	/// Reading from the underlying reader failed. Errors from `?` on the
	/// `ReadExt` helpers land here, and `deserialize_hex` also uses this
	/// variant (with `ErrorKind::InvalidData`) for hex it cannot parse.
	#[error("I/O error: {0}")]
	Io(#[from] io::Error),
	/// The bytes were read successfully but do not form a valid encoding.
	#[error("invalid protocol encoding: {message}")]
	Invalid {
		/// Human-readable description of what was invalid.
		message: String,
		/// Optional underlying error. Because the field is named `source`,
		/// thiserror reports it as this error's `Error::source()`.
		source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
	},
}
27
28impl ProtocolDecodingError {
29 pub fn invalid(message: impl fmt::Display) -> Self {
31 Self::Invalid {
32 message: message.to_string(),
33 source: None,
34 }
35 }
36
37 pub fn invalid_err<E>(source: E, message: impl fmt::Display) -> Self
39 where
40 E: std::error::Error + Send + Sync + 'static,
41 {
42 Self::Invalid {
43 message: message.to_string(),
44 source: Some(Box::new(source)),
45 }
46 }
47}
48
49impl From<bitcoin::consensus::encode::Error> for ProtocolDecodingError {
50 fn from(e: bitcoin::consensus::encode::Error) -> Self {
51 match e {
52 bitcoin::consensus::encode::Error::Io(e) => Self::Io(e.into()),
53 e => Self::invalid_err(e, "bitcoin protocol decoding error"),
54 }
55 }
56}
57
58impl From<bitcoin::io::Error> for ProtocolDecodingError {
59 fn from(e: bitcoin::io::Error) -> Self {
60 Self::Io(e.into())
61 }
62}
63
64pub trait ProtocolEncoding: Sized {
66 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error>;
69
70 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError>;
72
73 fn serialize(&self) -> Vec<u8> {
75 let mut buf = Vec::new();
76 self.encode(&mut buf).expect("buffers don't produce I/O errors");
77 buf
78 }
79
80 fn deserialize(mut byte_slice: &[u8]) -> Result<Self, ProtocolDecodingError> {
82 Self::decode(&mut byte_slice)
83 }
84
85 fn serialize_hex(&self) -> String {
87 use hex_conservative::Case::Lower;
88 let mut buf = String::new();
89 let mut writer = hex_conservative::display::HexWriter::new(&mut buf, Lower);
90 self.encode(&mut writer).expect("no I/O errors for buffers");
91 buf
92 }
93
94 fn deserialize_hex(hex_str: &str) -> Result<Self, ProtocolDecodingError> {
96 let mut iter = hex_conservative::HexToBytesIter::new(hex_str).map_err(|e| {
97 ProtocolDecodingError::Io(io::Error::new(io::ErrorKind::InvalidData, e))
98 })?;
99 Self::decode(&mut iter)
100 }
101}
102
/// Little-endian primitive writers for any [`io::Write`].
pub trait WriteExt: io::Write {
	/// Writes a single byte.
	fn emit_u8(&mut self, v: u8) -> Result<(), io::Error> {
		self.write_all(&[v])
	}

	/// Writes a `u16` as 2 little-endian bytes.
	fn emit_u16(&mut self, v: u16) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Writes a `u32` as 4 little-endian bytes.
	fn emit_u32(&mut self, v: u32) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Writes a `u64` as 8 little-endian bytes.
	fn emit_u64(&mut self, v: u64) -> Result<(), io::Error> {
		let bytes = v.to_le_bytes();
		self.write_all(&bytes)
	}

	/// Writes `slice` verbatim, with no length prefix.
	fn emit_slice(&mut self, slice: &[u8]) -> Result<(), io::Error> {
		self.write_all(slice)
	}

	/// Writes `value` as a Bitcoin-style CompactSize integer, using the
	/// shortest form, and returns the number of bytes written (1, 3, 5 or 9).
	fn emit_compact_size(&mut self, value: impl Into<u64>) -> Result<usize, io::Error> {
		let n: u64 = value.into();
		if n < 0xFD {
			self.emit_u8(n as u8)?;
			return Ok(1);
		}
		if n <= u64::from(u16::MAX) {
			self.emit_u8(0xFD)?;
			self.emit_u16(n as u16)?;
			return Ok(3);
		}
		if n <= u64::from(u32::MAX) {
			self.emit_u8(0xFE)?;
			self.emit_u32(n as u32)?;
			return Ok(5);
		}
		self.emit_u8(0xFF)?;
		self.emit_u64(n)?;
		Ok(9)
	}
}

impl<W: io::Write + ?Sized> WriteExt for W {}
158
/// Little-endian primitive readers for any [`io::Read`].
///
/// Every reader uses `read_exact`, so a short input fails with
/// `ErrorKind::UnexpectedEof`.
pub trait ReadExt: io::Read {
	/// Reads a single byte.
	fn read_u8(&mut self) -> Result<u8, io::Error> {
		Ok(u8::from_le_bytes(self.read_byte_array::<1>()?))
	}

	/// Reads a `u16` from 2 little-endian bytes.
	fn read_u16(&mut self) -> Result<u16, io::Error> {
		Ok(u16::from_le_bytes(self.read_byte_array::<2>()?))
	}

	/// Reads a `u32` from 4 little-endian bytes.
	fn read_u32(&mut self) -> Result<u32, io::Error> {
		Ok(u32::from_le_bytes(self.read_byte_array::<4>()?))
	}

	/// Reads a `u64` from 8 little-endian bytes.
	fn read_u64(&mut self) -> Result<u64, io::Error> {
		Ok(u64::from_le_bytes(self.read_byte_array::<8>()?))
	}

	/// Fills `slice` completely from the reader.
	fn read_slice(&mut self, slice: &mut [u8]) -> Result<(), io::Error> {
		self.read_exact(slice)
	}

	/// Reads exactly `N` bytes into a fixed-size array.
	fn read_byte_array<const N: usize>(&mut self) -> Result<[u8; N], io::Error> {
		let mut bytes = [0u8; N];
		self.read_exact(&mut bytes)?;
		Ok(bytes)
	}

	/// Reads a Bitcoin-style CompactSize integer.
	///
	/// Non-minimal encodings (a wide form used for a value that fits a
	/// narrower one) are rejected with `ErrorKind::InvalidData`.
	fn read_compact_size(&mut self) -> Result<u64, io::Error> {
		// Each wide form pairs the decoded value with the smallest value
		// that legitimately requires that form.
		let (value, minimum) = match self.read_u8()? {
			0xFF => (self.read_u64()?, 0x1_0000_0000),
			0xFE => (u64::from(self.read_u32()?), 0x1_0000),
			0xFD => (u64::from(self.read_u16()?), 0xFD),
			small => return Ok(u64::from(small)),
		};
		if value < minimum {
			Err(io::Error::new(io::ErrorKind::InvalidData, "non-minimal varint"))
		} else {
			Ok(value)
		}
	}
}

impl<R: io::Read + ?Sized> ReadExt for R {}
234
235
236impl ProtocolEncoding for PublicKey {
237 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
238 w.emit_slice(&self.serialize())
239 }
240
241 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
242 let mut buf = [0; secp256k1::constants::PUBLIC_KEY_SIZE];
243 r.read_slice(&mut buf[..])?;
244 PublicKey::from_slice(&buf).map_err(|e| {
245 ProtocolDecodingError::invalid_err(e, "invalid public key")
246 })
247 }
248}
249
250impl ProtocolEncoding for Option<PublicKey> {
251 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
252 if let Some(pk) = self {
253 w.emit_slice(&pk.serialize())
254 } else {
255 w.emit_u8(0)
256 }
257 }
258
259 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
260 let first = r.read_u8()?;
261 if first == 0 {
262 Ok(None)
263 } else {
264 let mut pk = [first; secp256k1::constants::PUBLIC_KEY_SIZE];
265 r.read_slice(&mut pk[1..])?;
266 Ok(Some(PublicKey::from_slice(&pk).map_err(|e| {
267 ProtocolDecodingError::invalid_err(e, "invalid public key")
268 })?))
269 }
270 }
271}
272
273impl ProtocolEncoding for schnorr::Signature {
274 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
275 w.emit_slice(&self.serialize())
276 }
277
278 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
279 let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
280 r.read_slice(&mut buf[..])?;
281 schnorr::Signature::from_slice(&buf).map_err(|e| {
282 ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
283 })
284 }
285}
286
287impl ProtocolEncoding for Option<schnorr::Signature> {
288 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
289 if let Some(sig) = self {
290 w.emit_slice(&sig.serialize())
291 } else {
292 w.emit_slice(&[0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE])
293 }
294 }
295
296 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
297 let mut buf = [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE];
298 r.read_slice(&mut buf[..])?;
299 if buf == [0; secp256k1::constants::SCHNORR_SIGNATURE_SIZE] {
300 Ok(None)
301 } else {
302 Ok(Some(schnorr::Signature::from_slice(&buf).map_err(|e| {
303 ProtocolDecodingError::invalid_err(e, "invalid schnorr signature")
304 })?))
305 }
306 }
307}
308
309impl ProtocolEncoding for sha256::Hash {
310 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
311 w.emit_slice(&self[..])
312 }
313
314 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
315 let mut buf = [0; sha256::Hash::LEN];
316 r.read_exact(&mut buf[..])?;
317 Ok(sha256::Hash::from_byte_array(buf))
318 }
319}
320
321macro_rules! impl_bitcoin_encode {
324 ($name:ty) => {
325 impl ProtocolEncoding for $name {
326 fn encode<W: io::Write + ?Sized>(&self, w: &mut W) -> Result<(), io::Error> {
327 let mut wrapped = bitcoin::io::FromStd::new(w);
328 bitcoin::consensus::Encodable::consensus_encode(self, &mut wrapped)?;
329 Ok(())
330 }
331
332 fn decode<R: io::Read + ?Sized>(r: &mut R) -> Result<Self, ProtocolDecodingError> {
333 let mut wrapped = bitcoin::io::FromStd::new(r);
334 let ret = bitcoin::consensus::Decodable::consensus_decode(&mut wrapped)?;
335 Ok(ret)
336 }
337 }
338 };
339}
340
341impl_bitcoin_encode!(bitcoin::BlockHash);
342impl_bitcoin_encode!(bitcoin::OutPoint);
343impl_bitcoin_encode!(bitcoin::TxOut);
344
345
346impl<'a, T: ProtocolEncoding + Clone> ProtocolEncoding for Cow<'a, T> {
347 fn encode<W: io::Write + ?Sized>(&self, writer: &mut W) -> Result<(), io::Error> {
348 ProtocolEncoding::encode(self.as_ref(), writer)
349 }
350
351 fn decode<R: io::Read + ?Sized>(reader: &mut R) -> Result<Self, ProtocolDecodingError> {
352 Ok(Cow::Owned(ProtocolEncoding::decode(reader)?))
353 }
354}
355
356
357pub mod serde {
358 use std::fmt;
377 use std::borrow::Cow;
378 use std::marker::PhantomData;
379
380 use serde::{de, ser, Deserialize, Deserializer, Serialize, Serializer};
381
382 use super::ProtocolEncoding;
383
384 struct SerWrapper<'a, T>(&'a T);
385
386 impl<'a, T: ProtocolEncoding> Serialize for SerWrapper<'a, T> {
387 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
388 if s.is_human_readable() {
389 s.serialize_str(&self.0.serialize_hex())
390 } else {
391 s.serialize_bytes(&self.0.serialize())
392 }
393 }
394 }
395
396 struct DeWrapper<T>(T);
397
398 impl<'de, T: ProtocolEncoding> Deserialize<'de> for DeWrapper<T> {
399 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
400 if d.is_human_readable() {
401 let s = <Cow<'de, str>>::deserialize(d)?;
402 Ok(DeWrapper(ProtocolEncoding::deserialize_hex(s.as_ref())
403 .map_err(serde::de::Error::custom)?))
404 } else {
405 let b = <Cow<'de, [u8]>>::deserialize(d)?;
406 Ok(DeWrapper(ProtocolEncoding::deserialize(b.as_ref())
407 .map_err(serde::de::Error::custom)?))
408 }
409 }
410 }
411
412 pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &T, s: S) -> Result<S::Ok, S::Error> {
413 SerWrapper(v).serialize(s)
414 }
415
416 pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<T, D::Error> {
417 Ok(DeWrapper::<T>::deserialize(d)?.0)
418 }
419
420 pub mod vec {
421 use super::*;
422
423 pub fn serialize<T: ProtocolEncoding, S: Serializer>(v: &[T], s: S) -> Result<S::Ok, S::Error> {
424 let mut seq = s.serialize_seq(Some(v.len()))?;
425 for item in v {
426 ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
427 }
428 ser::SerializeSeq::end(seq)
429 }
430
431 pub fn deserialize<'d, T: ProtocolEncoding, D: Deserializer<'d>>(d: D) -> Result<Vec<T>, D::Error> {
432 struct Visitor<T>(PhantomData<T>);
433
434 impl<'de, T: ProtocolEncoding> de::Visitor<'de> for Visitor<T> {
435 type Value = Vec<T>;
436
437 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
438 f.write_str("a vector of objects implementing ProtocolEncoding")
439 }
440
441 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
442 let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
443 while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
444 ret.push(v.0);
445 }
446 Ok(ret)
447 }
448 }
449 d.deserialize_seq(Visitor(PhantomData))
450 }
451 }
452
453 pub mod cow {
454 use super::*;
455
456 use std::borrow::Cow;
457
458 pub fn serialize<'a, T, S>(v: &Cow<'a, T>, s: S) -> Result<S::Ok, S::Error>
459 where
460 T: ProtocolEncoding + Clone,
461 S: Serializer,
462 {
463 SerWrapper(v.as_ref()).serialize(s)
464 }
465
466 pub fn deserialize<'d, T, D>(d: D) -> Result<Cow<'static, T>, D::Error>
467 where
468 T: ProtocolEncoding + Clone,
469 D: Deserializer<'d>,
470 {
471 Ok(Cow::Owned(DeWrapper::<T>::deserialize(d)?.0))
472 }
473
474 pub mod vec {
475 use super::*;
476
477 use std::borrow::Cow;
478
479 pub fn serialize<'a, T, S>(v: &Cow<'a, [T]>, s: S) -> Result<S::Ok, S::Error>
480 where
481 T: ProtocolEncoding + Clone,
482 S: Serializer,
483 {
484 let mut seq = s.serialize_seq(Some(v.len()))?;
485 for item in v.as_ref().iter() {
486 ser::SerializeSeq::serialize_element(&mut seq, &SerWrapper(item))?;
487 }
488 ser::SerializeSeq::end(seq)
489 }
490
491 pub fn deserialize<'d, T, D>(d: D) -> Result<Cow<'static, [T]>, D::Error>
492 where
493 T: ProtocolEncoding + Clone,
494 D: Deserializer<'d>,
495 {
496 struct Visitor<T>(PhantomData<T>);
497
498 impl<'de, T: ProtocolEncoding + Clone + 'static> de::Visitor<'de> for Visitor<T> {
499 type Value = Cow<'static, [T]>;
500
501 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
502 f.write_str("a vector of objects implementing ProtocolEncoding")
503 }
504
505 fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
506 let mut ret = Vec::with_capacity(seq.size_hint().unwrap_or_default());
507 while let Some(v) = seq.next_element::<DeWrapper<T>>()? {
508 ret.push(v.0);
509 }
510 Ok(ret.into())
511 }
512 }
513 d.deserialize_seq(Visitor(PhantomData))
514 }
515 }
516 }
517}
518
519
#[cfg(any(test, feature = "test-util"))]
pub mod test {
	//! Round-trip assertion helpers for encoding tests.

	use bitcoin::hex::DisplayHex;
	use ::serde::{Deserialize, Serialize};
	use serde_json;

	use super::*;

	/// Asserts that `object` survives a [`ProtocolEncoding`] round trip, and
	/// that re-encoding the decoded value yields identical bytes.
	///
	/// # Panics
	/// Panics if decoding fails or either comparison does not hold.
	pub fn encoding_roundtrip<T>(object: &T)
	where
		T: ProtocolEncoding + fmt::Debug + PartialEq,
	{
		let first_bytes = object.serialize();
		let parsed = T::deserialize(&first_bytes).unwrap();
		assert_eq!(*object, parsed);

		// Compare as hex so that a mismatch prints readable byte strings.
		let first_hex = first_bytes.as_hex().to_string();
		let second_hex = parsed.serialize().as_hex().to_string();
		assert_eq!(first_hex, second_hex);
	}

	/// Asserts that `object` survives a serde JSON round trip unchanged.
	///
	/// # Panics
	/// Panics if JSON serialization or deserialization fails, or if the
	/// decoded value differs from `object`.
	pub fn json_roundtrip<T>(object: &T)
	where
		T: fmt::Debug + PartialEq + Serialize + for<'de> Deserialize<'de>,
	{
		let json = serde_json::to_string(object).unwrap();
		let parsed = serde_json::from_str::<T>(&json).unwrap();
		assert_eq!(*object, parsed);
	}
}