1use core::fmt;
8
9use crate::{
10 collections::{Deque, Set, Vec},
11 fountain::part::MessageDescription,
12 fountain::{
13 chooser,
14 chooser::BaseFragmentChooser,
15 part::{IndexedPart, Part},
16 },
17};
18
/// Fountain decoder whose buffers are heap-allocated `alloc` collections
/// (`Vec`, `BTreeSet`, `VecDeque`).
#[cfg(feature = "alloc")]
pub type Decoder = BaseDecoder<Alloc>;
22
23pub type HeaplessDecoder<
25 const MAX_MESSAGE_LEN: usize,
26 const MAX_MIXED_PARTS: usize,
27 const MAX_FRAGMENT_LEN: usize,
28 const MAX_SEQUENCE_COUNT: usize,
29 const QUEUE_SIZE: usize,
30> = BaseDecoder<
31 Heapless<MAX_MESSAGE_LEN, MAX_MIXED_PARTS, MAX_FRAGMENT_LEN, MAX_SEQUENCE_COUNT, QUEUE_SIZE>,
32>;
33
34impl<
35 const MAX_MESSAGE_LEN: usize,
36 const MAX_MIXED_PARTS: usize,
37 const MAX_FRAGMENT_LEN: usize,
38 const MAX_SEQUENCE_COUNT: usize,
39 const QUEUE_SIZE: usize,
40 >
41 HeaplessDecoder<
42 MAX_MESSAGE_LEN,
43 MAX_MIXED_PARTS,
44 MAX_FRAGMENT_LEN,
45 MAX_SEQUENCE_COUNT,
46 QUEUE_SIZE,
47 >
48{
49 pub const fn new() -> Self {
51 Self {
52 message: heapless::Vec::new(),
53 mixed_parts: heapless::Vec::new(),
54 received: heapless::IndexSet::new(),
55 queue: heapless::Deque::new(),
56 fragment_chooser: chooser::HeaplessFragmentChooser::new(),
57 message_description: None,
58 }
59 }
60}
61
62#[derive(Default)]
68pub struct BaseDecoder<T: Types> {
69 message: T::Message,
70 mixed_parts: T::MixedParts,
71 received: T::Indexes,
72 queue: T::Queue,
73 fragment_chooser: BaseFragmentChooser<T::Chooser>,
74 message_description: Option<MessageDescription>,
75}
76
77impl<T: Types> BaseDecoder<T> {
78 pub fn receive(&mut self, part: &Part) -> Result<bool, Error> {
91 if self.is_complete() {
92 return Ok(false);
93 }
94
95 if !part.is_valid() {
96 return Err(Error::InvalidPart);
97 }
98
99 if self.is_empty() {
100 let message_len = part.data.len() * usize::try_from(part.sequence_count).unwrap();
101 if self.message.try_resize(message_len, 0).is_err() {
102 return Err(Error::NotEnoughSpace {
103 needed: message_len,
104 capacity: self.message.capacity(),
105 });
106 }
107 self.message_description = Some(part.to_message_description());
108 } else if !self.is_part_consistent(part) {
109 return Err(Error::InconsistentPart {
110 received: part.to_message_description(),
111 expected: self.message_description.clone().unwrap(),
112 });
113 }
114
115 let indexes = self.fragment_chooser.choose_fragments(
116 part.sequence,
117 part.sequence_count,
118 part.checksum,
119 );
120
121 let mut data = T::Fragment::default();
122 if data.try_extend_from_slice(part.data).is_err() {
123 return Err(Error::NotEnoughSpace {
124 needed: part.data.len(),
125 capacity: data.capacity(),
126 });
127 }
128
129 let part = IndexedPart::new(data, indexes);
130 self.queue.push_back(part);
131
132 while !self.is_complete() && !self.queue.is_empty() {
133 let part = self.queue.pop_front().unwrap();
134 if part.is_simple() {
135 self.process_simple(&part)?;
136 } else {
137 self.process_mixed(part);
138 }
139 }
140 Ok(!self.is_complete())
141 }
142
143 #[must_use]
150 pub fn is_part_consistent(&self, part: &Part) -> bool {
151 match self.message_description {
152 Some(ref message_description) => part == message_description,
153 None => false,
154 }
155 }
156
157 pub fn message(&self) -> Result<Option<&[u8]>, Error> {
169 if self.is_complete() {
170 if self.message[self.message_description.as_ref().unwrap().message_length..]
171 .iter()
172 .any(|&b| b != 0)
173 {
174 return Err(Error::InvalidPadding);
175 }
176
177 Ok(Some(
178 &self.message[..self.message_description.as_ref().unwrap().message_length],
179 ))
180 } else {
181 Ok(None)
182 }
183 }
184
185 #[must_use]
191 pub fn is_complete(&self) -> bool {
192 if self.is_empty() {
193 return false;
194 }
195
196 self.received.len()
197 == self
198 .message_description
199 .as_ref()
200 .unwrap()
201 .sequence_count
202 .try_into()
203 .unwrap()
204 }
205
206 pub fn estimated_percent_complete(&self) -> f64 {
208 if self.is_complete() {
209 return 1.0;
210 }
211
212 if self.is_empty() {
213 return 0.0;
214 }
215
216 let estimated_input_parts =
217 f64::from(self.message_description.as_ref().unwrap().sequence_count) * 1.75;
218 let received_parts = u32::try_from(self.received.len()).unwrap();
219 f64::min(0.99, f64::from(received_parts) / estimated_input_parts)
220 }
221
222 #[must_use]
227 pub fn is_empty(&self) -> bool {
228 self.message.is_empty()
229 && self.mixed_parts.is_empty()
230 && self.received.is_empty()
231 && self.queue.is_empty()
232 && self.message_description.is_none()
233 }
234
235 pub fn clear(&mut self) {
237 self.message.clear();
238 self.mixed_parts.clear();
239 self.received.clear();
240 self.queue.clear();
241 self.message_description = None;
242
243 debug_assert!(self.is_empty());
244 }
245
246 fn reduce_mixed(&mut self, part: &IndexedPart<T::Fragment, T::Indexes>) {
247 self.mixed_parts.retain_mut(|mixed_part| {
248 mixed_part.reduce(part);
249
250 if mixed_part.is_simple() {
251 self.queue.push_back(mixed_part.clone());
252 }
253
254 !mixed_part.is_simple()
255 });
256 }
257
258 fn process_simple(&mut self, part: &IndexedPart<T::Fragment, T::Indexes>) -> Result<(), Error> {
259 let index = *part.indexes.first().unwrap();
260 if self.received.contains(&index) {
261 return Ok(());
262 }
263
264 self.reduce_mixed(part);
265
266 let offset = index * self.message_description.as_ref().unwrap().fragment_length;
267 self.message[offset..offset + self.message_description.as_ref().unwrap().fragment_length]
268 .copy_from_slice(&part.data);
269 self.received
270 .insert(index)
271 .map_err(|_| Error::TooManyFragments)?;
272
273 Ok(())
274 }
275
276 fn process_mixed(&mut self, mut part: IndexedPart<T::Fragment, T::Indexes>) {
277 for mixed_part in (&self.mixed_parts as &[IndexedPart<T::Fragment, T::Indexes>]).iter() {
278 if part.indexes == mixed_part.indexes {
279 return;
280 }
281 }
282
283 for &index in self.received.iter() {
285 let offset = index * self.message_description.as_ref().unwrap().fragment_length;
286 part.reduce_by_simple(
287 &self.message
288 [offset..offset + self.message_description.as_ref().unwrap().fragment_length],
289 index,
290 );
291 if part.is_simple() {
292 break;
293 }
294 }
295
296 if !part.is_simple() {
298 for mixed_part in self.mixed_parts.iter() {
299 part.reduce(mixed_part);
300 if part.is_simple() {
301 break;
302 }
303 }
304 }
305
306 if part.is_simple() {
307 self.queue.push_back(part);
308 } else {
309 self.reduce_mixed(&part);
310 self.mixed_parts.try_push(part).ok();
311 }
312 }
313}
314
315pub trait Types: Default {
317 type Message: Vec<u8>;
319
320 type MixedParts: Vec<IndexedPart<Self::Fragment, Self::Indexes>>;
322
323 type Fragment: Clone + Vec<u8>;
325
326 type Indexes: PartialEq + Set<usize>;
328
329 type Queue: Deque<IndexedPart<Self::Fragment, Self::Indexes>>;
331
332 type Chooser: chooser::Types;
334}
335
/// Heap-allocated storage types (`Vec`, `BTreeSet`, `VecDeque`) used by the
/// `Decoder` alias.
#[cfg(feature = "alloc")]
#[derive(Default)]
pub struct Alloc;
340
#[cfg(feature = "alloc")]
impl Types for Alloc {
    type Message = alloc::vec::Vec<u8>;
    type Fragment = alloc::vec::Vec<u8>;
    type Indexes = alloc::collections::BTreeSet<usize>;

