1use std::io;
2use std::io::Write;
3
4use circular::Buffer;
5use nom_derive::Parse;
6
7use crate::packet::Packet;
8use crate::packet::PacketHeader;
9use crate::Error;
10
11#[derive(Clone, Debug)]
12pub struct Parser {
13 buffer: Buffer,
14 #[allow(dead_code)]
15 current_header: Option<PacketHeader>
16}
17
18impl Default for Parser {
19 fn default() -> Self {
20 #[allow(clippy::identity_op)] Self::with_capacity(1 * 1024 * 1024)
22 }
23}
24
25impl Parser {
26 pub fn with_capacity(capacity: usize) -> Self {
27 Self{
28 buffer: Buffer::with_capacity(capacity),
29 current_header: None
30 }
31 }
32
33 #[inline]
34 pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
35 self.buffer.write(buf)
36 }
37
38 fn check_packet(&mut self) -> bool {
39 match &self.current_header {
40 None => {
41 if(self.buffer.available_data() >= 5) {
42 let header_bytes = &self.buffer.data()[..5];
43 self.current_header = Some(PacketHeader::parse(header_bytes).unwrap().1);
44 self.check_packet()
45 } else {
46 false
47 }
48 },
49 Some(header) => {
50 self.buffer.available_data() >= header.length as usize + 4
51 }
52 }
53 }
54
55 pub fn get_packet(&mut self) -> Result<Option<Packet>, Error> {
56 match (self.check_packet(), &self.current_header) {
57 (false, _) => Ok(None),
58 (true, None) => Err(Error::NoHeader),
59 (true, Some(header)) => {
60 let len = header.length as usize + 4; let data = &self.buffer.data()[..len];
62 let packet = Packet::parse(data)?.1;
63 self.buffer.consume(len);
64 self.current_header = None;
65 Ok(Some(packet))
66 }
67 }
68 }
69}
70
/// Test helper: serializes `packet` with bincode configured for big-endian
/// byte order and fixed-width integer encoding.
///
/// Panics if serialization fails.
#[cfg(test)]
pub(crate) fn encode(packet: &Packet) -> Vec<u8> {
    use bincode::Options;

    bincode::DefaultOptions::new()
        .with_big_endian()
        .with_fixint_encoding()
        .serialize(packet)
        .unwrap()
}
77
#[cfg(test)]
mod test {
    use rand::Rng;
    use test_strategy::proptest;
    use crate::packet::*;
    use super::*;

    /// Splits `slice` into a random number of contiguous, non-overlapping
    /// chunks that together cover the whole slice, in order.
    ///
    /// Slices too short to be split by the randomized scheme below are
    /// returned as a single chunk. Otherwise every chunk is non-empty.
    fn random_slices<T>(slice: &[T]) -> Vec<&[T]> {
        // Below this length `slice.len() / 2 <= 2`, which would make the
        // range used to draw `magic` empty and cause `gen_range` to panic.
        const MIN_SPLITTABLE_LEN: usize = 6;
        if slice.len() < MIN_SPLITTABLE_LEN {
            return vec![slice];
        }

        let mut rng = rand::thread_rng();
        // `magic` bounds how many chunks may be produced: the larger it is,
        // the fewer chunks `count` can take.
        let magic = rng.gen_range(2..std::cmp::min(16, slice.len() / 2));
        let count = rng.gen_range(1..(slice.len() / magic));
        let approx_size = slice.len() / count;

        // Draw each chunk length from roughly 50%..150% of the average size.
        let mut lengths: Vec<usize> = (0..count)
            .map(|_| rng.gen_range((approx_size / 2)..((approx_size * 3) / 2)))
            .collect();

        // Nudge randomly chosen chunks until the lengths add up exactly to
        // the slice length, keeping every chunk at least one element long.
        let mut total: usize = lengths.iter().sum();
        while total < slice.len() {
            lengths[rng.gen_range(0..count)] += 1;
            total += 1;
        }
        while total > slice.len() {
            // `count < slice.len() < total`, so some chunk is longer than one
            // element and this loop always makes progress eventually.
            let idx = rng.gen_range(0..count);
            if lengths[idx] > 1 {
                lengths[idx] -= 1;
                total -= 1;
            }
        }

        let mut rest = slice;
        lengths
            .into_iter()
            .map(|len| {
                let (head, tail) = rest.split_at(len);
                rest = tail;
                head
            })
            .collect()
    }

