1use crate::core::transport::{constants, wire::packet::Packet};
2use crate::utils::TtlValue;
3use derive_more::{Display, Error};
4use std::collections::HashMap;
5use std::time::Duration;
6
7#[derive(Debug, Display, Error)]
8pub enum DecoderError {
9 #[display(fmt = "id:{}, index:{}", id, index)]
10 PacketExists {
11 id: u32,
12 index: u32,
13 },
14 #[display(fmt = "id:{}, index:{}", id, index)]
15 PacketBeyondLastIndex {
16 id: u32,
17 index: u32,
18 },
19 #[display(fmt = "id:{}, index:{}", id, index)]
20 FinalPacketAlreadyExists {
21 id: u32,
22 index: u32,
23 },
24 IncompletePacketCollection,
25}
26
27#[derive(Debug, Clone)]
28struct PacketGroup {
29 packets: HashMap<u32, Packet>,
31
32 final_index: Option<u32>,
35}
36
37impl Default for PacketGroup {
38 fn default() -> Self {
39 Self {
40 packets: HashMap::new(),
41 final_index: None,
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
47pub(crate) struct Decoder {
48 packet_groups: HashMap<TtlValue<u32>, PacketGroup>,
50
51 ttl: Duration,
54}
55
56impl Decoder {
57 pub fn new(ttl: Duration) -> Self {
58 Self {
59 packet_groups: HashMap::new(),
60 ttl,
61 }
62 }
63
64 #[cfg(test)]
66 pub fn len(&self) -> usize {
67 self.packet_groups.len()
68 }
69
70 pub fn add_packet(&mut self, packet: Packet) -> Result<(), DecoderError> {
72 let id = packet.id();
73 let index = packet.index();
74 let is_final = packet.is_final();
75
76 let group = self
79 .packet_groups
80 .entry(TtlValue::new(id, self.ttl))
81 .or_default();
82
83 if group.packets.contains_key(&index) {
85 return Err(DecoderError::PacketExists { id, index });
86 }
87
88 if let Some(last_index) = group.final_index {
90 if is_final {
91 return Err(DecoderError::FinalPacketAlreadyExists {
92 id,
93 index: last_index,
94 });
95 }
96 }
97
98 if group.final_index.map(|i| index > i).unwrap_or(false) {
100 return Err(DecoderError::PacketBeyondLastIndex { id, index });
101 }
102
103 group.packets.insert(index, packet);
105 if is_final {
106 group.final_index = Some(index);
107 }
108
109 Ok(())
110 }
111
112 pub fn remove_group(&mut self, group_id: u32) -> bool {
115 self.packet_groups.remove(&group_id.into()).is_some()
116 }
117
118 pub fn remove_expired(&mut self) {
120 self.packet_groups.retain(|k, _| !k.has_expired())
121 }
122
123 pub fn verify(&self, group_id: u32) -> bool {
125 self.packet_groups
126 .get(&group_id.into())
127 .and_then(|g| {
128 let total_packets = g.packets.len() as u32;
129 g.final_index.map(|i| i + 1 == total_packets)
130 })
131 .unwrap_or_default()
132 }
133
134 pub fn decode(&self, group_id: u32) -> Result<Vec<u8>, DecoderError> {
138 if !self.verify(group_id) {
140 return Err(DecoderError::IncompletePacketCollection);
141 }
142
143 let group = self.packet_groups.get(&group_id.into()).unwrap();
145
146 let mut packets = group.packets.values().collect::<Vec<&Packet>>();
148 packets.sort_unstable_by_key(|p| p.index());
149
150 Ok(packets.iter().flat_map(|p| p.data().clone()).collect())
153 }
154}
155
156impl Default for Decoder {
157 fn default() -> Self {
158 Self::new(constants::DEFAULT_TTL)
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::transport::wire::packet::{
        Metadata, PacketEncryption, PacketType,
    };

    /// Builds a packet for group `id` at `index` carrying `data`; the packet
    /// is marked final (without encryption) when `is_last` is set.
    fn make_packet(
        id: u32,
        index: u32,
        is_last: bool,
        data: Vec<u8>,
    ) -> Packet {
        let r#type = if is_last {
            PacketType::Final {
                encryption: PacketEncryption::None,
            }
        } else {
            PacketType::NotFinal
        };
        let metadata = Metadata { id, index, r#type };
        Packet::new(metadata, Default::default(), data)
    }

    /// Same as `make_packet`, but with an empty payload.
    fn make_empty_packet(id: u32, index: u32, is_last: bool) -> Packet {
        make_packet(id, index, is_last, vec![])
    }

    #[test]
    fn add_packet_fails_if_packet_already_exists() {
        let mut a = Decoder::default();
        let id = 123;
        let index = 999;

        // First packet at this index is accepted
        let result = a.add_packet(make_empty_packet(id, index, false));
        assert!(
            result.is_ok(),
            "Expected success for adding first packet, but got {}",
            result.unwrap_err(),
        );

        // Second packet at the same index is rejected
        match a
            .add_packet(make_empty_packet(id, index, false))
            .unwrap_err()
        {
            DecoderError::PacketExists {
                id: eid,
                index: eindex,
            } => {
                assert_eq!(id, eid, "Unexpected id returned in error");
                assert_eq!(index, eindex, "Unexpected index returned in error");
            }
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn add_packet_fails_if_adding_packet_beyond_last() {
        let mut a = Decoder::default();
        let id = 123;

        // Final packet at index 0 fixes the last index
        let result = a.add_packet(make_empty_packet(id, 0, true));
        assert!(
            result.is_ok(),
            "Expected success for adding first packet, but got {}",
            result.unwrap_err(),
        );

        // Anything past the final index must be rejected
        match a.add_packet(make_empty_packet(id, 1, false)).unwrap_err() {
            DecoderError::PacketBeyondLastIndex {
                id: eid,
                index: eindex,
            } => {
                assert_eq!(id, eid, "Beyond packet id was different");
                assert_eq!(eindex, 1, "Beyond packet index was wrong");
            }
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn add_packet_fails_if_last_packet_already_added() {
        let mut a = Decoder::default();

        // Final packet at index 1
        let result = a.add_packet(make_empty_packet(0, 1, true));
        assert!(
            result.is_ok(),
            "Expected success for adding first packet, but got {}",
            result.unwrap_err(),
        );

        // A second final packet is rejected and reports the existing one
        match a.add_packet(make_empty_packet(0, 0, true)).unwrap_err() {
            DecoderError::FinalPacketAlreadyExists { id, index } => {
                assert_eq!(id, 0, "Last packet id different than expected");
                assert_eq!(
                    index, 1,
                    "Last packet index different than expected"
                );
            }
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn remove_group_should_remove_the_underlying_packet_group() {
        let mut a = Decoder::default();

        // Three independent single-packet groups
        a.add_packet(make_empty_packet(0, 0, true)).unwrap();
        a.add_packet(make_empty_packet(1, 0, true)).unwrap();
        a.add_packet(make_empty_packet(2, 0, true)).unwrap();
        assert_eq!(a.packet_groups.len(), 3);

        // Removing an unknown group changes nothing
        assert!(!a.remove_group(3));
        assert_eq!(a.packet_groups.len(), 3);

        // Removing a known group drops exactly that group
        assert!(a.remove_group(1));
        assert_eq!(a.packet_groups.len(), 2);
    }

    #[test]
    fn remove_expired_should_only_retain_packet_groups_not_expired() {
        let mut a = Decoder::new(Duration::from_millis(10));

        // Two groups that will be older than the TTL after the sleep
        a.add_packet(make_empty_packet(0, 0, true)).unwrap();
        a.add_packet(make_empty_packet(1, 0, true)).unwrap();
        assert_eq!(a.packet_groups.len(), 2);

        // Wait past the TTL, then add a fresh group
        std::thread::sleep(Duration::from_millis(11));
        a.add_packet(make_empty_packet(2, 0, true)).unwrap();
        assert_eq!(a.packet_groups.len(), 3);

        a.remove_expired();
        assert_eq!(a.packet_groups.len(), 1, "Unexpired packet did not remain");
    }

    #[test]
    fn verify_yields_false_if_empty() {
        let a = Decoder::default();
        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_false_if_missing_last_packet() {
        let mut a = Decoder::default();

        // Only a non-final packet is present
        a.add_packet(make_empty_packet(0, 0, false)).unwrap();

        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_false_if_missing_first_packet() {
        let mut a = Decoder::default();

        // Final packet at index 1, nothing at index 0
        assert!(
            a.add_packet(make_empty_packet(0, 1, true)).is_ok(),
            "Unexpectedly failed to add a new packet",
        );

        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_false_if_missing_inbetween_packet() {
        let mut a = Decoder::default();

        assert!(
            a.add_packet(make_empty_packet(0, 0, false)).is_ok(),
            "Unexpectedly failed to add a new packet",
        );

        // Index 1 is skipped
        assert!(
            a.add_packet(make_empty_packet(0, 2, true)).is_ok(),
            "Unexpectedly failed to add a new packet",
        );

        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_true_if_have_all_packets() {
        let mut a = Decoder::default();

        assert!(
            a.add_packet(make_empty_packet(0, 0, true)).is_ok(),
            "Unexpectedly failed to add a new packet",
        );

        assert!(a.verify(0));
    }

    #[test]
    fn decode_fails_if_not_verified() {
        let a = Decoder::default();

        let result = a.decode(0);

        match result.unwrap_err() {
            DecoderError::IncompletePacketCollection => (),
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn decode_yields_data_from_single_packet_if_complete() {
        let mut a = Decoder::default();
        let data: Vec<u8> = vec![1, 2, 3];

        // Setup failures should surface here rather than as a decode error
        a.add_packet(make_packet(0, 0, true, data.clone())).unwrap();

        let collected_data = a.decode(0).unwrap();
        assert_eq!(data, collected_data);
    }

    #[test]
    fn decode_yields_combined_data_from_multiple_packets_if_complete() {
        let mut a = Decoder::default();
        let data: Vec<u8> = vec![1, 2, 3, 4, 5];

        // Packets are added out of order on purpose
        a.add_packet(make_packet(0, 2, true, data[3..].to_vec()))
            .unwrap();
        a.add_packet(make_packet(0, 0, false, data[0..1].to_vec()))
            .unwrap();
        a.add_packet(make_packet(0, 1, false, data[1..3].to_vec()))
            .unwrap();

        let collected_data = a.decode(0).unwrap();
        assert_eq!(data, collected_data);
    }
}