1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut, Bytes};
3use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
4use commonware_cryptography::{Digest, Hasher};
5use commonware_parallel::Strategy;
6use commonware_storage::bmt::{self, Builder};
7use commonware_utils::Cached;
8use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
9use std::marker::PhantomData;
10use thiserror::Error;
11
// Thread-local caches for the Reed-Solomon coders so repeated encode/decode
// calls on the same thread reuse the coder (via `Cached::take` below) instead
// of constructing a fresh one per call.
commonware_utils::thread_local_cache!(static CACHED_ENCODER: ReedSolomonEncoder);
commonware_utils::thread_local_cache!(static CACHED_DECODER: ReedSolomonDecoder);
18
/// Failures that can occur while encoding, verifying, or decoding chunks.
#[derive(Error, Debug)]
pub enum Error {
    /// The underlying reed-solomon backend rejected the operation.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The supplied chunks do not reproduce the committed root (e.g. a
    /// tampered shard or a non-canonical padding/encoding).
    #[error("inconsistent")]
    Inconsistent,
    /// An inclusion proof failed to verify or had an unexpected shape.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than the minimum number of chunks were provided to decode.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same shard index was supplied more than once.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The data is too large to encode (its length must fit in a `u32`).
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A shard index was out of range or did not match the expected index.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// The configured total shard count does not fit in a `u16`.
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
}
39
40fn total_shards(config: &Config) -> Result<u16, Error> {
41 let total = config.total_shards();
42 total
43 .try_into()
44 .map_err(|_| Error::TooManyTotalShards(total))
45}
46
/// A single erasure-coded shard, bundled with its position and an inclusion
/// proof tying it to the commitment root.
#[derive(Debug, Clone)]
pub struct Chunk<D: Digest> {
    // The shard bytes (either original or recovery data).
    shard: Bytes,

    // Position of this shard within the full set of shards.
    index: u16,

