cloud_mmr/
ser.rs

1// Copyright 2021 The Grin Developers
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Serialization and deserialization layer specialized for binary encoding.
//! Ensures consistency and safety. Basically a minimal subset of
//! rustc_serialize customized for our needs.
18//!
19//! To use it simply implement `Writeable` or `Readable` and then use the
20//! `serialize` or `deserialize` functions on them as appropriate.
21
22use crate::hash::{DefaultHashable, Hash, Hashed};
23use byteorder::{BigEndian, ByteOrder, ReadBytesExt};
24use bytes::Buf;
25use std::fmt::{self, Debug};
26use std::io::{self, Read, Write};
27use std::{error, marker, string};
28
29pub const PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion(1_000);
30
/// Possible errors deriving from serializing or deserializing.
///
/// The enum is itself (de)serializable with serde; the `io::ErrorKind`
/// carried by `IOErr` goes through the custom `serialize_error_kind` /
/// `deserialize_error_kind` helpers.
#[derive(Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
pub enum Error {
    /// Wraps an io error produced when reading or writing: a message string
    /// (presumably the io error's display text — see the `From` impls) plus
    /// the io error kind.
    IOErr(
        String,
        #[serde(
            serialize_with = "serialize_error_kind",
            deserialize_with = "deserialize_error_kind"
        )]
        io::ErrorKind,
    ),
    /// Expected a given value that wasn't found
    UnexpectedData {
        /// What we wanted
        expected: Vec<u8>,
        /// What we got
        received: Vec<u8>,
    },
    /// Data wasn't in a consumable format
    CorruptedData,
    /// Incorrect number of elements (e.g. when deserializing a vec via `read_multi`).
    CountError,
    /// When asked to read too much data
    TooLargeReadErr,
    /// Error from from_hex deserialization
    HexError(String),
    /// Inputs/outputs/kernels must be sorted lexicographically.
    SortError,
    /// Inputs/outputs/kernels must be unique.
    DuplicateError,
    /// Invalid block header version (per the hard-fork schedule).
    InvalidBlockVersion,
    /// Unsupported protocol version
    UnsupportedProtocolVersion,
    /// bincode error, carrying its message
    BincodeErr(String),
}
69
70impl From<io::Error> for Error {
71    fn from(e: io::Error) -> Error {
72        Error::IOErr(format!("{}", e), e.kind())
73    }
74}
75
76impl From<io::ErrorKind> for Error {
77    fn from(e: io::ErrorKind) -> Error {
78        Error::IOErr(format!("{}", io::Error::from(e)), e)
79    }
80}
81
82impl fmt::Display for Error {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        match *self {
85            Error::IOErr(ref e, ref _k) => write!(f, "{}", e),
86            Error::UnexpectedData {
87                expected: ref e,
88                received: ref r,
89            } => write!(f, "expected {:?}, got {:?}", e, r),
90            Error::CorruptedData => f.write_str("corrupted data"),
91            Error::CountError => f.write_str("count error"),
92            Error::SortError => f.write_str("sort order"),
93            Error::DuplicateError => f.write_str("duplicate"),
94            Error::TooLargeReadErr => f.write_str("too large read"),
95            Error::HexError(ref e) => write!(f, "hex error {:?}", e),
96            Error::InvalidBlockVersion => f.write_str("invalid block version"),
97            Error::UnsupportedProtocolVersion => f.write_str("unsupported protocol version"),
98            Error::BincodeErr(ref e) => write!(f, "bincode error {:?}", e),
99        }
100    }
101}
102
103impl error::Error for Error {
104    fn cause(&self) -> Option<&dyn error::Error> {
105        match *self {
106            Error::IOErr(ref _e, ref _k) => Some(self),
107            _ => None,
108        }
109    }
110
111    fn description(&self) -> &str {
112        match *self {
113            Error::IOErr(ref e, _) => e,
114            Error::UnexpectedData { .. } => "unexpected data",
115            Error::CorruptedData => "corrupted data",
116            Error::CountError => "count error",
117            Error::SortError => "sort order",
118            Error::DuplicateError => "duplicate error",
119            Error::TooLargeReadErr => "too large read",
120            Error::HexError(_) => "hex error",
121            Error::InvalidBlockVersion => "invalid block version",
122            Error::UnsupportedProtocolVersion => "unsupported protocol version",
123            Error::BincodeErr(_) => "bincode error",
124        }
125    }
126}
127
/// Signal to a serializable object how much of its data should be serialized
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SerializationMode {
    /// Serialize everything sufficiently to fully reconstruct the object
    Full,
    /// Serialize the data that defines the object
    Hash,
}

