Skip to main content

nox_core/protocol/
fec.rs

1//! Reed-Solomon FEC for SURB responses. D-of-(D+P) reconstruction over GF(2^8).
2//! Applied response-path only (exit -> client). All shards must be uniform size
3//! (last data shard zero-padded); output truncated to `original_data_len`.
4
5use reed_solomon_erasure::galois_8::ReedSolomon;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9/// FEC parameters carried on every fragment (12 bytes). Present on all fragments
10/// because any could be dropped and the reassembler needs these from whichever arrives first.
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12pub struct FecInfo {
13    pub data_shard_count: u32,
14    /// Used to truncate zero-padding from the last data shard after reconstruction.
15    pub original_data_len: u64,
16}
17
18#[derive(Error, Debug, Clone, PartialEq, Eq)]
19pub enum FecError {
20    #[error("No data shards provided")]
21    EmptyDataShards,
22
23    #[error("Zero parity shards requested")]
24    ZeroParityShards,
25
26    #[error("Total shards {total} exceeds GF(2^8) limit of 255 (data={data}, parity={parity})")]
27    TooManyShards {
28        data: usize,
29        parity: usize,
30        total: usize,
31    },
32
33    #[error(
34        "Non-uniform shard sizes: first shard is {expected} bytes, shard {index} is {got} bytes"
35    )]
36    NonUniformShards {
37        expected: usize,
38        index: usize,
39        got: usize,
40    },
41
42    #[error("Empty shard data (all shards must be non-empty)")]
43    EmptyShardData,
44
45    #[error("Reed-Solomon encoder creation failed: {0}")]
46    EncoderCreationFailed(String),
47
48    #[error("Reed-Solomon encoding failed: {0}")]
49    EncodingFailed(String),
50
51    #[error("Reed-Solomon reconstruction failed: {0}")]
52    ReconstructionFailed(String),
53
54    #[error(
55        "Insufficient shards for reconstruction: have {available}, need {required} (data_shard_count)"
56    )]
57    InsufficientShards { available: usize, required: usize },
58
59    #[error("Shard array length {got} does not match expected {expected} (data + parity)")]
60    ShardCountMismatch { expected: usize, got: usize },
61}
62
63/// Generate parity shards from uniform-length data shards using Reed-Solomon.
64/// Caller MUST zero-pad the last data shard to match the others before calling.
65pub fn encode_parity_shards(
66    data_shards: &[Vec<u8>],
67    parity_count: usize,
68) -> Result<Vec<Vec<u8>>, FecError> {
69    if data_shards.is_empty() {
70        return Err(FecError::EmptyDataShards);
71    }
72    if parity_count == 0 {
73        return Err(FecError::ZeroParityShards);
74    }
75
76    let total = data_shards.len() + parity_count;
77    if total > 255 {
78        return Err(FecError::TooManyShards {
79            data: data_shards.len(),
80            parity: parity_count,
81            total,
82        });
83    }
84
85    let shard_size = data_shards[0].len();
86    if shard_size == 0 {
87        return Err(FecError::EmptyShardData);
88    }
89
90    for (i, shard) in data_shards.iter().enumerate().skip(1) {
91        if shard.len() != shard_size {
92            return Err(FecError::NonUniformShards {
93                expected: shard_size,
94                index: i,
95                got: shard.len(),
96            });
97        }
98    }
99
100    let rs = ReedSolomon::new(data_shards.len(), parity_count)
101        .map_err(|e| FecError::EncoderCreationFailed(e.to_string()))?;
102
103    let mut parity: Vec<Vec<u8>> = (0..parity_count).map(|_| vec![0u8; shard_size]).collect();
104
105    let data_refs: Vec<&[u8]> = data_shards.iter().map(Vec::as_slice).collect();
106    let mut parity_refs: Vec<&mut [u8]> = parity.iter_mut().map(Vec::as_mut_slice).collect();
107
108    rs.encode_sep(&data_refs, &mut parity_refs)
109        .map_err(|e| FecError::EncodingFailed(e.to_string()))?;
110
111    Ok(parity)
112}
113
114/// Pad data shards to uniform size for RS alignment. Last chunk is zero-padded.
115pub fn pad_to_uniform(data_chunks: &[Vec<u8>]) -> Result<(Vec<Vec<u8>>, usize), FecError> {
116    if data_chunks.is_empty() {
117        return Err(FecError::EmptyDataShards);
118    }
119
120    let shard_size = data_chunks[0].len();
121    let padded: Vec<Vec<u8>> = data_chunks
122        .iter()
123        .map(|chunk| {
124            if chunk.len() == shard_size {
125                chunk.clone()
126            } else {
127                let mut padded = chunk.clone();
128                padded.resize(shard_size, 0);
129                padded
130            }
131        })
132        .collect();
133
134    Ok((padded, shard_size))
135}
136
137/// Reconstruct original data from a (possibly incomplete) set of D+P shard slots.
138/// Fast path if all data shards present; RS reconstruction otherwise.
139pub fn decode_shards(
140    shards: &mut [Option<Vec<u8>>],
141    data_shard_count: usize,
142    original_data_len: u64,
143) -> Result<Vec<u8>, FecError> {
144    if data_shard_count == 0 {
145        return Err(FecError::EmptyDataShards);
146    }
147
148    let total_shards = shards.len();
149    if total_shards < data_shard_count {
150        return Err(FecError::ShardCountMismatch {
151            expected: data_shard_count,
152            got: total_shards,
153        });
154    }
155
156    let parity_count = total_shards - data_shard_count;
157
158    let available = shards.iter().filter(|s| s.is_some()).count();
159    if available < data_shard_count {
160        return Err(FecError::InsufficientShards {
161            available,
162            required: data_shard_count,
163        });
164    }
165
166    let all_data_present = shards[..data_shard_count].iter().all(Option::is_some);
167    if all_data_present {
168        let mut result = Vec::with_capacity(original_data_len as usize);
169        for shard in &shards[..data_shard_count] {
170            if let Some(data) = shard.as_ref() {
171                result.extend_from_slice(data);
172            }
173        }
174        result.truncate(original_data_len as usize);
175        return Ok(result);
176    }
177
178    if parity_count == 0 {
179        return Err(FecError::InsufficientShards {
180            available,
181            required: data_shard_count,
182        });
183    }
184
185    let rs = ReedSolomon::new(data_shard_count, parity_count)
186        .map_err(|e| FecError::EncoderCreationFailed(e.to_string()))?;
187
188    rs.reconstruct(shards)
189        .map_err(|e| FecError::ReconstructionFailed(e.to_string()))?;
190
191    let mut result = Vec::with_capacity(original_data_len as usize);
192    for shard in &shards[..data_shard_count] {
193        match shard.as_ref() {
194            Some(data) => result.extend_from_slice(data),
195            None => {
196                return Err(FecError::ReconstructionFailed(
197                    "RS reconstruction did not fill all data shards".to_string(),
198                ));
199            }
200        }
201    }
202    result.truncate(original_data_len as usize);
203
204    Ok(result)
205}
206
// Unit tests for the FEC helpers: encode/decode round trips, erasure recovery,
// padding, shard-count limits, error variants, and the bincode size of `FecInfo`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Split `data` into `shard_size`-byte chunks and zero-pad them to uniform length.
    fn make_data_shards(data: &[u8], shard_size: usize) -> Vec<Vec<u8>> {
        let chunks: Vec<Vec<u8>> = data.chunks(shard_size).map(|c| c.to_vec()).collect();
        let (padded, _) = pad_to_uniform(&chunks).unwrap();
        padded
    }

    // All D+P shards present: decoding returns the original bytes with padding removed.
    #[test]
    fn test_encode_decode_roundtrip() {
        let original = b"Hello, Reed-Solomon FEC for mixnet responses!".to_vec();
        let shard_size = 16;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // 3 data shards

        let parity = encode_parity_shards(&data_shards, 2).unwrap();
        assert_eq!(parity.len(), 2);
        assert!(parity.iter().all(|p| p.len() == shard_size));

        let mut shards: Vec<Option<Vec<u8>>> = data_shards
            .iter()
            .chain(parity.iter())
            .map(|s| Some(s.clone()))
            .collect();

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // D=1, P=1: the only data shard is lost and rebuilt from the single parity shard.
    #[test]
    fn test_single_data_shard_recovery() {
        let original = b"Short message".to_vec();
        let shard_size = original.len();
        let data_shards = vec![original.clone()]; // D=1

        let parity = encode_parity_shards(&data_shards, 1).unwrap(); // P=1
        assert_eq!(parity.len(), 1);
        assert_eq!(parity[0].len(), shard_size);

        let mut shards: Vec<Option<Vec<u8>>> = vec![None, Some(parity[0].clone())];

        let recovered = decode_shards(&mut shards, 1, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // Every data shard present and every parity slot empty: exercises the
    // concatenate-only path.
    #[test]
    fn test_fast_path_no_rs_needed() {
        let original: Vec<u8> = (0..100).collect();
        let shard_size = 25;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // 4

        let parity = encode_parity_shards(&data_shards, 2).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = data_shards
            .iter()
            .map(|s| Some(s.clone()))
            .chain(std::iter::repeat_with(|| None).take(parity.len()))
            .collect();

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // One data shard dropped with both parity shards present: RS reconstruction path.
    #[test]
    fn test_drop_data_shard_rs_recovery() {
        let original: Vec<u8> = (0..300).map(|i| (i % 256) as u8).collect();
        let shard_size = 100;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // 3

        let parity = encode_parity_shards(&data_shards, 2).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = vec![
            Some(data_shards[0].clone()),
            None, // dropped!
            Some(data_shards[2].clone()),
            Some(parity[0].clone()),
            Some(parity[1].clone()),
        ];

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // 50 bytes in 16-byte shards: the last shard holds 2 real bytes plus 14 zero bytes
    // of padding, which decoding must strip.
    #[test]
    fn test_padding_edge_case() {
        let original: Vec<u8> = (0..50).collect();
        let shard_size = 16;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // ceil(50/16) = 4

        assert_eq!(d, 4);
        assert!(data_shards.iter().all(|s| s.len() == shard_size));
        assert_eq!(data_shards[3][2..], vec![0u8; 14]);

        let parity = encode_parity_shards(&data_shards, 1).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = data_shards
            .iter()
            .chain(parity.iter())
            .map(|s| Some(s.clone()))
            .collect();

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // 255 total shards is accepted; 256 is rejected with TooManyShards.
    #[test]
    fn test_max_shard_boundary() {
        let shard_size = 8;
        let data_shards: Vec<Vec<u8>> = (0..200).map(|i| vec![i as u8; shard_size]).collect();

        let parity = encode_parity_shards(&data_shards, 55).unwrap(); // 200 + 55 = 255
        assert_eq!(parity.len(), 55);

        let result = encode_parity_shards(&data_shards, 56);
        assert!(matches!(
            result,
            Err(FecError::TooManyShards { total: 256, .. })
        ));
    }

    // Only 2 of 5 slots filled with D=3: decoding fails before attempting RS.
    #[test]
    fn test_insufficient_shards_error() {
        let original: Vec<u8> = (0..300).map(|i| (i % 256) as u8).collect();
        let shard_size = 100;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // 3

        let parity = encode_parity_shards(&data_shards, 2).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = vec![
            None,
            None,
            Some(data_shards[2].clone()),
            None,
            Some(parity[1].clone()),
        ];

        let result = decode_shards(&mut shards, d, original.len() as u64);
        assert!(matches!(
            result,
            Err(FecError::InsufficientShards {
                available: 2,
                required: 3,
            })
        ));
    }

    // Pins the 12-byte bincode encoding of FecInfo that the module docs rely on.
    #[test]
    fn test_fec_info_serialization_roundtrip() {
        let info = FecInfo {
            data_shard_count: 10,
            original_data_len: 307_000,
        };

        let bytes = bincode::serialize(&info).unwrap();
        let recovered: FecInfo = bincode::deserialize(&bytes).unwrap();
        assert_eq!(info, recovered);

        assert_eq!(bytes.len(), 12); // u32 (4) + u64 (8)
    }

    // Bounds the bincode overhead of an optional FecInfo, both absent and present.
    #[test]
    fn test_option_fec_info_none_overhead() {
        let none_info: Option<FecInfo> = None;
        let bytes = bincode::serialize(&none_info).unwrap();
        assert!(bytes.len() <= 4);

        let some_info: Option<FecInfo> = Some(FecInfo {
            data_shard_count: 10,
            original_data_len: 307_000,
        });
        let some_bytes = bincode::serialize(&some_info).unwrap();
        assert!(some_bytes.len() <= 16);
    }

    #[test]
    fn test_empty_data_shards_error() {
        let result = encode_parity_shards(&[], 2);
        assert!(matches!(result, Err(FecError::EmptyDataShards)));
    }

    #[test]
    fn test_zero_parity_error() {
        let shards = vec![vec![1u8, 2, 3]];
        let result = encode_parity_shards(&shards, 0);
        assert!(matches!(result, Err(FecError::ZeroParityShards)));
    }

    // The second shard is shorter than the first, so encoding reports index 1.
    #[test]
    fn test_non_uniform_shards_error() {
        let shards = vec![vec![1u8, 2, 3], vec![4u8, 5]];
        let result = encode_parity_shards(&shards, 1);
        assert!(matches!(
            result,
            Err(FecError::NonUniformShards {
                expected: 3,
                index: 1,
                got: 2,
            })
        ));
    }

    // Only the short final chunk is zero-padded up to the full shard size.
    #[test]
    fn test_pad_to_uniform() {
        let chunks = vec![vec![1, 2, 3, 4, 5], vec![6, 7, 8, 9, 10], vec![11, 12]];

        let (padded, shard_size) = pad_to_uniform(&chunks).unwrap();
        assert_eq!(shard_size, 5);
        assert_eq!(padded.len(), 3);
        assert!(padded.iter().all(|s| s.len() == 5));
        assert_eq!(padded[2], vec![11, 12, 0, 0, 0]);
    }

    #[test]
    fn test_pad_to_uniform_empty_error() {
        let result = pad_to_uniform(&[]);
        assert!(matches!(result, Err(FecError::EmptyDataShards)));
    }

    // ~100 KB payload with ~30% parity; one data shard and the first parity shard
    // are dropped, and the payload is still recovered.
    #[test]
    fn test_large_payload_fec() {
        let original: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
        let shard_size = 30_700;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // ceil(100000/30700) = 4

        let p = ((d as f64) * 0.3).ceil() as usize; // 30% FEC = 2
        let parity = encode_parity_shards(&data_shards, p).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = data_shards
            .iter()
            .chain(parity.iter())
            .map(|s| Some(s.clone()))
            .collect();

        shards[1] = None;
        shards[d] = None;

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // All parity lost but all data present: decoding still takes the fast path.
    #[test]
    fn test_drop_all_parity_fast_path() {
        let original: Vec<u8> = (0..200).collect();
        let shard_size = 50;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();

        let parity = encode_parity_shards(&data_shards, 3).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = data_shards
            .iter()
            .map(|s| Some(s.clone()))
            .chain(std::iter::repeat_with(|| None).take(parity.len()))
            .collect();

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    // D=5, P=3 with exactly P losses (data shards 0 and 3, first parity shard):
    // the maximum erasure count this configuration tolerates.
    #[test]
    fn test_mixed_data_and_parity_drops() {
        let original: Vec<u8> = (0..500).map(|i| (i % 256) as u8).collect();
        let shard_size = 100;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len(); // 5

        let parity = encode_parity_shards(&data_shards, 3).unwrap(); // P=3

        let mut shards: Vec<Option<Vec<u8>>> = data_shards
            .iter()
            .chain(parity.iter())
            .map(|s| Some(s.clone()))
            .collect();

        shards[0] = None;
        shards[3] = None;
        shards[5] = None;
        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }
}