1use bytes::{Buf, BufMut};
116use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
117use commonware_cryptography::Hasher;
118use commonware_storage::bmt::{self, Builder};
119use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
120use std::collections::HashSet;
121use thiserror::Error;
122
/// Errors that can be returned while encoding data into chunks or decoding chunks back into
/// data.
#[derive(Error, Debug)]
pub enum Error {
    /// An error reported by the underlying Reed-Solomon encoder or decoder.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The root recomputed from the recovered (and re-encoded) shards does not match the
    /// root the caller supplied to `decode`.
    #[error("inconsistent")]
    Inconsistent,
    /// A chunk failed proof verification against the root, or a proof could not be
    /// generated for a shard during encoding.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer chunks than the required minimum were supplied to `decode`.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same chunk index was supplied more than once; carries the repeated index.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The length of the data is not acceptable (for example, `encode` rejects inputs
    /// longer than `u32::MAX` bytes); carries the offending length.
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A chunk index is not smaller than the total number of chunks; carries the index.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
}
141
/// A single shard of encoded data bundled with the information needed to check it against
/// a root digest.
#[derive(Clone)]
pub struct Chunk<H: Hasher> {
    /// The shard bytes (either an original shard or a recovery shard).
    pub shard: Vec<u8>,

    /// Position of this shard among all shards produced by encoding.
    pub index: u16,

