1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut};
3use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
4use commonware_cryptography::Hasher;
5use commonware_parallel::Strategy;
6use commonware_storage::bmt::{self, Builder};
7use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
8use std::{collections::HashSet, marker::PhantomData};
9use thiserror::Error;
10
/// Errors produced while encoding, decoding, or verifying Reed-Solomon
/// coded chunks.
#[derive(Error, Debug)]
pub enum Error {
    /// An error surfaced by the underlying `reed_solomon_simd` library.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The provided chunks do not re-encode to the committed Merkle root.
    #[error("inconsistent")]
    Inconsistent,
    /// A chunk's Merkle inclusion proof failed verification (or could not
    /// be generated during encoding).
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than the minimum number of chunks were supplied for decoding.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same chunk index was supplied more than once.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The payload is too large to encode (its length must fit in a u32).
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A chunk's index is out of range for the configured shard count.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// A chunk's embedded index does not match the expected index.
    #[error("wrong index: {0}")]
    WrongIndex(u16),
    /// The configured total shard count does not fit in a u16.
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
}
33
34fn total_shards(config: &Config) -> Result<u16, Error> {
35 let total = config.total_shards();
36 total
37 .try_into()
38 .map_err(|_| Error::TooManyTotalShards(total))
39}
40
/// A single erasure-coded shard bundled with its position and a Merkle
/// inclusion proof against the commitment root.
#[derive(Debug, Clone)]
pub struct Chunk<H: Hasher> {
    // Raw shard bytes (an original or recovery shard).
    shard: Vec<u8>,
    // Position of this shard in the encoded sequence.
    index: u16,
    // BMT proof that the digest of `shard` is included at `index` under the root.
    proof: bmt::Proof<H::Digest>,
}
53
54impl<H: Hasher> Chunk<H> {
55 const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H::Digest>) -> Self {
57 Self {
58 shard,
59 index,
60 proof,
61 }
62 }
63
64 fn verify(&self, index: u16, root: &H::Digest) -> bool {
66 if index != self.index {
68 return false;
69 }
70
71 let mut hasher = H::new();
73 hasher.update(&self.shard);
74 let shard_digest = hasher.finalize();
75
76 self.proof
78 .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
79 .is_ok()
80 }
81}
82
impl<H: Hasher> Write for Chunk<H> {
    // Serialization order (shard, index, proof) must mirror `Read::read_cfg`.
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
90
91impl<H: Hasher> Read for Chunk<H> {
92 type Cfg = crate::CodecConfig;
94
95 fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
96 let shard = Vec::<u8>::read_range(reader, ..=cfg.maximum_shard_size)?;
97 let index = u16::read(reader)?;
98 let proof = bmt::Proof::<H::Digest>::read_cfg(reader, &1)?;
99 Ok(Self {
100 shard,
101 index,
102 proof,
103 })
104 }
105}
106
107impl<H: Hasher> EncodeSize for Chunk<H> {
108 fn encode_size(&self) -> usize {
109 self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
110 }
111}
112
113impl<H: Hasher> PartialEq for Chunk<H> {
114 fn eq(&self, other: &Self) -> bool {
115 self.shard == other.shard && self.index == other.index && self.proof == other.proof
116 }
117}
118
119impl<H: Hasher> Eq for Chunk<H> {}
120
#[cfg(feature = "arbitrary")]
impl<H: Hasher> arbitrary::Arbitrary<'_> for Chunk<H>
where
    H::Digest: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws each field in declaration order from the unstructured input.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let shard = u.arbitrary()?;
        let index = u.arbitrary()?;
        let proof = u.arbitrary()?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
