amadeus_utils/
reed_solomon.rs

1/// Translated from https://github.com/amadeus-robot/reedsolomon_ex
2use reed_solomon_simd::{ReedSolomonDecoder, ReedSolomonEncoder};
3
/// Size in bytes of every shard handled by this module.
///
/// Used both as the chunk length when splitting input data and as the shard
/// size passed to the Reed-Solomon encoder and decoder.
pub const SHARD_SIZE: usize = 1024;
5
/// A paired Reed-Solomon encoder and decoder that are created together with
/// the same shard counts and shard size.
pub struct ReedSolomonResource {
    /// Encoder used to produce recovery shards from original data shards.
    pub encoder: ReedSolomonEncoder,
    /// Decoder used to restore missing original shards.
    pub decoder: ReedSolomonDecoder,
}
10
/// Errors returned by [`ReedSolomonResource`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// An error reported by the `reed_solomon_simd` crate, forwarded unchanged.
    #[error(transparent)]
    ReedSolomonSimd(#[from] reed_solomon_simd::Error),
}
16
17impl ReedSolomonResource {
18    pub fn new(data_shards: usize, recovery_shards: usize) -> Result<ReedSolomonResource, Error> {
19        let encoder = ReedSolomonEncoder::new(data_shards, recovery_shards, SHARD_SIZE)?;
20        let decoder = ReedSolomonDecoder::new(data_shards, recovery_shards, SHARD_SIZE)?;
21        Ok(ReedSolomonResource { encoder, decoder })
22    }
23
24    pub fn encode_shards(&mut self, data: &[u8]) -> Result<Vec<(usize, Vec<u8>)>, Error> {
25        let chunk_size = SHARD_SIZE;
26
27        let chunk_count = (data.len() + 1023) / SHARD_SIZE;
28        let mut encoded_shards = Vec::with_capacity(chunk_count * 2);
29        let mut itr = 0;
30
31        // step through `data` in increments of `chunk_size`
32        for chunk_start in (0..data.len()).step_by(chunk_size) {
33            let chunk_end = (chunk_start + chunk_size).min(data.len());
34            let chunk = &data[chunk_start..chunk_end];
35
36            // create a 1024-byte buffer initialized to 0
37            let mut buffer = [0u8; SHARD_SIZE];
38            buffer[..chunk.len()].copy_from_slice(chunk);
39
40            self.encoder.add_original_shard(buffer)?;
41
42            let bin = buffer.to_vec();
43            encoded_shards.push((itr, bin));
44            itr += 1;
45        }
46
47        let result = self.encoder.encode()?;
48        let recovery: Vec<_> = result.recovery_iter().collect();
49        for recovered_shard in recovery {
50            let bin = recovered_shard.to_vec();
51            encoded_shards.push((itr, bin));
52            itr += 1;
53        }
54
55        Ok(encoded_shards)
56    }
57
58    pub fn decode_shards(
59        &mut self,
60        shards: Vec<(usize, Vec<u8>)>,
61        total_shards: usize,
62        original_size: usize,
63    ) -> Result<Vec<u8>, Error> {
64        let mut combined = vec![0u8; original_size];
65
66        let half = total_shards / 2;
67        for (index, bin) in shards {
68            if index < half {
69                let shard_data = bin.as_slice();
70
71                let offset = index * SHARD_SIZE;
72                // protect against going past original_size
73                let end = (offset + shard_data.len()).min(original_size);
74                combined[offset..end].copy_from_slice(&shard_data[..(end - offset)]);
75
76                self.decoder.add_original_shard(index, shard_data)?;
77            } else {
78                self.decoder.add_recovery_shard(index - half, bin.as_slice())?;
79            }
80        }
81        let result = self.decoder.decode()?;
82
83        for idx in 0..half {
84            if let Some(shard_data) = result.restored_original(idx) {
85                let offset = idx * SHARD_SIZE;
86                let end = (offset + shard_data.len()).min(original_size);
87                combined[offset..end].copy_from_slice(&shard_data[..(end - offset)]);
88            }
89        }
90
91        Ok(combined)
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes a three-shard payload, drops two shards, and checks that
    /// decoding reproduces the input.
    #[test]
    fn rs_roundtrip_multi_shard() {
        // Payload covering three shards, the last one 10 bytes short of full.
        let len = SHARD_SIZE * 3 - 10;
        let data: Vec<u8> = (0..len).map(|i| (i % 251) as u8).collect();

        // Three chunks of input, so data_shards = recovery_shards = 3.
        let mut rs = ReedSolomonResource::new(3, 3).expect("new");
        let shards = rs.encode_shards(&data).expect("encode");
        assert_eq!(shards.len(), 6);

        // Discard original shard 1 and recovery shard 4 (fixed choice).
        let kept: Vec<(usize, Vec<u8>)> =
            shards.into_iter().filter(|(idx, _)| *idx != 1 && *idx != 4).collect();

        // total_shards = 3 originals + 3 recovery.
        let restored = rs.decode_shards(kept, 6, len).expect("decode");
        assert_eq!(restored, data);
    }
}