minicbor_io/
async_reader.rs

1use crate::Error;
2use futures_io::AsyncRead;
3use futures_util::AsyncReadExt;
4use minicbor::Decode;
5use std::io;
6
7/// Wraps an [`AsyncRead`] and reads length-delimited CBOR values.
8///
9/// *Requires cargo feature* `"async-io"`.
10#[derive(Debug)]
11pub struct AsyncReader<R> {
12    reader: R,
13    buffer: Vec<u8>,
14    max_len: usize,
15    state: State
16}
17
/// Where the reader currently is within a length-prefixed frame.
#[derive(Debug)]
enum State {
    /// The 4-byte length prefix is being collected.
    ///
    /// Holds the prefix bytes gathered so far and how many of them are valid.
    ReadLen([u8; 4], u8),
    /// The CBOR item itself is being collected.
    ///
    /// Holds the number of item bytes already written into the buffer.
    ReadVal(usize),
}
26
27impl State {
28    /// Setup a new state.
29    fn new() -> Self {
30        State::ReadLen([0; 4], 0)
31    }
32}
33
34impl<R> AsyncReader<R> {
35    /// Create a new reader with a max. buffer size of 512KiB.
36    pub fn new(reader: R) -> Self {
37        Self::with_buffer(reader, Vec::new())
38    }
39
40    /// Create a new reader with a max. buffer size of 512KiB.
41    pub fn with_buffer(reader: R, buffer: Vec<u8>) -> Self {
42        Self { reader, buffer, max_len: 512 * 1024, state: State::new() }
43    }
44
45    /// Set the max. buffer size in bytes.
46    ///
47    /// If length values greater than this are decoded, an
48    /// [`Error::InvalidLen`] will be returned.
49    pub fn set_max_len(&mut self, val: u32) {
50        self.max_len = val as usize
51    }
52
53    /// Get a reference to the inner reader.
54    pub fn reader(&self) -> &R {
55        &self.reader
56    }
57
58    /// Get a mutable reference to the inner reader.
59    pub fn reader_mut(&mut self) -> &mut R {
60        &mut self.reader
61    }
62
63    /// Deconstruct this reader into the inner reader and the buffer.
64    pub fn into_parts(self) -> (R, Vec<u8>) {
65        (self.reader, self.buffer)
66    }
67}
68
69impl<R: AsyncRead + Unpin> AsyncReader<R> {
70    /// Read the next CBOR value and decode it.
71    ///
72    /// The value is assumed to be preceded by a `u32` (4 bytes in network
73    /// byte order) denoting the length of the CBOR item in bytes.
74    ///
75    /// Reading 0 bytes when decoding the length prefix results in `Ok(None)`,
76    /// otherwise either `Some` value or an error is returned.
77    ///
78    /// # Cancellation
79    ///
80    /// The future returned by `AsyncReader::read` can be dropped while still
81    /// pending. Subsequent calls to `AsyncReader::read` will resume reading
82    /// where the previous future left off.
83    pub async fn read<'a, T: Decode<'a, ()>>(&'a mut self) -> Result<Option<T>, Error> {
84        self.read_with(&mut ()).await
85    }
86
87    /// Like [`AsyncReader::read`] but accepting a user provided decoding context.
88    pub async fn read_with<'a, C, T: Decode<'a, C>>(&'a mut self, ctx: &mut C) -> Result<Option<T>, Error> {
89        loop {
90            match self.state {
91                State::ReadLen(buf, 4) => {
92                    let len = u32::from_be_bytes(buf) as usize;
93                    if len > self.max_len {
94                        return Err(Error::InvalidLen)
95                    }
96                    self.buffer.clear();
97                    self.buffer.resize(len, 0u8);
98                    self.state = State::ReadVal(0)
99                }
100                State::ReadLen(ref mut buf, ref mut o) => {
101                    let n = self.reader.read(&mut buf[usize::from(*o) ..]).await?;
102                    if n == 0 {
103                        return if *o == 0 {
104                            Ok(None)
105                        } else {
106                            Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
107                        }
108                    }
109                    *o += n as u8
110                }
111                State::ReadVal(o) if o >= self.buffer.len() => {
112                    self.state = State::new();
113                    return minicbor::decode_with(&self.buffer, ctx).map_err(Error::Decode).map(Some)
114                }
115                State::ReadVal(ref mut o) => {
116                    let n = self.reader.read(&mut self.buffer[*o ..]).await?;
117                    if n == 0 {
118                        return Err(Error::Io(io::ErrorKind::UnexpectedEof.into()))
119                    }
120                    *o += n
121                }
122            }
123        }
124    }
125}
126