foundation_ur/fountain/
decoder.rs

1// SPDX-FileCopyrightText: © 2023 Foundation Devices, Inc. <hello@foundationdevices.com>
2// SPDX-FileCopyrightText: © 2020 Dominik Spicher <dominikspicher@gmail.com>
3// SPDX-License-Identifier: MIT
4
5//! Decoder.
6
7use core::fmt;
8
9use crate::{
10    collections::{Deque, Set, Vec},
11    fountain::part::MessageDescription,
12    fountain::{
13        chooser,
14        chooser::BaseFragmentChooser,
15        part::{IndexedPart, Part},
16    },
17};
18
19/// A [`decoder`](BaseDecoder) that uses [`alloc`] collection types.
20#[cfg(feature = "alloc")]
21pub type Decoder = BaseDecoder<Alloc>;
22
23/// A [`decoder`](BaseDecoder) that uses fixed-capacity collection types.
24pub type HeaplessDecoder<
25    const MAX_MESSAGE_LEN: usize,
26    const MAX_MIXED_PARTS: usize,
27    const MAX_FRAGMENT_LEN: usize,
28    const MAX_SEQUENCE_COUNT: usize,
29    const QUEUE_SIZE: usize,
30> = BaseDecoder<
31    Heapless<MAX_MESSAGE_LEN, MAX_MIXED_PARTS, MAX_FRAGMENT_LEN, MAX_SEQUENCE_COUNT, QUEUE_SIZE>,
32>;
33
34impl<
35        const MAX_MESSAGE_LEN: usize,
36        const MAX_MIXED_PARTS: usize,
37        const MAX_FRAGMENT_LEN: usize,
38        const MAX_SEQUENCE_COUNT: usize,
39        const QUEUE_SIZE: usize,
40    >
41    HeaplessDecoder<
42        MAX_MESSAGE_LEN,
43        MAX_MIXED_PARTS,
44        MAX_FRAGMENT_LEN,
45        MAX_SEQUENCE_COUNT,
46        QUEUE_SIZE,
47    >
48{
49    /// Constructs a new [`HeaplessDecoder`].
50    pub const fn new() -> Self {
51        Self {
52            message: heapless::Vec::new(),
53            mixed_parts: heapless::Vec::new(),
54            received: heapless::IndexSet::new(),
55            queue: heapless::Deque::new(),
56            fragment_chooser: chooser::HeaplessFragmentChooser::new(),
57            message_description: None,
58        }
59    }
60}
61
62/// A decoder capable of receiving and recombining fountain-encoded transmissions.
63///
64/// # Examples
65///
66/// See the [`crate::fountain`] module documentation for an example.
67#[derive(Default)]
68pub struct BaseDecoder<T: Types> {
69    message: T::Message,
70    mixed_parts: T::MixedParts,
71    received: T::Indexes,
72    queue: T::Queue,
73    fragment_chooser: BaseFragmentChooser<T::Chooser>,
74    message_description: Option<MessageDescription>,
75}
76
77impl<T: Types> BaseDecoder<T> {
78    /// Receives a fountain-encoded part into the decoder.
79    ///
80    /// # Examples
81    ///
82    /// See the [`crate::fountain`] module documentation for an example.
83    ///
84    /// # Errors
85    ///
86    /// If the part would fail [`validate`] because it is inconsistent
87    /// with previously received parts, an error will be returned.
88    ///
89    /// [`validate`]: BaseDecoder::is_part_consistent
90    pub fn receive(&mut self, part: &Part) -> Result<bool, Error> {
91        if self.is_complete() {
92            return Ok(false);
93        }
94
95        if !part.is_valid() {
96            return Err(Error::InvalidPart);
97        }
98
99        if self.is_empty() {
100            let message_len = part.data.len() * usize::try_from(part.sequence_count).unwrap();
101            if self.message.try_resize(message_len, 0).is_err() {
102                return Err(Error::NotEnoughSpace {
103                    needed: message_len,
104                    capacity: self.message.capacity(),
105                });
106            }
107            self.message_description = Some(part.to_message_description());
108        } else if !self.is_part_consistent(part) {
109            return Err(Error::InconsistentPart {
110                received: part.to_message_description(),
111                expected: self.message_description.clone().unwrap(),
112            });
113        }
114
115        let indexes = self.fragment_chooser.choose_fragments(
116            part.sequence,
117            part.sequence_count,
118            part.checksum,
119        );
120
121        let mut data = T::Fragment::default();
122        if data.try_extend_from_slice(part.data).is_err() {
123            return Err(Error::NotEnoughSpace {
124                needed: part.data.len(),
125                capacity: data.capacity(),
126            });
127        }
128
129        let part = IndexedPart::new(data, indexes);
130        self.queue.push_back(part);
131
132        while !self.is_complete() && !self.queue.is_empty() {
133            let part = self.queue.pop_front().unwrap();
134            if part.is_simple() {
135                self.process_simple(&part)?;
136            } else {
137                self.process_mixed(part);
138            }
139        }
140        Ok(!self.is_complete())
141    }
142
143    /// Checks whether a [`Part`] is receivable by the decoder.
144    ///
145    /// This can fail if other parts were previously received whose
146    /// metadata (such as number of segments) is inconsistent with the
147    /// present [`Part`]. Note that a fresh decoder will always return
148    /// false here.
149    #[must_use]
150    pub fn is_part_consistent(&self, part: &Part) -> bool {
151        match self.message_description {
152            Some(ref message_description) => part == message_description,
153            None => false,
154        }
155    }
156
157    /// If [`complete`], returns the decoded message, `None` otherwise.
158    ///
159    /// # Errors
160    ///
161    /// If an inconsistent internal state is detected, an error will be returned.
162    ///
163    /// # Examples
164    ///
165    /// See the [`crate::fountain`] module documentation for an example.
166    ///
167    /// [`complete`]: BaseDecoder::is_complete
168    pub fn message(&self) -> Result<Option<&[u8]>, Error> {
169        if self.is_complete() {
170            if self.message[self.message_description.as_ref().unwrap().message_length..]
171                .iter()
172                .any(|&b| b != 0)
173            {
174                return Err(Error::InvalidPadding);
175            }
176
177            Ok(Some(
178                &self.message[..self.message_description.as_ref().unwrap().message_length],
179            ))
180        } else {
181            Ok(None)
182        }
183    }
184
185    /// Returns whether the decoder is complete and hence the message available.
186    ///
187    /// # Examples
188    ///
189    /// See the [`crate::fountain`] module documentation for an example.
190    #[must_use]
191    pub fn is_complete(&self) -> bool {
192        if self.is_empty() {
193            return false;
194        }
195
196        self.received.len()
197            == self
198                .message_description
199                .as_ref()
200                .unwrap()
201                .sequence_count
202                .try_into()
203                .unwrap()
204    }
205
206    /// Calculate estimated percentage of completion.
207    pub fn estimated_percent_complete(&self) -> f64 {
208        if self.is_complete() {
209            return 1.0;
210        }
211
212        if self.is_empty() {
213            return 0.0;
214        }
215
216        let estimated_input_parts =
217            f64::from(self.message_description.as_ref().unwrap().sequence_count) * 1.75;
218        let received_parts = u32::try_from(self.received.len()).unwrap();
219        f64::min(0.99, f64::from(received_parts) / estimated_input_parts)
220    }
221
222    /// Returns `true` if the decoder doesn't contain any data.
223    ///
224    /// Once a part is successfully [received](Self::receive) this method will
225    /// return `false`.
226    #[must_use]
227    pub fn is_empty(&self) -> bool {
228        self.message.is_empty()
229            && self.mixed_parts.is_empty()
230            && self.received.is_empty()
231            && self.queue.is_empty()
232            && self.message_description.is_none()
233    }
234
235    /// Clear the decoder so that it can be used again.
236    pub fn clear(&mut self) {
237        self.message.clear();
238        self.mixed_parts.clear();
239        self.received.clear();
240        self.queue.clear();
241        self.message_description = None;
242
243        debug_assert!(self.is_empty());
244    }
245
246    fn reduce_mixed(&mut self, part: &IndexedPart<T::Fragment, T::Indexes>) {
247        self.mixed_parts.retain_mut(|mixed_part| {
248            mixed_part.reduce(part);
249
250            if mixed_part.is_simple() {
251                self.queue.push_back(mixed_part.clone());
252            }
253
254            !mixed_part.is_simple()
255        });
256    }
257
258    fn process_simple(&mut self, part: &IndexedPart<T::Fragment, T::Indexes>) -> Result<(), Error> {
259        let index = *part.indexes.first().unwrap();
260        if self.received.contains(&index) {
261            return Ok(());
262        }
263
264        self.reduce_mixed(part);
265
266        let offset = index * self.message_description.as_ref().unwrap().fragment_length;
267        self.message[offset..offset + self.message_description.as_ref().unwrap().fragment_length]
268            .copy_from_slice(&part.data);
269        self.received
270            .insert(index)
271            .map_err(|_| Error::TooManyFragments)?;
272
273        Ok(())
274    }
275
276    fn process_mixed(&mut self, mut part: IndexedPart<T::Fragment, T::Indexes>) {
277        for mixed_part in (&self.mixed_parts as &[IndexedPart<T::Fragment, T::Indexes>]).iter() {
278            if part.indexes == mixed_part.indexes {
279                return;
280            }
281        }
282
283        // Reduce this part by all simple parts.
284        for &index in self.received.iter() {
285            let offset = index * self.message_description.as_ref().unwrap().fragment_length;
286            part.reduce_by_simple(
287                &self.message
288                    [offset..offset + self.message_description.as_ref().unwrap().fragment_length],
289                index,
290            );
291            if part.is_simple() {
292                break;
293            }
294        }
295
296        // Then reduce this part by all the mixed parts.
297        if !part.is_simple() {
298            for mixed_part in self.mixed_parts.iter() {
299                part.reduce(mixed_part);
300                if part.is_simple() {
301                    break;
302                }
303            }
304        }
305
306        if part.is_simple() {
307            self.queue.push_back(part);
308        } else {
309            self.reduce_mixed(&part);
310            self.mixed_parts.try_push(part).ok();
311        }
312    }
313}
314
315/// Types for [`BaseDecoder`].
316pub trait Types: Default {
317    /// Decoded message buffer.
318    type Message: Vec<u8>;
319
320    /// Mixed parts storage.
321    type MixedParts: Vec<IndexedPart<Self::Fragment, Self::Indexes>>;
322
323    /// Fragment buffer.
324    type Fragment: Clone + Vec<u8>;
325
326    /// Indexes storage.
327    type Indexes: PartialEq + Set<usize>;
328
329    /// Part queue.
330    type Queue: Deque<IndexedPart<Self::Fragment, Self::Indexes>>;
331
332    /// Fragment chooser types.
333    type Chooser: chooser::Types;
334}
335
/// Growable [`alloc`] storage for [`BaseDecoder`].
#[cfg(feature = "alloc")]
#[derive(Default)]
pub struct Alloc;