    // Both part collections hold parts built from the fragment and index
    // types chosen above.
    type MixedParts = alloc::vec::Vec<IndexedPart<Self::Fragment, Self::Indexes>>;
    type Queue = alloc::collections::VecDeque<IndexedPart<Self::Fragment, Self::Indexes>>;

    type Chooser = chooser::Alloc;
}
353
/// Fixed-capacity `heapless` storage types used by the `HeaplessDecoder` alias.
///
/// Each const parameter is the capacity of one buffer:
/// - `MAX_MESSAGE_LEN` for the message bytes,
/// - `MAX_MIXED_PARTS` for the stored mixed parts,
/// - `MAX_FRAGMENT_LEN` for a single fragment's bytes,
/// - `MAX_SEQUENCE_COUNT` for each set of fragment indexes,
/// - `QUEUE_SIZE` for the queue of pending parts.
#[derive(Default)]
pub struct Heapless<
    const MAX_MESSAGE_LEN: usize,
    const MAX_MIXED_PARTS: usize,
    const MAX_FRAGMENT_LEN: usize,
    const MAX_SEQUENCE_COUNT: usize,
    const QUEUE_SIZE: usize,
>;
363
364impl<
365 const MAX_MESSAGE_LEN: usize,
366 const MAX_MIXED_PARTS: usize,
367 const MAX_FRAGMENT_LEN: usize,
368 const MAX_SEQUENCE_COUNT: usize,
369 const QUEUE_SIZE: usize,
370 > Types
371 for Heapless<MAX_MESSAGE_LEN, MAX_MIXED_PARTS, MAX_FRAGMENT_LEN, MAX_SEQUENCE_COUNT, QUEUE_SIZE>
372{
373 type Message = heapless::Vec<u8, MAX_MESSAGE_LEN>;
374
375 type MixedParts = heapless::Vec<
376 IndexedPart<
377 heapless::Vec<u8, MAX_FRAGMENT_LEN>,
378 heapless::FnvIndexSet<usize, MAX_SEQUENCE_COUNT>,
379 >,
380 MAX_MIXED_PARTS,
381 >;
382
383 type Fragment = heapless::Vec<u8, MAX_FRAGMENT_LEN>;
384
385 type Indexes = heapless::FnvIndexSet<usize, MAX_SEQUENCE_COUNT>;
386
387 type Queue = heapless::Deque<
388 IndexedPart<
389 heapless::Vec<u8, MAX_FRAGMENT_LEN>,
390 heapless::FnvIndexSet<usize, MAX_SEQUENCE_COUNT>,
391 >,
392 QUEUE_SIZE,
393 >;
394
395 type Chooser = chooser::Heapless<MAX_SEQUENCE_COUNT>;
396}
397
398#[derive(Debug)]
400pub enum Error {
401 InvalidPadding,
403 InconsistentPart {
405 received: MessageDescription,
407 expected: MessageDescription,
409 },
410 InvalidPart,
412 NotEnoughSpace {
414 needed: usize,
416 capacity: usize,
418 },
419 TooManyFragments,
421}
422
423impl fmt::Display for Error {
424 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
425 match self {
426 Error::InvalidPadding => write!(f, "Invalid padding")?,
427 Error::InconsistentPart { received, expected } => {
428 write!(f, "Inconsistent part: ")?;
429
430 if received.sequence_count != expected.sequence_count {
431 write!(
432 f,
433 "sequence count mismatch (received {}, expected {}). ",
434 received.sequence_count, expected.sequence_count
435 )?;
436 }
437
438 if received.message_length != expected.message_length {
439 write!(
440 f,
441 "message length mismatch (received {}, expected {}). ",
442 received.message_length, expected.message_length
443 )?;
444 }
445
446 if received.checksum != expected.checksum {
447 write!(
448 f,
449 "checksum mismatch (received {:X}, expected {:X}). ",
450 received.checksum, expected.checksum
451 )?;
452 }
453
454 if received.fragment_length != expected.fragment_length {
455 write!(
456 f,
457 "checksum mismatch (received {:X}, expected {:X}). ",
458 received.fragment_length, expected.fragment_length
459 )?;
460 }
461 }
462 Error::InvalidPart => write!(f, "The scanned part is empty")?,
463 Error::NotEnoughSpace { needed, capacity } => {
464 write!(f, "Not enough space: needed {needed}, capacity {capacity}")?
465 }
466 Error::TooManyFragments => write!(f, "Too many fragments for the current message")?,
467 };
468 Ok(())
469 }
470}
471
/// Makes `Error` usable as a standard error when the `std` feature is enabled.
/// Its message comes from the `Display` implementation.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
474
#[cfg(test)]
#[cfg(feature = "alloc")]
pub mod tests {
    use super::*;
    use crate::fountain::fragment_length;
    use crate::{fountain::Encoder, xoshiro::test_utils::make_message};

