1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut, Bytes};
3use commonware_codec::{BufsMut, EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
4use commonware_cryptography::{Digest, Hasher};
5use commonware_parallel::Strategy;
6use commonware_storage::bmt::{self, Builder};
7use commonware_utils::Cached;
8use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
9use std::marker::PhantomData;
10use thiserror::Error;
11
// Thread-local, reusable Reed-Solomon coder state. Encode/decode take these
// via `Cached::take` and `reset` them for the requested geometry, avoiding a
// full rebuild of coder state on every call.
commonware_utils::thread_local_cache!(static CACHED_ENCODER: ReedSolomonEncoder);
commonware_utils::thread_local_cache!(static CACHED_DECODER: ReedSolomonDecoder);
18
/// Errors that can arise while encoding, checking, or decoding shards.
#[derive(Error, Debug)]
pub enum Error {
    /// The underlying SIMD Reed-Solomon codec rejected the operation.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The provided shards are not a consistent encoding of a single payload
    /// (e.g. tampered shard, non-zero padding, or root mismatch after re-encode).
    #[error("inconsistent")]
    Inconsistent,
    /// A merkle inclusion proof failed verification or could not be produced.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than the minimum number of shards were supplied.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same shard index was supplied more than once.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The payload length exceeds what a u32 length prefix can describe.
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A shard index is out of range for the configured shard count.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// The configured total shard count does not fit in a u16.
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
    /// A checked shard was verified against a different commitment than the
    /// one being decoded.
    #[error("checked shard commitment does not match decode commitment")]
    CommitmentMismatch,
}
41
42fn total_shards(config: &Config) -> Result<u16, Error> {
43 let total = config.total_shards();
44 total
45 .try_into()
46 .map_err(|_| Error::TooManyTotalShards(total))
47}
48
/// A single erasure-coded shard together with the merkle proof binding it to
/// the commitment (the binary-merkle-tree root over all shards).
#[derive(Debug, Clone)]
pub struct Chunk<D: Digest> {
    /// The shard's bytes (original or recovery data).
    shard: Bytes,

    /// The shard's position in the encoding (originals occupy the low
    /// indices, recovery shards follow).
    index: u16,

    /// Merkle inclusion proof of this shard's digest under the commitment root.
    proof: bmt::Proof<D>,
}
61
62impl<D: Digest> Chunk<D> {
63 const fn new(shard: Bytes, index: u16, proof: bmt::Proof<D>) -> Self {
65 Self {
66 shard,
67 index,
68 proof,
69 }
70 }
71
72 fn verify<H: Hasher<Digest = D>>(&self, index: u16, root: &D) -> Option<CheckedChunk<D>> {
74 if index != self.index {
76 return None;
77 }
78
79 let mut hasher = H::new();
81 hasher.update(&self.shard);
82 let shard_digest = hasher.finalize();
83
84 self.proof
86 .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
87 .ok()?;
88
89 Some(CheckedChunk::new(
90 *root,
91 self.shard.clone(),
92 self.index,
93 shard_digest,
94 ))
95 }
96}
97
/// A [Chunk] whose merkle inclusion proof has already been verified against a
/// commitment.
///
/// Carries the root it was verified against and the shard's digest so that
/// decoding can cross-check chunks without re-verifying proofs.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CheckedChunk<D: Digest> {
    /// The commitment this chunk was verified against.
    root: D,
    /// The shard's bytes.
    shard: Bytes,
    /// The shard's position in the encoding.
    index: u16,
    /// Digest of `shard`.
    digest: D,
}
110
111impl<D: Digest> CheckedChunk<D> {
112 const fn new(root: D, shard: Bytes, index: u16, digest: D) -> Self {
113 Self {
114 root,
115 shard,
116 index,
117 digest,
118 }
119 }
120}
121
impl<D: Digest> Write for Chunk<D> {
    // Serialization order: shard bytes, then index, then proof. `Read` below
    // consumes fields in the same order.
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }

    fn write_bufs(&self, buf: &mut impl BufsMut) {
        // The shard is written via `write_bufs` — presumably so its bytes can
        // be attached as a separate buffer segment without copying; the
        // remaining fields are written inline. TODO(review): confirm against
        // `Bytes`' `write_bufs` impl.
        self.shard.write_bufs(buf);
        self.index.write(buf);
        self.proof.write(buf);
    }
}
135
impl<D: Digest> Read for Chunk<D> {
    type Cfg = crate::CodecConfig;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Bound the shard size during decode so an untrusted peer cannot
        // force an oversized allocation.
        let shard = Bytes::read_cfg(reader, &RangeCfg::new(..=cfg.maximum_shard_size))?;
        let index = u16::read(reader)?;
        // NOTE(review): the proof is read with cfg `&1` — presumably a bound
        // related to proof arity/size; confirm against `bmt::Proof`'s `Read`
        // impl.
        let proof = bmt::Proof::<D>::read_cfg(reader, &1)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
151
impl<D: Digest> EncodeSize for Chunk<D> {
    fn encode_size(&self) -> usize {
        // Sum of the three fields, matching the order written in `write`.
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }

    fn encode_inline_size(&self) -> usize {
        // Same as `encode_size` except the shard contributes only its inline
        // portion (its payload may travel as a separate buffer via
        // `write_bufs`).
        self.shard.encode_inline_size() + self.index.encode_size() + self.proof.encode_size()
    }
}
161
162impl<D: Digest> PartialEq for Chunk<D> {
163 fn eq(&self, other: &Self) -> bool {
164 self.shard == other.shard && self.index == other.index && self.proof == other.proof
165 }
166}
167
// Equality above is a total structural comparison, so `Eq` holds.
impl<D: Digest> Eq for Chunk<D> {}
169
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Chunk<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    // Produces unvalidated chunks (the proof need not match the shard or
    // index) for fuzzing and codec-conformance testing.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            shard: u.arbitrary::<Vec<u8>>()?.into(),
            index: u.arbitrary()?,
            proof: u.arbitrary()?,
        })
    }
}
183
184fn prepare_data(mut data: impl Buf, k: usize) -> (Vec<u8>, usize) {
190 let data_len = data.remaining();
192 let prefixed_len = u32::SIZE + data_len;
193 let mut shard_len = prefixed_len.div_ceil(k);
194
195 if !shard_len.is_multiple_of(2) {
197 shard_len += 1;
198 }
199
200 let length_bytes = (data_len as u32).to_be_bytes();
202 let mut padded = vec![0u8; k * shard_len];
203 padded[..u32::SIZE].copy_from_slice(&length_bytes);
204 data.copy_to_slice(&mut padded[u32::SIZE..u32::SIZE + data_len]);
205
206 (padded, shard_len)
207}
208
/// Reassembles the original payload from the first `k` (original) shards,
/// stripping the u32 length prefix and requiring that every byte after the
/// payload is zero (i.e. the encoding is canonical).
///
/// Returns [Error::NotEnoughChunks] if fewer than `k` shards are given and
/// [Error::Inconsistent] if the declared length exceeds the available bytes
/// or any padding byte is non-zero.
fn extract_data(shards: &[&[u8]], k: usize) -> Result<Vec<u8>, Error> {
    let shards = shards.get(..k).ok_or(Error::NotEnoughChunks)?;
    let data_len = read_data_len(shards)?;
    let mut data = Vec::with_capacity(data_len);
    let mut prefix_bytes_left = u32::SIZE;
    let mut data_bytes_left = data_len;
    for shard in shards {
        // Skip shards wholly consumed by the length prefix.
        if prefix_bytes_left >= shard.len() {
            prefix_bytes_left -= shard.len();
            continue;
        }

        // Copy this shard's payload portion (after any remaining prefix bytes).
        let payload = &shard[prefix_bytes_left..];
        let copy_len = data_bytes_left.min(payload.len());
        data.extend_from_slice(&payload[..copy_len]);
        data_bytes_left -= copy_len;

        // Everything after the payload must be zero padding.
        if !payload[copy_len..].iter().all(|byte| *byte == 0) {
            return Err(Error::Inconsistent);
        }
        prefix_bytes_left = 0;
    }

    // The declared length must be fully satisfied by the shards.
    if data_bytes_left != 0 {
        return Err(Error::Inconsistent);
    }

    Ok(data)
}
249
250fn read_data_len(shards: &[&[u8]]) -> Result<usize, Error> {
253 let total_len: usize = shards.iter().map(|s| s.len()).sum();
254 if total_len < u32::SIZE {
255 return Err(Error::Inconsistent);
256 }
257
258 let mut prefix = [0u8; u32::SIZE];
260 let mut prefix_len = 0usize;
261 for shard in shards {
262 if prefix_len == u32::SIZE {
263 break;
264 }
265 let read = (u32::SIZE - prefix_len).min(shard.len());
266 prefix[prefix_len..prefix_len + read].copy_from_slice(&shard[..read]);
267 prefix_len += read;
268 }
269
270 let data_len = u32::from_be_bytes(prefix) as usize;
271 let payload_len = total_len - u32::SIZE;
272 if data_len > payload_len {
273 return Err(Error::Inconsistent);
274 }
275 Ok(data_len)
276}
277
/// Result of an encode: the commitment root plus one proven chunk per shard.
type Encoding<D> = (D, Vec<Chunk<D>>);
280
/// Encodes `data` into `total` shards, any `min` of which suffice to
/// reconstruct it.
///
/// Returns the binary-merkle-tree root committing to all shards together with
/// one proven [Chunk] per index (the `min` original shards first, then the
/// recovery shards). Shard hashing is parallelized via `strategy`.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn encode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    data: impl Buf,
    strategy: &S,
) -> Result<Encoding<H::Digest>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    // Lengths are carried in a u32 prefix, so larger payloads are unrepresentable.
    let data_len = data.remaining();
    if data_len > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data_len));
    }

    // Length-prefix and zero-pad into k equally-sized original shards.
    let (padded, shard_len) = prepare_data(data, k);

    // Produce the m recovery shards, reusing (or resetting) the thread-local
    // encoder to avoid rebuilding coder state per call.
    let recovery_buf = {
        let mut encoder = Cached::take(
            &CACHED_ENCODER,
            || ReedSolomonEncoder::new(k, m, shard_len),
            |enc| enc.reset(k, m, shard_len),
        )
        .map_err(Error::ReedSolomon)?;
        for shard in padded.chunks(shard_len) {
            encoder
                .add_original_shard(shard)
                .map_err(Error::ReedSolomon)?;
        }

        // Flatten the recovery shards into one contiguous buffer so they can
        // be cheaply sliced as `Bytes` below.
        let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
        let mut buf = Vec::with_capacity(m * shard_len);
        for shard in encoding.recovery_iter() {
            buf.extend_from_slice(shard);
        }
        buf
    };

    let originals: Bytes = padded.into();
    let recoveries: Bytes = recovery_buf.into();

    // Build the commitment: one leaf per shard digest, originals first.
    let mut builder = Builder::<H>::new(n);
    let shard_slices: Vec<Bytes> = (0..k)
        .map(|i| originals.slice(i * shard_len..(i + 1) * shard_len))
        .chain((0..m).map(|i| recoveries.slice(i * shard_len..(i + 1) * shard_len)))
        .collect();
    // Hash shards (potentially in parallel) per the provided strategy.
    let shard_hashes = strategy.map_init_collect_vec(&shard_slices, H::new, |hasher, shard| {
        hasher.update(shard);
        hasher.finalize()
    });
    for hash in &shard_hashes {
        builder.add(hash);
    }
    let tree = builder.build();
    let root = tree.root();

    // Pair every shard with its inclusion proof.
    let mut chunks = Vec::with_capacity(n);
    for (i, shard) in shard_slices.into_iter().enumerate() {
        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
        chunks.push(Chunk::new(shard, i as u16, proof));
    }

    Ok((root, chunks))
}
366
/// Recovers the original payload from at least `min` checked chunks.
///
/// Missing original shards are reconstructed via Reed-Solomon decoding. The
/// full shard set is then re-encoded and re-committed, and the rebuilt root
/// must equal `root`; this detects a malicious encoder whose committed shards
/// are not a consistent encoding of any single payload.
///
/// Errors: [Error::NotEnoughChunks] (fewer than `min` chunks),
/// [Error::CommitmentMismatch] (a chunk was checked against a different root),
/// [Error::InvalidIndex] / [Error::DuplicateIndex] (bad indices), and
/// [Error::Inconsistent] (root or padding mismatch).
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn decode<'a, H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: impl Iterator<Item = &'a CheckedChunk<H::Digest>>,
    strategy: &S,
) -> Result<Vec<u8>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    let mut chunks = chunks.peekable();
    let Some(first) = chunks.peek() else {
        return Err(Error::NotEnoughChunks);
    };

    // All shards share the first chunk's length; mismatches surface as
    // Reed-Solomon errors when shards are added below.
    let shard_len = first.shard.len();
    let mut shard_digests: Vec<Option<H::Digest>> = vec![None; n];
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    let mut provided = 0usize;
    for chunk in chunks {
        provided += 1;
        // Every chunk must have been checked against the root being decoded.
        if &chunk.root != root {
            return Err(Error::CommitmentMismatch);
        }
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        let digest_slot = &mut shard_digests[index as usize];
        if digest_slot.is_some() {
            return Err(Error::DuplicateIndex(index));
        }

        // Partition into originals (< min) and recoveries (>= min, re-based to 0).
        *digest_slot = Some(chunk.digest);
        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_ref()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.as_ref()));
        }
    }
    if provided < k {
        return Err(Error::NotEnoughChunks);
    }

    // Reconstruct the missing original shards (thread-local decoder reuse).
    let mut decoder = Cached::take(
        &CACHED_DECODER,
        || ReedSolomonDecoder::new(k, m, shard_len),
        |dec| dec.reset(k, m, shard_len),
    )
    .map_err(Error::ReedSolomon)?;
    for (idx, shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Assemble the full set of k original shards: provided + restored.
    let mut shards = vec![Default::default(); k];
    for (idx, shard) in provided_originals
        .into_iter()
        .chain(decoding.restored_original_iter())
    {
        shards[idx] = shard;
    }

    // Re-encode the originals to regenerate all recovery shards so the whole
    // shard set can be re-committed and compared against `root`.
    let mut encoder = Cached::take(
        &CACHED_ENCODER,
        || ReedSolomonEncoder::new(k, m, shard_len),
        |enc| enc.reset(k, m, shard_len),
    )
    .map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    shards.extend(encoding.recovery_iter());

    // Hash only the shards whose digests were not already supplied by a
    // checked chunk (potentially in parallel per `strategy`).
    for (i, digest) in strategy.map_init_collect_vec(
        shard_digests
            .iter()
            .enumerate()
            .filter_map(|(i, digest)| digest.is_none().then_some(i)),
        H::new,
        |hasher, i| {
            hasher.update(shards[i]);
            (i, hasher.finalize())
        },
    ) {
        shard_digests[i] = Some(digest);
    }

    // Rebuild the commitment over the complete shard set.
    let mut builder = Builder::<H>::new(n);
    shard_digests
        .into_iter()
        .map(|digest| digest.expect("digest must be present for every shard"))
        .for_each(|digest| {
            builder.add(&digest);
        });
    let tree = builder.build();

    // A mismatched root means the original encoding was not consistent.
    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    extract_data(&shards, k)
}
507
/// A coding [Scheme] backed by SIMD Reed-Solomon erasure coding with a
/// binary-merkle-tree commitment over the shards, generic over the hasher `H`.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    // Zero-sized marker tying the scheme to a hasher without storing one.
    _marker: PhantomData<H>,
}
594
595impl<H> std::fmt::Debug for ReedSolomon<H> {
596 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
597 f.debug_struct("ReedSolomon").finish()
598 }
599}
600
impl<H: Hasher> Scheme for ReedSolomon<H> {
    type Commitment = H::Digest;
    type Shard = Chunk<H::Digest>;
    type CheckedShard = CheckedChunk<H::Digest>;
    type Error = Error;