#[cfg(feature = "alloc")]
impl Types for Alloc {
    type Fragment = alloc::vec::Vec<u8>;
    type Indexes = alloc::collections::BTreeSet<usize>;
    type Message = alloc::vec::Vec<u8>;
    type MixedParts = alloc::vec::Vec<IndexedPart<Self::Fragment, Self::Indexes>>;
    type Queue = alloc::collections::VecDeque<IndexedPart<Self::Fragment, Self::Indexes>>;
    type Chooser = chooser::Alloc;
}
353
354/// [`heapless`] types for [`BaseDecoder`].
355#[derive(Default)]
356pub struct Heapless<
357    const MAX_MESSAGE_LEN: usize,
358    const MAX_MIXED_PARTS: usize,
359    const MAX_FRAGMENT_LEN: usize,
360    const MAX_SEQUENCE_COUNT: usize,
361    const QUEUE_SIZE: usize,
362>;
363
364impl<
365        const MAX_MESSAGE_LEN: usize,
366        const MAX_MIXED_PARTS: usize,
367        const MAX_FRAGMENT_LEN: usize,
368        const MAX_SEQUENCE_COUNT: usize,
369        const QUEUE_SIZE: usize,
370    > Types
371    for Heapless<MAX_MESSAGE_LEN, MAX_MIXED_PARTS, MAX_FRAGMENT_LEN, MAX_SEQUENCE_COUNT, QUEUE_SIZE>
372{
373    type Message = heapless::Vec<u8, MAX_MESSAGE_LEN>;
374
375    type MixedParts = heapless::Vec<
376        IndexedPart<
377            heapless::Vec<u8, MAX_FRAGMENT_LEN>,
378            heapless::FnvIndexSet<usize, MAX_SEQUENCE_COUNT>,
379        >,
380        MAX_MIXED_PARTS,
381    >;
382
383    type Fragment = heapless::Vec<u8, MAX_FRAGMENT_LEN>;
384
385    type Indexes = heapless::FnvIndexSet<usize, MAX_SEQUENCE_COUNT>;
386
387    type Queue = heapless::Deque<
388        IndexedPart<
389            heapless::Vec<u8, MAX_FRAGMENT_LEN>,
390            heapless::FnvIndexSet<usize, MAX_SEQUENCE_COUNT>,
391        >,
392        QUEUE_SIZE,
393    >;
394
395    type Chooser = chooser::Heapless<MAX_SEQUENCE_COUNT>;
396}
397
398/// Errors that can happen during decoding.
399#[derive(Debug)]
400pub enum Error {
401    /// The padding is invalid.
402    InvalidPadding,
403    /// The received part is inconsistent with the previously received ones.
404    InconsistentPart {
405        /// The description of the message from the received part.
406        received: MessageDescription,
407        /// The expected description of the message originated from the previous parts scanned.
408        expected: MessageDescription,
409    },
410    /// The received part is empty.
411    InvalidPart,
412    /// Not enough space to receive the part.
413    NotEnoughSpace {
414        /// Needed space.
415        needed: usize,
416        /// Current capacity.
417        capacity: usize,
418    },
419    /// Too many fragments.
420    TooManyFragments,
421}
422
423impl fmt::Display for Error {
424    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
425        match self {
426            Error::InvalidPadding => write!(f, "Invalid padding")?,
427            Error::InconsistentPart { received, expected } => {
428                write!(f, "Inconsistent part: ")?;
429
430                if received.sequence_count != expected.sequence_count {
431                    write!(
432                        f,
433                        "sequence count mismatch (received {}, expected {}). ",
434                        received.sequence_count, expected.sequence_count
435                    )?;
436                }
437
438                if received.message_length != expected.message_length {
439                    write!(
440                        f,
441                        "message length mismatch (received {}, expected {}). ",
442                        received.message_length, expected.message_length
443                    )?;
444                }
445
446                if received.checksum != expected.checksum {
447                    write!(
448                        f,
449                        "checksum mismatch (received {:X}, expected {:X}). ",
450                        received.checksum, expected.checksum
451                    )?;
452                }
453
454                if received.fragment_length != expected.fragment_length {
455                    write!(
456                        f,
457                        "checksum mismatch (received {:X}, expected {:X}). ",
458                        received.fragment_length, expected.fragment_length
459                    )?;
460                }
461            }
462            Error::InvalidPart => write!(f, "The scanned part is empty")?,
463            Error::NotEnoughSpace { needed, capacity } => {
464                write!(f, "Not enough space: needed {needed}, capacity {capacity}")?
465            }
466            Error::TooManyFragments => write!(f, "Too many fragments for the current message")?,
467        };
468        Ok(())
469    }
470}
471
472#[cfg(feature = "std")]
473impl std::error::Error for Error {}
474
#[cfg(test)]
#[cfg(feature = "alloc")]
pub mod tests {
    use super::*;
    use crate::fountain::fragment_length;
    use crate::{fountain::Encoder, xoshiro::test_utils::make_message};