    const MESSAGE_SIZE: usize = 32767;
    const MAX_FRAGMENT_LEN: usize = 1000;
    const MAX_SEQUENCE_COUNT: usize = 64;
    const MAX_MESSAGE_SIZE: usize =
        fragment_length(MESSAGE_SIZE, MAX_FRAGMENT_LEN) * MAX_SEQUENCE_COUNT;
    const SEED: &str = "Wolf";

    fn message() -> alloc::vec::Vec<u8> {
        make_message(SEED, MESSAGE_SIZE)
    }

    /// Every part from one encoder is accepted and the message round-trips,
    /// for both storage backends.
    #[test]
    fn test_decoder() {
        fn test<T: Types>(decoder: &mut BaseDecoder<T>) {
            let message = message();
            let mut encoder = Encoder::new();
            encoder.start(&message, MAX_FRAGMENT_LEN);
            while !decoder.is_complete() {
                assert_eq!(decoder.message().unwrap(), None);
                let part = encoder.next_part();
                let _next = decoder.receive(&part).unwrap();
            }
            assert_eq!(decoder.message().unwrap(), Some(message.as_slice()));
        }

        let mut heapless_decoder: HeaplessDecoder<
            MAX_MESSAGE_SIZE,
            MAX_SEQUENCE_COUNT,
            MAX_FRAGMENT_LEN,
            MAX_SEQUENCE_COUNT,
            MAX_SEQUENCE_COUNT,
        > = HeaplessDecoder::new();
        let mut decoder = Decoder::default();

        test(&mut heapless_decoder);
        test(&mut decoder);
    }

