1use reed_solomon_erasure::galois_8::ReedSolomon;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
12pub struct FecInfo {
13 pub data_shard_count: u32,
14 pub original_data_len: u64,
16}
17
18#[derive(Error, Debug, Clone, PartialEq, Eq)]
19pub enum FecError {
20 #[error("No data shards provided")]
21 EmptyDataShards,
22
23 #[error("Zero parity shards requested")]
24 ZeroParityShards,
25
26 #[error("Total shards {total} exceeds GF(2^8) limit of 255 (data={data}, parity={parity})")]
27 TooManyShards {
28 data: usize,
29 parity: usize,
30 total: usize,
31 },
32
33 #[error(
34 "Non-uniform shard sizes: first shard is {expected} bytes, shard {index} is {got} bytes"
35 )]
36 NonUniformShards {
37 expected: usize,
38 index: usize,
39 got: usize,
40 },
41
42 #[error("Empty shard data (all shards must be non-empty)")]
43 EmptyShardData,
44
45 #[error("Reed-Solomon encoder creation failed: {0}")]
46 EncoderCreationFailed(String),
47
48 #[error("Reed-Solomon encoding failed: {0}")]
49 EncodingFailed(String),
50
51 #[error("Reed-Solomon reconstruction failed: {0}")]
52 ReconstructionFailed(String),
53
54 #[error(
55 "Insufficient shards for reconstruction: have {available}, need {required} (data_shard_count)"
56 )]
57 InsufficientShards { available: usize, required: usize },
58
59 #[error("Shard array length {got} does not match expected {expected} (data + parity)")]
60 ShardCountMismatch { expected: usize, got: usize },
61}
62
63pub fn encode_parity_shards(
66 data_shards: &[Vec<u8>],
67 parity_count: usize,
68) -> Result<Vec<Vec<u8>>, FecError> {
69 if data_shards.is_empty() {
70 return Err(FecError::EmptyDataShards);
71 }
72 if parity_count == 0 {
73 return Err(FecError::ZeroParityShards);
74 }
75
76 let total = data_shards.len() + parity_count;
77 if total > 255 {
78 return Err(FecError::TooManyShards {
79 data: data_shards.len(),
80 parity: parity_count,
81 total,
82 });
83 }
84
85 let shard_size = data_shards[0].len();
86 if shard_size == 0 {
87 return Err(FecError::EmptyShardData);
88 }
89
90 for (i, shard) in data_shards.iter().enumerate().skip(1) {
91 if shard.len() != shard_size {
92 return Err(FecError::NonUniformShards {
93 expected: shard_size,
94 index: i,
95 got: shard.len(),
96 });
97 }
98 }
99
100 let rs = ReedSolomon::new(data_shards.len(), parity_count)
101 .map_err(|e| FecError::EncoderCreationFailed(e.to_string()))?;
102
103 let mut parity: Vec<Vec<u8>> = (0..parity_count).map(|_| vec![0u8; shard_size]).collect();
104
105 let data_refs: Vec<&[u8]> = data_shards.iter().map(Vec::as_slice).collect();
106 let mut parity_refs: Vec<&mut [u8]> = parity.iter_mut().map(Vec::as_mut_slice).collect();
107
108 rs.encode_sep(&data_refs, &mut parity_refs)
109 .map_err(|e| FecError::EncodingFailed(e.to_string()))?;
110
111 Ok(parity)
112}
113
114pub fn pad_to_uniform(data_chunks: &[Vec<u8>]) -> Result<(Vec<Vec<u8>>, usize), FecError> {
116 if data_chunks.is_empty() {
117 return Err(FecError::EmptyDataShards);
118 }
119
120 let shard_size = data_chunks[0].len();
121 let padded: Vec<Vec<u8>> = data_chunks
122 .iter()
123 .map(|chunk| {
124 if chunk.len() == shard_size {
125 chunk.clone()
126 } else {
127 let mut padded = chunk.clone();
128 padded.resize(shard_size, 0);
129 padded
130 }
131 })
132 .collect();
133
134 Ok((padded, shard_size))
135}
136
137pub fn decode_shards(
140 shards: &mut [Option<Vec<u8>>],
141 data_shard_count: usize,
142 original_data_len: u64,
143) -> Result<Vec<u8>, FecError> {
144 if data_shard_count == 0 {
145 return Err(FecError::EmptyDataShards);
146 }
147
148 let total_shards = shards.len();
149 if total_shards < data_shard_count {
150 return Err(FecError::ShardCountMismatch {
151 expected: data_shard_count,
152 got: total_shards,
153 });
154 }
155
156 let parity_count = total_shards - data_shard_count;
157
158 let available = shards.iter().filter(|s| s.is_some()).count();
159 if available < data_shard_count {
160 return Err(FecError::InsufficientShards {
161 available,
162 required: data_shard_count,
163 });
164 }
165
166 let all_data_present = shards[..data_shard_count].iter().all(Option::is_some);
167 if all_data_present {
168 let mut result = Vec::with_capacity(original_data_len as usize);
169 for shard in &shards[..data_shard_count] {
170 if let Some(data) = shard.as_ref() {
171 result.extend_from_slice(data);
172 }
173 }
174 result.truncate(original_data_len as usize);
175 return Ok(result);
176 }
177
178 if parity_count == 0 {
179 return Err(FecError::InsufficientShards {
180 available,
181 required: data_shard_count,
182 });
183 }
184
185 let rs = ReedSolomon::new(data_shard_count, parity_count)
186 .map_err(|e| FecError::EncoderCreationFailed(e.to_string()))?;
187
188 rs.reconstruct(shards)
189 .map_err(|e| FecError::ReconstructionFailed(e.to_string()))?;
190
191 let mut result = Vec::with_capacity(original_data_len as usize);
192 for shard in &shards[..data_shard_count] {
193 match shard.as_ref() {
194 Some(data) => result.extend_from_slice(data),
195 None => {
196 return Err(FecError::ReconstructionFailed(
197 "RS reconstruction did not fill all data shards".to_string(),
198 ));
199 }
200 }
201 }
202 result.truncate(original_data_len as usize);
203
204 Ok(result)
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits `data` into `shard_size` chunks and zero-pads the final one.
    fn make_data_shards(data: &[u8], shard_size: usize) -> Vec<Vec<u8>> {
        let chunks: Vec<Vec<u8>> = data.chunks(shard_size).map(|c| c.to_vec()).collect();
        let (padded, _) = pad_to_uniform(&chunks).unwrap();
        padded
    }

    /// Builds a fully present shard array: data shards followed by parity shards.
    fn all_present(data: &[Vec<u8>], parity: &[Vec<u8>]) -> Vec<Option<Vec<u8>>> {
        data.iter()
            .chain(parity.iter())
            .map(|s| Some(s.clone()))
            .collect()
    }

    /// Builds a shard array with every data shard present and every parity shard missing.
    fn data_only(data: &[Vec<u8>], parity_len: usize) -> Vec<Option<Vec<u8>>> {
        data.iter()
            .map(|s| Some(s.clone()))
            .chain(std::iter::repeat_with(|| None).take(parity_len))
            .collect()
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let original = b"Hello, Reed-Solomon FEC for mixnet responses!".to_vec();
        let shard_size = 16;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();

        let parity = encode_parity_shards(&data_shards, 2).unwrap();
        assert_eq!(parity.len(), 2);
        assert!(parity.iter().all(|p| p.len() == shard_size));

        let mut shards = all_present(&data_shards, &parity);
        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_single_data_shard_recovery() {
        let original = b"Short message".to_vec();
        let shard_size = original.len();
        let data_shards = vec![original.clone()];

        let parity = encode_parity_shards(&data_shards, 1).unwrap();
        assert_eq!(parity.len(), 1);
        assert_eq!(parity[0].len(), shard_size);

        let mut shards: Vec<Option<Vec<u8>>> = vec![None, Some(parity[0].clone())];
        let recovered = decode_shards(&mut shards, 1, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_fast_path_no_rs_needed() {
        let original: Vec<u8> = (0..100).collect();
        let shard_size = 25;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        let parity = encode_parity_shards(&data_shards, 2).unwrap();

        let mut shards = data_only(&data_shards, parity.len());
        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_drop_data_shard_rs_recovery() {
        let original: Vec<u8> = (0..300).map(|i| (i % 256) as u8).collect();
        let shard_size = 100;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        let parity = encode_parity_shards(&data_shards, 2).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = vec![
            Some(data_shards[0].clone()),
            None,
            Some(data_shards[2].clone()),
            Some(parity[0].clone()),
            Some(parity[1].clone()),
        ];
        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_padding_edge_case() {
        let original: Vec<u8> = (0..50).collect();
        let shard_size = 16;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        assert_eq!(d, 4);
        assert!(data_shards.iter().all(|s| s.len() == shard_size));
        // 50 = 3 * 16 + 2, so the last shard has 14 bytes of zero padding.
        assert_eq!(data_shards[3][2..], vec![0u8; 14]);

        let parity = encode_parity_shards(&data_shards, 1).unwrap();

        let mut shards = all_present(&data_shards, &parity);
        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_max_shard_boundary() {
        let shard_size = 8;
        let data_shards: Vec<Vec<u8>> = (0..200).map(|i| vec![i as u8; shard_size]).collect();

        // 200 + 55 = 255 is the largest accepted total.
        let parity = encode_parity_shards(&data_shards, 55).unwrap();
        assert_eq!(parity.len(), 55);

        let result = encode_parity_shards(&data_shards, 56);
        assert!(matches!(
            result,
            Err(FecError::TooManyShards { total: 256, .. })
        ));
    }

    #[test]
    fn test_insufficient_shards_error() {
        let original: Vec<u8> = (0..300).map(|i| (i % 256) as u8).collect();
        let shard_size = 100;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        let parity = encode_parity_shards(&data_shards, 2).unwrap();

        let mut shards: Vec<Option<Vec<u8>>> = vec![
            None,
            None,
            Some(data_shards[2].clone()),
            None,
            Some(parity[1].clone()),
        ];
        let result = decode_shards(&mut shards, d, original.len() as u64);
        assert!(matches!(
            result,
            Err(FecError::InsufficientShards {
                available: 2,
                required: 3,
            })
        ));
    }

    #[test]
    fn test_fec_info_serialization_roundtrip() {
        let info = FecInfo {
            data_shard_count: 10,
            original_data_len: 307_000,
        };

        let bytes = bincode::serialize(&info).unwrap();
        let recovered: FecInfo = bincode::deserialize(&bytes).unwrap();
        assert_eq!(info, recovered);

        // 4 bytes (u32) + 8 bytes (u64): no framing overhead.
        assert_eq!(bytes.len(), 12);
    }

    #[test]
    fn test_option_fec_info_none_overhead() {
        let none_info: Option<FecInfo> = None;
        let bytes = bincode::serialize(&none_info).unwrap();
        assert!(bytes.len() <= 4);

        let some_info: Option<FecInfo> = Some(FecInfo {
            data_shard_count: 10,
            original_data_len: 307_000,
        });
        let some_bytes = bincode::serialize(&some_info).unwrap();
        assert!(some_bytes.len() <= 16);
    }

    #[test]
    fn test_empty_data_shards_error() {
        let result = encode_parity_shards(&[], 2);
        assert!(matches!(result, Err(FecError::EmptyDataShards)));
    }

    #[test]
    fn test_zero_parity_error() {
        let shards = vec![vec![1u8, 2, 3]];
        let result = encode_parity_shards(&shards, 0);
        assert!(matches!(result, Err(FecError::ZeroParityShards)));
    }

    #[test]
    fn test_empty_shard_data_error() {
        let shards = vec![Vec::<u8>::new(); 2];
        let result = encode_parity_shards(&shards, 1);
        assert!(matches!(result, Err(FecError::EmptyShardData)));
    }

    #[test]
    fn test_non_uniform_shards_error() {
        let shards = vec![vec![1u8, 2, 3], vec![4u8, 5]];
        let result = encode_parity_shards(&shards, 1);
        assert!(matches!(
            result,
            Err(FecError::NonUniformShards {
                expected: 3,
                index: 1,
                got: 2,
            })
        ));
    }

    #[test]
    fn test_pad_to_uniform() {
        let chunks = vec![vec![1, 2, 3, 4, 5], vec![6, 7, 8, 9, 10], vec![11, 12]];

        let (padded, shard_size) = pad_to_uniform(&chunks).unwrap();
        assert_eq!(shard_size, 5);
        assert_eq!(padded.len(), 3);
        assert!(padded.iter().all(|s| s.len() == 5));
        assert_eq!(padded[2], vec![11, 12, 0, 0, 0]);
    }

    #[test]
    fn test_pad_to_uniform_empty_error() {
        let result = pad_to_uniform(&[]);
        assert!(matches!(result, Err(FecError::EmptyDataShards)));
    }

    #[test]
    fn test_decode_zero_data_shard_count_error() {
        let mut shards: Vec<Option<Vec<u8>>> = vec![Some(vec![1u8, 2, 3])];
        let result = decode_shards(&mut shards, 0, 3);
        assert!(matches!(result, Err(FecError::EmptyDataShards)));
    }

    #[test]
    fn test_decode_shard_count_mismatch_error() {
        let mut shards: Vec<Option<Vec<u8>>> = vec![Some(vec![1u8, 2]), Some(vec![3u8, 4])];
        let result = decode_shards(&mut shards, 3, 6);
        assert!(matches!(
            result,
            Err(FecError::ShardCountMismatch {
                expected: 3,
                got: 2,
            })
        ));
    }

    #[test]
    fn test_decode_missing_data_without_parity_error() {
        let mut shards: Vec<Option<Vec<u8>>> = vec![Some(vec![1u8, 2]), None];
        let result = decode_shards(&mut shards, 2, 4);
        assert!(matches!(
            result,
            Err(FecError::InsufficientShards {
                available: 1,
                required: 2,
            })
        ));
    }

    #[test]
    fn test_decode_fast_path_without_parity_truncates_padding() {
        let mut shards: Vec<Option<Vec<u8>>> = vec![Some(vec![1u8, 2, 3]), Some(vec![4u8, 5, 0])];
        let recovered = decode_shards(&mut shards, 2, 5).unwrap();
        assert_eq!(recovered, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_decode_len_longer_than_data_returns_all_data() {
        let mut shards: Vec<Option<Vec<u8>>> = vec![Some(vec![1u8, 2])];
        let recovered = decode_shards(&mut shards, 1, 10).unwrap();
        assert_eq!(recovered, vec![1, 2]);
    }

    #[test]
    fn test_large_payload_fec() {
        let original: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
        let shard_size = 30_700;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        // Roughly 30% redundancy, rounded up.
        let p = ((d as f64) * 0.3).ceil() as usize;
        let parity = encode_parity_shards(&data_shards, p).unwrap();

        let mut shards = all_present(&data_shards, &parity);
        shards[1] = None;
        shards[d] = None;

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_drop_all_parity_fast_path() {
        let original: Vec<u8> = (0..200).collect();
        let shard_size = 50;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        let parity = encode_parity_shards(&data_shards, 3).unwrap();

        let mut shards = data_only(&data_shards, parity.len());
        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_mixed_data_and_parity_drops() {
        let original: Vec<u8> = (0..500).map(|i| (i % 256) as u8).collect();
        let shard_size = 100;
        let data_shards = make_data_shards(&original, shard_size);
        let d = data_shards.len();
        let parity = encode_parity_shards(&data_shards, 3).unwrap();

        let mut shards = all_present(&data_shards, &parity);
        // Two data shards and the first parity shard are lost.
        shards[0] = None;
        shards[3] = None;
        shards[5] = None;

        let recovered = decode_shards(&mut shards, d, original.len() as u64).unwrap();
        assert_eq!(recovered, original);
    }
}