    const MESSAGE_SIZE: usize = 32767;
    const MAX_FRAGMENT_LEN: usize = 1000;
    const MAX_SEQUENCE_COUNT: usize = 64;
    const MAX_MESSAGE_SIZE: usize =
        fragment_length(MESSAGE_SIZE, MAX_FRAGMENT_LEN) * MAX_SEQUENCE_COUNT;
    const SEED: &str = "Wolf";

    fn message() -> alloc::vec::Vec<u8> {
        make_message(SEED, MESSAGE_SIZE)
    }

    #[test]
    fn test_decoder() {
        fn test<T: Types>(decoder: &mut BaseDecoder<T>) {
            let message = message();
            let mut encoder = Encoder::new();
            encoder.start(&message, MAX_FRAGMENT_LEN);
            while !decoder.is_complete() {
                assert_eq!(decoder.message().unwrap(), None);
                let part = encoder.next_part();
                let _next = decoder.receive(&part).unwrap();
            }
            assert_eq!(decoder.message().unwrap(), Some(message.as_slice()));
        }

        let mut heapless_decoder: HeaplessDecoder<
            MAX_MESSAGE_SIZE,
            MAX_SEQUENCE_COUNT,
            MAX_FRAGMENT_LEN,
            MAX_SEQUENCE_COUNT,
            MAX_SEQUENCE_COUNT,
        > = HeaplessDecoder::new();
        let mut decoder = Decoder::default();

        test(&mut heapless_decoder);
        test(&mut decoder);
    }