    /// Dropping every other part still lets the mixed parts fill the gaps.
    #[test]
    fn test_decoder_skip_some_simple_fragments() {
        let message = make_message(SEED, MESSAGE_SIZE);
        let mut encoder = Encoder::new();
        encoder.start(&message, MAX_FRAGMENT_LEN);
        let mut decoder = Decoder::default();
        let mut skip = false;
        while !decoder.is_complete() {
            let part = encoder.next_part();
            if !skip {
                decoder.receive(&part).unwrap();
            }
            skip = !skip;
        }
        assert_eq!(decoder.message().unwrap(), Some(message.as_slice()));
    }

    /// `receive` reports whether more parts are needed, and rejects a part
    /// whose description does not match the first one.
    #[test]
    fn test_decoder_receive_return_value() {
        let message = make_message(SEED, MESSAGE_SIZE);
        let mut encoder = Encoder::new();
        encoder.start(&message, MAX_FRAGMENT_LEN);
        let mut decoder = Decoder::default();
        let part = encoder.next_part();
        assert!(decoder.receive(&part).unwrap());

        // A tampered checksum makes the part inconsistent with the first one.
        let mut part = encoder.next_part();
        part.checksum += 1;
        assert!(matches!(
            decoder.receive(&part),
            Err(Error::InconsistentPart { .. })
        ));

        while !decoder.is_complete() {
            let part = encoder.next_part();
            decoder.receive(&part).unwrap();
        }

        // Once complete, further parts are ignored and report "done".
        let part = encoder.next_part();
        assert!(!decoder.receive(&part).unwrap());
    }

