radicle_protocol/
deserializer.rs1use std::io;
2use std::marker::PhantomData;
3
4use crate::bounded;
5use crate::bounded::BoundedVec;
6use crate::service::message::Message;
7use crate::wire;
8
/// Message stream deserializer.
///
/// Raw bytes are appended to an internal bounded buffer and decoded into items
/// of type `D` (by default [`Message`]) once enough input has been buffered to
/// decode a complete item.
#[derive(Debug)]
pub struct Deserializer<const B: usize, D = Message> {
    /// Bytes that have been received but not decoded yet.
    unparsed: BoundedVec<u8, B>,
    /// Zero-sized marker tying the deserializer to its item type `D`.
    item: PhantomData<D>,
}
17
18impl<const B: usize, D: wire::Decode> Default for Deserializer<B, D> {
19 fn default() -> Self {
20 Self::new(wire::Size::MAX as usize + 1)
21 }
22}
23
24impl<const B: usize, D> TryFrom<Vec<u8>> for Deserializer<B, D> {
25 type Error = bounded::Error;
26
27 fn try_from(unparsed: Vec<u8>) -> Result<Self, Self::Error> {
28 BoundedVec::try_from(unparsed).map(|unparsed| Self {
29 unparsed,
30 item: PhantomData,
31 })
32 }
33}
34
35impl<const B: usize, D: wire::Decode> Deserializer<B, D> {
36 pub fn new(capacity: usize) -> Self {
38 Self {
39 unparsed: BoundedVec::with_capacity(capacity)
40 .expect("Deserializer::new: capacity exceeds maximum"),
41 item: PhantomData,
42 }
43 }
44
45 pub fn input(&mut self, bytes: &[u8]) -> Result<(), bounded::Error> {
47 self.unparsed.extend_from_slice(bytes)
48 }
49
50 pub fn deserialize_next(&mut self) -> Result<Option<D>, wire::Invalid> {
52 let mut reader = io::Cursor::new(self.unparsed.as_slice());
53
54 match D::decode(&mut reader) {
55 Ok(msg) => {
56 let pos = reader.position() as usize;
57 self.unparsed.drain(..pos);
58
59 Ok(Some(msg))
60 }
61 Err(wire::Error::UnexpectedEnd { .. }) => Ok(None),
62 Err(wire::Error::Invalid(err)) => Err(err),
63 }
64 }
65
66 pub fn unparsed(&mut self) -> impl ExactSizeIterator<Item = u8> + '_ {
68 self.unparsed.drain(..)
69 }
70
71 pub fn is_empty(&self) -> bool {
73 self.unparsed.is_empty()
74 }
75
76 pub fn len(&self) -> usize {
78 self.unparsed.len()
79 }
80}
81
// SAFETY: every method forwards unchanged to the inner `BoundedVec<u8, B>`
// buffer. The consistency `BufMut` requires between `remaining_mut`,
// `chunk_mut` and `advance_mut` is therefore exactly the consistency the
// inner buffer provides. NOTE(review): this relies on `BoundedVec`'s own
// methods upholding the `BufMut` contract — confirm in `bounded.rs`.
unsafe impl<const B: usize, D: wire::Decode> bytes::BufMut for Deserializer<B, D> {
    /// Number of bytes that can still be written, as reported by the inner
    /// buffer.
    fn remaining_mut(&self) -> usize {
        self.unparsed.remaining_mut()
    }

    /// Marks the next `cnt` bytes of the chunk as written.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the next `cnt` bytes of the slice returned
    /// by [`chunk_mut`](bytes::BufMut::chunk_mut) have been initialized.
    unsafe fn advance_mut(&mut self, cnt: usize) {
        // SAFETY: `chunk_mut` hands out the inner buffer's chunk unchanged, so
        // the caller's guarantee about those `cnt` bytes carries over directly
        // to the inner buffer.
        self.unparsed.advance_mut(cnt);
    }

    /// Writable, possibly uninitialized, spare space of the inner buffer.
    fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice {
        self.unparsed.chunk_mut()
    }
}
95
96impl<const B: usize, D: wire::Decode> io::Write for Deserializer<B, D> {
97 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
98 self.input(buf).map_err(|_| io::ErrorKind::OutOfMemory)?;
99
100 Ok(buf.len())
101 }
102
103 fn flush(&mut self) -> io::Result<()> {
104 Ok(())
105 }
106}
107
108impl<const B: usize, D: wire::Decode> Iterator for Deserializer<B, D> {
109 type Item = Result<D, wire::Invalid>;
110
111 fn next(&mut self) -> Option<Self::Item> {
112 self.deserialize_next().transpose()
113 }
114}
115
#[cfg(test)]
mod test {
    use super::*;
    use qcheck_macros::quickcheck;

    use radicle::assert_matches;

    /// Length-prefixed encoding of the string `"hello"`.
    const MSG_HELLO: &[u8] = &[5, b'h', b'e', b'l', b'l', b'o'];
    /// Length-prefixed encoding of the string `"bye"`.
    const MSG_BYE: &[u8] = &[3, b'b', b'y', b'e'];

    #[test]
    fn test_decode_next() {
        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.input(&[3, b'b']).unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(None));
        assert_eq!(decoder.unparsed.len(), 2);

        decoder.input(b"y").unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(None));
        assert_eq!(decoder.unparsed.len(), 3);

        decoder.input(b"e").unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(Some(s)) if s.as_str() == "bye");
        assert_eq!(decoder.unparsed.len(), 0);
        assert!(decoder.is_empty());
    }

    #[test]
    fn test_unparsed() {
        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.input(&[3, b'b', b'y']).unwrap();
        assert_eq!(decoder.unparsed().collect::<Vec<_>>(), vec![3, b'b', b'y']);
        assert!(decoder.is_empty());
    }

    #[test]
    fn test_iterator() {
        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.input(MSG_HELLO).unwrap();
        decoder.input(MSG_BYE).unwrap();

        let msgs = decoder
            .by_ref()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        assert_eq!(msgs, vec![String::from("hello"), String::from("bye")]);
        assert!(decoder.is_empty());
    }

    #[test]
    fn test_io_write() {
        use std::io::Write as _;

        let mut decoder = Deserializer::<1024, String>::new(8);

        decoder.write_all(MSG_BYE).unwrap();
        decoder.flush().unwrap();
        assert_matches!(decoder.deserialize_next(), Ok(Some(s)) if s.as_str() == "bye");
        assert!(decoder.is_empty());
    }

    #[quickcheck]
    fn prop_decode_next(chunk_size: usize) {
        let mut bytes = vec![];
        let mut msgs = vec![];
        let mut decoder = Deserializer::<1024, String>::new(8);

        bytes.extend_from_slice(MSG_HELLO);
        bytes.extend_from_slice(MSG_BYE);

        // Map the arbitrary input onto `1..=bytes.len()`, so every chunk size
        // is exercised, including sizes that split a length prefix from its
        // payload. The previous `1 + n % HELLO + BYE` parsed as
        // `1 + (n % HELLO) + BYE` and never produced sizes below 5.
        let chunk_size = 1 + chunk_size % bytes.len();

        for chunk in bytes.chunks(chunk_size) {
            decoder.input(chunk).unwrap();

            while let Some(msg) = decoder.deserialize_next().unwrap() {
                msgs.push(msg);
            }
        }

        assert_eq!(decoder.unparsed.len(), 0);
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0], String::from("hello"));
        assert_eq!(msgs[1], String::from("bye"));
    }
}