    #[test]
    fn test_decoder_skip_some_simple_fragments() {
        let message = make_message(SEED, MESSAGE_SIZE);
        let mut encoder = Encoder::new();
        encoder.start(&message, MAX_FRAGMENT_LEN);
        let mut decoder = Decoder::default();
        let mut skip = false;
        while !decoder.is_complete() {
            let part = encoder.next_part();
            if !skip {
                let _next = decoder.receive(&part);
            }
            skip = !skip;
        }
        assert_eq!(decoder.message().unwrap(), Some(message.as_slice()));
    }

    #[test]
    fn test_decoder_receive_return_value() {
        let message = make_message(SEED, MESSAGE_SIZE);
        let mut encoder = Encoder::new();
        encoder.start(&message, MAX_FRAGMENT_LEN);
        let mut decoder = Decoder::default();
        let part = encoder.next_part();
        assert!(decoder.receive(&part).unwrap());
        // A part whose metadata disagrees with the first part is rejected.
        let mut part = encoder.next_part();
        part.checksum += 1;
        assert!(matches!(
            decoder.receive(&part),
            Err(Error::InconsistentPart { .. })
        ));
        // Once the decoder is complete, further parts are ignored.
        while !decoder.is_complete() {
            let part = encoder.next_part();
            decoder.receive(&part).unwrap();
        }
        let part = encoder.next_part();
        assert!(!decoder.receive(&part).unwrap());
    }

