Skip to main content

radicle_protocol/
deserializer.rs

1use std::io;
2use std::marker::PhantomData;
3
4use crate::bounded;
5use crate::bounded::BoundedVec;
6use crate::service::message::Message;
7use crate::wire;
8
/// Message stream deserializer.
///
/// Used, for example, to turn a byte stream into network messages.
///
/// Raw bytes are accumulated in a buffer bounded by `B` and decoded into
/// values of type `D` (defaulting to [`Message`]) once a complete item is
/// buffered.
#[derive(Debug)]
pub struct Deserializer<const B: usize, D = Message> {
    /// Bytes that have been received but not yet decoded; bounded by `B`
    /// via [`BoundedVec`].
    unparsed: BoundedVec<u8, B>,
    /// Type marker for the decoded item type; no `D` value is stored.
    item: PhantomData<D>,
}
17
18impl<const B: usize, D: wire::Decode> Default for Deserializer<B, D> {
19    fn default() -> Self {
20        Self::new(wire::Size::MAX as usize + 1)
21    }
22}
23
24impl<const B: usize, D> TryFrom<Vec<u8>> for Deserializer<B, D> {
25    type Error = bounded::Error;
26
27    fn try_from(unparsed: Vec<u8>) -> Result<Self, Self::Error> {
28        BoundedVec::try_from(unparsed).map(|unparsed| Self {
29            unparsed,
30            item: PhantomData,
31        })
32    }
33}
34
35impl<const B: usize, D: wire::Decode> Deserializer<B, D> {
36    /// Create a new stream decoder.
37    pub fn new(capacity: usize) -> Self {
38        Self {
39            unparsed: BoundedVec::with_capacity(capacity)
40                .expect("Deserializer::new: capacity exceeds maximum"),
41            item: PhantomData,
42        }
43    }
44
45    /// Input bytes into the decoder.
46    pub fn input(&mut self, bytes: &[u8]) -> Result<(), bounded::Error> {
47        self.unparsed.extend_from_slice(bytes)
48    }
49
50    /// Decode and return the next message. Returns [`None`] if nothing was decoded.
51    pub fn deserialize_next(&mut self) -> Result<Option<D>, wire::Invalid> {
52        let mut reader = io::Cursor::new(self.unparsed.as_slice());
53
54        match D::decode(&mut reader) {
55            Ok(msg) => {
56                let pos = reader.position() as usize;
57                self.unparsed.drain(..pos);
58
59                Ok(Some(msg))
60            }
61            Err(wire::Error::UnexpectedEnd { .. }) => Ok(None),
62            Err(wire::Error::Invalid(err)) => Err(err),
63        }
64    }
65
66    /// Drain the unparsed buffer.
67    pub fn unparsed(&mut self) -> impl ExactSizeIterator<Item = u8> + '_ {
68        self.unparsed.drain(..)
69    }
70
71    /// Return whether there are unparsed bytes.
72    pub fn is_empty(&self) -> bool {
73        self.unparsed.is_empty()
74    }
75
76    /// Return the size of the unparsed data.
77    pub fn len(&self) -> usize {
78        self.unparsed.len()
79    }
80}
81
// SAFETY: every method forwards directly to the identically named method on
// the inner `BoundedVec`, so this impl satisfies the `BufMut` contract exactly
// to the extent that those `BoundedVec` methods do.
// NOTE(review): relies on `BoundedVec::{remaining_mut, advance_mut, chunk_mut}`
// forming a sound `BufMut`-style implementation — confirm in `bounded`.
unsafe impl<const B: usize, D: wire::Decode> bytes::BufMut for Deserializer<B, D> {
    /// Number of bytes that can still be written, as reported by the inner buffer.
    fn remaining_mut(&self) -> usize {
        self.unparsed.remaining_mut()
    }

    /// Mark `cnt` more bytes of the inner buffer as initialized.
    unsafe fn advance_mut(&mut self, cnt: usize) {
        // SAFETY: the caller must uphold `BufMut::advance_mut`'s contract
        // (the next `cnt` bytes of `chunk_mut()` are initialized). `chunk_mut`
        // below returns the inner buffer's chunk, so that guarantee carries over
        // unchanged to `self.unparsed.advance_mut`.
        unsafe {
            self.unparsed.advance_mut(cnt);
        }
    }

    /// The inner buffer's writable, possibly-uninitialized chunk.
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        self.unparsed.chunk_mut()
    }
}
97
98impl<const B: usize, D: wire::Decode> io::Write for Deserializer<B, D> {
99    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
100        self.input(buf).map_err(|_| io::ErrorKind::OutOfMemory)?;
101
102        Ok(buf.len())
103    }
104
105    fn flush(&mut self) -> io::Result<()> {
106        Ok(())
107    }
108}
109
110impl<const B: usize, D: wire::Decode> Iterator for Deserializer<B, D> {
111    type Item = Result<D, wire::Invalid>;
112
113    fn next(&mut self) -> Option<Self::Item> {
114        self.deserialize_next().transpose()
115    }
116}
117
#[cfg(test)]
mod test {
    use super::*;
    use qcheck_macros::quickcheck;

    use radicle::assert_matches;

    /// `"hello"` encoded as a one-byte length followed by its bytes.
    const MSG_HELLO: &[u8] = &[5, b'h', b'e', b'l', b'l', b'o'];
    /// `"bye"` encoded as a one-byte length followed by its bytes.
    const MSG_BYE: &[u8] = &[3, b'b', b'y', b'e'];

    #[test]
    fn test_decode_next() {
        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.input(&[3, b'b']).unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(None));
        assert_eq!(decoder.unparsed.len(), 2);

        decoder.input(b"y").unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(None));
        assert_eq!(decoder.unparsed.len(), 3);

        decoder.input(b"e").unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(Some(s)) if s.as_str() == "bye");
        assert_eq!(decoder.unparsed.len(), 0);
        assert!(decoder.is_empty());
    }

    #[test]
    fn test_unparsed() {
        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.input(&[3, b'b', b'y']).unwrap();
        assert_eq!(decoder.unparsed().collect::<Vec<_>>(), vec![3, b'b', b'y']);
        assert!(decoder.is_empty());
    }

    #[test]
    fn test_write_then_iterate() {
        use std::io::Write as _;

        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.write_all(MSG_HELLO).unwrap();
        decoder.write_all(MSG_BYE).unwrap();

        // Spelled out because `io::Write` also provides a `by_ref` method.
        let msgs = Iterator::by_ref(&mut decoder)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert_eq!(msgs, vec![String::from("hello"), String::from("bye")]);
        assert!(decoder.is_empty());
    }

    #[quickcheck]
    fn prop_decode_next(chunk_size: usize) {
        let mut bytes = vec![];
        let mut msgs = vec![];
        let mut decoder = Deserializer::<1024, String>::new(8);

        bytes.extend_from_slice(MSG_HELLO);
        bytes.extend_from_slice(MSG_BYE);

        // Map the arbitrary input onto `1..=bytes.len()` so that every chunk
        // size, down to single bytes, is exercised.
        let chunk_size = 1 + chunk_size % bytes.len();

        for chunk in bytes.as_slice().chunks(chunk_size) {
            decoder.input(chunk).unwrap();

            while let Some(msg) = decoder.deserialize_next().unwrap() {
                msgs.push(msg);
            }
        }

        assert_eq!(decoder.unparsed.len(), 0);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0], String::from("hello"));
        assert_eq!(msgs[1], String::from("bye"));
    }
}