use crate::{Config, Scheme};
use bytes::{Buf, BufMut};
use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
use commonware_cryptography::{Digest, Hasher};
use commonware_parallel::Strategy;
use commonware_storage::bmt::{self, Builder};
use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
use std::{collections::HashSet, marker::PhantomData};
use thiserror::Error;

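/// Errors that can occur while encoding or decoding Reed-Solomon chunks.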
#[derive(Error, Debug)]
pub enum Error {
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    #[error("inconsistent")]
    Inconsistent,
    #[error("invalid proof")]
    InvalidProof,
    #[error("not enough chunks")]
    NotEnoughChunks,
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    #[error("wrong index: {0}")]
    WrongIndex(u16),
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
}

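/// Converts the configured total shard count to a `u16`, returning
/// [`Error::TooManyTotalShards`] if it does not fit.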
fn total_shards(config: &Config) -> Result<u16, Error> {
    let total = config.total_shards();
    total
        .try_into()
        .map_err(|_| Error::TooManyTotalShards(total))
}

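/// An erasure-coded shard paired with its index in the encoding and a Binary
/// Merkle Tree proof that the shard's digest is included in the commitment.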
#[derive(Debug, Clone)]
pub struct Chunk<D: Digest> {
    shard: Vec<u8>,

    index: u16,

    proof: bmt::Proof<D>,
}

impl<D: Digest> Chunk<D> {
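    /// Creates a [`Chunk`] from a shard, its index, and its Merkle inclusion proof.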
    const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<D>) -> Self {
        Self {
            shard,
            index,
            proof,
        }
    }

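    /// Returns `true` if this chunk is for `index` and its shard's digest is
    /// proven to be included at that position in the Binary Merkle Tree with
    /// the given `root`.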
    fn verify<H: Hasher<Digest = D>>(&self, index: u16, root: &D) -> bool {
        if index != self.index {
            return false;
        }

        let mut hasher = H::new();
        hasher.update(&self.shard);
        let shard_digest = hasher.finalize();

        self.proof
            .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
            .is_ok()
    }
}

impl<D: Digest> Write for Chunk<D> {
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}

impl<D: Digest> Read for Chunk<D> {
    type Cfg = crate::CodecConfig;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        let shard = Vec::<u8>::read_range(reader, ..=cfg.maximum_shard_size)?;
        let index = u16::read(reader)?;
        let proof = bmt::Proof::<D>::read_cfg(reader, &1)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}

impl<D: Digest> EncodeSize for Chunk<D> {
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }
}

impl<D: Digest> PartialEq for Chunk<D> {
    fn eq(&self, other: &Self) -> bool {
        self.shard == other.shard && self.index == other.index && self.proof == other.proof
    }
}

impl<D: Digest> Eq for Chunk<D> {}

#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Chunk<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            shard: u.arbitrary()?,
            index: u.arbitrary()?,
            proof: u.arbitrary()?,
        })
    }
}

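/// Splits `data` into `k` original shards of equal, even length.
///
/// The data is prefixed with its length as a big-endian `u32` and zero-padded
/// so it divides evenly into `k` shards whose length is a multiple of two (the
/// encoder requires an even shard length). For example, 10 bytes of data with
/// `k = 3` gives a prefixed length of 14 bytes, so `shard_len` is
/// `ceil(14 / 3) = 5`, rounded up to 6; the result is three 6-byte shards, the
/// last 4 bytes of which are padding.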
fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
    let data_len = data.len();
    let prefixed_len = u32::SIZE + data_len;
    let mut shard_len = prefixed_len.div_ceil(k);

    if !shard_len.is_multiple_of(2) {
        shard_len += 1;
    }

    let length_bytes = (data_len as u32).to_be_bytes();
    let mut padded = vec![0u8; k * shard_len];
    padded[..u32::SIZE].copy_from_slice(&length_bytes);
    padded[u32::SIZE..u32::SIZE + data_len].copy_from_slice(&data);

    let mut shards = Vec::with_capacity(k + m);
    for chunk in padded.chunks(shard_len) {
        shards.push(chunk.to_vec());
    }
    shards
}

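/// Reassembles the original data from the first `k` (original) shards by
/// reading the big-endian `u32` length prefix and then taking exactly that
/// many bytes, discarding any padding.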
fn extract_data(shards: Vec<&[u8]>, k: usize) -> Vec<u8> {
    let mut data = shards.into_iter().take(k).flatten();

    let data_len = (&mut data)
        .take(u32::SIZE)
        .copied()
        .collect::<Vec<_>>()
        .try_into()
        .expect("insufficient data");
    let data_len = u32::from_be_bytes(data_len) as usize;

    data.take(data_len).copied().collect()
}

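/// A Binary Merkle Tree over the digests of all shards, together with the
/// shards themselves (originals followed by recovery shards).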
type Encoding<D> = (bmt::Tree<D>, Vec<Vec<u8>>);

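/// Encodes `data` into `total` shards, any `min` of which suffice to
/// reconstruct it, and builds a Binary Merkle Tree over the digests of all
/// shards (originals first, then recovery shards), hashing shards according
/// to `strategy`.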
fn encode_inner<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    data: Vec<u8>,
    strategy: &S,
) -> Result<Encoding<H::Digest>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if data.len() > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data.len()));
    }

    let mut shards = prepare_data(data, k, m);
    let shard_len = shards[0].len();

    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in &shards {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }

    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<Vec<u8>> = encoding
        .recovery_iter()
        .map(|shard| shard.to_vec())
        .collect();
    shards.extend(recovery_shards);

    let mut builder = Builder::<H>::new(n);
    let shard_hashes = strategy.map_init_collect_vec(&shards, H::new, |hasher, shard| {
        hasher.update(shard);
        hasher.finalize()
    });
    for hash in &shard_hashes {
        builder.add(hash);
    }
    let tree = builder.build();

    Ok((tree, shards))
}

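/// Encodes `data` and wraps each shard in a [`Chunk`] carrying its index and
/// Merkle inclusion proof; the returned commitment is the tree root.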
#[allow(clippy::type_complexity)]
fn encode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    data: Vec<u8>,
    strategy: &S,
) -> Result<(H::Digest, Vec<Chunk<H::Digest>>), Error> {
    let (tree, shards) = encode_inner::<H, _>(total, min, data, strategy)?;
    let root = tree.root();
    let n = total as usize;

    let mut chunks = Vec::with_capacity(n);
    for (i, shard) in shards.into_iter().enumerate() {
        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
        chunks.push(Chunk::new(shard, i as u16, proof));
    }

    Ok((root, chunks))
}

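/// Recovers the original data from at least `min` chunks.
///
/// Missing original shards are reconstructed with the Reed-Solomon decoder,
/// the recovery shards are re-derived from the originals, and the Binary
/// Merkle Tree is rebuilt over everything; if its root does not match `root`,
/// the chunks came from an inconsistent encoding and [`Error::Inconsistent`]
/// is returned.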
fn decode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: &[Chunk<H::Digest>],
    strategy: &S,
) -> Result<Vec<u8>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if chunks.len() < k {
        return Err(Error::NotEnoughChunks);
    }

    let shard_len = chunks[0].shard.len();
    let mut seen = HashSet::new();
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    for chunk in chunks {
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        if seen.contains(&index) {
            return Err(Error::DuplicateIndex(index));
        }
        seen.insert(index);

        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_slice()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.as_slice()));
        }
    }

    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for (idx, ref shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, ref shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    let mut shards = Vec::with_capacity(n);
    shards.resize(k, Default::default());
    for (idx, shard) in provided_originals {
        shards[idx] = shard;
    }
    for (idx, shard) in decoding.restored_original_iter() {
        shards[idx] = shard;
    }

    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<&[u8]> = encoding.recovery_iter().collect();
    shards.extend(recovery_shards);

    let mut builder = Builder::<H>::new(n);
    let shard_hashes = strategy.map_init_collect_vec(&shards, H::new, |hasher, shard| {
        hasher.update(shard);
        hasher.finalize()
    });
    for hash in &shard_hashes {
        builder.add(hash);
    }
    let tree = builder.build();

    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    Ok(extract_data(shards, k))
}

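/// A Reed-Solomon erasure coding [`Scheme`] whose commitment is the root of a
/// Binary Merkle Tree built (with hasher `H`) over all shards, originals
/// followed by recovery shards.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest; it
/// assumes the workspace's `Sha256` hasher and `Sequential` strategy, and a
/// [`Config`] with `minimum_shards` and `extra_shards` fields as exercised by
/// the tests below):
///
/// ```ignore
/// use commonware_cryptography::Sha256;
/// use commonware_parallel::Sequential;
///
/// let config = Config {
///     minimum_shards: 3,
///     extra_shards: 5,
/// };
///
/// // Encode the data into `minimum_shards + extra_shards` proven chunks.
/// let (commitment, shards) =
///     ReedSolomon::<Sha256>::encode(&config, &b"hello world"[..], &Sequential)?;
///
/// // Verify each chunk against the commitment before using it.
/// let mut checked = Vec::new();
/// for (i, shard) in shards.into_iter().take(3).enumerate() {
///     let (_, checked_shard, _) =
///         ReedSolomon::<Sha256>::reshard(&config, &commitment, i as u16, shard)?;
///     checked.push(checked_shard);
/// }
///
/// // Any `minimum_shards` verified chunks recover the original data.
/// let data = ReedSolomon::<Sha256>::decode(&config, &commitment, (), &checked, &Sequential)?;
/// assert_eq!(data, b"hello world");
/// ```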
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    _marker: PhantomData<H>,
}

impl<H> std::fmt::Debug for ReedSolomon<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReedSolomon").finish()
    }
}

impl<H: Hasher> Scheme for ReedSolomon<H> {
    type Commitment = H::Digest;

    type Shard = Chunk<H::Digest>;
    type ReShard = Chunk<H::Digest>;
    type CheckedShard = Chunk<H::Digest>;
    type CheckingData = ();

    type Error = Error;

    fn encode(
        config: &Config,
        mut data: impl Buf,
        strategy: &impl Strategy,
    ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
        let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
        encode::<H, _>(total_shards(config)?, config.minimum_shards, data, strategy)
    }

    fn reshard(
        _config: &Config,
        commitment: &Self::Commitment,
        index: u16,
        shard: Self::Shard,
    ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::ReShard), Self::Error> {
        if shard.index != index {
            return Err(Error::WrongIndex(index));
        }
        if shard.verify::<H>(shard.index, commitment) {
            Ok(((), shard.clone(), shard))
        } else {
            Err(Error::InvalidProof)
        }
    }

    fn check(
        _config: &Config,
        commitment: &Self::Commitment,
        _checking_data: &Self::CheckingData,
        index: u16,
        reshard: Self::ReShard,
    ) -> Result<Self::CheckedShard, Self::Error> {
        if reshard.index != index {
            return Err(Error::WrongIndex(reshard.index));
        }
        if !reshard.verify::<H>(reshard.index, commitment) {
            return Err(Error::InvalidProof);
        }
        Ok(reshard)
    }

    fn decode(
        config: &Config,
        commitment: &Self::Commitment,
        _checking_data: Self::CheckingData,
        shards: &[Self::CheckedShard],
        strategy: &impl Strategy,
    ) -> Result<Vec<u8>, Self::Error> {
        decode::<H, _>(
            total_shards(config)?,
            config.minimum_shards,
            commitment,
            shards,
            strategy,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;
    use commonware_parallel::Sequential;

    const STRATEGY: Sequential = Sequential;

    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[4].clone(),
            chunks[6].clone(),
        ];

        let decoded = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        for i in 0..total {
            assert!(!chunks[i as usize].verify::<Sha256>(i + 1, &root));
        }
    }

    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        encode::<Sha256, _>(3, 3, data.to_vec(), &STRATEGY).unwrap();
    }

    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        encode::<Sha256, _>(5, 0, data.to_vec(), &STRATEGY).unwrap();
    }

    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000];
        let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        let (_correct_root, chunks) =
            encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        for i in 0..total {
            assert!(!chunks[i as usize].verify::<Sha256>(i, &malicious_root));
        }

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();

        let result = decode::<Sha256, _>(total, min, &malicious_root, &minimal, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF;
        }

        let result = decode::<Sha256, _>(total, min, &root, &chunks, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        let selected_indices = vec![0, 1, 3];
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        let result = decode::<Sha256, _>(total, min, &malicious_root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[1].clone(),
            chunks[6].clone(),
        ];

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        let result = encode::<Sha256, _>(total, min, data, &STRATEGY);
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    #[test]
    fn test_too_many_total_shards() {
        assert!(ReedSolomon::<Sha256>::encode(
            &Config {
                minimum_shards: u16::MAX / 2 + 1,
                extra_shards: u16::MAX,
            },
            [].as_slice(),
            &STRATEGY,
        )
        .is_err())
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256Digest>>,
        }
    }
}