    // Proof that the shard's digest is the leaf at `index` under the
    // commitment root.
    proof: bmt::Proof<D>,
}
59
60impl<D: Digest> Chunk<D> {
61 const fn new(shard: Bytes, index: u16, proof: bmt::Proof<D>) -> Self {
63 Self {
64 shard,
65 index,
66 proof,
67 }
68 }
69
70 fn verify<H: Hasher<Digest = D>>(&self, index: u16, root: &D) -> Option<CheckedChunk<D>> {
72 if index != self.index {
74 return None;
75 }
76
77 let mut hasher = H::new();
79 hasher.update(&self.shard);
80 let shard_digest = hasher.finalize();
81
82 self.proof
84 .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
85 .ok()?;
86
87 Some(CheckedChunk::new(
88 self.shard.clone(),
89 self.index,
90 shard_digest,
91 ))
92 }
93}
94
/// A shard whose inclusion proof has already been verified (or that was
/// produced locally), carrying the shard's digest alongside its bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CheckedChunk<D: Digest> {
    // The verified shard bytes.
    shard: Bytes,
    // Position of the shard within the full set of shards.
    index: u16,
    // Digest of `shard`, as used for the commitment leaf.
    digest: D,
}
105
106impl<D: Digest> CheckedChunk<D> {
107 const fn new(shard: Bytes, index: u16, digest: D) -> Self {
108 Self {
109 shard,
110 index,
111 digest,
112 }
113 }
114}
115
impl<D: Digest> Write for Chunk<D> {
    fn write(&self, writer: &mut impl BufMut) {
        // Field order defines the wire format and must mirror
        // `Read::read_cfg`: shard, index, proof.
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
123
impl<D: Digest> Read for Chunk<D> {
    type Cfg = crate::CodecConfig;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Bound the shard allocation by the configured maximum so a malformed
        // message cannot trigger an attacker-sized allocation.
        let shard = Bytes::read_cfg(reader, &RangeCfg::new(..=cfg.maximum_shard_size))?;
        let index = u16::read(reader)?;
        // NOTE(review): the proof is read with cfg `&1` — presumably the
        // number of proven elements per proof; confirm against the
        // `bmt::Proof` `Read` impl.
        let proof = bmt::Proof::<D>::read_cfg(reader, &1)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
139
140impl<D: Digest> EncodeSize for Chunk<D> {
141 fn encode_size(&self) -> usize {
142 self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
143 }
144}
145
146impl<D: Digest> PartialEq for Chunk<D> {
147 fn eq(&self, other: &Self) -> bool {
148 self.shard == other.shard && self.index == other.index && self.proof == other.proof
149 }
150}
151
152impl<D: Digest> Eq for Chunk<D> {}
153
154#[cfg(feature = "arbitrary")]
155impl<D: Digest> arbitrary::Arbitrary<'_> for Chunk<D>
156where
157 D: for<'a> arbitrary::Arbitrary<'a>,
158{
159 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
160 Ok(Self {
161 shard: u.arbitrary::<Vec<u8>>()?.into(),
162 index: u.arbitrary()?,
163 proof: u.arbitrary()?,
164 })
165 }
166}
167
168fn prepare_data(data: &[u8], k: usize) -> (Vec<u8>, usize) {
174 let data_len = data.len();
176 let prefixed_len = u32::SIZE + data_len;
177 let mut shard_len = prefixed_len.div_ceil(k);
178
179 if !shard_len.is_multiple_of(2) {
181 shard_len += 1;
182 }
183
184 let length_bytes = (data_len as u32).to_be_bytes();
186 let mut padded = vec![0u8; k * shard_len];
187 padded[..u32::SIZE].copy_from_slice(&length_bytes);
188 padded[u32::SIZE..u32::SIZE + data_len].copy_from_slice(data);
189
190 (padded, shard_len)
191}
192
/// Reassembles the original payload from the first `k` (original) shards.
///
/// Reads the length prefix, copies the payload, and rejects any encoding
/// whose trailing padding is non-zero so that the padded representation is
/// canonical (a given payload has exactly one valid encoding).
fn extract_data(shards: &[&[u8]], k: usize) -> Result<Vec<u8>, Error> {
    let shards = shards.get(..k).ok_or(Error::NotEnoughChunks)?;
    let (data_len, payload_len) = read_prefix_and_payload_len(shards)?;
    let mut payload = copy_payload_after_prefix(shards, payload_len);
    validate_zero_padding(&payload, data_len)?;
    payload.truncate(data_len);
    Ok(payload)
}
205
206fn read_prefix_and_payload_len(shards: &[&[u8]]) -> Result<(usize, usize), Error> {
209 let total_len: usize = shards.iter().map(|s| s.len()).sum();
210 if total_len < u32::SIZE {
211 return Err(Error::Inconsistent);
212 }
213
214 let mut prefix = [0u8; u32::SIZE];
216 let mut prefix_len = 0usize;
217 for shard in shards {
218 if prefix_len == u32::SIZE {
219 break;
220 }
221 let read = (u32::SIZE - prefix_len).min(shard.len());
222 prefix[prefix_len..prefix_len + read].copy_from_slice(&shard[..read]);
223 prefix_len += read;
224 }
225
226 let data_len = u32::from_be_bytes(prefix) as usize;
227 let payload_len = total_len - u32::SIZE;
228 if data_len > payload_len {
229 return Err(Error::Inconsistent);
230 }
231 Ok((data_len, payload_len))
232}
233
234fn copy_payload_after_prefix(shards: &[&[u8]], payload_len: usize) -> Vec<u8> {
237 let mut payload = Vec::with_capacity(payload_len);
238 let mut prefix_bytes_left = u32::SIZE;
239 for shard in shards {
240 if prefix_bytes_left >= shard.len() {
241 prefix_bytes_left -= shard.len();
242 continue;
243 }
244 payload.extend_from_slice(&shard[prefix_bytes_left..]);
245 prefix_bytes_left = 0;
246 }
247 payload
248}
249
250fn validate_zero_padding(payload: &[u8], data_len: usize) -> Result<(), Error> {
253 if !payload[data_len..].iter().all(|byte| *byte == 0) {
255 return Err(Error::Inconsistent);
256 }
257 Ok(())
258}
259
/// The commitment root paired with every chunk produced by `encode`.
type Encoding<D> = (D, Vec<Chunk<D>>);
262
/// Encodes `data` into `total` shards (`min` original shards plus
/// `total - min` recovery shards) and commits to all of them with a tree
/// built via [`bmt::Builder`].
///
/// Returns the tree root and one proof-carrying [`Chunk`] per shard.
/// Shard hashing is parallelized according to `strategy`.
///
/// # Panics
///
/// Panics unless `total > min > 0`.
///
/// # Errors
///
/// Returns [`Error::InvalidDataLength`] when `data` exceeds `u32::MAX` bytes
/// (the length prefix would overflow), or [`Error::ReedSolomon`] when the
/// backend rejects the shard geometry.
fn encode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    data: Vec<u8>,
    strategy: &S,
) -> Result<Encoding<H::Digest>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if data.len() > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data.len()));
    }

    // Length-prefix and zero-pad the data into `k` equal, even-length shards.
    let (padded, shard_len) = prepare_data(&data, k);

    // Generate the `m` recovery shards, reusing a thread-local encoder when
    // its geometry can be reset to (k, m, shard_len).
    let recovery_buf = {
        let mut encoder = Cached::take(
            &CACHED_ENCODER,
            || ReedSolomonEncoder::new(k, m, shard_len),
            |enc| enc.reset(k, m, shard_len),
        )
        .map_err(Error::ReedSolomon)?;
        for shard in padded.chunks(shard_len) {
            encoder
                .add_original_shard(shard)
                .map_err(Error::ReedSolomon)?;
        }

        // Flatten the recovery shards into one contiguous buffer so slices
        // can be taken from it zero-copy below.
        let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
        let mut buf = Vec::with_capacity(m * shard_len);
        for shard in encoding.recovery_iter() {
            buf.extend_from_slice(shard);
        }
        buf
    };

    // Convert to `Bytes` so per-shard slices share the underlying buffers.
    let originals: Bytes = padded.into();
    let recoveries: Bytes = recovery_buf.into();

    // Hash every shard (originals first, then recoveries) and build the
    // commitment tree over the digests.
    let mut builder = Builder::<H>::new(n);
    let shard_slices: Vec<Bytes> = (0..k)
        .map(|i| originals.slice(i * shard_len..(i + 1) * shard_len))
        .chain((0..m).map(|i| recoveries.slice(i * shard_len..(i + 1) * shard_len)))
        .collect();
    let shard_hashes = strategy.map_init_collect_vec(&shard_slices, H::new, |hasher, shard| {
        hasher.update(shard);
        hasher.finalize()
    });
    for hash in &shard_hashes {
        builder.add(hash);
    }
    let tree = builder.build();
    let root = tree.root();

    // Attach an inclusion proof to every shard.
    let mut chunks = Vec::with_capacity(n);
    for (i, shard) in shard_slices.into_iter().enumerate() {
        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
        chunks.push(Chunk::new(shard, i as u16, proof));
    }

    Ok((root, chunks))
}
347
/// Reconstructs the original data from at least `min` checked chunks and
/// verifies the reconstruction against the committed `root`.
///
/// After restoring the original shards, the full shard set is re-encoded and
/// re-hashed so the commitment tree can be rebuilt in full; a root mismatch
/// (e.g. from an inconsistent or maliciously constructed encoding) is
/// rejected with [`Error::Inconsistent`].
///
/// # Panics
///
/// Panics unless `total > min > 0`.
///
/// # Errors
///
/// Returns [`Error::NotEnoughChunks`], [`Error::InvalidIndex`],
/// [`Error::DuplicateIndex`], [`Error::ReedSolomon`], or
/// [`Error::Inconsistent`] as described above.
fn decode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: &[CheckedChunk<H::Digest>],
    strategy: &S,
) -> Result<Vec<u8>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if chunks.len() < k {
        return Err(Error::NotEnoughChunks);
    }

    // All shards share one length; the backend rejects mismatched lengths.
    let shard_len = chunks[0].shard.len();
    // Record each provided chunk's digest by index (also detecting
    // duplicates) and partition shards into originals vs. recoveries.
    let mut shard_digests: Vec<Option<H::Digest>> = vec![None; n];
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    for chunk in chunks {
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        let digest_slot = &mut shard_digests[index as usize];
        if digest_slot.is_some() {
            return Err(Error::DuplicateIndex(index));
        }

        *digest_slot = Some(chunk.digest);
        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_ref()));
        } else {
            // Recovery shard indices are zero-based within the recovery set.
            provided_recoveries.push((index as usize - k, chunk.shard.as_ref()));
        }
    }

    // Restore the missing original shards, reusing a thread-local decoder.
    let mut decoder = Cached::take(
        &CACHED_DECODER,
        || ReedSolomonDecoder::new(k, m, shard_len),
        |dec| dec.reset(k, m, shard_len),
    )
    .map_err(Error::ReedSolomon)?;
    for (idx, ref shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, ref shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Assemble the complete original shard set: provided shards fill their
    // slots, restored shards fill the rest.
    let mut shards = vec![Default::default(); k];
    for (idx, shard) in provided_originals
        .into_iter()
        .chain(decoding.restored_original_iter())
    {
        shards[idx] = shard;
    }

    // Re-derive all recovery shards from the full original set so the entire
    // commitment tree can be recomputed.
    let mut encoder = Cached::take(
        &CACHED_ENCODER,
        || ReedSolomonEncoder::new(k, m, shard_len),
        |enc| enc.reset(k, m, shard_len),
    )
    .map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    shards.extend(encoding.recovery_iter());

    // Hash only the shards whose digests were not provided, in parallel per
    // `strategy`, and fill in the remaining slots.
    for (i, digest) in strategy.map_init_collect_vec(
        shard_digests
            .iter()
            .enumerate()
            .filter_map(|(i, digest)| digest.is_none().then_some(i)),
        H::new,
        |hasher, i| {
            hasher.update(shards[i]);
            (i, hasher.finalize())
        },
    ) {
        shard_digests[i] = Some(digest);
    }

    // Rebuild the commitment tree over every shard digest.
    let mut builder = Builder::<H>::new(n);
    shard_digests
        .into_iter()
        .map(|digest| digest.expect("digest must be present for every shard"))
        .for_each(|digest| {
            builder.add(&digest);
        });
    let tree = builder.build();

    // Any tampering or inconsistent encoding shows up as a root mismatch.
    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    extract_data(&shards, k)
}
479
/// [`Scheme`] implementation backed by Reed-Solomon erasure coding with a
/// [`bmt`] commitment over the shard digests, generic over the hasher `H`.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    // Carries the hasher type parameter without storing a value.
    _marker: PhantomData<H>,
}
566
567impl<H> std::fmt::Debug for ReedSolomon<H> {
568 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
569 f.debug_struct("ReedSolomon").finish()
570 }
571}
572
573impl<H: Hasher> Scheme for ReedSolomon<H> {
574 type Commitment = H::Digest;
575
576 type StrongShard = Chunk<H::Digest>;
577 type WeakShard = Chunk<H::Digest>;
578 type CheckedShard = CheckedChunk<H::Digest>;
579 type CheckingData = ();
580
581 type Error = Error;
582
583 fn encode(
584 config: &Config,
585 mut data: impl Buf,
586 strategy: &impl Strategy,
587 ) -> Result<(Self::Commitment, Vec<Self::StrongShard>), Self::Error> {
588 let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
589 encode::<H, _>(
590 total_shards(config)?,
591 config.minimum_shards.get(),
592 data,
593 strategy,
594 )
595 }
596
597 fn weaken(
598 config: &Config,
599 commitment: &Self::Commitment,
600 index: u16,
601 shard: Self::StrongShard,
602 ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::WeakShard), Self::Error> {
603 let total = total_shards(config)?;
604 if index >= total {
605 return Err(Error::InvalidIndex(index));
606 }
607 if shard.proof.leaf_count != u32::from(total) {
608 return Err(Error::InvalidProof);
609 }
610 if shard.index != index {
611 return Err(Error::InvalidIndex(index));
612 }
613 let checked_shard = shard
614 .verify::<H>(shard.index, commitment)
615 .ok_or(Error::InvalidProof)?;
616 Ok(((), checked_shard, shard))
617 }
618
619 fn check(
620 config: &Config,
621 commitment: &Self::Commitment,
622 _checking_data: &Self::CheckingData,
623 index: u16,
624 weak_shard: Self::WeakShard,
625 ) -> Result<Self::CheckedShard, Self::Error> {
626 let total = total_shards(config)?;
627 if index >= total {
628 return Err(Error::InvalidIndex(index));
629 }
630 if weak_shard.proof.leaf_count != u32::from(total) {
631 return Err(Error::InvalidProof);
632 }
633 if weak_shard.index != index {
634 return Err(Error::InvalidIndex(weak_shard.index));
635 }
636 weak_shard
637 .verify::<H>(weak_shard.index, commitment)
638 .ok_or(Error::InvalidProof)
639 }
640
641 fn decode(
642 config: &Config,
643 commitment: &Self::Commitment,
644 _checking_data: Self::CheckingData,
645 shards: &[Self::CheckedShard],
646 strategy: &impl Strategy,
647 ) -> Result<Vec<u8>, Self::Error> {
648 decode::<H, _>(
649 total_shards(config)?,
650 config.minimum_shards.get(),
651 commitment,
652 shards,
653 strategy,
654 )
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;
    use commonware_parallel::Sequential;
    use commonware_utils::NZU16;

    type RS = ReedSolomon<Sha256>;
    const STRATEGY: Sequential = Sequential;

    // Turns a `Chunk` into a `CheckedChunk` by hashing its shard directly,
    // bypassing proof verification (tests construct the inputs themselves).
    fn checked(
        chunk: Chunk<<Sha256 as Hasher>::Digest>,
    ) -> CheckedChunk<<Sha256 as Hasher>::Digest> {
        let Chunk { shard, index, .. } = chunk;
        let digest = Sha256::hash(&shard);
        CheckedChunk::new(shard, index, digest)
    }

    // Decoding succeeds from any `min` chunks, including recovery chunks.
    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Mix of original (0) and recovery (4, 6) chunks.
        let pieces: Vec<_> = vec![
            checked(chunks[0].clone()),
            checked(chunks[4].clone()),
            checked(chunks[6].clone()),
        ];

        let decoded = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    // Fewer than `min` chunks is rejected up front.
    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let pieces: Vec<_> = chunks.into_iter().take(2).map(checked).collect();

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    // Supplying the same index twice is detected.
    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let pieces = vec![
            checked(chunks[0].clone()),
            checked(chunks[0].clone()),
            checked(chunks[1].clone()),
        ];

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    // `Chunk::verify` fails when the expected index does not match.
    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        for i in 0..total {
            assert!(chunks[i as usize].verify::<Sha256>(i + 1, &root).is_none());
        }
    }

    // `encode` asserts `total > min`.
    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        encode::<Sha256, _>(3, 3, data.to_vec(), &STRATEGY).unwrap();
    }

    // `encode` asserts `min > 0`.
    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        encode::<Sha256, _>(5, 0, data.to_vec(), &STRATEGY).unwrap();
    }

    // Empty payloads round-trip (the length prefix still occupies a shard).
    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let minimal = chunks
            .into_iter()
            .take(min as usize)
            .map(checked)
            .collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    // Multi-shard payloads round-trip.
    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000];
        let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        let minimal = chunks
            .into_iter()
            .take(min as usize)
            .map(checked)
            .collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    // Proofs and decode both reject a root that does not match the encoding.
    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        let (_correct_root, chunks) =
            encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // Every proof fails against the bogus root.
        for i in 0..total {
            assert!(chunks[i as usize]
                .clone()
                .verify::<Sha256>(i, &malicious_root)
                .is_none());
        }

        let minimal = chunks
            .into_iter()
            .take(min as usize)
            .map(checked)
            .collect::<Vec<_>>();

        // Decode recomputes the tree and notices the mismatch.
        let result = decode::<Sha256, _>(total, min, &malicious_root, &minimal, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // A shard encoded under one config is rejected when verified under
    // another (the proof's leaf count does not match).
    #[test]
    fn test_mismatched_config_rejected_during_weaken_and_check() {
        let config_expected = Config {
            minimum_shards: NZU16!(2),
            extra_shards: NZU16!(2),
        };
        let config_actual = Config {
            minimum_shards: NZU16!(3),
            extra_shards: NZU16!(3),
        };

        let data = b"leaf_count mismatch proof";
        let (commitment, shards) = RS::encode(&config_actual, data.as_slice(), &STRATEGY).unwrap();

        let strong = shards[0].clone();
        let weaken_result = RS::weaken(&config_expected, &commitment, 0, strong);
        assert!(matches!(weaken_result, Err(Error::InvalidProof)));

        let (checking_data, _, weak) =
            RS::weaken(&config_actual, &commitment, 1, shards[1].clone()).unwrap();
        let check_result = RS::check(&config_expected, &commitment, &checking_data, 1, weak);
        assert!(matches!(check_result, Err(Error::InvalidProof)));
    }

    // A flipped byte in one shard (with a matching recomputed digest) still
    // fails decode's root comparison.
    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
        let mut pieces: Vec<_> = chunks.into_iter().map(checked).collect();

        if !pieces[1].shard.is_empty() {
            let mut shard = pieces[1].shard.to_vec();
            shard[0] ^= 0xFF;
            pieces[1].shard = shard.into();
            pieces[1].digest = Sha256::hash(&pieces[1].shard);
        }

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // A commitment over inconsistent shards (a corrupted recovery shard) is
    // detected when decode re-derives the recovery set.
    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        let (padded, shard_size) = prepare_data(data, min as usize);

        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in padded.chunks(shard_size) {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Corrupt one recovery shard before committing.
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        let mut malicious_shards: Vec<Vec<u8>> =
            padded.chunks(shard_size).map(|s| s.to_vec()).collect();
        malicious_shards.extend(recovery_shards);

        // Commit to the inconsistent shard set.
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Hand decode a mix that includes the corrupted recovery shard (3).
        let selected_indices = vec![0, 1, 3];
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard.into(), i as u16, merkle_proof);
            pieces.push(chunk);
        }
        let pieces: Vec<_> = pieces.into_iter().map(checked).collect();

        let result = decode::<Sha256, _>(total, min, &malicious_root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // An otherwise-valid encoding with a non-zero padding byte must not
    // decode: padding is part of the canonical representation.
    #[test]
    fn test_non_canonical_padding_rejected() {
        let data = b"X";
        let total = 6u16;
        let min = 3u16;
        let k = min as usize;
        let m = total as usize - k;

        let (mut padded, shard_len) = prepare_data(data, k);
        let payload_end = u32::SIZE + data.len();
        let total_original_len = k * shard_len;
        assert!(payload_end < total_original_len, "test requires padding");

        // Overwrite the first padding byte with a non-zero value.
        let pad_shard = payload_end / shard_len;
        let pad_offset = payload_end % shard_len;
        padded[pad_shard * shard_len + pad_offset] = 0xAA;

        let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).unwrap();
        for shard in padded.chunks(shard_len) {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery = encoder.encode().unwrap();
        let mut shards: Vec<Vec<u8>> = padded.chunks(shard_len).map(|s| s.to_vec()).collect();
        shards.extend(recovery.recovery_iter().map(|s| s.to_vec()));

        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let tree = builder.build();
        let non_canonical_root = tree.root();

        let mut pieces = Vec::with_capacity(k);
        for (i, shard) in shards.iter().take(k).enumerate() {
            let proof = tree.proof(i as u32).unwrap();
            pieces.push(checked(Chunk::new(shard.clone().into(), i as u16, proof)));
        }

        let result = decode::<Sha256, _>(total, min, &non_canonical_root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // Decode rejects a chunk whose index is out of range (>= total).
    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        let mut invalid = checked(chunks[1].clone());
        invalid.index = 8;
        let pieces: Vec<_> = vec![
            checked(chunks[0].clone()),
            invalid,
            checked(chunks[6].clone()),
        ];

        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    // The largest supported configuration round-trips.
    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2;

        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        let minimal = chunks
            .into_iter()
            .take(min as usize)
            .map(checked)
            .collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    // A shard geometry beyond the backend's limits surfaces its error.
    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000];
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        let result = encode::<Sha256, _>(total, min, data, &STRATEGY);
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    // A config whose total shard count overflows `u16` is rejected.
    #[test]
    fn test_too_many_total_shards() {
        assert!(RS::encode(
            &Config {
                minimum_shards: NZU16!(u16::MAX / 2 + 1),
                extra_shards: NZU16!(u16::MAX),
            },
            [].as_slice(),
            &STRATEGY,
        )
        .is_err())
    }

    // Codec round-trip conformance for `Chunk` under arbitrary inputs.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256Digest>>,
        }
    }
}