    #[test]
    fn test_decoder_part_validation() {
        fn test<T: Types>(decoder: &mut BaseDecoder<T>) {
            let mut encoder = Encoder::new();
            encoder.start("foo".as_bytes(), 2);

            let mut part = encoder.next_part();
            assert!(decoder.receive(&part).unwrap());
            assert!(decoder.is_part_consistent(&part));
            part.checksum += 1;
            assert!(!decoder.is_part_consistent(&part));
            part.checksum -= 1;
            part.message_length += 1;
            assert!(!decoder.is_part_consistent(&part));
            part.message_length -= 1;
            part.sequence_count += 1;
            assert!(!decoder.is_part_consistent(&part));
            part.sequence_count -= 1;
            part.data = &[0];
            assert!(!decoder.is_part_consistent(&part));
        }

        let mut heapless_decoder: HeaplessDecoder<8, 8, 8, 8, 8> = HeaplessDecoder::new();
        let mut decoder = Decoder::default();

        test(&mut heapless_decoder);
        test(&mut decoder);
    }

    #[test]
    fn test_empty_decoder_empty_part() {
        fn test<T: Types>(decoder: &mut BaseDecoder<T>) {
            let mut part = Part {
                sequence: 12,
                sequence_count: 8,
                message_length: 100,
                checksum: 0x1234_5678,
                data: &[1, 5, 3, 3, 5],
            };

            // Check sequence_count.
            part.sequence_count = 0;
            assert!(matches!(decoder.receive(&part), Err(Error::InvalidPart)));
            part.sequence_count = 8;

            // Check message_length.
            part.message_length = 0;
            assert!(matches!(decoder.receive(&part), Err(Error::InvalidPart)));
            part.message_length = 100;

            // Check data.
            part.data = &[];
            assert!(matches!(decoder.receive(&part), Err(Error::InvalidPart)));
            part.data = &[1, 5, 3, 3, 5];

            // Should not validate as there aren't any previous parts received.
            assert!(!decoder.is_part_consistent(&part));
        }

        let mut heapless_decoder: HeaplessDecoder<100, 8, 5, 8, 8> = HeaplessDecoder::new();
        let mut decoder = Decoder::default();

        test(&mut heapless_decoder);
        test(&mut decoder);
    }
}