// commonware_coding/reed_solomon.rs

1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut};
3use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
4use commonware_cryptography::Hasher;
5use commonware_parallel::Strategy;
6use commonware_storage::bmt::{self, Builder};
7use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
8use std::{collections::HashSet, marker::PhantomData};
9use thiserror::Error;
10
/// Errors that can occur when interacting with the Reed-Solomon coder.
#[derive(Error, Debug)]
pub enum Error {
    /// The underlying `reed-solomon-simd` encoder or decoder failed.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// Re-encoding the recovered shards produced a [bmt] root that does not
    /// match the provided one (malicious or corrupted encoding).
    #[error("inconsistent")]
    Inconsistent,
    /// A chunk's Merkle proof could not be generated or verified.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than the minimum number of chunks were provided for decoding.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same chunk index was provided more than once.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The data is too large to encode (its length must fit in a `u32`).
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A chunk's index is out of range for the configured total shard count.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// A shard's embedded index does not match the expected index.
    #[error("wrong index: {0}")]
    WrongIndex(u16),
    /// The configured total shard count does not fit in a `u16`.
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
}
33
34fn total_shards(config: &Config) -> Result<u16, Error> {
35    let total = config.total_shards();
36    total
37        .try_into()
38        .map_err(|_| Error::TooManyTotalShards(total))
39}
40
/// A piece of data from a Reed-Solomon encoded object.
#[derive(Debug, Clone)]
pub struct Chunk<H: Hasher> {
    /// The shard of encoded data (original if `index < minimum_shards`, recovery otherwise).
    shard: Vec<u8>,

    /// The index of [Chunk] in the original data.
    index: u16,

