sosistab/fec/
decoder.rs

1use std::sync::Arc;
2
3use crate::buffer::Buff;
4
5use super::{post_decode, wrapped::WrappedReedSolomon};
6
/// A single-use FEC decoder.
///
/// One instance handles the shards of a single frame: packets are fed in one
/// at a time through `decode`, and once a reconstruction has succeeded the
/// decoder is marked done and rejects further packets.
#[derive(Debug)]
pub struct FrameDecoder {
    /// Number of data shards in the frame.
    data_shards: usize,
    /// Number of parity shards in the frame; zero means packets are passed
    /// through without any buffering.
    parity_shards: usize,
    /// Per-shard scratch buffers, indexed by shard number. Starts empty and is
    /// allocated lazily by `decode`; every buffer has the length of the packet
    /// that triggered the allocation.
    space: Vec<Vec<u8>>,
    /// `present[i]` is true once shard `i` has been received.
    present: Vec<bool>,
    /// Number of `true` entries in `present`.
    present_count: usize,
    /// Reed-Solomon codec obtained from `WrappedReedSolomon::new_cached` for
    /// this (data, parity) configuration.
    rs_decoder: Arc<WrappedReedSolomon>,
    /// Set after a successful reconstruction (and on every call when there
    /// are no parity shards).
    done: bool,
}
18
19impl FrameDecoder {
20    #[tracing::instrument(level = "trace")]
21    pub fn new(data_shards: usize, parity_shards: usize) -> Self {
22        FrameDecoder {
23            data_shards,
24            parity_shards,
25            present_count: 0,
26            space: vec![],
27            present: vec![false; data_shards + parity_shards],
28            rs_decoder: WrappedReedSolomon::new_cached(data_shards, parity_shards),
29            done: false,
30        }
31    }
32
33    #[tracing::instrument(level = "trace", skip(self, pkt))]
34    pub fn decode(&mut self, pkt: &[u8], pkt_idx: usize) -> Option<Vec<Buff>> {
35        // if rand::random::<f64>() < 0.1 {
36        //     tracing::debug!("decoding with {}/{}", self.data_shards, self.parity_shards);
37        // }
38        // if we don't have parity shards, don't touch anything
39        if self.parity_shards == 0 {
40            self.done = true;
41            return Some(vec![post_decode(Buff::copy_from_slice(pkt))?]);
42        }
43        if self.space.is_empty() {
44            tracing::trace!("decode with pad len {}", pkt.len());
45            self.space = vec![vec![0u8; pkt.len()]; self.data_shards + self.parity_shards]
46        }
47        if self.space.len() <= pkt_idx {
48            return None;
49        }
50        if self.done
51            || pkt_idx > self.space.len()
52            || pkt_idx > self.present.len()
53            || self.space[pkt_idx].len() != pkt.len()
54        {
55            return None;
56        }
57        // decompress without allocation
58        self.space[pkt_idx].copy_from_slice(pkt);
59        if !self.present[pkt_idx] {
60            self.present_count += 1
61        }
62        self.present[pkt_idx] = true;
63        // if I'm a data shard, just return it
64        if pkt_idx < self.data_shards {
65            return Some(vec![post_decode(Buff::copy_from_slice(
66                &self.space[pkt_idx],
67            ))?]);
68        }
69        if self.present_count < self.data_shards {
70            tracing::trace!("don't even attempt yet");
71            return None;
72        }
73        let mut ref_vec: Vec<(&mut [u8], bool)> = self
74            .space
75            .iter_mut()
76            .zip(self.present.iter())
77            .map(|(v, pres)| (v.as_mut(), *pres))
78            .collect();
79        // otherwise, attempt to reconstruct
80        tracing::trace!(
81            "attempting to reconstruct (data={}, parity={})",
82            self.data_shards,
83            self.parity_shards
84        );
85        self.rs_decoder.get_inner().reconstruct(&mut ref_vec).ok()?;
86        self.done = true;
87        let res = self
88            .space
89            .drain(0..)
90            .zip(self.present.iter().cloned())
91            .take(self.data_shards)
92            .filter_map(|(elem, present)| {
93                if !present {
94                    post_decode(Buff::copy_from_slice(&elem))
95                } else {
96                    None
97                }
98            })
99            .collect();
100        Some(res)
101    }
102}