    /// Encodes `data` into proven chunks under `config` (see [encode]).
    fn encode(
        config: &Config,
        data: impl Buf,
        strategy: &impl Strategy,
    ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
        encode::<H, _>(
            total_shards(config)?,
            config.minimum_shards.get(),
            data,
            strategy,
        )
    }

    /// Validates a received shard against the commitment before it is used
    /// for reconstruction.
    fn check(
        config: &Config,
        commitment: &Self::Commitment,
        index: u16,
        shard: &Self::Shard,
    ) -> Result<Self::CheckedShard, Self::Error> {
        let total = total_shards(config)?;
        // The requested index must be addressable under this configuration.
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        // Reject proofs whose leaf count differs from this configuration's
        // shard count (i.e. generated under a different config).
        if shard.proof.leaf_count != u32::from(total) {
            return Err(Error::InvalidProof);
        }
        // The shard must claim exactly the index the caller expects.
        if shard.index != index {
            return Err(Error::InvalidIndex(shard.index));
        }
        shard
            .verify::<H>(shard.index, commitment)
            .ok_or(Error::InvalidProof)
    }

    /// Reconstructs the original data from checked shards (see [decode]).
    fn decode<'a>(
        config: &Config,
        commitment: &Self::Commitment,
        shards: impl Iterator<Item = &'a Self::CheckedShard>,
        strategy: &impl Strategy,
    ) -> Result<Vec<u8>, Self::Error> {
        decode::<H, _>(
            total_shards(config)?,
            config.minimum_shards.get(),
            commitment,
            shards,
            strategy,
        )
    }
}
656
657#[cfg(test)]
658mod tests {
659 use super::*;
660 use commonware_codec::Encode;
661 use commonware_cryptography::Sha256;
662 use commonware_parallel::Sequential;
663 use commonware_runtime::{deterministic, iobuf::EncodeExt, BufferPooler, Runner};
664 use commonware_utils::NZU16;
665
    /// Scheme under test.
    type RS = ReedSolomon<Sha256>;
    /// All tests hash sequentially (no parallel strategy).
    const STRATEGY: Sequential = Sequential;
668
669 fn checked(
670 root: <Sha256 as Hasher>::Digest,
671 chunk: Chunk<<Sha256 as Hasher>::Digest>,
672 ) -> CheckedChunk<<Sha256 as Hasher>::Digest> {
673 let Chunk { shard, index, .. } = chunk;
674 let digest = Sha256::hash(&shard);
675 CheckedChunk::new(root, shard, index, digest)
676 }
677
678 #[test]
679 fn test_recovery() {
680 let data = b"Testing recovery pieces";
681 let total = 8u16;
682 let min = 3u16;
683
684 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
686
687 let pieces: Vec<_> = vec![
689 checked(root, chunks[0].clone()), checked(root, chunks[4].clone()), checked(root, chunks[6].clone()), ];
693
694 let decoded = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY).unwrap();
696 assert_eq!(decoded, data);
697 }
698
699 #[test]
700 fn test_not_enough_pieces() {
701 let data = b"Test insufficient pieces";
702 let total = 6u16;
703 let min = 4u16;
704
705 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
707
708 let pieces: Vec<_> = chunks
710 .into_iter()
711 .take(2)
712 .map(|c| checked(root, c))
713 .collect();
714
715 let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
717 assert!(matches!(result, Err(Error::NotEnoughChunks)));
718 }
719
720 #[test]
721 fn test_duplicate_index() {
722 let data = b"Test duplicate detection";
723 let total = 5u16;
724 let min = 3u16;
725
726 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
728
729 let pieces = [
731 checked(root, chunks[0].clone()),
732 checked(root, chunks[0].clone()),
733 checked(root, chunks[1].clone()),
734 ];
735
736 let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
738 assert!(matches!(result, Err(Error::DuplicateIndex(0))));
739 }
740
741 #[test]
742 fn test_invalid_index() {
743 let data = b"Test invalid index";
744 let total = 5u16;
745 let min = 3u16;
746
747 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
749
750 for i in 0..total {
752 assert!(chunks[i as usize].verify::<Sha256>(i + 1, &root).is_none());
753 }
754 }
755
756 #[test]
757 #[should_panic(expected = "assertion failed: total > min")]
758 fn test_invalid_total() {
759 let data = b"Test parameter validation";
760
761 encode::<Sha256, _>(3, 3, data.as_slice(), &STRATEGY).unwrap();
763 }
764
765 #[test]
766 #[should_panic(expected = "assertion failed: min > 0")]
767 fn test_invalid_min() {
768 let data = b"Test parameter validation";
769
770 encode::<Sha256, _>(5, 0, data.as_slice(), &STRATEGY).unwrap();
772 }
773
774 #[test]
775 fn test_empty_data() {
776 let data = b"";
777 let total = 100u16;
778 let min = 30u16;
779
780 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
782
783 let minimal = chunks
785 .into_iter()
786 .take(min as usize)
787 .map(|c| checked(root, c))
788 .collect::<Vec<_>>();
789 let decoded = decode::<Sha256, _>(total, min, &root, minimal.iter(), &STRATEGY).unwrap();
790 assert_eq!(decoded, data);
791 }
792
793 #[test]
794 fn test_large_data() {
795 let data = vec![42u8; 1000]; let total = 7u16;
797 let min = 4u16;
798
799 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
801
802 let minimal = chunks
804 .into_iter()
805 .take(min as usize)
806 .map(|c| checked(root, c))
807 .collect::<Vec<_>>();
808 let decoded = decode::<Sha256, _>(total, min, &root, minimal.iter(), &STRATEGY).unwrap();
809 assert_eq!(decoded, data);
810 }
811
812 #[test]
813 fn test_malicious_root_detection() {
814 let data = b"Original data that should be protected";
815 let total = 7u16;
816 let min = 4u16;
817
818 let (_correct_root, chunks) =
820 encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
821
822 let mut hasher = Sha256::new();
824 hasher.update(b"malicious_data_that_wasnt_actually_encoded");
825 let malicious_root = hasher.finalize();
826
827 for i in 0..total {
829 assert!(chunks[i as usize]
830 .clone()
831 .verify::<Sha256>(i, &malicious_root)
832 .is_none());
833 }
834
835 let minimal = chunks
838 .into_iter()
839 .take(min as usize)
840 .map(|c| checked(_correct_root, c))
841 .collect::<Vec<_>>();
842
843 let result = decode::<Sha256, _>(total, min, &malicious_root, minimal.iter(), &STRATEGY);
846 assert!(matches!(result, Err(Error::CommitmentMismatch)));
847 }
848
849 #[test]
850 fn test_mismatched_config_rejected_during_check() {
851 let config_expected = Config {
852 minimum_shards: NZU16!(2),
853 extra_shards: NZU16!(2),
854 };
855 let config_actual = Config {
856 minimum_shards: NZU16!(3),
857 extra_shards: NZU16!(3),
858 };
859
860 let data = b"leaf_count mismatch proof";
861 let (commitment, shards) = RS::encode(&config_actual, data.as_slice(), &STRATEGY).unwrap();
862
863 let check_result = RS::check(&config_expected, &commitment, 0, &shards[0]);
866 assert!(matches!(check_result, Err(Error::InvalidProof)));
867 }
868
869 #[test]
870 fn test_manipulated_chunk_detection() {
871 let data = b"Data integrity must be maintained";
872 let total = 6u16;
873 let min = 3u16;
874
875 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
877 let mut pieces: Vec<_> = chunks.into_iter().map(|c| checked(root, c)).collect();
878
879 if !pieces[1].shard.is_empty() {
881 let mut shard = pieces[1].shard.to_vec();
882 shard[0] ^= 0xFF; pieces[1].shard = shard.into();
884 pieces[1].digest = Sha256::hash(&pieces[1].shard);
885 }
886
887 let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
889 assert!(matches!(result, Err(Error::Inconsistent)));
890 }
891
892 #[test]
893 fn test_inconsistent_shards() {
894 let data = b"Test data for malicious encoding";
895 let total = 5u16;
896 let min = 3u16;
897 let m = total - min;
898
899 let (padded, shard_size) = prepare_data(data.as_slice(), min as usize);
901
902 let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
904 for shard in padded.chunks(shard_size) {
905 encoder.add_original_shard(shard).unwrap();
906 }
907 let recovery_result = encoder.encode().unwrap();
908 let mut recovery_shards: Vec<Vec<u8>> = recovery_result
909 .recovery_iter()
910 .map(|s| s.to_vec())
911 .collect();
912
913 if !recovery_shards[0].is_empty() {
915 recovery_shards[0][0] ^= 0xFF;
916 }
917
918 let mut malicious_shards: Vec<Vec<u8>> =
920 padded.chunks(shard_size).map(|s| s.to_vec()).collect();
921 malicious_shards.extend(recovery_shards);
922
923 let mut builder = Builder::<Sha256>::new(total as usize);
925 for shard in &malicious_shards {
926 let mut hasher = Sha256::new();
927 hasher.update(shard);
928 builder.add(&hasher.finalize());
929 }
930 let malicious_tree = builder.build();
931 let malicious_root = malicious_tree.root();
932
933 let selected_indices = vec![0, 1, 3]; let mut pieces = Vec::new();
936 for &i in &selected_indices {
937 let merkle_proof = malicious_tree.proof(i as u32).unwrap();
938 let shard = malicious_shards[i].clone();
939 let chunk = Chunk::new(shard.into(), i as u16, merkle_proof);
940 pieces.push(chunk);
941 }
942 let pieces: Vec<_> = pieces
943 .into_iter()
944 .map(|c| checked(malicious_root, c))
945 .collect();
946
947 let result = decode::<Sha256, _>(total, min, &malicious_root, pieces.iter(), &STRATEGY);
949 assert!(matches!(result, Err(Error::Inconsistent)));
950 }
951
952 #[test]
956 fn test_non_canonical_padding_rejected() {
957 let data = b"X";
958 let total = 6u16;
959 let min = 3u16;
960 let k = min as usize;
961 let m = total as usize - k;
962
963 let (mut padded, shard_len) = prepare_data(data.as_slice(), k);
964 let payload_end = u32::SIZE + data.len();
965 let total_original_len = k * shard_len;
966 assert!(payload_end < total_original_len, "test requires padding");
967
968 let pad_shard = payload_end / shard_len;
970 let pad_offset = payload_end % shard_len;
971 padded[pad_shard * shard_len + pad_offset] = 0xAA;
972
973 let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).unwrap();
974 for shard in padded.chunks(shard_len) {
975 encoder.add_original_shard(shard).unwrap();
976 }
977 let recovery = encoder.encode().unwrap();
978 let mut shards: Vec<Vec<u8>> = padded.chunks(shard_len).map(|s| s.to_vec()).collect();
979 shards.extend(recovery.recovery_iter().map(|s| s.to_vec()));
980
981 let mut builder = Builder::<Sha256>::new(total as usize);
982 for shard in &shards {
983 let mut hasher = Sha256::new();
984 hasher.update(shard);
985 builder.add(&hasher.finalize());
986 }
987 let tree = builder.build();
988 let non_canonical_root = tree.root();
989
990 let mut pieces = Vec::with_capacity(k);
991 for (i, shard) in shards.iter().take(k).enumerate() {
992 let proof = tree.proof(i as u32).unwrap();
993 pieces.push(checked(
994 non_canonical_root,
995 Chunk::new(shard.clone().into(), i as u16, proof),
996 ));
997 }
998
999 let result = decode::<Sha256, _>(total, min, &non_canonical_root, pieces.iter(), &STRATEGY);
1000 assert!(matches!(result, Err(Error::Inconsistent)));
1001 }
1002
1003 #[test]
1004 fn test_decode_invalid_index() {
1005 let data = b"Testing recovery pieces";
1006 let total = 8u16;
1007 let min = 3u16;
1008
1009 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
1011
1012 let mut invalid = checked(root, chunks[1].clone());
1014 invalid.index = 8;
1015 let pieces: Vec<_> = vec![
1016 checked(root, chunks[0].clone()), invalid, checked(root, chunks[6].clone()), ];
1020
1021 let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
1023 assert!(matches!(result, Err(Error::InvalidIndex(8))));
1024 }
1025
1026 #[test]
1027 fn test_max_chunks() {
1028 let data = vec![42u8; 1000]; let total = u16::MAX;
1030 let min = u16::MAX / 2;
1031
1032 let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
1034
1035 let minimal = chunks
1037 .into_iter()
1038 .take(min as usize)
1039 .map(|c| checked(root, c))
1040 .collect::<Vec<_>>();
1041 let decoded = decode::<Sha256, _>(total, min, &root, minimal.iter(), &STRATEGY).unwrap();
1042 assert_eq!(decoded, data);
1043 }
1044
1045 #[test]
1046 fn test_too_many_chunks() {
1047 let data = vec![42u8; 1000]; let total = u16::MAX;
1049 let min = u16::MAX / 2 - 1;
1050
1051 let result = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY);
1053 assert!(matches!(
1054 result,
1055 Err(Error::ReedSolomon(
1056 reed_solomon_simd::Error::UnsupportedShardCount {
1057 original_count: _,
1058 recovery_count: _,
1059 }
1060 ))
1061 ));
1062 }
1063
1064 #[test]
1065 fn test_too_many_total_shards() {
1066 assert!(RS::encode(
1067 &Config {
1068 minimum_shards: NZU16!(u16::MAX / 2 + 1),
1069 extra_shards: NZU16!(u16::MAX),
1070 },
1071 [].as_slice(),
1072 &STRATEGY,
1073 )
1074 .is_err())
1075 }
1076
    /// Encoding a chunk through a pooled network buffer must produce the same
    /// bytes as the plain `encode` path.
    #[test]
    fn test_chunk_encode_with_pool_matches_encode() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let pool = context.network_buffer_pool();

            let data = b"pool encoding test";
            let (_root, chunks) = encode::<Sha256, _>(5, 3, data.as_slice(), &STRATEGY).unwrap();
            let chunk = &chunks[0];

            let encoded = chunk.encode();
            let mut encoded_pool = chunk.encode_with_pool(pool);
            // Flatten the pooled buffer before comparing byte-for-byte.
            let mut encoded_pool_bytes = vec![0u8; encoded_pool.remaining()];
            encoded_pool.copy_to_slice(&mut encoded_pool_bytes);
            assert_eq!(encoded_pool_bytes, encoded.as_ref());
        });
    }
1094
    /// Codec conformance (encode/decode round-trip) fuzzing for [Chunk].
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256Digest>>,
        }
    }
1105}