amadeus_utils/
reed_solomon.rs1use reed_solomon_simd::{ReedSolomonDecoder, ReedSolomonEncoder};
3
/// Size in bytes of every shard: the chunk size `encode_shards` splits data into
/// and the shard size both the encoder and decoder are configured with.
pub const SHARD_SIZE: usize = 1 << 10;

6pub struct ReedSolomonResource {
7 pub encoder: ReedSolomonEncoder,
8 pub decoder: ReedSolomonDecoder,
9}
10
11#[derive(Debug, thiserror::Error)]
12pub enum Error {
13 #[error(transparent)]
14 ReedSolomonSimd(#[from] reed_solomon_simd::Error),
15}
16
17impl ReedSolomonResource {
18 pub fn new(data_shards: usize, recovery_shards: usize) -> Result<ReedSolomonResource, Error> {
19 let encoder = ReedSolomonEncoder::new(data_shards, recovery_shards, SHARD_SIZE)?;
20 let decoder = ReedSolomonDecoder::new(data_shards, recovery_shards, SHARD_SIZE)?;
21 Ok(ReedSolomonResource { encoder, decoder })
22 }
23
24 pub fn encode_shards(&mut self, data: &[u8]) -> Result<Vec<(usize, Vec<u8>)>, Error> {
25 let chunk_size = SHARD_SIZE;
26
27 let chunk_count = (data.len() + 1023) / SHARD_SIZE;
28 let mut encoded_shards = Vec::with_capacity(chunk_count * 2);
29 let mut itr = 0;
30
31 for chunk_start in (0..data.len()).step_by(chunk_size) {
33 let chunk_end = (chunk_start + chunk_size).min(data.len());
34 let chunk = &data[chunk_start..chunk_end];
35
36 let mut buffer = [0u8; SHARD_SIZE];
38 buffer[..chunk.len()].copy_from_slice(chunk);
39
40 self.encoder.add_original_shard(buffer)?;
41
42 let bin = buffer.to_vec();
43 encoded_shards.push((itr, bin));
44 itr += 1;
45 }
46
47 let result = self.encoder.encode()?;
48 let recovery: Vec<_> = result.recovery_iter().collect();
49 for recovered_shard in recovery {
50 let bin = recovered_shard.to_vec();
51 encoded_shards.push((itr, bin));
52 itr += 1;
53 }
54
55 Ok(encoded_shards)
56 }
57
58 pub fn decode_shards(
59 &mut self,
60 shards: Vec<(usize, Vec<u8>)>,
61 total_shards: usize,
62 original_size: usize,
63 ) -> Result<Vec<u8>, Error> {
64 let mut combined = vec![0u8; original_size];
65
66 let half = total_shards / 2;
67 for (index, bin) in shards {
68 if index < half {
69 let shard_data = bin.as_slice();
70
71 let offset = index * SHARD_SIZE;
72 let end = (offset + shard_data.len()).min(original_size);
74 combined[offset..end].copy_from_slice(&shard_data[..(end - offset)]);
75
76 self.decoder.add_original_shard(index, shard_data)?;
77 } else {
78 self.decoder.add_recovery_shard(index - half, bin.as_slice())?;
79 }
80 }
81 let result = self.decoder.decode()?;
82
83 for idx in 0..half {
84 if let Some(shard_data) = result.restored_original(idx) {
85 let offset = idx * SHARD_SIZE;
86 let end = (offset + shard_data.len()).min(original_size);
87 combined[offset..end].copy_from_slice(&shard_data[..(end - offset)]);
88 }
89 }
90
91 Ok(combined)
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic `len`-byte payload (values cycle through 0..251).
    fn sample_data(len: usize) -> Vec<u8> {
        (0..len).map(|i| (i % 251) as u8).collect()
    }

    /// Keeps every shard whose own index is not listed in `dropped`.
    fn without(shards: Vec<(usize, Vec<u8>)>, dropped: &[usize]) -> Vec<(usize, Vec<u8>)> {
        shards
            .into_iter()
            .filter(|(idx, _)| !dropped.contains(idx))
            .collect()
    }

    #[test]
    fn rs_roundtrip_multi_shard() {
        let len = SHARD_SIZE * 3 - 10;
        let data = sample_data(len);

        let mut rs = ReedSolomonResource::new(3, 3).expect("new");
        let shards = rs.encode_shards(&data).expect("encode");
        assert_eq!(shards.len(), 6);

        // Lose one original (1) and one recovery (4) shard; 4 of 6 remain.
        let kept = without(shards, &[1, 4]);
        let restored = rs.decode_shards(kept, 6, len).expect("decode");
        assert_eq!(restored, data);
    }

    #[test]
    fn rs_encode_layout() {
        let len = SHARD_SIZE * 3 - 10;
        let data = sample_data(len);

        let mut rs = ReedSolomonResource::new(3, 3).expect("new");
        let shards = rs.encode_shards(&data).expect("encode");

        let indices: Vec<usize> = shards.iter().map(|(idx, _)| *idx).collect();
        assert_eq!(indices, (0..6).collect::<Vec<_>>());
        assert!(shards.iter().all(|(_, bin)| bin.len() == SHARD_SIZE));

        // Original shards are the input split into SHARD_SIZE chunks, with the
        // last one zero-padded.
        assert_eq!(&shards[0].1[..], &data[..SHARD_SIZE]);
        assert_eq!(&shards[2].1[..SHARD_SIZE - 10], &data[2 * SHARD_SIZE..]);
        assert!(shards[2].1[SHARD_SIZE - 10..].iter().all(|&b| b == 0));
    }

    #[test]
    fn rs_roundtrip_from_recovery_only() {
        let len = SHARD_SIZE * 3 - 10;
        let data = sample_data(len);

        let mut rs = ReedSolomonResource::new(3, 3).expect("new");
        let shards = rs.encode_shards(&data).expect("encode");

        // Drop every original shard; the three recovery shards must suffice.
        let kept = without(shards, &[0, 1, 2]);
        let restored = rs.decode_shards(kept, 6, len).expect("decode");
        assert_eq!(restored, data);
    }
}