    /// Changing any described field of a part makes it inconsistent.
    #[test]
    fn test_decoder_part_validation() {
        fn test<T: Types>(decoder: &mut BaseDecoder<T>) {
            let mut encoder = Encoder::new();
            encoder.start("foo".as_bytes(), 2);

            let mut part = encoder.next_part();
            assert!(decoder.receive(&part).unwrap());
            assert!(decoder.is_part_consistent(&part));
            part.checksum += 1;
            assert!(!decoder.is_part_consistent(&part));
            part.checksum -= 1;
            part.message_length += 1;
            assert!(!decoder.is_part_consistent(&part));
            part.message_length -= 1;
            part.sequence_count += 1;
            assert!(!decoder.is_part_consistent(&part));
            part.sequence_count -= 1;
            // A different data length means a different fragment length.
            part.data = &[0];
            assert!(!decoder.is_part_consistent(&part));
        }

        let mut heapless_decoder: HeaplessDecoder<8, 8, 8, 8, 8> = HeaplessDecoder::new();
        let mut decoder = Decoder::default();

        test(&mut heapless_decoder);
        test(&mut decoder);
    }

    /// Parts with a zero sequence count, zero message length or empty data
    /// are rejected, and an empty decoder considers no part consistent.
    #[test]
    fn test_empty_decoder_empty_part() {
        fn test<T: Types>(decoder: &mut BaseDecoder<T>) {
            let mut part = Part {
                sequence: 12,
                sequence_count: 8,
                message_length: 100,
                checksum: 0x1234_5678,
                data: &[1, 5, 3, 3, 5],
            };

            // Zero sequence count.
            part.sequence_count = 0;
            assert!(matches!(decoder.receive(&part), Err(Error::InvalidPart)));
            part.sequence_count = 8;

            // Zero message length.
            part.message_length = 0;
            assert!(matches!(decoder.receive(&part), Err(Error::InvalidPart)));
            part.message_length = 100;

            // Empty data.
            part.data = &[];
            assert!(matches!(decoder.receive(&part), Err(Error::InvalidPart)));
            part.data = &[1, 5, 3, 3, 5];

            // Nothing was accepted, so there is no description to match.
            assert!(!decoder.is_part_consistent(&part));
        }

        let mut heapless_decoder: HeaplessDecoder<100, 8, 5, 8, 8> = HeaplessDecoder::new();
        let mut decoder = Decoder::default();

        test(&mut heapless_decoder);
        test(&mut decoder);
    }
}