1use bytes::{Buf, BufMut};
116use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, ReadRangeExt, Write};
117use commonware_cryptography::Hasher;
118use commonware_storage::bmt::{self, Builder};
119use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
120use std::collections::HashSet;
121use thiserror::Error;
122
/// Errors that can arise while encoding, decoding, or verifying
/// Reed-Solomon-coded data.
#[derive(Error, Debug)]
pub enum Error {
    /// The underlying SIMD Reed-Solomon library rejected the operation
    /// (e.g. unsupported shard counts or mismatched shard lengths).
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The recovered shards re-encode to a different Merkle root than the
    /// one provided, i.e. the original encoding was inconsistent.
    #[error("inconsistent")]
    Inconsistent,
    /// A chunk's Merkle proof failed verification against the root.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer chunks were supplied than the minimum required to decode.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same chunk index was supplied more than once.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The data is too large to be described by the u32 length prefix.
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A chunk index was outside the valid range `0..total`.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
}
141
/// A single erasure-coded shard together with its position and a Merkle
/// proof binding it to the encoding's root.
#[derive(Clone)]
pub struct Chunk<H: Hasher> {
    /// The shard bytes (original or recovery data).
    pub shard: Vec<u8>,

    /// Position of this shard in the encoding (`0..total`).
    pub index: u16,

    /// Merkle proof that `shard` is the leaf at `index` under the root.
    pub proof: bmt::Proof<H>,
}
154
155impl<H: Hasher> Chunk<H> {
156 pub fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
158 Self {
159 shard,
160 index,
161 proof,
162 }
163 }
164
165 pub fn verify(&self, index: u16, root: &H::Digest) -> bool {
167 if index != self.index {
169 return false;
170 }
171
172 let mut hasher = H::new();
174 hasher.update(&self.shard);
175 let shard_digest = hasher.finalize();
176
177 self.proof
179 .verify(&mut hasher, &shard_digest, self.index as u32, root)
180 .is_ok()
181 }
182}
183
// Wire format: shard, then index, then proof. The order must match
// `Read::read_cfg` for `Chunk` below.
impl<H: Hasher> Write for Chunk<H> {
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
191
impl<H: Hasher> Read for Chunk<H> {
    /// Maximum allowed shard length in bytes.
    type Cfg = usize;

    // Fields are read in the same order they are written: shard, index, proof.
    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Bound the shard allocation by the caller-provided maximum.
        let shard = Vec::<u8>::read_range(reader, ..=*cfg)?;
        let index = u16::read(reader)?;
        let proof = bmt::Proof::<H>::read(reader)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
207
// Encoded size is the sum of the three fields written by `Write`.
impl<H: Hasher> EncodeSize for Chunk<H> {
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }
}
213
/// The minimum set of original shards required to reconstruct the data,
/// committed to by a single Merkle range proof instead of per-shard proofs.
#[derive(Clone)]
pub struct Data<H: Hasher> {
    /// The original (non-recovery) shards, in index order starting at 0.
    pub shards: Vec<Vec<u8>>,

    /// Range proof over the digests of `shards` (leaves `[0, len)`).
    pub proof: bmt::RangeProof<H>,
}
226
227impl<H: Hasher> Data<H> {
228 pub fn new(shards: Vec<Vec<u8>>, proof: bmt::RangeProof<H>) -> Self {
230 Self { shards, proof }
231 }
232
233 pub fn verify(&self, min: u16, root: &H::Digest) -> bool {
237 let k = min as usize;
239 if self.shards.len() != k {
240 return false;
241 }
242
243 let mut hasher = H::new();
245 let mut shards = Vec::with_capacity(k);
246 for shard in &self.shards {
247 hasher.update(shard);
248 shards.push(hasher.finalize());
249 }
250
251 self.proof.verify(&mut hasher, 0, &shards, root).is_ok()
253 }
254
255 pub fn extract(self) -> Vec<u8> {
257 let k = self.shards.len();
258 extract_data(self.shards, k)
259 }
260}
261
// Wire format: shards, then proof. The order must match `Read::read_cfg`
// for `Data` below.
impl<H: Hasher> Write for Data<H> {
    fn write(&self, writer: &mut impl BufMut) {
        self.shards.write(writer);
        self.proof.write(writer);
    }
}
268
impl<H: Hasher> Read for Data<H> {
    /// Maximum allowed size of each shard in bytes.
    type Cfg = usize;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // At least one shard; at most u16::MAX (shard indices are u16).
        let shard_count: RangeCfg = (1..=u16::MAX as usize).into();
        // Each shard is bounded by the caller-provided maximum size.
        let shard_size: RangeCfg = (..=*cfg).into();
        let shards = Vec::<Vec<u8>>::read_cfg(reader, &(shard_count, (shard_size, ())))?;
        let proof = bmt::RangeProof::<H>::read(reader)?;

