1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut};
3use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
4use commonware_cryptography::Hasher;
5use commonware_storage::bmt::{self, Builder};
6use rayon::{
7 iter::{IntoParallelRefIterator, ParallelIterator},
8 ThreadPoolBuilder,
9};
10use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
11use std::{collections::HashSet, marker::PhantomData};
12use thiserror::Error;
13
/// Errors that can arise while encoding, decoding, or validating chunks.
#[derive(Error, Debug)]
pub enum Error {
    /// The underlying `reed-solomon-simd` codec rejected the operation.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The reconstructed shard set does not hash back to the provided commitment root.
    #[error("inconsistent")]
    Inconsistent,
    /// A chunk's Merkle inclusion proof failed verification (or proof generation failed).
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than the minimum number of chunks were supplied for decoding.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// Two supplied chunks claim the same shard index.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The data to encode is too large (its length must fit in a `u32` prefix).
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A chunk's index is out of range for the configured total shard count.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// A chunk's embedded index does not match the index it was presented under.
    #[error("wrong index: {0}")]
    WrongIndex(u16),
    /// The configured total shard count does not fit in a `u16`.
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
}
36
37fn total_shards(config: &Config) -> Result<u16, Error> {
38 let total = config.total_shards();
39 total
40 .try_into()
41 .map_err(|_| Error::TooManyTotalShards(total))
42}
43
/// A single erasure-coded shard together with a BMT inclusion proof binding it
/// to a commitment root at a specific index.
#[derive(Debug, Clone)]
pub struct Chunk<H: Hasher> {
    /// The raw shard bytes (either an original or a recovery shard).
    shard: Vec<u8>,

    /// The shard's position in the encoding (`0..total`).
    index: u16,

    /// Merkle proof that the digest of `shard` is the leaf at `index`.
    proof: bmt::Proof<H>,
}
56
57impl<H: Hasher> Chunk<H> {
58 const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
60 Self {
61 shard,
62 index,
63 proof,
64 }
65 }
66
67 fn verify(&self, index: u16, root: &H::Digest) -> bool {
69 if index != self.index {
71 return false;
72 }
73
74 let mut hasher = H::new();
76 hasher.update(&self.shard);
77 let shard_digest = hasher.finalize();
78
79 self.proof
81 .verify(&mut hasher, &shard_digest, self.index as u32, root)
82 .is_ok()
83 }
84}
85
impl<H: Hasher> Write for Chunk<H> {
    // Serialization order: shard bytes, then index, then proof (must match `read_cfg`).
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
93
94impl<H: Hasher> Read for Chunk<H> {
95 type Cfg = crate::CodecConfig;
97
98 fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
99 let shard = Vec::<u8>::read_range(reader, ..=cfg.maximum_shard_size)?;
100 let index = u16::read(reader)?;
101 let proof = bmt::Proof::<H>::read(reader)?;
102 Ok(Self {
103 shard,
104 index,
105 proof,
106 })
107 }
108}
109
impl<H: Hasher> EncodeSize for Chunk<H> {
    // Total encoded size is the sum of the three serialized fields.
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }
}
115
// Manual equality: deriving `PartialEq` would add an unnecessary `H: PartialEq`
// bound, even though only the digest/proof types need comparison.
impl<H: Hasher> PartialEq for Chunk<H> {
    fn eq(&self, other: &Self) -> bool {
        self.shard == other.shard && self.index == other.index && self.proof == other.proof
    }
}