impl SerializationMode {
    /// Returns `true` when serializing in `Hash` mode.
    pub fn is_hash_mode(&self) -> bool {
        matches!(self, SerializationMode::Hash)
    }
}
146
147/// Implementations defined how different numbers and binary structures are
148/// written to an underlying stream or container (depending on implementation).
149pub trait Writer {
150    /// The mode this serializer is writing in
151    fn serialization_mode(&self) -> SerializationMode;
152
153    /// Protocol version for version specific serialization rules.
154    fn protocol_version(&self) -> ProtocolVersion;
155
156    /// Writes a u8 as bytes
157    fn write_u8(&mut self, n: u8) -> Result<(), Error> {
158        self.write_fixed_bytes(&[n])
159    }
160
161    /// Writes a u16 as bytes
162    fn write_u16(&mut self, n: u16) -> Result<(), Error> {
163        let mut bytes = [0; 2];
164        BigEndian::write_u16(&mut bytes, n);
165        self.write_fixed_bytes(&bytes)
166    }
167
168    /// Writes a u32 as bytes
169    fn write_u32(&mut self, n: u32) -> Result<(), Error> {
170        let mut bytes = [0; 4];
171        BigEndian::write_u32(&mut bytes, n);
172        self.write_fixed_bytes(&bytes)
173    }
174
175    /// Writes a u32 as bytes
176    fn write_i32(&mut self, n: i32) -> Result<(), Error> {
177        let mut bytes = [0; 4];
178        BigEndian::write_i32(&mut bytes, n);
179        self.write_fixed_bytes(&bytes)
180    }
181
182    /// Writes a u64 as bytes
183    fn write_u64(&mut self, n: u64) -> Result<(), Error> {
184        let mut bytes = [0; 8];
185        BigEndian::write_u64(&mut bytes, n);
186        self.write_fixed_bytes(&bytes)
187    }
188
189    /// Writes a i64 as bytes
190    fn write_i64(&mut self, n: i64) -> Result<(), Error> {
191        let mut bytes = [0; 8];
192        BigEndian::write_i64(&mut bytes, n);
193        self.write_fixed_bytes(&bytes)
194    }
195
196    /// Writes a variable number of bytes. The length is encoded as a 64-bit
197    /// prefix.
198    fn write_bytes<T: AsRef<[u8]>>(&mut self, bytes: T) -> Result<(), Error> {
199        self.write_u64(bytes.as_ref().len() as u64)?;
200        self.write_fixed_bytes(bytes)
201    }
202
203    /// Writes a fixed number of bytes. The reader is expected to know the actual length on read.
204    fn write_fixed_bytes<T: AsRef<[u8]>>(&mut self, bytes: T) -> Result<(), Error>;
205
206    /// Writes a fixed length of "empty" bytes.
207    fn write_empty_bytes(&mut self, length: usize) -> Result<(), Error> {
208        self.write_fixed_bytes(vec![0u8; length])
209    }
210}
211
/// Signal to a deserializable object how much of its data should be deserialized
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DeserializationMode {
    /// Deserialize everything sufficiently to fully reconstruct the object
    Full,
    /// For Block Headers, skip reading proof
    SkipPow,
}