    /// The multi-proof of the shard in the [bmt] at the given index.
    proof: bmt::Proof<H::Digest>,
}
53
54impl<H: Hasher> Chunk<H> {
55    /// Create a new [Chunk] from the given shard, index, and proof.
56    const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H::Digest>) -> Self {
57        Self {
58            shard,
59            index,
60            proof,
61        }
62    }
63
64    /// Verify a [Chunk] against the given root.
65    fn verify(&self, index: u16, root: &H::Digest) -> bool {
66        // Ensure the index matches
67        if index != self.index {
68            return false;
69        }
70
71        // Compute shard digest
72        let mut hasher = H::new();
73        hasher.update(&self.shard);
74        let shard_digest = hasher.finalize();
75
76        // Verify proof
77        self.proof
78            .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
79            .is_ok()
80    }
81}
82
impl<H: Hasher> Write for Chunk<H> {
    // Serialization order (shard, index, proof) must stay in sync with
    // `Read::read_cfg` and the size computed by `EncodeSize::encode_size`.
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
90
impl<H: Hasher> Read for Chunk<H> {
    /// The maximum size of the shard.
    type Cfg = crate::CodecConfig;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Bound the shard allocation by the configured maximum to avoid
        // attacker-controlled over-allocation.
        let shard = Vec::<u8>::read_range(reader, ..=cfg.maximum_shard_size)?;
        let index = u16::read(reader)?;
        // NOTE(review): `&1` is the proof read config — presumably because each
        // chunk carries a proof for exactly one leaf (see `encode`'s
        // `tree.proof(i)`); confirm against `bmt::Proof::read_cfg`.
        let proof = bmt::Proof::<H::Digest>::read_cfg(reader, &1)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
106
impl<H: Hasher> EncodeSize for Chunk<H> {
    // Sum of the encoded sizes of the fields, in the order written by `Write`.
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }
}
112
// Implemented manually (rather than derived) so equality does not require a
// `H: PartialEq` bound — only the fields, which depend on `H::Digest`, are compared.
impl<H: Hasher> PartialEq for Chunk<H> {
    fn eq(&self, other: &Self) -> bool {
        self.shard == other.shard && self.index == other.index && self.proof == other.proof
    }
}

impl<H: Hasher> Eq for Chunk<H> {}
120
#[cfg(feature = "arbitrary")]
impl<H: Hasher> arbitrary::Arbitrary<'_> for Chunk<H>
where
    H::Digest: for<'a> arbitrary::Arbitrary<'a>,
{
    // Generates an unchecked chunk (the proof is random, not necessarily valid);
    // used for codec conformance testing (see the `conformance` test module).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            shard: u.arbitrary()?,
            index: u.arbitrary()?,
            proof: u.arbitrary()?,
        })
    }
}
134
/// Prepare data for encoding: length-prefix it, pad it, and split it into `k`
/// equally-sized original shards (reserving capacity for `m` recovery shards).
fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
    // The payload is prefixed with its length as a big-endian u32.
    const PREFIX_LEN: usize = std::mem::size_of::<u32>();
    let data_len = data.len();

    // Shard length is ceil((prefix + payload) / k), rounded up to the next even
    // value (required for optimizations in `reed-solomon-simd`).
    let mut shard_len = (PREFIX_LEN + data_len).div_ceil(k);
    shard_len += shard_len % 2;

    // Lay out [length prefix | payload | zero padding] across k * shard_len bytes.
    let mut padded = vec![0u8; k * shard_len];
    padded[..PREFIX_LEN].copy_from_slice(&(data_len as u32).to_be_bytes());
    padded[PREFIX_LEN..PREFIX_LEN + data_len].copy_from_slice(&data);

    // Split into k original shards; recovery shards will be appended later.
    let mut shards = Vec::with_capacity(k + m);
    shards.extend(padded.chunks(shard_len).map(|chunk| chunk.to_vec()));
    shards
}
159
/// Extract the original payload from the first `k` (original) shards.
fn extract_data(shards: Vec<&[u8]>, k: usize) -> Vec<u8> {
    // Stream the bytes of the k original shards in order.
    let mut bytes = shards.into_iter().take(k).flatten().copied();

    // Read the big-endian u32 length prefix.
    let mut prefix = [0u8; std::mem::size_of::<u32>()];
    for byte in &mut prefix {
        *byte = bytes.next().expect("insufficient data");
    }
    let data_len = u32::from_be_bytes(prefix) as usize;

    // The payload follows the prefix; any trailing padding is dropped.
    bytes.take(data_len).collect()
}
177
/// Type alias for the internal encoding result: the shard [bmt] and all `n` shards
/// (originals followed by recovery shards).
type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);
180
181/// Inner logic for [encode()]
182fn encode_inner<H: Hasher, S: Strategy>(
183    total: u16,
184    min: u16,
185    data: Vec<u8>,
186    strategy: &S,
187) -> Result<Encoding<H>, Error> {
188    // Validate parameters
189    assert!(total > min);
190    assert!(min > 0);
191    let n = total as usize;
192    let k = min as usize;
193    let m = n - k;
194    if data.len() > u32::MAX as usize {
195        return Err(Error::InvalidDataLength(data.len()));
196    }
197
198    // Prepare data
199    let mut shards = prepare_data(data, k, m);
200    let shard_len = shards[0].len();
201
202    // Create encoder
203    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
204    for shard in &shards {
205        encoder
206            .add_original_shard(shard)
207            .map_err(Error::ReedSolomon)?;
208    }
209
210    // Compute recovery shards
211    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
212    let recovery_shards: Vec<Vec<u8>> = encoding
213        .recovery_iter()
214        .map(|shard| shard.to_vec())
215        .collect();
216    shards.extend(recovery_shards);
217
218    // Build Merkle tree
219    let mut builder = Builder::<H>::new(n);
220    let shard_hashes = strategy.map_init_collect_vec(&shards, H::new, |hasher, shard| {
221        hasher.update(shard);
222        hasher.finalize()
223    });
224    for hash in &shard_hashes {
225        builder.add(hash);
226    }
227    let tree = builder.build();
228
229    Ok((tree, shards))
230}
231
232/// Encode data using a Reed-Solomon coder and insert it into a [bmt].
233///
234/// # Parameters
235///
236/// - `total`: The total number of chunks to generate.
237/// - `min`: The minimum number of chunks required to decode the data.
238/// - `data`: The data to encode.
239/// - `concurrency`: The level of concurrency to use.
240///
241/// # Returns
242///
243/// - `root`: The root of the [bmt].
244/// - `chunks`: [Chunk]s of encoded data (that can be proven against `root`).
245fn encode<H: Hasher, S: Strategy>(
246    total: u16,
247    min: u16,
248    data: Vec<u8>,
249    strategy: &S,
250) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
251    // Encode data
252    let (tree, shards) = encode_inner::<H, _>(total, min, data, strategy)?;
253    let root = tree.root();
254    let n = total as usize;
255
256    // Generate chunks
257    let mut chunks = Vec::with_capacity(n);
258    for (i, shard) in shards.into_iter().enumerate() {
259        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
260        chunks.push(Chunk::new(shard, i as u16, proof));
261    }
262
263    Ok((root, chunks))
264}
265
266/// Decode data from a set of [Chunk]s.
267///
268/// It is assumed that all [Chunk]s have already been verified against the given root using [Chunk::verify].
269///
270/// # Parameters
271///
272/// - `total`: The total number of chunks to generate.
273/// - `min`: The minimum number of chunks required to decode the data.
274/// - `root`: The root of the [bmt].
275/// - `chunks`: [Chunk]s of encoded data (that can be proven against `root`).
276/// - `concurrency`: The level of concurrency to use.
277///
278/// # Returns
279///
280/// - `data`: The decoded data.
281fn decode<H: Hasher, S: Strategy>(
282    total: u16,
283    min: u16,
284    root: &H::Digest,
285    chunks: &[Chunk<H>],
286    strategy: &S,
287) -> Result<Vec<u8>, Error> {
288    // Validate parameters
289    assert!(total > min);
290    assert!(min > 0);
291    let n = total as usize;
292    let k = min as usize;
293    let m = n - k;
294    if chunks.len() < k {
295        return Err(Error::NotEnoughChunks);
296    }
297
298    // Verify chunks
299    let shard_len = chunks[0].shard.len();
300    let mut seen = HashSet::new();
301    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
302    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
303    for chunk in chunks {
304        // Check for duplicate index
305        let index = chunk.index;
306        if index >= total {
307            return Err(Error::InvalidIndex(index));
308        }
309        if seen.contains(&index) {
310            return Err(Error::DuplicateIndex(index));
311        }
312        seen.insert(index);
313
314        // Add to provided shards
315        if index < min {
316            provided_originals.push((index as usize, chunk.shard.as_slice()));
317        } else {
318            provided_recoveries.push((index as usize - k, chunk.shard.as_slice()));
319        }
320    }
321
322    // Decode original data
323    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
324    for (idx, ref shard) in &provided_originals {
325        decoder
326            .add_original_shard(*idx, shard)
327            .map_err(Error::ReedSolomon)?;
328    }
329    for (idx, ref shard) in &provided_recoveries {
330        decoder
331            .add_recovery_shard(*idx, shard)
332            .map_err(Error::ReedSolomon)?;
333    }
334    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;
335
336    // Reconstruct all original shards
337    let mut shards = Vec::with_capacity(n);
338    shards.resize(k, Default::default());
339    for (idx, shard) in provided_originals {
340        shards[idx] = shard;
341    }
342    for (idx, shard) in decoding.restored_original_iter() {
343        shards[idx] = shard;
344    }
345
346    // Encode recovered data to get recovery shards
347    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
348    for shard in shards.iter().take(k) {
349        encoder
350            .add_original_shard(shard)
351            .map_err(Error::ReedSolomon)?;
352    }
353    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
354    let recovery_shards: Vec<&[u8]> = encoding.recovery_iter().collect();
355    shards.extend(recovery_shards);
356
357    // Build Merkle tree
358    let mut builder = Builder::<H>::new(n);
359    let shard_hashes = strategy.map_init_collect_vec(&shards, H::new, |hasher, shard| {
360        hasher.update(shard);
361        hasher.finalize()
362    });
363    for hash in &shard_hashes {
364        builder.add(hash);
365    }
366    let tree = builder.build();
367
368    // Confirm root is consistent
369    if tree.root() != *root {
370        return Err(Error::Inconsistent);
371    }
372
373    // Extract original data
374    Ok(extract_data(shards, k))
375}
376
377/// A SIMD-optimized Reed-Solomon coder that emits chunks that can be proven against a [bmt].
378///
379/// # Behavior
380///
381/// The encoder takes input data, splits it into `k` data shards, and generates `m` recovery
382/// shards using [Reed-Solomon encoding](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction).
383/// All `n = k + m` shards are then used to build a [bmt], producing a single root hash. Each shard
384/// is packaged as a chunk containing the shard data, its index, and a Merkle multi-proof against the [bmt] root.
385///
386/// ## Encoding
387///
388/// ```text
389///               +--------------------------------------+
390///               |         Original Data (Bytes)        |
391///               +--------------------------------------+
392///                                  |
393///                                  v
394///               +--------------------------------------+
395///               | [Length Prefix | Original Data...]   |
396///               +--------------------------------------+
397///                                  |
398///                                  v
399///              +----------+ +----------+    +-----------+
400///              |  Shard 0 | |  Shard 1 | .. | Shard k-1 |  (Data Shards)
401///              +----------+ +----------+    +-----------+
402///                     |            |             |
403///                     |            |             |
404///                     +------------+-------------+
405///                                  |
406///                                  v
407///                        +------------------+
408///                        | Reed-Solomon     |
409///                        | Encoder (k, m)   |
410///                        +------------------+
411///                                  |
412///                                  v
413///              +----------+ +----------+    +-----------+
414///              |  Shard k | | Shard k+1| .. | Shard n-1 |  (Recovery Shards)
415///              +----------+ +----------+    +-----------+
416/// ```
417///
418/// ## Merkle Tree Construction
419///
420/// All `n` shards (data and recovery) are hashed and used as leaves to build a [bmt].
421///
422/// ```text
423/// Shards:    [Shard 0, Shard 1, ..., Shard n-1]
424///             |        |              |
425///             v        v              v
426/// Hashes:    [H(S_0), H(S_1), ..., H(S_n-1)]
427///             \       / \       /
428///              \     /   \     /
429///               +---+     +---+
430///                 |         |
431///                 \         /
432///                  \       /
433///                   +-----+
434///                      |
435///                      v
436///                +----------+
437///                |   Root   |
438///                +----------+
439/// ```
440///
441/// The final output is the [bmt] root and a set of `n` chunks.
442///
443/// `(Root, [Chunk 0, Chunk 1, ..., Chunk n-1])`
444///
445/// Each chunk contains:
446/// - `shard`: The shard data (original or recovery).
447/// - `index`: The shard's original index (0 to n-1).
448/// - `proof`: A Merkle multi-proof of the shard's inclusion in the [bmt].
449///
450/// ## Decoding and Verification
451///
452/// The decoder requires any `k` chunks to reconstruct the original data.
453/// 1. Each chunk's Merkle multi-proof is verified against the [bmt] root.
454/// 2. The shards from the valid chunks are used to reconstruct the original `k` data shards.
455/// 3. To ensure consistency, the recovered data shards are re-encoded, and a new [bmt] root is
456///    generated. This new root MUST match the original [bmt] root. This prevents attacks where
457///    an adversary provides a valid set of chunks that decode to different data.
458/// 4. If the roots match, the original data is extracted from the reconstructed data shards.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    /// Zero-sized marker binding the hasher type `H` without storing one.
    _marker: PhantomData<H>,
}
463
464impl<H> std::fmt::Debug for ReedSolomon<H> {
465    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466        f.debug_struct("ReedSolomon").finish()
467    }
468}
469
470impl<H: Hasher> Scheme for ReedSolomon<H> {
471    type Commitment = H::Digest;
472
473    type Shard = Chunk<H>;
474    type ReShard = Chunk<H>;
475    type CheckedShard = Chunk<H>;
476    type CheckingData = ();
477
478    type Error = Error;
479
480    fn encode(
481        config: &Config,
482        mut data: impl Buf,
483        strategy: &impl Strategy,
484    ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
485        let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
486        encode(total_shards(config)?, config.minimum_shards, data, strategy)
487    }
488
489    fn reshard(
490        _config: &Config,
491        commitment: &Self::Commitment,
492        index: u16,
493        shard: Self::Shard,
494    ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::ReShard), Self::Error> {
495        if shard.index != index {
496            return Err(Error::WrongIndex(index));
497        }
498        if shard.verify(shard.index, commitment) {
499            Ok(((), shard.clone(), shard))
500        } else {
501            Err(Error::InvalidProof)
502        }
503    }
504
505    fn check(
506        _config: &Config,
507        commitment: &Self::Commitment,
508        _checking_data: &Self::CheckingData,
509        index: u16,
510        reshard: Self::ReShard,
511    ) -> Result<Self::CheckedShard, Self::Error> {
512        if reshard.index != index {
513            return Err(Error::WrongIndex(reshard.index));
514        }
515        if !reshard.verify(reshard.index, commitment) {
516            return Err(Error::InvalidProof);
517        }
518        Ok(reshard)
519    }
520
521    fn decode(
522        config: &Config,
523        commitment: &Self::Commitment,
524        _checking_data: Self::CheckingData,
525        shards: &[Self::CheckedShard],
526        strategy: &impl Strategy,
527    ) -> Result<Vec<u8>, Self::Error> {
528        decode(
529            total_shards(config)?,
530            config.minimum_shards,
531            commitment,
532            shards,
533            strategy,
534        )
535    }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;
    use commonware_parallel::Sequential;

    /// Hash shards sequentially in tests; parallelism is orthogonal to the logic under test.
    const STRATEGY: Sequential = Sequential;

    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        // Encode the data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Use a mix of original and recovery pieces
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[4].clone(), // recovery
            chunks[6].clone(), // recovery
        ];

        // Try to decode with a mix of original and recovery pieces
        let decoded = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Try with fewer than min
        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        // Fail to decode
        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Include duplicate index by cloning the first chunk
        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        // Fail to decode
        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Verify all proofs at invalid index (off-by-one: i+1 never matches chunk i)
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        // total <= min should panic
        encode::<Sha256, _>(3, 3, data.to_vec(), &STRATEGY).unwrap();
    }

    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        // min = 0 should panic
        encode::<Sha256, _>(5, 0, data.to_vec(), &STRATEGY).unwrap();
    }

    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = 7u16;
        let min = 4u16;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        // Encode data correctly to get valid chunks
        let (_correct_root, chunks) =
            encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Create a malicious/fake root (simulating a malicious encoder)
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // Verify all proofs at incorrect root
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        // Collect valid pieces (these are legitimate fragments)
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();

        // Attempt to decode with malicious root
        let result = decode::<Sha256, _>(total, min, &malicious_root, &minimal, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        // Encode data
        let (root, mut chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Tamper with one of the chunks by modifying the shard data
        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF; // Flip bits in first byte
        }

        // Try to decode with the tampered chunk
        let result = decode::<Sha256, _>(total, min, &root, &chunks, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        // Compute original data encoding
        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        // Re-encode the data
        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Tamper with one recovery shard
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        // Build malicious shards
        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        // Build malicious tree (commits to the tampered recovery shard)
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Generate chunks for min pieces, including the tampered recovery
        let selected_indices = vec![0, 1, 3]; // originals 0,1 and recovery 0 (index 3)
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        // Fail to decode: re-encoding the recovered originals yields a different root
        let result = decode::<Sha256, _>(total, min, &malicious_root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        // Encode the data
        let (root, mut chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();

        // Use a mix of original and recovery pieces (index 8 == total is out of range)
        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[1].clone(), // recovery with invalid index
            chunks[6].clone(), // recovery
        ];

        // Fail to decode
        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = u16::MAX;
        let min = u16::MAX / 2;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        // Encode data: recovery count exceeds what the RS backend supports
        let result = encode::<Sha256, _>(total, min, data, &STRATEGY);
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    #[test]
    fn test_too_many_total_shards() {
        // minimum_shards + extra_shards overflows u16, so `total_shards` fails
        assert!(ReedSolomon::<Sha256>::encode(
            &Config {
                minimum_shards: u16::MAX / 2 + 1,
                extra_shards: u16::MAX,
            },
            [].as_slice(),
            &STRATEGY,
        )
        .is_err())
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256>>,
        }
    }
}