Skip to main content

over_there/core/transport/wire/input/
decoder.rs

1use crate::core::transport::{constants, wire::packet::Packet};
2use crate::utils::TtlValue;
3use derive_more::{Display, Error};
4use std::collections::HashMap;
5use std::time::Duration;
6
7#[derive(Debug, Display, Error)]
8pub enum DecoderError {
9    #[display(fmt = "id:{}, index:{}", id, index)]
10    PacketExists {
11        id: u32,
12        index: u32,
13    },
14    #[display(fmt = "id:{}, index:{}", id, index)]
15    PacketBeyondLastIndex {
16        id: u32,
17        index: u32,
18    },
19    #[display(fmt = "id:{}, index:{}", id, index)]
20    FinalPacketAlreadyExists {
21        id: u32,
22        index: u32,
23    },
24    IncompletePacketCollection,
25}
26
27#[derive(Debug, Clone)]
28struct PacketGroup {
29    /// Collection of packets, where the key is the index of the packet
30    packets: HashMap<u32, Packet>,
31
32    /// The final index of the packet group, which we only know once we've
33    /// received the final packet (can still be out of order)
34    final_index: Option<u32>,
35}
36
37impl Default for PacketGroup {
38    fn default() -> Self {
39        Self {
40            packets: HashMap::new(),
41            final_index: None,
42        }
43    }
44}
45
46#[derive(Debug, Clone)]
47pub(crate) struct Decoder {
48    /// Map of unique id to associated group of packets being decoded
49    packet_groups: HashMap<TtlValue<u32>, PacketGroup>,
50
51    /// Maximum time-to-live for each group of packets before being removed;
52    /// this time can be updated upon adding a new packet to a group
53    ttl: Duration,
54}
55
56impl Decoder {
57    pub fn new(ttl: Duration) -> Self {
58        Self {
59            packet_groups: HashMap::new(),
60            ttl,
61        }
62    }
63
64    /// Returns the total packet groups contained within the decoder
65    #[cfg(test)]
66    pub fn len(&self) -> usize {
67        self.packet_groups.len()
68    }
69
70    /// Adds a new packet to the decoder, consuming it for reconstruction
71    pub fn add_packet(&mut self, packet: Packet) -> Result<(), DecoderError> {
72        let id = packet.id();
73        let index = packet.index();
74        let is_final = packet.is_final();
75
76        // Check if we already have a group for this packet, otherwise create
77        // a new group
78        let group = self
79            .packet_groups
80            .entry(TtlValue::new(id, self.ttl))
81            .or_default();
82
83        // Check if we already have this packet
84        if group.packets.contains_key(&index) {
85            return Err(DecoderError::PacketExists { id, index });
86        }
87
88        // Check if we are adding a final packet when we already have one
89        if let Some(last_index) = group.final_index {
90            if is_final {
91                return Err(DecoderError::FinalPacketAlreadyExists {
92                    id,
93                    index: last_index,
94                });
95            }
96        }
97
98        // Check if we are trying to add a packet beyond the final one
99        if group.final_index.map(|i| index > i).unwrap_or(false) {
100            return Err(DecoderError::PacketBeyondLastIndex { id, index });
101        }
102
103        // Add the packet to our group and, if it's final, mark it
104        group.packets.insert(index, packet);
105        if is_final {
106            group.final_index = Some(index);
107        }
108
109        Ok(())
110    }
111
112    /// Removes the specified packet group, returning whether or not the
113    /// group existed to be removed
114    pub fn remove_group(&mut self, group_id: u32) -> bool {
115        self.packet_groups.remove(&group_id.into()).is_some()
116    }
117
118    /// Removes all expired packet groups from the decoder
119    pub fn remove_expired(&mut self) {
120        self.packet_groups.retain(|k, _| !k.has_expired())
121    }
122
123    /// Determines whether or not all packets have been added to the decoder
124    pub fn verify(&self, group_id: u32) -> bool {
125        self.packet_groups
126            .get(&group_id.into())
127            .and_then(|g| {
128                let total_packets = g.packets.len() as u32;
129                g.final_index.map(|i| i + 1 == total_packets)
130            })
131            .unwrap_or_default()
132    }
133
134    /// Reconstructs the data represented by the packets
135    /// NOTE: This currently produces a copy of all data instead of passing
136    ///       back out ownership
137    pub fn decode(&self, group_id: u32) -> Result<Vec<u8>, DecoderError> {
138        // Verify that we have all packets
139        if !self.verify(group_id) {
140            return Err(DecoderError::IncompletePacketCollection);
141        }
142
143        // Grab the appropriate group, which we can now assume exists
144        let group = self.packet_groups.get(&group_id.into()).unwrap();
145
146        // Gather references to packets in proper order
147        let mut packets = group.packets.values().collect::<Vec<&Packet>>();
148        packets.sort_unstable_by_key(|p| p.index());
149
150        // Collect packet data into one unified binary representation
151        // TODO: Improve by NOT cloning data
152        Ok(packets.iter().flat_map(|p| p.data().clone()).collect())
153    }
154}
155
156impl Default for Decoder {
157    fn default() -> Self {
158        Self::new(constants::DEFAULT_TTL)
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::transport::wire::packet::{
        Metadata, PacketEncryption, PacketType,
    };

    /// Make a packet with data; if last, mark as so with no nonce
    fn make_packet(
        id: u32,
        index: u32,
        is_last: bool,
        data: Vec<u8>,
    ) -> Packet {
        let r#type = if is_last {
            PacketType::Final {
                encryption: PacketEncryption::None,
            }
        } else {
            PacketType::NotFinal
        };
        let metadata = Metadata { id, index, r#type };
        Packet::new(metadata, Default::default(), data)
    }

    /// Make an empty packet; if last, mark as so with no nonce
    fn make_empty_packet(id: u32, index: u32, is_last: bool) -> Packet {
        make_packet(id, index, is_last, vec![])
    }

    #[test]
    fn add_packet_fails_if_packet_already_exists() {
        let mut a = Decoder::default();
        let id = 123;
        let index = 999;

        // Add first packet successfully
        a.add_packet(make_empty_packet(id, index, false))
            .unwrap_or_else(|e| {
                panic!("Expected success for adding first packet, but got {}", e)
            });

        // Fail if adding packet with same index
        match a
            .add_packet(make_empty_packet(id, index, false))
            .unwrap_err()
        {
            DecoderError::PacketExists {
                id: eid,
                index: eindex,
            } => {
                assert_eq!(id, eid, "Unexpected id returned in error");
                assert_eq!(index, eindex, "Unexpected index returned in error");
            }
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn add_packet_fails_if_adding_packet_beyond_last() {
        let mut a = Decoder::default();
        let id = 123;

        // Add first packet successfully
        a.add_packet(make_empty_packet(id, 0, true))
            .unwrap_or_else(|e| {
                panic!("Expected success for adding first packet, but got {}", e)
            });

        // Fail if adding packet after final packet
        match a.add_packet(make_empty_packet(id, 1, false)).unwrap_err() {
            DecoderError::PacketBeyondLastIndex {
                id: eid,
                index: eindex,
            } => {
                assert_eq!(id, eid, "Beyond packet id was different");
                assert_eq!(eindex, 1, "Beyond packet index was wrong");
            }
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn add_packet_fails_if_last_packet_already_added() {
        let mut a = Decoder::default();

        // Make the second packet (index) be the last packet
        a.add_packet(make_empty_packet(0, 1, true))
            .unwrap_or_else(|e| {
                panic!("Expected success for adding first packet, but got {}", e)
            });

        // Fail if making the first packet (index) be the last packet
        // when we already have a last packet
        match a.add_packet(make_empty_packet(0, 0, true)).unwrap_err() {
            DecoderError::FinalPacketAlreadyExists { id, index } => {
                assert_eq!(id, 0, "Last packet id different than expected");
                assert_eq!(
                    index, 1,
                    "Last packet index different than expected"
                );
            }
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn remove_group_should_remove_the_underlying_packet_group() {
        let mut a = Decoder::default();

        // Add a couple of packets
        a.add_packet(make_empty_packet(0, 0, true)).unwrap();
        a.add_packet(make_empty_packet(1, 0, true)).unwrap();
        a.add_packet(make_empty_packet(2, 0, true)).unwrap();
        assert_eq!(a.packet_groups.len(), 3);

        // Remove a group that doesn't exist
        assert!(!a.remove_group(3));
        assert_eq!(a.packet_groups.len(), 3);

        // Remove a group that does exist
        assert!(a.remove_group(1));
        assert_eq!(a.packet_groups.len(), 2);
    }

    #[test]
    fn remove_expired_should_only_retain_packet_groups_not_expired() {
        // A TTL well above scheduler jitter makes it unlikely that the newest
        // group also expires if the test thread is descheduled between adding
        // it and calling remove_expired.
        let ttl = Duration::from_millis(50);
        let mut a = Decoder::new(ttl);

        // Add a couple of packets
        a.add_packet(make_empty_packet(0, 0, true)).unwrap();
        a.add_packet(make_empty_packet(1, 0, true)).unwrap();
        assert_eq!(a.packet_groups.len(), 2);

        // Add another packet group once the first two have expired
        std::thread::sleep(ttl + Duration::from_millis(10));
        a.add_packet(make_empty_packet(2, 0, true)).unwrap();
        assert_eq!(a.packet_groups.len(), 3);

        // Remove the expired packet groups
        a.remove_expired();
        assert_eq!(a.packet_groups.len(), 1, "Unexpired packet did not remain");
    }

    #[test]
    fn verify_yields_false_if_empty() {
        let a = Decoder::default();
        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_false_if_missing_last_packet() {
        let mut a = Decoder::default();

        // Add first packet (index 0), still needing final packet
        a.add_packet(make_empty_packet(0, 0, false))
            .expect("Unexpectedly failed to add a new packet");

        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_false_if_missing_first_packet() {
        let mut a = Decoder::default();

        // Add packet at end (index 1), still needing first packet
        a.add_packet(make_empty_packet(0, 1, true))
            .expect("Unexpectedly failed to add a new packet");

        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_false_if_missing_inbetween_packet() {
        let mut a = Decoder::default();

        // Add packet at beginning (index 0)
        a.add_packet(make_empty_packet(0, 0, false))
            .expect("Unexpectedly failed to add a new packet");

        // Add packet at end (index 2)
        a.add_packet(make_empty_packet(0, 2, true))
            .expect("Unexpectedly failed to add a new packet");

        assert!(!a.verify(0));
    }

    #[test]
    fn verify_yields_true_if_have_all_packets() {
        let mut a = Decoder::default();

        a.add_packet(make_empty_packet(0, 0, true))
            .expect("Unexpectedly failed to add a new packet");

        assert!(a.verify(0));
    }

    #[test]
    fn decode_fails_if_not_verified() {
        let a = Decoder::default();

        let result = a.decode(0);

        match result.unwrap_err() {
            DecoderError::IncompletePacketCollection => (),
            e => panic!("Unexpected error {} received", e),
        }
    }

    #[test]
    fn decode_yields_data_from_single_packet_if_complete() {
        let mut a = Decoder::default();
        let data: Vec<u8> = vec![1, 2, 3];

        // Try a single packet and collecting data
        a.add_packet(make_packet(0, 0, true, data.clone())).unwrap();

        let collected_data = a.decode(0).unwrap();
        assert_eq!(data, collected_data);
    }

    #[test]
    fn decode_yields_combined_data_from_multiple_packets_if_complete() {
        let mut a = Decoder::default();
        let data: Vec<u8> = vec![1, 2, 3, 4, 5];

        // Add multiple packets out of order and collect their data
        a.add_packet(make_packet(0, 2, true, data[3..].to_vec()))
            .unwrap();
        a.add_packet(make_packet(0, 0, false, data[0..1].to_vec()))
            .unwrap();
        a.add_packet(make_packet(0, 1, false, data[1..3].to_vec()))
            .unwrap();

        let collected_data = a.decode(0).unwrap();
        assert_eq!(data, collected_data);
    }
}