        Ok(Self { shards, proof })
    }
}
282
// Encoded size is the sum of the two fields written by `Write`.
impl<H: Hasher> EncodeSize for Data<H> {
    fn encode_size(&self) -> usize {
        self.shards.encode_size() + self.proof.encode_size()
    }
}
288
/// Splits `data` into `k` equally-sized shards, prefixing it with its
/// length (big-endian u32) and zero-padding the tail.
///
/// `m` is the number of recovery shards that will later be appended; it is
/// only used to pre-size the returned vector's capacity.
///
/// The shard length is rounded up to the next even number because the
/// SIMD Reed-Solomon implementation requires even shard sizes.
fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
    let data_len = data.len();
    // Account for the 4-byte big-endian length prefix.
    let prefixed_len = std::mem::size_of::<u32>() + data_len;
    // Ceil-divide across `k` shards, then round up to an even length.
    let shard_len = prefixed_len.div_ceil(k).next_multiple_of(2);

    // Stream: length prefix followed by the payload bytes.
    let mut src = (data_len as u32).to_be_bytes().into_iter().chain(data);

    // Capacity covers the `m` recovery shards appended by the caller.
    let mut shards = Vec::with_capacity(k + m);
    for _ in 0..k {
        let mut shard: Vec<u8> = (&mut src).take(shard_len).collect();
        // Zero-pad the final shard(s) once the source is exhausted.
        shard.resize(shard_len, 0);
        shards.push(shard);
    }
    shards
}
314
/// Reassembles the original data from the first `k` shards, stripping the
/// 4-byte big-endian length prefix and any zero padding.
fn extract_data(shards: Vec<Vec<u8>>, k: usize) -> Vec<u8> {
    // Stream the bytes of the first `k` (original) shards in order.
    let mut bytes = shards.into_iter().take(k).flatten();

    // The stream begins with a four-byte big-endian length prefix.
    let mut prefix = [0u8; 4];
    for slot in prefix.iter_mut() {
        *slot = bytes.next().expect("insufficient data");
    }
    let data_len = u32::from_be_bytes(prefix) as usize;

    // The next `data_len` bytes are the payload; the remainder is padding.
    bytes.take(data_len).collect()
}
331
332pub type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);
334
335pub fn encode_inner<H: Hasher>(total: u16, min: u16, data: Vec<u8>) -> Result<Encoding<H>, Error> {
337 assert!(total > min);
339 assert!(min > 0);
340 let n = total as usize;
341 let k = min as usize;
342 let m = n - k;
343 if data.len() > u32::MAX as usize {
344 return Err(Error::InvalidDataLength(data.len()));
345 }
346
347 let mut shards = prepare_data(data, k, m);
349 let shard_len = shards[0].len();
350
351 let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
353 for shard in &shards {
354 encoder
355 .add_original_shard(shard)
356 .map_err(Error::ReedSolomon)?;
357 }
358
359 let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
361 let recovery_shards: Vec<Vec<u8>> = encoding
362 .recovery_iter()
363 .map(|shard| shard.to_vec())
364 .collect();
365 shards.extend(recovery_shards);
366
367 let mut builder = Builder::<H>::new(n);
369 let mut hasher = H::new();
370 for shard in &shards {
371 builder.add(&{
372 hasher.update(shard);
373 hasher.finalize()
374 });
375 }
376 let tree = builder.build();
377
378 Ok((tree, shards))
379}
380
381pub fn encode<H: Hasher>(
394 total: u16,
395 min: u16,
396 data: Vec<u8>,
397) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
398 let (tree, shards) = encode_inner::<H>(total, min, data)?;
400 let root = tree.root();
401 let n = total as usize;
402
403 let mut chunks = Vec::with_capacity(n);
405 for (i, shard) in shards.into_iter().enumerate() {
406 let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
407 chunks.push(Chunk::new(shard, i as u16, proof));
408 }
409
410 Ok((root, chunks))
411}
412
413pub fn translate<H: Hasher>(
426 total: u16,
427 min: u16,
428 data: Vec<u8>,
429) -> Result<(H::Digest, Data<H>), Error> {
430 let (tree, mut shards) = encode_inner::<H>(total, min, data)?;
432 let root = tree.root();
433 let k = min as usize;
434
435 shards.truncate(k);
437
438 let proof = tree
440 .range_proof(0, (min - 1) as u32)
441 .map_err(|_| Error::InvalidProof)?;
442 let proof = Data::new(shards, proof);
443
444 Ok((root, proof))
445}
446
447pub fn decode<H: Hasher>(
460 total: u16,
461 min: u16,
462 root: &H::Digest,
463 chunks: Vec<Chunk<H>>,
464) -> Result<Vec<u8>, Error> {
465 assert!(total > min);
467 assert!(min > 0);
468 let n = total as usize;
469 let k = min as usize;
470 let m = n - k;
471 if chunks.len() < k {
472 return Err(Error::NotEnoughChunks);
473 }
474
475 let shard_len = chunks[0].shard.len();
477 let mut seen = HashSet::new();
478 let mut provided_originals: Vec<(usize, Vec<u8>)> = Vec::new();
479 let mut provided_recoveries: Vec<(usize, Vec<u8>)> = Vec::new();
480 for chunk in chunks {
481 let index = chunk.index;
483 if index >= total {
484 return Err(Error::InvalidIndex(index));
485 }
486 if seen.contains(&index) {
487 return Err(Error::DuplicateIndex(index));
488 }
489 seen.insert(index);
490
491 if !chunk.verify(chunk.index, root) {
493 return Err(Error::InvalidProof);
494 }
495
496 if index < min {
498 provided_originals.push((index as usize, chunk.shard));
499 } else {
500 provided_recoveries.push((index as usize - k, chunk.shard));
501 }
502 }
503
504 let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
506 for (idx, ref shard) in &provided_originals {
507 decoder
508 .add_original_shard(*idx, shard)
509 .map_err(Error::ReedSolomon)?;
510 }
511 for (idx, ref shard) in &provided_recoveries {
512 decoder
513 .add_recovery_shard(*idx, shard)
514 .map_err(Error::ReedSolomon)?;
515 }
516 let decoding = decoder.decode().map_err(Error::ReedSolomon)?;
517
518 let mut shards = Vec::with_capacity(n);
520 shards.resize(k, Vec::new());
521 for (idx, shard) in provided_originals {
522 shards[idx] = shard;
523 }
524 for (idx, shard) in decoding.restored_original_iter() {
525 shards[idx] = shard.to_vec();
526 }
527
528 let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
530 for shard in shards.iter().take(k) {
531 encoder
532 .add_original_shard(shard)
533 .map_err(Error::ReedSolomon)?;
534 }
535 let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
536 let recovery_shards: Vec<Vec<u8>> = encoding
537 .recovery_iter()
538 .map(|shard| shard.to_vec())
539 .collect();
540 shards.extend(recovery_shards);
541
542 let mut builder = Builder::<H>::new(n);
544 let mut hasher = H::new();
545 for shard in &shards {
546 builder.add(&{
547 hasher.update(shard);
548 hasher.finalize()
549 });
550 }
551 let computed_tree = builder.build();
552
553 if computed_tree.root() != *root {
555 return Err(Error::Inconsistent);
556 }
557
558 Ok(extract_data(shards, k))
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode, Encode};
    use commonware_cryptography::Sha256;

    // Round-trip: encode, verify every chunk, decode from exactly `min` chunks.
    #[test]
    fn test_basic() {
        let data = b"Hello, Reed-Solomon!";
        let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();
        assert_eq!(chunks.len(), total as usize);

        // Every chunk should carry a valid proof for its own index.
        for i in 0..total {
            assert!(chunks[i as usize].verify(i, &root));
        }

        // Decoding from the minimum number of chunks must succeed.
        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // Decoding works when the encoding has many more shards than the minimum.
    #[test]
    fn test_moderate() {
        let data = b"Testing with more pieces than minimum";
        let total = 10u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // Decoding succeeds from a mix of original and recovery chunks.
    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // One original chunk (index 0) and two recovery chunks (indices 4, 6).
        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[4].clone(),
            chunks[6].clone(),
        ];

        let decoded = decode::<Sha256>(total, min, &root, pieces).unwrap();
        assert_eq!(decoded, data);
    }

    // Fewer than `min` chunks must be rejected before any recovery is attempted.
    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        let result = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    // Supplying the same chunk twice is detected by its index.
    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        let result = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    // Verifying a chunk against the wrong index must fail.
    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // `verify` checks the claimed index, so an off-by-one never passes.
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    // `total` must strictly exceed `min`.
    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        encode::<Sha256>(3, 3, data.to_vec()).unwrap();
    }

    // `min` must be non-zero.
    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        encode::<Sha256>(5, 0, data.to_vec()).unwrap();
    }

    // Zero-length data still round-trips (only the length prefix is encoded).
    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // Data spanning multiple shard lengths round-trips.
    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000];
        let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // A root that doesn't commit to the chunks is rejected both per-chunk
    // and during decode.
    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        let (_correct_root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // No chunk proof can verify against an unrelated root.
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        let minimal = chunks.into_iter().take(min as usize).collect();

        let result = decode::<Sha256>(total, min, &malicious_root, minimal);
        assert!(matches!(result, Err(Error::InvalidProof)));
    }

    // Flipping a byte in a shard invalidates its Merkle proof.
    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF;
        }

        let result = decode::<Sha256>(total, min, &root, chunks);
        assert!(matches!(result, Err(Error::InvalidProof)));
    }

    // An encoder that commits to a corrupted recovery shard produces chunks
    // with valid proofs, but decode's re-encoding check catches the mismatch.
    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Corrupt a recovery shard BEFORE building the tree, so the root
        // commits to the inconsistent encoding.
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Include the corrupted recovery shard (index 3) among the pieces.
        let selected_indices = vec![0, 1, 3];
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        let result = decode::<Sha256>(total, min, &malicious_root, pieces);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // Tiny input exercising the even-length padding of shards.
    #[test]
    fn test_odd_shard_len() {
        let data = b"a";
        let total = 3u16;
        let min = 2u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[2].clone(),
        ];

        let decoded = decode::<Sha256>(total, min, &root, pieces).unwrap();
        assert_eq!(decoded, data);
    }

    // A chunk whose index is out of range (>= total) is rejected.
    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Forge an out-of-range index (valid indices are 0..8).
        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[1].clone(),
            chunks[6].clone(),
        ];

        let result = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    // The largest supported shard count still round-trips.
    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2;

        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // Exceeding the library's supported shard counts surfaces as a
    // Reed-Solomon error rather than a panic.
    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        let result = encode::<Sha256>(total, min, data.clone());
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    // `translate` produces a verifiable range proof and extractable data.
    #[test]
    fn test_translate() {
        let data = b"Test data for minimal proof functionality";
        let total = 7u16;
        let min = 4u16;

        let (root, proof) = translate::<Sha256>(total, min, data.to_vec()).unwrap();

        assert!(proof.verify(min, &root));

        let extracted = proof.extract();
        assert_eq!(extracted, data);
    }

    // A range proof must not verify against an unrelated root.
    #[test]
    fn test_translate_wrong_root() {
        let data = b"Test data for minimal proof";
        let total = 5u16;
        let min = 3u16;

        let (_, proof) = translate::<Sha256>(total, min, data.to_vec()).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(b"wrong_root");
        let wrong_root = hasher.finalize();

        assert!(!proof.verify(min, &wrong_root));
    }

    // `Data` round-trips through the codec and still verifies/extracts.
    #[test]
    fn test_translate_serialization() {
        let data = b"Test serialization of minimal proof";
        let total = 5u16;
        let min = 3u16;

        let (root, proof) = translate::<Sha256>(total, min, data.to_vec()).unwrap();

        let serialized = proof.encode();

        let max_shard_size = proof.shards[0].len();
        let deserialized = Data::<Sha256>::decode_cfg(serialized, &max_shard_size).unwrap();

        assert!(deserialized.verify(min, &root));
        assert_eq!(deserialized.extract(), data);
    }
}