    /// Proof that the digest of `shard` sits at position `index` in the tree committed to
    /// by the root.
    pub proof: bmt::Proof<H>,
}
154
155impl<H: Hasher> Chunk<H> {
156 pub fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
158 Self {
159 shard,
160 index,
161 proof,
162 }
163 }
164
165 pub fn verify(&self, root: &H::Digest) -> bool {
167 let mut hasher = H::new();
169 hasher.update(&self.shard);
170 let shard_digest = hasher.finalize();
171
172 self.proof
174 .verify(&mut hasher, &shard_digest, self.index as u32, root)
175 .is_ok()
176 }
177}
178
179impl<H: Hasher> Write for Chunk<H> {
180 fn write(&self, writer: &mut impl BufMut) {
181 self.shard.write(writer);
182 self.index.write(writer);
183 self.proof.write(writer);
184 }
185}
186
187impl<H: Hasher> Read for Chunk<H> {
188 type Cfg = usize;
190
191 fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
192 let shard = Vec::<u8>::read_range(reader, ..=*cfg)?;
193 let index = u16::read(reader)?;
194 let proof = bmt::Proof::<H>::read(reader)?;
195 Ok(Self {
196 shard,
197 index,
198 proof,
199 })
200 }
201}
202
203impl<H: Hasher> EncodeSize for Chunk<H> {
204 fn encode_size(&self) -> usize {
205 self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
206 }
207}
208
/// Split `data` into `k` equally sized, zero-padded original shards.
///
/// The payload is prefixed with its length as a 4-byte big-endian `u32`, then laid out
/// contiguously across `k` shards. The shard length is `ceil((4 + data.len()) / k)`, rounded
/// up to the next even number (presumably a requirement of the Reed-Solomon backend), and
/// any unused tail bytes are zero.
///
/// The returned vector holds exactly `k` shards but reserves room for `k + m` so that the
/// caller can append recovery shards without reallocating.
///
/// # Panics
///
/// Panics if `k` is zero, or if `data.len()` does not fit in a `u32` (callers are expected
/// to have rejected such payloads already).
fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
    // Length prefix. A checked conversion is used so that an oversized payload can never be
    // recorded with a silently truncated length.
    let data_len = data.len();
    let length_bytes = u32::try_from(data_len)
        .expect("data length must fit in u32")
        .to_be_bytes();

    // Compute the (even) shard length.
    let prefixed_len = length_bytes.len() + data_len;
    let mut shard_len = prefixed_len.div_ceil(k);
    if shard_len % 2 != 0 {
        shard_len += 1;
    }

    // Lay out `[length prefix | data | zero padding]` once, using bulk copies.
    let padded_len = shard_len * k;
    let mut buffer = Vec::with_capacity(padded_len);
    buffer.extend_from_slice(&length_bytes);
    buffer.extend_from_slice(&data);
    drop(data); // release the input before the shards are allocated
    buffer.resize(padded_len, 0);

    // Cut the buffer into `k` shards; capacity leaves room for the `m` recovery shards.
    let mut shards = Vec::with_capacity(k + m);
    shards.extend(buffer.chunks_exact(shard_len).map(<[u8]>::to_vec));
    shards
}
234
235fn extract_data(shards: Vec<Vec<u8>>, k: usize) -> Vec<u8> {
237 let mut data = shards.into_iter().take(k).flatten();
239
240 let data_len = (&mut data)
242 .take(u32::SIZE)
243 .collect::<Vec<_>>()
244 .try_into()
245 .expect("insufficient data");
246 let data_len = u32::from_be_bytes(data_len) as usize;
247
248 data.take(data_len).collect()
250}
251
252pub fn encode<H: Hasher>(
265 total: u16,
266 min: u16,
267 data: Vec<u8>,
268) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
269 assert!(total > min);
271 assert!(min > 0);
272 let n = total as usize;
273 let k = min as usize;
274 let m = n - k;
275 if data.len() > u32::MAX as usize {
276 return Err(Error::InvalidDataLength(data.len()));
277 }
278
279 let mut shards = prepare_data(data, k, m);
281 let shard_len = shards[0].len();
282
283 let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
285 for shard in &shards {
286 encoder
287 .add_original_shard(shard)
288 .map_err(Error::ReedSolomon)?;
289 }
290
291 let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
293 let recovery_shards: Vec<Vec<u8>> = encoding
294 .recovery_iter()
295 .map(|shard| shard.to_vec())
296 .collect();
297 shards.extend(recovery_shards);
298
299 let mut builder = Builder::<H>::new(n);
301 let mut hasher = H::new();
302 for shard in &shards {
303 builder.add(&{
304 hasher.update(shard);
305 hasher.finalize()
306 });
307 }
308 let tree = builder.build();
309 let root = tree.root();
310
311 let mut chunks = Vec::with_capacity(n);
313 for (i, shard) in shards.into_iter().enumerate() {
314 let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
315 chunks.push(Chunk::new(shard, i as u16, proof));
316 }
317
318 Ok((root, chunks))
319}
320
321pub fn decode<H: Hasher>(
334 total: u16,
335 min: u16,
336 root: &H::Digest,
337 chunks: Vec<Chunk<H>>,
338) -> Result<Vec<u8>, Error> {
339 assert!(total > min);
341 assert!(min > 0);
342 let n = total as usize;
343 let k = min as usize;
344 let m = n - k;
345 if chunks.len() < k {
346 return Err(Error::NotEnoughChunks);
347 }
348
349 let shard_len = chunks[0].shard.len();
351 let mut seen = HashSet::new();
352 let mut provided_originals: Vec<(usize, Vec<u8>)> = Vec::new();
353 let mut provided_recoveries: Vec<(usize, Vec<u8>)> = Vec::new();
354 for chunk in chunks {
355 let index = chunk.index;
357 if index >= total {
358 return Err(Error::InvalidIndex(index));
359 }
360 if seen.contains(&index) {
361 return Err(Error::DuplicateIndex(index));
362 }
363 seen.insert(index);
364
365 if !chunk.verify(root) {
367 return Err(Error::InvalidProof);
368 }
369
370 if index < min {
372 provided_originals.push((index as usize, chunk.shard));
373 } else {
374 provided_recoveries.push((index as usize - k, chunk.shard));
375 }
376 }
377
378 let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
380 for (idx, ref shard) in &provided_originals {
381 decoder
382 .add_original_shard(*idx, shard)
383 .map_err(Error::ReedSolomon)?;
384 }
385 for (idx, ref shard) in &provided_recoveries {
386 decoder
387 .add_recovery_shard(*idx, shard)
388 .map_err(Error::ReedSolomon)?;
389 }
390 let decoding = decoder.decode().map_err(Error::ReedSolomon)?;
391
392 let mut shards = Vec::with_capacity(n);
394 shards.resize(k, Vec::new());
395 for (idx, shard) in provided_originals {
396 shards[idx] = shard;
397 }
398 for (idx, shard) in decoding.restored_original_iter() {
399 shards[idx] = shard.to_vec();
400 }
401
402 let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
404 for shard in shards.iter().take(k) {
405 encoder
406 .add_original_shard(shard)
407 .map_err(Error::ReedSolomon)?;
408 }
409 let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
410 let recovery_shards: Vec<Vec<u8>> = encoding
411 .recovery_iter()
412 .map(|shard| shard.to_vec())
413 .collect();
414 shards.extend(recovery_shards);
415
416 let mut builder = Builder::<H>::new(n);
418 let mut hasher = H::new();
419 for shard in &shards {
420 builder.add(&{
421 hasher.update(shard);
422 hasher.finalize()
423 });
424 }
425 let computed_tree = builder.build();
426
427 if computed_tree.root() != *root {
429 return Err(Error::Inconsistent);
430 }
431
432 Ok(extract_data(shards, k))
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;

    /// Encoding and then decoding from the first `min` chunks returns the input.
    #[test]
    fn test_basic() {
        let data = b"Hello, Reed-Solomon!";
        let (total, min) = (7u16, 4u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();
        assert_eq!(chunks.len(), total as usize);

        // Decode with exactly `min` chunks (all of them original shards)
        let subset: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, subset).unwrap();
        assert_eq!(decoded, data);
    }

    /// Same round trip with a larger number of recovery shards.
    #[test]
    fn test_moderate() {
        let data = b"Testing with more pieces than minimum";
        let (total, min) = (10u16, 4u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Decode with exactly `min` chunks (all of them original shards)
        let subset: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, subset).unwrap();
        assert_eq!(decoded, data);
    }

    /// Decoding works from a mix of original and recovery chunks.
    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let (total, min) = (8u16, 3u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // One original chunk (index 0) and two recovery chunks (indices 4 and 6)
        let pieces: Vec<_> = [0usize, 4, 6].iter().map(|&i| chunks[i].clone()).collect();

        // Decode from the mixed set
        let decoded = decode::<Sha256>(total, min, &root, pieces).unwrap();
        assert_eq!(decoded, data);
    }

    /// Supplying fewer than `min` chunks is rejected.
    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let (total, min) = (6u16, 4u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Keep only two chunks although four are required
        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        // Decoding must fail
        let outcome = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(outcome, Err(Error::NotEnoughChunks)));
    }

    /// Supplying the same chunk index twice is rejected.
    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let (total, min) = (5u16, 3u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Chunk 0 appears twice
        let pieces: Vec<_> = [0usize, 0, 1].iter().map(|&i| chunks[i].clone()).collect();

        // Decoding must fail and report the repeated index
        let outcome = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(outcome, Err(Error::DuplicateIndex(0))));
    }

    /// `total` must be strictly greater than `min`.
    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        // total == min
        encode::<Sha256>(3, 3, data.to_vec()).unwrap();
    }

    /// `min` must be non-zero.
    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        // min == 0
        encode::<Sha256>(5, 0, data.to_vec()).unwrap();
    }

    /// Empty input survives a round trip.
    #[test]
    fn test_empty_data() {
        let data = b"";
        let (total, min) = (100u16, 30u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Decode with exactly `min` chunks
        let subset: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, subset).unwrap();
        assert_eq!(decoded, data);
    }

    /// A 1000-byte input survives a round trip.
    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000];
        let (total, min) = (7u16, 4u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        // Decode with exactly `min` chunks
        let subset: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, subset).unwrap();
        assert_eq!(decoded, data);
    }

    /// Valid chunks do not verify against an unrelated root.
    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let (total, min) = (7u16, 4u16);

        // Encode the data; the genuine root is deliberately ignored
        let (_correct_root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Build a root that has nothing to do with the chunks
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // Take exactly `min` genuine chunks
        let subset: Vec<_> = chunks.into_iter().take(min as usize).collect();

        // Proof verification against the unrelated root must fail
        let outcome = decode::<Sha256>(total, min, &malicious_root, subset);
        assert!(matches!(outcome, Err(Error::InvalidProof)));
    }

    /// A chunk whose shard bytes were altered fails proof verification.
    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let (total, min) = (6u16, 3u16);

        // Encode the data
        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Flip every bit of the first byte of chunk 1 (if it has one)
        if let Some(byte) = chunks[1].shard.first_mut() {
            *byte ^= 0xFF;
        }

        // Decoding must fail on the tampered chunk
        let outcome = decode::<Sha256>(total, min, &root, chunks);
        assert!(matches!(outcome, Err(Error::InvalidProof)));
    }

    /// Chunks that verify against a root built over a corrupted recovery shard are rejected
    /// once the shards are re-encoded and the root is recomputed.
    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let (total, min) = (5u16, 3u16);
        let (n, k) = (total as usize, min as usize);
        let m = n - k;

        // Build the original shards by hand
        let originals = prepare_data(data.to_vec(), k, m);
        let shard_size = originals[0].len();

        // Compute the honest recovery shards
        let mut encoder = ReedSolomonEncoder::new(k, m, shard_size).unwrap();
        for original in &originals {
            encoder.add_original_shard(original).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|shard| shard.to_vec())
            .collect();

        // Corrupt the first recovery shard
        if let Some(byte) = recovery_shards[0].first_mut() {
            *byte ^= 0xFF;
        }

        // Originals followed by the (partly corrupted) recovery shards
        let mut malicious_shards = originals;
        malicious_shards.extend(recovery_shards);

        // Commit to the corrupted shard set
        let mut builder = Builder::<Sha256>::new(n);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Two original shards plus the corrupted recovery shard (index 3)
        let pieces: Vec<_> = [0usize, 1, 3]
            .iter()
            .map(|&i| {
                let merkle_proof = malicious_tree.proof(i as u32).unwrap();
                Chunk::new(malicious_shards[i].clone(), i as u16, merkle_proof)
            })
            .collect();

        // Every proof verifies, but the recomputed root cannot match
        let outcome = decode::<Sha256>(total, min, &malicious_root, pieces);
        assert!(matches!(outcome, Err(Error::Inconsistent)));
    }

    /// A one-byte input (which needs shard-length rounding) survives a round trip.
    #[test]
    fn test_odd_shard_len() {
        let data = b"a";
        let (total, min) = (3u16, 2u16);

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // One original chunk and the single recovery chunk
        let pieces: Vec<_> = [0usize, 2].iter().map(|&i| chunks[i].clone()).collect();

        // Decode from the mixed set
        let decoded = decode::<Sha256>(total, min, &root, pieces).unwrap();
        assert_eq!(decoded, data);
    }

    /// A chunk index equal to `total` is rejected.
    #[test]
    fn test_invalid_index() {
        let data = b"Testing recovery pieces";
        let (total, min) = (8u16, 3u16);

        // Encode the data
        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Give chunk 1 an out-of-range index
        chunks[1].index = 8;
        let pieces: Vec<_> = [0usize, 1, 6].iter().map(|&i| chunks[i].clone()).collect();

        // Decoding must fail and report the bad index
        let outcome = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(outcome, Err(Error::InvalidIndex(8))));
    }

    /// Round trip with `total = u16::MAX` and `min = u16::MAX / 2`.
    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2;

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        // Decode with exactly `min` chunks
        let subset: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, subset).unwrap();
        assert_eq!(decoded, data);
    }

    /// With one fewer original shard than in `test_max_chunks`, the Reed-Solomon encoder
    /// reports an unsupported shard count.
    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        // Encoding must fail
        let outcome = encode::<Sha256>(total, min, data.clone());
        assert!(matches!(
            outcome,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount { .. }
            ))
        ));
    }
}