    /// Checks the in-memory size of the `get_packet` result types.
    /// Ignored by default; run explicitly with `cargo test -- --ignored`.
    #[ignore]
    #[test]
    fn validate_packet_result_size() {
        assert_eq!(
            std::mem::size_of::<Result<Packet, Error>>(),
            std::mem::size_of::<Packet>() + std::mem::size_of::<*const ()>()
        );
        assert_eq!(
            std::mem::size_of::<Result<Option<Packet>, Error>>(),
            std::mem::size_of::<Result<Packet, Error>>()
        );
    }

    /// A single packet written in one piece is returned exactly once.
    #[test]
    fn single_init() {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        let packet = Payload::init(1, vec![]).into_packet();
        stream.write(&encode(&packet)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(packet)));
        assert_eq!(stream.get_packet(), Ok(None));
    }

    /// A packet written in two pieces only becomes available after the
    /// second piece arrives.
    #[test]
    fn multipart_init() {
        let mut stream = Parser::default();
        let packet = Payload::init(2, vec![]).into_packet();
        let bytes = encode(&packet);
        stream.write(&bytes[0..3]).unwrap();
        assert_eq!(stream.get_packet(), Ok(None));
        stream.write(&bytes[3..bytes.len()]).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(packet)));
    }

    /// Two different packets written and read back alternately.
    #[test]
    fn handshake() {
        let mut stream = Parser::default();
        let init = Payload::init(32768, (0..100).collect()).into_packet();
        stream.write(&encode(&init)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(init)));
        let version = Payload::version(3, (100..150).collect()).into_packet();
        stream.write(&encode(&version)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(version)));
        assert_eq!(stream.get_packet(), Ok(None));
    }

    /// Two packets written back to back are read out in write order.
    #[test]
    fn handshake_queued() {
        let mut stream = Parser::default();
        let init = Payload::init(32768, (0..100).collect()).into_packet();
        stream.write(&encode(&init)).unwrap();
        let version = Payload::version(3, (100..150).collect()).into_packet();
        stream.write(&encode(&version)).unwrap();
        assert_eq!(stream.get_packet(), Ok(Some(init)));
        assert_eq!(stream.get_packet(), Ok(Some(version)));
        assert_eq!(stream.get_packet(), Ok(None));
    }

    /// Arbitrary packets, each written whole and read back immediately.
    #[proptest]
    fn arbitrary_sequence(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in input {
            stream.write(&encode(&packet)).unwrap();
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
            assert_eq!(stream.get_packet(), Ok(None));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }

    /// Arbitrary packets, all written first and then read back in order.
    #[proptest]
    fn arbitrary_sequence_queued(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in &input {
            stream.write(&encode(packet)).unwrap();
        }
        for packet in input {
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }

    /// Arbitrary packets, each written in random fragments and read back
    /// immediately.
    #[proptest]
    fn arbitrary_sequence_multipart(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in input {
            let bytes = encode(&packet);
            for slice in random_slices(&bytes) {
                stream.write(slice).unwrap();
            }
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
            assert_eq!(stream.get_packet(), Ok(None));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }

    /// Arbitrary packets, all written in random fragments first and then
    /// read back in order.
    #[proptest]
    fn arbitrary_sequence_multipart_queued(input: Vec<Packet>) {
        let mut stream = Parser::default();
        assert_eq!(stream.get_packet(), Ok(None));
        for packet in &input {
            let bytes = encode(packet);
            for slice in random_slices(&bytes) {
                stream.write(slice).unwrap();
            }
        }
        for packet in input {
            assert_eq!(stream.get_packet(), Ok(Some(packet)));
        }
        assert_eq!(stream.get_packet(), Ok(None));
    }
}
240