134
135fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
137 let data_len = data.len();
139 let prefixed_len = u32::SIZE + data_len;
140 let mut shard_len = prefixed_len.div_ceil(k);
141
142 if !shard_len.is_multiple_of(2) {
144 shard_len += 1;
145 }
146
147 let length_bytes = (data_len as u32).to_be_bytes();
149 let mut padded = vec![0u8; k * shard_len];
150 padded[..u32::SIZE].copy_from_slice(&length_bytes);
151 padded[u32::SIZE..u32::SIZE + data_len].copy_from_slice(&data);
152
153 let mut shards = Vec::with_capacity(k + m); for chunk in padded.chunks(shard_len) {
155 shards.push(chunk.to_vec());
156 }
157 shards
158}
159
160fn extract_data(shards: Vec<&[u8]>, k: usize) -> Vec<u8> {
162 let mut data = shards.into_iter().take(k).flatten();
164
165 let data_len = (&mut data)
167 .take(u32::SIZE)
168 .copied()
169 .collect::<Vec<_>>()
170 .try_into()
171 .expect("insufficient data");
172 let data_len = u32::from_be_bytes(data_len) as usize;
173
174 data.take(data_len).copied().collect()
176}
177
178type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);
180
181fn encode_inner<H: Hasher, S: Strategy>(
183 total: u16,
184 min: u16,
185 data: Vec<u8>,
186 strategy: &S,
187) -> Result<Encoding<H>, Error> {
188 assert!(total > min);
190 assert!(min > 0);
191 let n = total as usize;
192 let k = min as usize;
193 let m = n - k;
194 if data.len() > u32::MAX as usize {
195 return Err(Error::InvalidDataLength(data.len()));
196 }
197
198 let mut shards = prepare_data(data, k, m);
200 let shard_len = shards[0].len();
201
202 let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
204 for shard in &shards {
205 encoder
206 .add_original_shard(shard)
207 .map_err(Error::ReedSolomon)?;
208 }
209
210 let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
212 let recovery_shards: Vec<Vec<u8>> = encoding
213 .recovery_iter()
214 .map(|shard| shard.to_vec())
215 .collect();
216 shards.extend(recovery_shards);
217
218 let mut builder = Builder::<H>::new(n);
220 let shard_hashes = strategy.map_init_collect_vec(&shards, H::new, |hasher, shard| {
221 hasher.update(shard);
222 hasher.finalize()
223 });
224 for hash in &shard_hashes {
225 builder.add(hash);
226 }
227 let tree = builder.build();
228
229 Ok((tree, shards))
230}
231
232fn encode<H: Hasher, S: Strategy>(
246 total: u16,
247 min: u16,
248 data: Vec<u8>,
249 strategy: &S,
250) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
251 let (tree, shards) = encode_inner::<H, _>(total, min, data, strategy)?;
253 let root = tree.root();
254 let n = total as usize;
255
256 let mut chunks = Vec::with_capacity(n);
258 for (i, shard) in shards.into_iter().enumerate() {
259 let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
260 chunks.push(Chunk::new(shard, i as u16, proof));
261 }
262
263 Ok((root, chunks))
264}
265
/// Recovers the original payload from at least `min` chunks.
///
/// Individual proofs are not checked here; instead, after Reed-Solomon
/// reconstruction the recovered originals are re-encoded and the full Merkle
/// tree is rebuilt — if its root differs from `root`, the chunk set was
/// inconsistently (e.g. maliciously) encoded and [`Error::Inconsistent`] is
/// returned.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn decode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: &[Chunk<H>],
    strategy: &S,
) -> Result<Vec<u8>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if chunks.len() < k {
        return Err(Error::NotEnoughChunks);
    }

    // All shards are assumed to share the first chunk's length; mismatches
    // surface as errors from the decoder below.
    let shard_len = chunks[0].shard.len();
    // Partition provided chunks into originals (index < min) and recoveries
    // (index >= min, re-based to start at 0), rejecting out-of-range and
    // duplicate indices.
    let mut seen = HashSet::new();
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    for chunk in chunks {
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        if seen.contains(&index) {
            return Err(Error::DuplicateIndex(index));
        }
        seen.insert(index);

        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_slice()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.as_slice()));
        }
    }

    // Feed everything we have into the decoder to restore missing originals.
    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for (idx, ref shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, ref shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Assemble the full set of k original shards: provided ones first, then
    // the ones the decoder restored.
    let mut shards = Vec::with_capacity(n);
    shards.resize(k, Default::default());
    for (idx, shard) in provided_originals {
        shards[idx] = shard;
    }
    for (idx, shard) in decoding.restored_original_iter() {
        shards[idx] = shard;
    }

    // Re-encode the originals to regenerate all m recovery shards.
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<&[u8]> = encoding.recovery_iter().collect();
    shards.extend(recovery_shards);

    // Rebuild the Merkle tree over all n shard digests, exactly as encoding
    // did, so the roots can be compared.
    let mut builder = Builder::<H>::new(n);
    let shard_hashes = strategy.map_init_collect_vec(&shards, H::new, |hasher, shard| {
        hasher.update(shard);
        hasher.finalize()
    });
    for hash in &shard_hashes {
        builder.add(hash);
    }
    let tree = builder.build();

    // Any tampered or inconsistently encoded shard changes the root.
    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    Ok(extract_data(shards, k))
}
376
/// A coding [`Scheme`] backed by Reed-Solomon erasure coding with a Merkle
/// (BMT) commitment over all shards, generic over the hasher `H`.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    // Zero-sized marker tying the scheme to a hasher without storing one.
    _marker: PhantomData<H>,
}