impl<H: Hasher> Eq for Chunk<H> {}
123
// Fuzzing/conformance support: generate arbitrary chunks when the `arbitrary`
// feature is enabled and the digest type itself supports `Arbitrary`.
#[cfg(feature = "arbitrary")]
impl<H: Hasher> arbitrary::Arbitrary<'_> for Chunk<H>
where
    H::Digest: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            shard: u.arbitrary()?,
            index: u.arbitrary()?,
            proof: u.arbitrary()?,
        })
    }
}
137
138fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
140 let data_len = data.len();
142 let prefixed_len = u32::SIZE + data_len;
143 let mut shard_len = prefixed_len.div_ceil(k);
144
145 if !shard_len.is_multiple_of(2) {
147 shard_len += 1;
148 }
149
150 let length_bytes = (data_len as u32).to_be_bytes();
152 let mut padded = vec![0u8; k * shard_len];
153 padded[..u32::SIZE].copy_from_slice(&length_bytes);
154 padded[u32::SIZE..u32::SIZE + data_len].copy_from_slice(&data);
155
156 let mut shards = Vec::with_capacity(k + m); for chunk in padded.chunks(shard_len) {
158 shards.push(chunk.to_vec());
159 }
160 shards
161}
162
163fn extract_data(shards: Vec<&[u8]>, k: usize) -> Vec<u8> {
165 let mut data = shards.into_iter().take(k).flatten();
167
168 let data_len = (&mut data)
170 .take(u32::SIZE)
171 .copied()
172 .collect::<Vec<_>>()
173 .try_into()
174 .expect("insufficient data");
175 let data_len = u32::from_be_bytes(data_len) as usize;
176
177 data.take(data_len).copied().collect()
179}
180
181type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);
183
/// Erasure-encodes `data` into `total` shards (`min` originals plus recovery
/// shards) and builds a binary Merkle tree over every shard's digest.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
///
/// # Errors
///
/// Returns [Error::InvalidDataLength] if `data` is longer than `u32::MAX`
/// bytes (the length prefix is a `u32`), or [Error::ReedSolomon] if the codec
/// rejects the shard counts or shard size.
fn encode_inner<H: Hasher>(
    total: u16,
    min: u16,
    data: Vec<u8>,
    concurrency: usize,
) -> Result<Encoding<H>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if data.len() > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data.len()));
    }

    // Split into k originals; indexing [0] is safe because min > 0.
    let mut shards = prepare_data(data, k, m);
    let shard_len = shards[0].len();

    // Feed the originals to the encoder in index order (the order defines
    // each shard's position in the code).
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in &shards {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }

    // Append the m recovery shards after the originals, so shard i (0 <= i < n)
    // is original for i < k and recovery for i >= k.
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<Vec<u8>> = encoding
        .recovery_iter()
        .map(|shard| shard.to_vec())
        .collect();
    shards.extend(recovery_shards);

    // Build the Merkle tree over shard digests; leaf order must match shard order.
    let mut builder = Builder::<H>::new(n);
    if concurrency > 1 {
        // Hash shards in parallel on a dedicated pool, preserving order via collect.
        let pool = ThreadPoolBuilder::new()
            .num_threads(concurrency)
            .build()
            .expect("unable to build thread pool");
        let shard_hashes = pool.install(|| {
            shards
                .par_iter()
                .map_init(H::new, |hasher, shard| {
                    hasher.update(shard);
                    hasher.finalize()
                })
                .collect::<Vec<_>>()
        });
        for hash in &shard_hashes {
            builder.add(hash);
        }
    } else {
        let mut hasher = H::new();
        for shard in &shards {
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
    }
    let tree = builder.build();

    Ok((tree, shards))
}
251
252fn encode<H: Hasher>(
266 total: u16,
267 min: u16,
268 data: Vec<u8>,
269 concurrency: usize,
270) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
271 let (tree, shards) = encode_inner::<H>(total, min, data, concurrency)?;
273 let root = tree.root();
274 let n = total as usize;
275
276 let mut chunks = Vec::with_capacity(n);
278 for (i, shard) in shards.into_iter().enumerate() {
279 let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
280 chunks.push(Chunk::new(shard, i as u16, proof));
281 }
282
283 Ok((root, chunks))
284}
285
/// Recovers the original data from at least `min` chunks and checks the result
/// against the commitment `root`.
///
/// Individual chunk proofs are NOT verified here. Instead, the full shard set
/// is reconstructed, re-encoded, and hashed into a fresh Merkle tree whose
/// root must equal `root` — this rejects corrupted shards as well as roots
/// that do not correspond to a consistent encoding.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
///
/// # Errors
///
/// Returns [Error::NotEnoughChunks], [Error::InvalidIndex],
/// [Error::DuplicateIndex], [Error::ReedSolomon], or [Error::Inconsistent].
fn decode<H: Hasher>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: &[Chunk<H>],
    concurrency: usize,
) -> Result<Vec<u8>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if chunks.len() < k {
        return Err(Error::NotEnoughChunks);
    }

    // All shards in one encoding share a length; the decoder validates mismatches.
    let shard_len = chunks[0].shard.len();
    let mut seen = HashSet::new();
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    for chunk in chunks {
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        if seen.contains(&index) {
            return Err(Error::DuplicateIndex(index));
        }
        seen.insert(index);

        // Indices < min are original shards; indices >= min are recovery
        // shards, whose position within the recovery set is (index - k).
        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_slice()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.as_slice()));
        }
    }

    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for (idx, ref shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, ref shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Assemble the k original shards: supplied ones first, then restored ones
    // (the decoder only returns shards that were missing).
    let mut shards = Vec::with_capacity(n);
    shards.resize(k, Default::default());
    for (idx, shard) in provided_originals {
        shards[idx] = shard;
    }
    for (idx, shard) in decoding.restored_original_iter() {
        shards[idx] = shard;
    }

    // Re-encode from the originals to recompute ALL recovery shards, so the
    // rebuilt tree covers exactly what a faithful encoder would have produced.
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<&[u8]> = encoding.recovery_iter().collect();
    shards.extend(recovery_shards);

    // Rebuild the Merkle tree over all n shards, mirroring `encode_inner`.
    let mut builder = Builder::<H>::new(n);
    if concurrency > 1 {
        let pool = ThreadPoolBuilder::new()
            .num_threads(concurrency)
            .build()
            .expect("unable to build thread pool");
        let shard_hashes = pool.install(|| {
            shards
                .par_iter()
                .map_init(H::new, |hasher, shard| {
                    hasher.update(shard);
                    hasher.finalize()
                })
                .collect::<Vec<_>>()
        });

        for hash in &shard_hashes {
            builder.add(hash);
        }
    } else {
        let mut hasher = H::new();
        for shard in &shards {
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
    }
    let tree = builder.build();

    // Any tampering with shards (or a bogus root) surfaces here.
    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    Ok(extract_data(shards, k))
}
414
/// A coding scheme backed by Reed-Solomon erasure coding with a binary Merkle
/// tree commitment over all shards.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    /// Marker for the hasher used to build commitments; no data is stored.
    _marker: PhantomData<H>,
}
501
// Manual `Debug` so `H` itself does not need to implement `Debug`.
impl<H> std::fmt::Debug for ReedSolomon<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReedSolomon").finish()
    }
}
507
508impl<H: Hasher> Scheme for ReedSolomon<H> {
509 type Commitment = H::Digest;
510
511 type Shard = Chunk<H>;
512 type ReShard = Chunk<H>;
513 type CheckedShard = Chunk<H>;
514 type CheckingData = ();
515
516 type Error = Error;
517
518 fn encode(
519 config: &Config,
520 mut data: impl Buf,
521 concurrency: usize,
522 ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
523 let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
524 encode(
525 total_shards(config)?,
526 config.minimum_shards,
527 data,
528 concurrency,
529 )
530 }
531
532 fn reshard(
533 _config: &Config,
534 commitment: &Self::Commitment,
535 index: u16,
536 shard: Self::Shard,
537 ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::ReShard), Self::Error> {
538 if shard.index != index {
539 return Err(Error::WrongIndex(index));
540 }
541 if shard.verify(shard.index, commitment) {
542 Ok(((), shard.clone(), shard))
543 } else {
544 Err(Error::InvalidProof)
545 }
546 }
547
548 fn check(
549 _config: &Config,
550 commitment: &Self::Commitment,
551 _checking_data: &Self::CheckingData,
552 index: u16,
553 reshard: Self::ReShard,
554 ) -> Result<Self::CheckedShard, Self::Error> {
555 if reshard.index != index {
556 return Err(Error::WrongIndex(reshard.index));
557 }
558 if !reshard.verify(reshard.index, commitment) {
559 return Err(Error::InvalidProof);
560 }
561 Ok(reshard)
562 }
563
564 fn decode(
565 config: &Config,
566 commitment: &Self::Commitment,
567 _checking_data: Self::CheckingData,
568 shards: &[Self::CheckedShard],
569 concurrency: usize,
570 ) -> Result<Vec<u8>, Self::Error> {
571 decode(
572 total_shards(config)?,
573 config.minimum_shards,
574 commitment,
575 shards,
576 concurrency,
577 )
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;

    // Hash shards on a single thread in tests (concurrency <= 1 takes the serial path).
    const CONCURRENCY: usize = 1;

    // Decoding from a mix of original and recovery chunks recovers the data.
    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // One original (0) and two recovery shards (4, 6) — exactly `min` chunks.
        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[4].clone(),
            chunks[6].clone(),
        ];

        let decoded = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    // Fewer than `min` chunks must fail fast with NotEnoughChunks.
    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        let result = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    // A repeated chunk index is rejected even when enough chunks are supplied.
    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        let result = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    // Verifying a chunk under any index other than its own must fail.
    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    // total == min violates the encoder's precondition and panics.
    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        encode::<Sha256>(3, 3, data.to_vec(), CONCURRENCY).unwrap();
    }

    // min == 0 violates the encoder's precondition and panics.
    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        encode::<Sha256>(5, 0, data.to_vec(), CONCURRENCY).unwrap();
    }

    // Empty payloads round-trip (the length prefix alone drives extraction).
    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    // A payload spanning many shards round-trips from the minimal chunk set.
    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000];
        let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.clone(), CONCURRENCY).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    // Chunks from a genuine encoding never verify against an unrelated root,
    // and decoding against it reports Inconsistent.
    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        let (_correct_root, chunks) =
            encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();

        let result = decode::<Sha256>(total, min, &malicious_root, &minimal, CONCURRENCY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // Flipping a byte in one shard is caught by the root re-computation in decode.
    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF;
        }

        let result = decode::<Sha256>(total, min, &root, &chunks, CONCURRENCY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // A "valid" tree built over tampered recovery shards still fails decode,
    // because decode re-derives recovery shards from the originals.
    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Corrupt the first recovery shard before committing to it.
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        // Build a tree (and proofs) over the tampered shard set.
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Two originals (0, 1) and the corrupted recovery shard (3).
        let selected_indices = vec![0, 1, 3];
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        let result = decode::<Sha256>(total, min, &malicious_root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // A chunk whose index is out of range (>= total) is rejected by decode.
    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(),
            chunks[1].clone(),
            chunks[6].clone(),
        ];

        let result = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    // The largest supported configuration (total = u16::MAX) round-trips.
    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2;

        let (root, chunks) = encode::<Sha256>(total, min, data.clone(), CONCURRENCY).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    // One more recovery shard than the codec supports surfaces the library error.
    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        let result = encode::<Sha256>(total, min, data, CONCURRENCY);
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    // A config whose total (minimum + extra) overflows u16 is rejected up front.
    #[test]
    fn test_too_many_total_shards() {
        assert!(ReedSolomon::<Sha256>::encode(
            &Config {
                minimum_shards: u16::MAX / 2 + 1,
                extra_shards: u16::MAX,
            },
            [].as_slice(),
            CONCURRENCY,
        )
        .is_err())
    }

    // Codec round-trip conformance for Chunk under the `arbitrary` feature.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256>>,
        }
    }
}