/// The default deserialization mode is `Full`.
///
/// Provided through the standard `Default` trait rather than an inherent
/// `default()` method, so the type works with `Default`-based APIs. Existing
/// `DeserializationMode::default()` calls keep working because `Default` is
/// in the prelude.
impl Default for DeserializationMode {
    fn default() -> Self {
        DeserializationMode::Full
    }
}
227
228/// Implementations defined how different numbers and binary structures are
229/// read from an underlying stream or container (depending on implementation).
230pub trait Reader {
231    /// The mode this reader is reading from
232    fn deserialization_mode(&self) -> DeserializationMode;
233    /// Read a u8 from the underlying Read
234    fn read_u8(&mut self) -> Result<u8, Error>;
235    /// Read a u16 from the underlying Read
236    fn read_u16(&mut self) -> Result<u16, Error>;
237    /// Read a u32 from the underlying Read
238    fn read_u32(&mut self) -> Result<u32, Error>;
239    /// Read a u64 from the underlying Read
240    fn read_u64(&mut self) -> Result<u64, Error>;
241    /// Read a i32 from the underlying Read
242    fn read_i32(&mut self) -> Result<i32, Error>;
243    /// Read a i64 from the underlying Read
244    fn read_i64(&mut self) -> Result<i64, Error>;
245    /// Read a u64 len prefix followed by that number of exact bytes.
246    fn read_bytes_len_prefix(&mut self) -> Result<Vec<u8>, Error>;
247    /// Read a fixed number of bytes from the underlying reader.
248    fn read_fixed_bytes(&mut self, length: usize) -> Result<Vec<u8>, Error>;
249    /// Consumes a byte from the reader, producing an error if it doesn't have
250    /// the expected value
251    fn expect_u8(&mut self, val: u8) -> Result<u8, Error>;
252    /// Access to underlying protocol version to support
253    /// version specific deserialization logic.
254    fn protocol_version(&self) -> ProtocolVersion;
255
256    /// Read a fixed number of "empty" bytes from the underlying reader.
257    /// It is an error if any non-empty bytes encountered.
258    fn read_empty_bytes(&mut self, length: usize) -> Result<(), Error> {
259        for _ in 0..length {
260            if self.read_u8()? != 0u8 {
261                return Err(Error::CorruptedData);
262            }
263        }
264        Ok(())
265    }
266}
267
268/// Trait that every type that can be serialized as binary must implement.
269/// Writes directly to a Writer, a utility type thinly wrapping an
270/// underlying Write implementation.
271pub trait Writeable {
272    /// Write the data held by this Writeable to the provided writer
273    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error>;
274}
275
276/// Reader that exposes an Iterator interface.
277pub struct IteratingReader<'a, T, R: Reader> {
278    count: u64,
279    curr: u64,
280    reader: &'a mut R,
281    _marker: marker::PhantomData<T>,
282}
283
284impl<'a, T, R: Reader> IteratingReader<'a, T, R> {
285    /// Constructor to create a new iterating reader for the provided underlying reader.
286    /// Takes a count so we know how many to iterate over.
287    pub fn new(reader: &'a mut R, count: u64) -> Self {
288        let curr = 0;
289        IteratingReader {
290            count,
291            curr,
292            reader,
293            _marker: marker::PhantomData,
294        }
295    }
296}
297
298impl<'a, T, R> Iterator for IteratingReader<'a, T, R>
299where
300    T: Readable,
301    R: Reader,
302{
303    type Item = T;
304
305    fn next(&mut self) -> Option<T> {
306        if self.curr >= self.count {
307            return None;
308        }
309        self.curr += 1;
310        T::read(self.reader).ok()
311    }
312}
313
314/// Reads multiple serialized items into a Vec.
315pub fn read_multi<T, R>(reader: &mut R, count: u64) -> Result<Vec<T>, Error>
316where
317    T: Readable,
318    R: Reader,
319{
320    // Very rudimentary check to ensure we do not overflow anything
321    // attempting to read huge amounts of data.
322    // Probably better than checking if count * size overflows a u64 though.
323    if count > 1_000_000 {
324        return Err(Error::TooLargeReadErr);
325    }
326
327    let res: Vec<T> = IteratingReader::new(reader, count).collect();
328    if res.len() as u64 != count {
329        return Err(Error::CountError);
330    }
331    Ok(res)
332}
333
/// Protocol version for serialization/deserialization, wrapping the raw
/// version number as a `u32`.
/// Note: This is used in various places including but not limited to
/// the p2p layer and our local db storage layer.
/// We may speak multiple versions to various peers and a potentially *different*
/// version for our local db.
#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialOrd, PartialEq, Serialize)]
pub struct ProtocolVersion(pub u32);
341
342impl ProtocolVersion {
343    /// The max protocol version supported.
344    pub const MAX: u32 = std::u32::MAX;
345
346    /// Protocol version as u32 to allow for convenient exhaustive matching on values.
347    pub fn value(self) -> u32 {
348        self.0
349    }
350
351    /// Our default "local" protocol version.
352    /// This protocol version is provided to peers as part of the Hand/Shake
353    /// negotiation in the p2p layer. Connected peers will negotiate a suitable
354    /// protocol version for serialization/deserialization of p2p messages.
355    pub fn local() -> ProtocolVersion {
356        PROTOCOL_VERSION
357    }
358
359    /// We need to specify a protocol version for our local database.
360    /// Regardless of specific version used when sending/receiving data between peers
361    /// we need to take care with serialization/deserialization of data locally in the db.
362    pub fn local_db() -> ProtocolVersion {
363        ProtocolVersion(1)
364    }
365}
366
367impl fmt::Display for ProtocolVersion {
368    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
369        write!(f, "{}", self.0)
370    }
371}
372
373impl From<ProtocolVersion> for u32 {
374    fn from(v: ProtocolVersion) -> u32 {
375        v.0
376    }
377}
378
379impl Writeable for ProtocolVersion {
380    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
381        writer.write_u32(self.0)
382    }
383}
384
385impl Readable for ProtocolVersion {
386    fn read<R: Reader>(reader: &mut R) -> Result<ProtocolVersion, Error> {
387        let version = reader.read_u32()?;
388        Ok(ProtocolVersion(version))
389    }
390}
391
392/// Trait that every type that can be deserialized from binary must implement.
393/// Reads directly to a Reader, a utility type thinly wrapping an
394/// underlying Read implementation.
395pub trait Readable
396where
397    Self: Sized,
398{
399    /// Reads the data necessary to this Readable from the provided reader
400    fn read<R: Reader>(reader: &mut R) -> Result<Self, Error>;
401}
402
403/// Deserializes a Readable from any std::io::Read implementation.
404pub fn deserialize<T: Readable, R: Read>(
405    source: &mut R,
406    version: ProtocolVersion,
407    mode: DeserializationMode,
408) -> Result<T, Error> {
409    let mut reader = BinReader::new(source, version, mode);
410    T::read(&mut reader)
411}
412
413/// Deserialize a Readable based on our default "local" protocol version.
414pub fn deserialize_default<T: Readable, R: Read>(source: &mut R) -> Result<T, Error> {
415    deserialize(
416        source,
417        ProtocolVersion::local(),
418        DeserializationMode::default(),
419    )
420}
421
422/// Serializes a Writeable into any std::io::Write implementation.
423pub fn serialize<W: Writeable>(
424    sink: &mut dyn Write,
425    version: ProtocolVersion,
426    thing: &W,
427) -> Result<(), Error> {
428    let mut writer = BinWriter::new(sink, version);
429    thing.write(&mut writer)
430}
431
432/// Serialize a Writeable according to our default "local" protocol version.
433pub fn serialize_default<W: Writeable>(sink: &mut dyn Write, thing: &W) -> Result<(), Error> {
434    serialize(sink, ProtocolVersion::local(), thing)
435}
436
437/// Utility function to serialize a writeable directly in memory using a
438/// Vec<u8>.
439pub fn ser_vec<W: Writeable>(thing: &W, version: ProtocolVersion) -> Result<Vec<u8>, Error> {
440    let mut vec = vec![];
441    serialize(&mut vec, version, thing)?;
442    Ok(vec)
443}
444
445/// Utility to read from a binary source
446pub struct BinReader<'a, R: Read> {
447    source: &'a mut R,
448    version: ProtocolVersion,
449    deser_mode: DeserializationMode,
450}
451
452impl<'a, R: Read> BinReader<'a, R> {
453    /// Constructor for a new BinReader for the provided source and protocol version.
454    pub fn new(source: &'a mut R, version: ProtocolVersion, mode: DeserializationMode) -> Self {
455        BinReader {
456            source,
457            version,
458            deser_mode: mode,
459        }
460    }
461}
462
463fn map_io_err(err: io::Error) -> Error {
464    Error::IOErr(format!("{}", err), err.kind())
465}
466
467/// Utility wrapper for an underlying byte Reader. Defines higher level methods
468/// to read numbers, byte vectors, hashes, etc.
469impl<'a, R: Read> Reader for BinReader<'a, R> {
470    fn deserialization_mode(&self) -> DeserializationMode {
471        self.deser_mode
472    }
473    fn read_u8(&mut self) -> Result<u8, Error> {
474        self.source.read_u8().map_err(map_io_err)
475    }
476    fn read_u16(&mut self) -> Result<u16, Error> {
477        self.source.read_u16::<BigEndian>().map_err(map_io_err)
478    }
479    fn read_u32(&mut self) -> Result<u32, Error> {
480        self.source.read_u32::<BigEndian>().map_err(map_io_err)
481    }
482    fn read_i32(&mut self) -> Result<i32, Error> {
483        self.source.read_i32::<BigEndian>().map_err(map_io_err)
484    }
485    fn read_u64(&mut self) -> Result<u64, Error> {
486        self.source.read_u64::<BigEndian>().map_err(map_io_err)
487    }
488    fn read_i64(&mut self) -> Result<i64, Error> {
489        self.source.read_i64::<BigEndian>().map_err(map_io_err)
490    }
491    /// Read a variable size vector from the underlying Read. Expects a usize
492    fn read_bytes_len_prefix(&mut self) -> Result<Vec<u8>, Error> {
493        let len = self.read_u64()?;
494        self.read_fixed_bytes(len as usize)
495    }
496
497    /// Read a fixed number of bytes.
498    fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
499        // not reading more than 100k bytes in a single read
500        if len > 100_000 {
501            return Err(Error::TooLargeReadErr);
502        }
503        let mut buf = vec![0; len];
504        self.source
505            .read_exact(&mut buf)
506            .map(move |_| buf)
507            .map_err(map_io_err)
508    }
509
510    fn expect_u8(&mut self, val: u8) -> Result<u8, Error> {
511        let b = self.read_u8()?;
512        if b == val {
513            Ok(b)
514        } else {
515            Err(Error::UnexpectedData {
516                expected: vec![val],
517                received: vec![b],
518            })
519        }
520    }
521
522    fn protocol_version(&self) -> ProtocolVersion {
523        self.version
524    }
525}
526
527/// A reader that reads straight off a stream.
528/// Tracks total bytes read so we can verify we read the right number afterwards.
529pub struct StreamingReader<'a> {
530    total_bytes_read: u64,
531    version: ProtocolVersion,
532    stream: &'a mut dyn Read,
533    deser_mode: DeserializationMode,
534}
535
536impl<'a> StreamingReader<'a> {
537    /// Create a new streaming reader with the provided underlying stream.
538    /// Also takes a duration to be used for each individual read_exact call.
539    pub fn new(stream: &'a mut dyn Read, version: ProtocolVersion) -> StreamingReader<'a> {
540        StreamingReader {
541            total_bytes_read: 0,
542            version,
543            stream,
544            deser_mode: DeserializationMode::Full,
545        }
546    }
547
548    /// Returns the total bytes read via this streaming reader.
549    pub fn total_bytes_read(&self) -> u64 {
550        self.total_bytes_read
551    }
552}
553
554/// Note: We use read_fixed_bytes() here to ensure our "async" I/O behaves as expected.
555impl<'a> Reader for StreamingReader<'a> {
556    fn deserialization_mode(&self) -> DeserializationMode {
557        self.deser_mode
558    }
559    fn read_u8(&mut self) -> Result<u8, Error> {
560        let buf = self.read_fixed_bytes(1)?;
561        Ok(buf[0])
562    }
563    fn read_u16(&mut self) -> Result<u16, Error> {
564        let buf = self.read_fixed_bytes(2)?;
565        Ok(BigEndian::read_u16(&buf[..]))
566    }
567    fn read_u32(&mut self) -> Result<u32, Error> {
568        let buf = self.read_fixed_bytes(4)?;
569        Ok(BigEndian::read_u32(&buf[..]))
570    }
571    fn read_i32(&mut self) -> Result<i32, Error> {
572        let buf = self.read_fixed_bytes(4)?;
573        Ok(BigEndian::read_i32(&buf[..]))
574    }
575    fn read_u64(&mut self) -> Result<u64, Error> {
576        let buf = self.read_fixed_bytes(8)?;
577        Ok(BigEndian::read_u64(&buf[..]))
578    }
579    fn read_i64(&mut self) -> Result<i64, Error> {
580        let buf = self.read_fixed_bytes(8)?;
581        Ok(BigEndian::read_i64(&buf[..]))
582    }
583
584    /// Read a variable size vector from the underlying stream. Expects a usize
585    fn read_bytes_len_prefix(&mut self) -> Result<Vec<u8>, Error> {
586        let len = self.read_u64()?;
587        self.total_bytes_read += 8;
588        self.read_fixed_bytes(len as usize)
589    }
590
591    /// Read a fixed number of bytes.
592    fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
593        let mut buf = vec![0u8; len];
594        self.stream.read_exact(&mut buf)?;
595        self.total_bytes_read += len as u64;
596        Ok(buf)
597    }
598
599    fn expect_u8(&mut self, val: u8) -> Result<u8, Error> {
600        let b = self.read_u8()?;
601        if b == val {
602            Ok(b)
603        } else {
604            Err(Error::UnexpectedData {
605                expected: vec![val],
606                received: vec![b],
607            })
608        }
609    }
610
611    fn protocol_version(&self) -> ProtocolVersion {
612        self.version
613    }
614}
615
616/// Protocol version-aware wrapper around a `Buf` impl
617pub struct BufReader<'a, B: Buf> {
618    inner: &'a mut B,
619    version: ProtocolVersion,
620    bytes_read: usize,
621    deser_mode: DeserializationMode,
622}
623
624impl<'a, B: Buf> BufReader<'a, B> {
625    /// Construct a new BufReader
626    pub fn new(buf: &'a mut B, version: ProtocolVersion) -> Self {
627        Self {
628            inner: buf,
629            version,
630            bytes_read: 0,
631            deser_mode: DeserializationMode::Full,
632        }
633    }
634
635    /// Check whether the buffer has enough bytes remaining to perform a read
636    fn has_remaining(&mut self, len: usize) -> Result<(), Error> {
637        if self.inner.remaining() >= len {
638            self.bytes_read += len;
639            Ok(())
640        } else {
641            Err(io::ErrorKind::UnexpectedEof.into())
642        }
643    }
644
645    /// The total bytes read
646    pub fn bytes_read(&self) -> u64 {
647        self.bytes_read as u64
648    }
649
650    /// Convenience function to read from the buffer and deserialize
651    pub fn body<T: Readable>(&mut self) -> Result<T, Error> {
652        T::read(self)
653    }
654}
655
656impl<'a, B: Buf> Reader for BufReader<'a, B> {
657    fn deserialization_mode(&self) -> DeserializationMode {
658        self.deser_mode
659    }
660
661    fn read_u8(&mut self) -> Result<u8, Error> {
662        self.has_remaining(1)?;
663        Ok(self.inner.get_u8())
664    }
665
666    fn read_u16(&mut self) -> Result<u16, Error> {
667        self.has_remaining(2)?;
668        Ok(self.inner.get_u16())
669    }
670
671    fn read_u32(&mut self) -> Result<u32, Error> {
672        self.has_remaining(4)?;
673        Ok(self.inner.get_u32())
674    }
675
676    fn read_u64(&mut self) -> Result<u64, Error> {
677        self.has_remaining(8)?;
678        Ok(self.inner.get_u64())
679    }
680
681    fn read_i32(&mut self) -> Result<i32, Error> {
682        self.has_remaining(4)?;
683        Ok(self.inner.get_i32())
684    }
685
686    fn read_i64(&mut self) -> Result<i64, Error> {
687        self.has_remaining(8)?;
688        Ok(self.inner.get_i64())
689    }
690
691    fn read_bytes_len_prefix(&mut self) -> Result<Vec<u8>, Error> {
692        let len = self.read_u64()?;
693        self.read_fixed_bytes(len as usize)
694    }
695
696    fn read_fixed_bytes(&mut self, len: usize) -> Result<Vec<u8>, Error> {
697        // not reading more than 100k bytes in a single read
698        if len > 100_000 {
699            return Err(Error::TooLargeReadErr);
700        }
701        self.has_remaining(len)?;
702
703        let mut buf = vec![0; len];
704        self.inner.copy_to_slice(&mut buf[..]);
705        Ok(buf)
706    }
707
708    fn expect_u8(&mut self, val: u8) -> Result<u8, Error> {
709        let b = self.read_u8()?;
710        if b == val {
711            Ok(b)
712        } else {
713            Err(Error::UnexpectedData {
714                expected: vec![val],
715                received: vec![b],
716            })
717        }
718    }
719
720    fn protocol_version(&self) -> ProtocolVersion {
721        self.version
722    }
723}
724
725/// Collections of items must be sorted lexicographically and all unique.
726pub trait VerifySortedAndUnique<T> {
727    /// Verify a collection of items is sorted and all unique.
728    fn verify_sorted_and_unique(&self) -> Result<(), Error>;
729}
730
731impl<T: Ord> VerifySortedAndUnique<T> for Vec<T> {
732    fn verify_sorted_and_unique(&self) -> Result<(), Error> {
733        for pair in self.windows(2) {
734            if pair[0] > pair[1] {
735                return Err(Error::SortError);
736            } else if pair[0] == pair[1] {
737                return Err(Error::DuplicateError);
738            }
739        }
740        Ok(())
741    }
742}
743
744/// Utility wrapper for an underlying byte Writer. Defines higher level methods
745/// to write numbers, byte vectors, hashes, etc.
746pub struct BinWriter<'a> {
747    sink: &'a mut dyn Write,
748    version: ProtocolVersion,
749}
750
751impl<'a> BinWriter<'a> {
752    /// Wraps a standard Write in a new BinWriter
753    pub fn new(sink: &'a mut dyn Write, version: ProtocolVersion) -> BinWriter<'a> {
754        BinWriter { sink, version }
755    }
756
757    /// Constructor for BinWriter with default "local" protocol version.
758    pub fn default(sink: &'a mut dyn Write) -> BinWriter<'a> {
759        BinWriter::new(sink, ProtocolVersion::local())
760    }
761}
762
763impl<'a> Writer for BinWriter<'a> {
764    fn serialization_mode(&self) -> SerializationMode {
765        SerializationMode::Full
766    }
767
768    fn write_fixed_bytes<T: AsRef<[u8]>>(&mut self, bytes: T) -> Result<(), Error> {
769        self.sink.write_all(bytes.as_ref())?;
770        Ok(())
771    }
772
773    fn protocol_version(&self) -> ProtocolVersion {
774        self.version
775    }
776}
777
778macro_rules! impl_int {
779    ($int:ty, $w_fn:ident, $r_fn:ident) => {
780        impl Writeable for $int {
781            fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
782                writer.$w_fn(*self)
783            }
784        }
785
786        impl Readable for $int {
787            fn read<R: Reader>(reader: &mut R) -> Result<$int, Error> {
788                reader.$r_fn()
789            }
790        }
791    };
792}
793
794impl_int!(u8, write_u8, read_u8);
795impl_int!(u16, write_u16, read_u16);
796impl_int!(u32, write_u32, read_u32);
797impl_int!(i32, write_i32, read_i32);
798impl_int!(u64, write_u64, read_u64);
799impl_int!(i64, write_i64, read_i64);
800
801impl<T> Readable for Vec<T>
802where
803    T: Readable,
804{
805    fn read<R: Reader>(reader: &mut R) -> Result<Vec<T>, Error> {
806        let mut buf = Vec::new();
807        loop {
808            let elem = T::read(reader);
809            match elem {
810                Ok(e) => buf.push(e),
811                Err(Error::IOErr(ref _d, ref kind)) if *kind == io::ErrorKind::UnexpectedEof => {
812                    break;
813                }
814                Err(e) => return Err(e),
815            }
816        }
817        Ok(buf)
818    }
819}
820
821impl<T> Writeable for Vec<T>
822where
823    T: Writeable,
824{
825    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
826        for elmt in self {
827            elmt.write(writer)?;
828        }
829        Ok(())
830    }
831}
832
833impl<'a, A: Writeable> Writeable for &'a A {
834    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
835        Writeable::write(*self, writer)
836    }
837}
838
839impl<A: Writeable, B: Writeable> Writeable for (A, B) {
840    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
841        Writeable::write(&self.0, writer)?;
842        Writeable::write(&self.1, writer)
843    }
844}
845
846impl<A: Readable, B: Readable> Readable for (A, B) {
847    fn read<R: Reader>(reader: &mut R) -> Result<(A, B), Error> {
848        Ok((Readable::read(reader)?, Readable::read(reader)?))
849    }
850}
851
852impl<A: Writeable, B: Writeable, C: Writeable> Writeable for (A, B, C) {
853    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
854        Writeable::write(&self.0, writer)?;
855        Writeable::write(&self.1, writer)?;
856        Writeable::write(&self.2, writer)
857    }
858}
859
860impl<A: Writeable, B: Writeable, C: Writeable, D: Writeable> Writeable for (A, B, C, D) {
861    fn write<W: Writer>(&self, writer: &mut W) -> Result<(), Error> {
862        Writeable::write(&self.0, writer)?;
863        Writeable::write(&self.1, writer)?;
864        Writeable::write(&self.2, writer)?;
865        Writeable::write(&self.3, writer)
866    }
867}
868
869impl<A: Readable, B: Readable, C: Readable> Readable for (A, B, C) {
870    fn read<R: Reader>(reader: &mut R) -> Result<(A, B, C), Error> {
871        Ok((
872            Readable::read(reader)?,
873            Readable::read(reader)?,
874            Readable::read(reader)?,
875        ))
876    }
877}
878
879impl<A: Readable, B: Readable, C: Readable, D: Readable> Readable for (A, B, C, D) {
880    fn read<R: Reader>(reader: &mut R) -> Result<(A, B, C, D), Error> {
881        Ok((
882            Readable::read(reader)?,
883            Readable::read(reader)?,
884            Readable::read(reader)?,
885            Readable::read(reader)?,
886        ))
887    }
888}
889
890/// Trait for types that can be added to a PMMR.
891pub trait PMMRable: Writeable + Clone + Debug + DefaultHashable {
892    /// The type of element actually stored in the MMR data file.
893    /// This allows us to store Hash elements in the header MMR for variable size BlockHeaders.
894    type E: Readable + Writeable + Debug;
895
896    /// Convert the pmmrable into the element to be stored in the MMR data file.
897    fn as_elmt(&self) -> Self::E;
898
899    /// Size of each element if "fixed" size. Elements are "variable" size if None.
900    fn elmt_size() -> Option<u16>;
901}
902
903/// Generic trait to ensure PMMR elements can be hashed with an index
904pub trait PMMRIndexHashable {
905    /// Hash with a given index
906    fn hash_with_index(&self, index: u64) -> Hash;
907}
908
909impl<T: DefaultHashable> PMMRIndexHashable for T {
910    fn hash_with_index(&self, index: u64) -> Hash {
911        (index, self).hash()
912    }
913}
914
915// serializer for io::Errorkind, originally auto-generated by serde-derive
916// slightly modified to handle the #[non_exhaustive] tag on io::ErrorKind
917fn serialize_error_kind<S>(kind: &io::ErrorKind, serializer: S) -> Result<S::Ok, S::Error>
918where
919    S: serde::Serializer,
920{
921    match *kind {
922        io::ErrorKind::NotFound => {
923            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 0u32, "NotFound")
924        }
925        io::ErrorKind::PermissionDenied => serde::Serializer::serialize_unit_variant(
926            serializer,
927            "ErrorKind",
928            1u32,
929            "PermissionDenied",
930        ),
931        io::ErrorKind::ConnectionRefused => serde::Serializer::serialize_unit_variant(
932            serializer,
933            "ErrorKind",
934            2u32,
935            "ConnectionRefused",
936        ),
937        io::ErrorKind::ConnectionReset => serde::Serializer::serialize_unit_variant(
938            serializer,
939            "ErrorKind",
940            3u32,
941            "ConnectionReset",
942        ),
943        io::ErrorKind::ConnectionAborted => serde::Serializer::serialize_unit_variant(
944            serializer,
945            "ErrorKind",
946            4u32,
947            "ConnectionAborted",
948        ),
949        io::ErrorKind::NotConnected => {
950            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 5u32, "NotConnected")
951        }
952        io::ErrorKind::AddrInUse => {
953            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 6u32, "AddrInUse")
954        }
955        io::ErrorKind::AddrNotAvailable => serde::Serializer::serialize_unit_variant(
956            serializer,
957            "ErrorKind",
958            7u32,
959            "AddrNotAvailable",
960        ),
961        io::ErrorKind::BrokenPipe => {
962            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 8u32, "BrokenPipe")
963        }
964        io::ErrorKind::AlreadyExists => serde::Serializer::serialize_unit_variant(
965            serializer,
966            "ErrorKind",
967            9u32,
968            "AlreadyExists",
969        ),
970        io::ErrorKind::WouldBlock => {
971            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 10u32, "WouldBlock")
972        }
973        io::ErrorKind::InvalidInput => serde::Serializer::serialize_unit_variant(
974            serializer,
975            "ErrorKind",
976            11u32,
977            "InvalidInput",
978        ),
979        io::ErrorKind::InvalidData => {
980            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 12u32, "InvalidData")
981        }
982        io::ErrorKind::TimedOut => {
983            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 13u32, "TimedOut")
984        }
985        io::ErrorKind::WriteZero => {
986            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 14u32, "WriteZero")
987        }
988        io::ErrorKind::Interrupted => {
989            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 15u32, "Interrupted")
990        }
991        io::ErrorKind::Other => {
992            serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 16u32, "Other")
993        }
994        io::ErrorKind::UnexpectedEof => serde::Serializer::serialize_unit_variant(
995            serializer,
996            "ErrorKind",
997            17u32,
998            "UnexpectedEof",
999        ),
1000        // #[non_exhaustive] is used on the definition of ErrorKind for future compatability
1001        // That means match statements always need to match on _.
1002        // The downside here is that rustc won't be able to warn us if io::ErrorKind another
1003        // field is added to io::ErrorKind
1004        _ => serde::Serializer::serialize_unit_variant(serializer, "ErrorKind", 16u32, "Other"),
1005    }
1006}
1007
1008// deserializer for io::Errorkind, originally auto-generated by serde-derive
1009fn deserialize_error_kind<'de, D>(deserializer: D) -> Result<io::ErrorKind, D::Error>
1010where
1011    D: serde::Deserializer<'de>,
1012{
1013    #[allow(non_camel_case_types)]
1014    enum Field {
1015        field0,
1016        field1,
1017        field2,
1018        field3,
1019        field4,
1020        field5,
1021        field6,
1022        field7,
1023        field8,
1024        field9,
1025        field10,
1026        field11,
1027        field12,
1028        field13,
1029        field14,
1030        field15,
1031        field16,
1032        field17,
1033    }
1034    struct FieldVisitor;
1035    impl<'de> serde::de::Visitor<'de> for FieldVisitor {
1036        type Value = Field;
1037        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1038            fmt::Formatter::write_str(formatter, "variant identifier")
1039        }
1040        fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
1041        where
1042            E: serde::de::Error,
1043        {
1044            match value {
1045                0u64 => Ok(Field::field0),
1046                1u64 => Ok(Field::field1),
1047                2u64 => Ok(Field::field2),
1048                3u64 => Ok(Field::field3),
1049                4u64 => Ok(Field::field4),
1050                5u64 => Ok(Field::field5),
1051                6u64 => Ok(Field::field6),
1052                7u64 => Ok(Field::field7),
1053                8u64 => Ok(Field::field8),
1054                9u64 => Ok(Field::field9),
1055                10u64 => Ok(Field::field10),
1056                11u64 => Ok(Field::field11),
1057                12u64 => Ok(Field::field12),
1058                13u64 => Ok(Field::field13),
1059                14u64 => Ok(Field::field14),
1060                15u64 => Ok(Field::field15),
1061                16u64 => Ok(Field::field16),
1062                17u64 => Ok(Field::field17),
1063                _ => Err(serde::de::Error::invalid_value(
1064                    serde::de::Unexpected::Unsigned(value),
1065                    &"variant index 0 <= i < 18",
1066                )),
1067            }
1068        }
1069        fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
1070        where
1071            E: serde::de::Error,
1072        {
1073            match value {
1074                "NotFound" => Ok(Field::field0),
1075                "PermissionDenied" => Ok(Field::field1),
1076                "ConnectionRefused" => Ok(Field::field2),
1077                "ConnectionReset" => Ok(Field::field3),
1078                "ConnectionAborted" => Ok(Field::field4),
1079                "NotConnected" => Ok(Field::field5),
1080                "AddrInUse" => Ok(Field::field6),
1081                "AddrNotAvailable" => Ok(Field::field7),
1082                "BrokenPipe" => Ok(Field::field8),
1083                "AlreadyExists" => Ok(Field::field9),
1084                "WouldBlock" => Ok(Field::field10),
1085                "InvalidInput" => Ok(Field::field11),
1086                "InvalidData" => Ok(Field::field12),
1087                "TimedOut" => Ok(Field::field13),
1088                "WriteZero" => Ok(Field::field14),
1089                "Interrupted" => Ok(Field::field15),
1090                "Other" => Ok(Field::field16),
1091                "UnexpectedEof" => Ok(Field::field17),
1092                _ => Err(serde::de::Error::unknown_variant(value, VARIANTS)),
1093            }
1094        }
1095        fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
1096        where
1097            E: serde::de::Error,
1098        {
1099            match value {
1100                b"NotFound" => Ok(Field::field0),
1101                b"PermissionDenied" => Ok(Field::field1),
1102                b"ConnectionRefused" => Ok(Field::field2),
1103                b"ConnectionReset" => Ok(Field::field3),
1104                b"ConnectionAborted" => Ok(Field::field4),
1105                b"NotConnected" => Ok(Field::field5),
1106                b"AddrInUse" => Ok(Field::field6),
1107                b"AddrNotAvailable" => Ok(Field::field7),
1108                b"BrokenPipe" => Ok(Field::field8),
1109                b"AlreadyExists" => Ok(Field::field9),
1110                b"WouldBlock" => Ok(Field::field10),
1111                b"InvalidInput" => Ok(Field::field11),
1112                b"InvalidData" => Ok(Field::field12),
1113                b"TimedOut" => Ok(Field::field13),
1114                b"WriteZero" => Ok(Field::field14),
1115                b"Interrupted" => Ok(Field::field15),
1116                b"Other" => Ok(Field::field16),
1117                b"UnexpectedEof" => Ok(Field::field17),
1118                _ => {
1119                    let value = &string::String::from_utf8_lossy(value);
1120                    Err(serde::de::Error::unknown_variant(value, VARIANTS))
1121                }
1122            }
1123        }
1124    }
1125    impl<'de> serde::Deserialize<'de> for Field {
1126        #[inline]
1127        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
1128        where
1129            D: serde::Deserializer<'de>,
1130        {
1131            serde::Deserializer::deserialize_identifier(deserializer, FieldVisitor)
1132        }
1133    }
1134    struct Visitor<'de> {
1135        marker: marker::PhantomData<io::ErrorKind>,
1136        lifetime: marker::PhantomData<&'de ()>,
1137    }
1138    impl<'de> serde::de::Visitor<'de> for Visitor<'de> {
1139        type Value = io::ErrorKind;
1140        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
1141            fmt::Formatter::write_str(formatter, "enum io::ErrorKind")
1142        }
1143        fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
1144        where
1145            A: serde::de::EnumAccess<'de>,
1146        {
1147            match match serde::de::EnumAccess::variant(data) {
1148                Ok(val) => val,
1149                Err(err) => {
1150                    return Err(err);
1151                }
1152            } {
1153                (Field::field0, variant) => {
1154                    match serde::de::VariantAccess::unit_variant(variant) {
1155                        Ok(val) => val,
1156                        Err(err) => {
1157                            return Err(err);
1158                        }
1159                    };
1160                    Ok(io::ErrorKind::NotFound)
1161                }
1162                (Field::field1, variant) => {
1163                    match serde::de::VariantAccess::unit_variant(variant) {
1164                        Ok(val) => val,
1165                        Err(err) => {
1166                            return Err(err);
1167                        }
1168                    };
1169                    Ok(io::ErrorKind::PermissionDenied)
1170                }
1171                (Field::field2, variant) => {
1172                    match serde::de::VariantAccess::unit_variant(variant) {
1173                        Ok(val) => val,
1174                        Err(err) => {
1175                            return Err(err);
1176                        }
1177                    };
1178                    Ok(io::ErrorKind::ConnectionRefused)
1179                }
1180                (Field::field3, variant) => {
1181                    match serde::de::VariantAccess::unit_variant(variant) {
1182                        Ok(val) => val,
1183                        Err(err) => {
1184                            return Err(err);
1185                        }
1186                    };
1187                    Ok(io::ErrorKind::ConnectionReset)
1188                }
1189                (Field::field4, variant) => {
1190                    match serde::de::VariantAccess::unit_variant(variant) {
1191                        Ok(val) => val,
1192                        Err(err) => {
1193                            return Err(err);
1194                        }
1195                    };
1196                    Ok(io::ErrorKind::ConnectionAborted)
1197                }
1198                (Field::field5, variant) => {
1199                    match serde::de::VariantAccess::unit_variant(variant) {
1200                        Ok(val) => val,
1201                        Err(err) => {
1202                            return Err(err);
1203                        }
1204                    };
1205                    Ok(io::ErrorKind::NotConnected)
1206                }
1207                (Field::field6, variant) => {
1208                    match serde::de::VariantAccess::unit_variant(variant) {
1209                        Ok(val) => val,
1210                        Err(err) => {
1211                            return Err(err);
1212                        }
1213                    };
1214                    Ok(io::ErrorKind::AddrInUse)
1215                }
1216                (Field::field7, variant) => {
1217                    match serde::de::VariantAccess::unit_variant(variant) {
1218                        Ok(val) => val,
1219                        Err(err) => {
1220                            return Err(err);
1221                        }
1222                    };
1223                    Ok(io::ErrorKind::AddrNotAvailable)
1224                }
1225                (Field::field8, variant) => {
1226                    match serde::de::VariantAccess::unit_variant(variant) {
1227                        Ok(val) => val,
1228                        Err(err) => {
1229                            return Err(err);
1230                        }
1231                    };
1232                    Ok(io::ErrorKind::BrokenPipe)
1233                }
1234                (Field::field9, variant) => {
1235                    match serde::de::VariantAccess::unit_variant(variant) {
1236                        Ok(val) => val,
1237                        Err(err) => {
1238                            return Err(err);
1239                        }
1240                    };
1241                    Ok(io::ErrorKind::AlreadyExists)
1242                }
1243                (Field::field10, variant) => {
1244                    match serde::de::VariantAccess::unit_variant(variant) {
1245                        Ok(val) => val,
1246                        Err(err) => {
1247                            return Err(err);
1248                        }
1249                    };
1250                    Ok(io::ErrorKind::WouldBlock)
1251                }
1252                (Field::field11, variant) => {
1253                    match serde::de::VariantAccess::unit_variant(variant) {
1254                        Ok(val) => val,
1255                        Err(err) => {
1256                            return Err(err);
1257                        }
1258                    };
1259                    Ok(io::ErrorKind::InvalidInput)
1260                }
1261                (Field::field12, variant) => {
1262                    match serde::de::VariantAccess::unit_variant(variant) {
1263                        Ok(val) => val,
1264                        Err(err) => {
1265                            return Err(err);
1266                        }
1267                    };
1268                    Ok(io::ErrorKind::InvalidData)
1269                }
1270                (Field::field13, variant) => {
1271                    match serde::de::VariantAccess::unit_variant(variant) {
1272                        Ok(val) => val,
1273                        Err(err) => {
1274                            return Err(err);
1275                        }
1276                    };
1277                    Ok(io::ErrorKind::TimedOut)
1278                }
1279                (Field::field14, variant) => {
1280                    match serde::de::VariantAccess::unit_variant(variant) {
1281                        Ok(val) => val,
1282                        Err(err) => {
1283                            return Err(err);
1284                        }
1285                    };
1286                    Ok(io::ErrorKind::WriteZero)
1287                }
1288                (Field::field15, variant) => {
1289                    match serde::de::VariantAccess::unit_variant(variant) {
1290                        Ok(val) => val,
1291                        Err(err) => {
1292                            return Err(err);
1293                        }
1294                    };
1295                    Ok(io::ErrorKind::Interrupted)
1296                }
1297                (Field::field16, variant) => {
1298                    match serde::de::VariantAccess::unit_variant(variant) {
1299                        Ok(val) => val,
1300                        Err(err) => {
1301                            return Err(err);
1302                        }
1303                    };
1304                    Ok(io::ErrorKind::Other)
1305                }
1306                (Field::field17, variant) => {
1307                    match serde::de::VariantAccess::unit_variant(variant) {
1308                        Ok(val) => val,
1309                        Err(err) => {
1310                            return Err(err);
1311                        }
1312                    };
1313                    Ok(io::ErrorKind::UnexpectedEof)
1314                }
1315            }
1316        }
1317    }
1318    const VARIANTS: &[&str] = &[
1319        "NotFound",
1320        "PermissionDenied",
1321        "ConnectionRefused",
1322        "ConnectionReset",
1323        "ConnectionAborted",
1324        "NotConnected",
1325        "AddrInUse",
1326        "AddrNotAvailable",
1327        "BrokenPipe",
1328        "AlreadyExists",
1329        "WouldBlock",
1330        "InvalidInput",
1331        "InvalidData",
1332        "TimedOut",
1333        "WriteZero",
1334        "Interrupted",
1335        "Other",
1336        "UnexpectedEof",
1337    ];
1338    serde::Deserializer::deserialize_enum(
1339        deserializer,
1340        "ErrorKind",
1341        VARIANTS,
1342        Visitor {
1343            marker: marker::PhantomData::<io::ErrorKind>,
1344            lifetime: marker::PhantomData,
1345        },
1346    )
1347}