// Manual impl so `Debug` does not require `H: Debug`.
impl<H> std::fmt::Debug for ReedSolomon<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReedSolomon").finish()
    }
}
469
470impl<H: Hasher> Scheme for ReedSolomon<H> {
471 type Commitment = H::Digest;
472
473 type Shard = Chunk<H>;
474 type ReShard = Chunk<H>;
475 type CheckedShard = Chunk<H>;
476 type CheckingData = ();
477
478 type Error = Error;
479
480 fn encode(
481 config: &Config,
482 mut data: impl Buf,
483 strategy: &impl Strategy,
484 ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
485 let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
486 encode(total_shards(config)?, config.minimum_shards, data, strategy)
487 }
488
489 fn reshard(
490 _config: &Config,
491 commitment: &Self::Commitment,
492 index: u16,
493 shard: Self::Shard,
494 ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::ReShard), Self::Error> {
495 if shard.index != index {
496 return Err(Error::WrongIndex(index));
497 }
498 if shard.verify(shard.index, commitment) {
499 Ok(((), shard.clone(), shard))
500 } else {
501 Err(Error::InvalidProof)
502 }
503 }
504
505 fn check(
506 _config: &Config,
507 commitment: &Self::Commitment,
508 _checking_data: &Self::CheckingData,
509 index: u16,
510 reshard: Self::ReShard,
511 ) -> Result<Self::CheckedShard, Self::Error> {
512 if reshard.index != index {
513 return Err(Error::WrongIndex(reshard.index));
514 }
515 if !reshard.verify(reshard.index, commitment) {
516 return Err(Error::InvalidProof);
517 }
518 Ok(reshard)
519 }
520
521 fn decode(
522 config: &Config,
523 commitment: &Self::Commitment,
524 _checking_data: Self::CheckingData,
525 shards: &[Self::CheckedShard],
526 strategy: &impl Strategy,
527 ) -> Result<Vec<u8>, Self::Error> {
528 decode(
529 total_shards(config)?,
530 config.minimum_shards,
531 commitment,
532 shards,
533 strategy,
534 )
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;
    use commonware_parallel::Sequential;

    const STRATEGY: Sequential = Sequential;

    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // A mix of original and recovery chunks suffices to decode.
        let pieces = vec![chunks[0].clone(), chunks[4].clone(), chunks[6].clone()];

        let decoded = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Two pieces is fewer than the minimum of four.
        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Chunk 0 appears twice.
        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // A chunk never verifies against an index other than its own.
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        encode::<Sha256, _>(3, 3, data.to_vec(), &STRATEGY).unwrap();
    }

    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        encode::<Sha256, _>(5, 0, data.to_vec(), &STRATEGY).unwrap();
    }

    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let minimal: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000];
        let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        let minimal: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        let (_correct_root, chunks) =
            encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // A root for data that was never actually encoded.
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // No chunk's proof verifies against the bogus root.
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        let minimal: Vec<_> = chunks.into_iter().take(min as usize).collect();

        // Decoding against the bogus root trips the consistency check.
        let result = decode::<Sha256, _>(total, min, &malicious_root, &minimal, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Corrupt a single byte of one shard.
        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF;
        }

        let result = decode::<Sha256, _>(total, min, &root, &chunks, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        // Encode honestly, then tamper with a recovery shard before committing.
        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        // Commit to the tampered shard set, yielding internally valid proofs.
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Select a mix that includes the tampered recovery shard (index 3).
        let selected_indices = vec![0, 1, 3];
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        // Re-encoding during decode exposes the inconsistency.
        let result = decode::<Sha256, _>(total, min, &malicious_root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Index 8 is out of range for total = 8.
        chunks[1].index = 8;
        let pieces = vec![chunks[0].clone(), chunks[1].clone(), chunks[6].clone()];

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        let minimal: Vec<_> = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        // These shard counts exceed what the codec supports.
        let result = encode::<Sha256, _>(total, min, data, &STRATEGY);
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    #[test]
    fn test_too_many_total_shards() {
        assert!(ReedSolomon::<Sha256>::encode(
            &Config {
                minimum_shards: u16::MAX / 2 + 1,
                extra_shards: u16::MAX,
            },
            [].as_slice(),
            &STRATEGY,
        )
        .is_err())
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256>>,
        }
    }
}