commonware_coding/reed_solomon/
mod.rs

use crate::{Config, Scheme};
use bytes::{Buf, BufMut};
use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
use commonware_cryptography::Hasher;
use commonware_storage::bmt::{self, Builder};
use rayon::{
    iter::{IntoParallelRefIterator, ParallelIterator},
    ThreadPoolBuilder,
};
use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
use std::{collections::HashSet, marker::PhantomData};
use thiserror::Error;

/// Errors that can occur when interacting with the Reed-Solomon coder.
#[derive(Error, Debug)]
pub enum Error {
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    #[error("inconsistent")]
    Inconsistent,
    #[error("invalid proof")]
    InvalidProof,
    #[error("not enough chunks")]
    NotEnoughChunks,
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    #[error("wrong index: {0}")]
    WrongIndex(u16),
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
}

fn total_shards(config: &Config) -> Result<u16, Error> {
    let total = config.total_shards();
    total
        .try_into()
        .map_err(|_| Error::TooManyTotalShards(total))
}

/// A piece of data from a Reed-Solomon encoded object.
#[derive(Debug, Clone)]
pub struct Chunk<H: Hasher> {
    /// The shard of encoded data.
    shard: Vec<u8>,

    /// The index of the [Chunk] among all encoded shards (data and recovery).
    index: u16,

    /// The inclusion proof of the shard in the [bmt] at the given index.
    proof: bmt::Proof<H>,
}

impl<H: Hasher> Chunk<H> {
    /// Create a new [Chunk] from the given shard, index, and proof.
    const fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
        Self {
            shard,
            index,
            proof,
        }
    }

    /// Verify a [Chunk] against the given root, checking that it carries the
    /// expected index.
    fn verify(&self, index: u16, root: &H::Digest) -> bool {
        // Ensure the index matches
        if index != self.index {
            return false;
        }

        // Compute shard digest
        let mut hasher = H::new();
        hasher.update(&self.shard);
        let shard_digest = hasher.finalize();

        // Verify proof
        self.proof
            .verify(&mut hasher, &shard_digest, self.index as u32, root)
            .is_ok()
    }
}

impl<H: Hasher> Write for Chunk<H> {
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}

impl<H: Hasher> Read for Chunk<H> {
    /// The codec configuration, which bounds the maximum size of a shard.
    type Cfg = crate::CodecConfig;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        let shard = Vec::<u8>::read_range(reader, ..=cfg.maximum_shard_size)?;
        let index = u16::read(reader)?;
        let proof = bmt::Proof::<H>::read(reader)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}

impl<H: Hasher> EncodeSize for Chunk<H> {
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }
}

impl<H: Hasher> PartialEq for Chunk<H> {
    fn eq(&self, other: &Self) -> bool {
        self.shard == other.shard && self.index == other.index && self.proof == other.proof
    }
}

impl<H: Hasher> Eq for Chunk<H> {}

#[cfg(feature = "arbitrary")]
impl<H: Hasher> arbitrary::Arbitrary<'_> for Chunk<H>
where
    H::Digest: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            shard: u.arbitrary()?,
            index: u.arbitrary()?,
            proof: u.arbitrary()?,
        })
    }
}

/// Prepare data for encoding by prefixing its length and splitting it into
/// `k` equally-sized shards.
fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
    // Compute shard length
    let data_len = data.len();
    let prefixed_len = u32::SIZE + data_len;
    let mut shard_len = prefixed_len.div_ceil(k);

    // Ensure shard length is even (required for optimizations in `reed-solomon-simd`)
    if !shard_len.is_multiple_of(2) {
        shard_len += 1;
    }

    // Prepare data
    let length_bytes = (data_len as u32).to_be_bytes();
    let mut padded = vec![0u8; k * shard_len];
    padded[..u32::SIZE].copy_from_slice(&length_bytes);
    padded[u32::SIZE..u32::SIZE + data_len].copy_from_slice(&data);

    let mut shards = Vec::with_capacity(k + m); // reserve space for the recovery shards added later
    for chunk in padded.chunks(shard_len) {
        shards.push(chunk.to_vec());
    }
    shards
}
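
// Worked example (illustrative, assuming `k = 3`): for `data.len() == 10`, the
// length-prefixed payload is 4 + 10 = 14 bytes, so `shard_len` is
// `ceil(14 / 3) = 5`, rounded up to 6 to stay even. The padded buffer is
// `3 * 6 = 18` bytes: a 4-byte big-endian length prefix, 10 data bytes, and
// 4 zero bytes of padding, split into 3 data shards of 6 bytes each.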

/// Extract the original data from the first `k` (data) shards, using the
/// length prefix to strip padding.
fn extract_data(shards: Vec<&[u8]>, k: usize) -> Vec<u8> {
    // Concatenate shards
    let mut data = shards.into_iter().take(k).flatten();

    // Extract length prefix
    let data_len = (&mut data)
        .take(u32::SIZE)
        .copied()
        .collect::<Vec<_>>()
        .try_into()
        .expect("insufficient data");
    let data_len = u32::from_be_bytes(data_len) as usize;

    // Extract data
    data.take(data_len).copied().collect()
}
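
// Continuing the example above: concatenating the 3 data shards yields the
// 18-byte padded buffer; the first 4 bytes decode to the big-endian length 10,
// and the next 10 bytes are the original data (trailing padding is ignored).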

/// Type alias for the internal encoding result.
type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);

/// Inner logic for [encode()]
fn encode_inner<H: Hasher>(
    total: u16,
    min: u16,
    data: Vec<u8>,
    concurrency: usize,
) -> Result<Encoding<H>, Error> {
    // Validate parameters
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if data.len() > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data.len()));
    }

    // Prepare data
    let mut shards = prepare_data(data, k, m);
    let shard_len = shards[0].len();

    // Create encoder
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in &shards {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }

    // Compute recovery shards
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<Vec<u8>> = encoding
        .recovery_iter()
        .map(|shard| shard.to_vec())
        .collect();
    shards.extend(recovery_shards);

    // Build Merkle tree
    let mut builder = Builder::<H>::new(n);
    if concurrency > 1 {
        let pool = ThreadPoolBuilder::new()
            .num_threads(concurrency)
            .build()
            .expect("unable to build thread pool");
        let shard_hashes = pool.install(|| {
            shards
                .par_iter()
                .map_init(H::new, |hasher, shard| {
                    hasher.update(shard);
                    hasher.finalize()
                })
                .collect::<Vec<_>>()
        });
        for hash in &shard_hashes {
            builder.add(hash);
        }
    } else {
        let mut hasher = H::new();
        for shard in &shards {
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
    }
    let tree = builder.build();

    Ok((tree, shards))
}

/// Encode data using a Reed-Solomon coder and insert it into a [bmt].
///
/// # Parameters
///
/// - `total`: The total number of chunks to generate.
/// - `min`: The minimum number of chunks required to decode the data.
/// - `data`: The data to encode.
/// - `concurrency`: The level of concurrency to use.
///
/// # Returns
///
/// - `root`: The root of the [bmt].
/// - `chunks`: [Chunk]s of encoded data (that can be proven against `root`).
fn encode<H: Hasher>(
    total: u16,
    min: u16,
    data: Vec<u8>,
    concurrency: usize,
) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
    // Encode data
    let (tree, shards) = encode_inner::<H>(total, min, data, concurrency)?;
    let root = tree.root();
    let n = total as usize;

    // Generate chunks
    let mut chunks = Vec::with_capacity(n);
    for (i, shard) in shards.into_iter().enumerate() {
        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
        chunks.push(Chunk::new(shard, i as u16, proof));
    }

    Ok((root, chunks))
}

/// Decode data from a set of [Chunk]s.
///
/// It is assumed that all [Chunk]s have already been verified against the given root using [Chunk::verify].
///
/// # Parameters
///
/// - `total`: The total number of chunks generated during encoding.
/// - `min`: The minimum number of chunks required to decode the data.
/// - `root`: The root of the [bmt].
/// - `chunks`: At least `min` [Chunk]s of encoded data (previously verified against `root`).
/// - `concurrency`: The level of concurrency to use.
///
/// # Returns
///
/// - `data`: The decoded data.
fn decode<H: Hasher>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: &[Chunk<H>],
    concurrency: usize,
) -> Result<Vec<u8>, Error> {
    // Validate parameters
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if chunks.len() < k {
        return Err(Error::NotEnoughChunks);
    }

    // Validate chunk indices and partition shards into originals and recoveries
    let shard_len = chunks[0].shard.len();
    let mut seen = HashSet::new();
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    for chunk in chunks {
        // Check for an out-of-range or duplicate index
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        if !seen.insert(index) {
            return Err(Error::DuplicateIndex(index));
        }

        // Add to provided shards
        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_slice()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.as_slice()));
        }
    }

    // Decode original data
    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for (idx, ref shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, ref shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Reconstruct all original shards
    let mut shards = Vec::with_capacity(n);
    shards.resize(k, Default::default());
    for (idx, shard) in provided_originals {
        shards[idx] = shard;
    }
    for (idx, shard) in decoding.restored_original_iter() {
        shards[idx] = shard;
    }

    // Encode recovered data to get recovery shards
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<&[u8]> = encoding.recovery_iter().collect();
    shards.extend(recovery_shards);

    // Build Merkle tree
    let mut builder = Builder::<H>::new(n);
    if concurrency > 1 {
        let pool = ThreadPoolBuilder::new()
            .num_threads(concurrency)
            .build()
            .expect("unable to build thread pool");
        let shard_hashes = pool.install(|| {
            shards
                .par_iter()
                .map_init(H::new, |hasher, shard| {
                    hasher.update(shard);
                    hasher.finalize()
                })
                .collect::<Vec<_>>()
        });

        for hash in &shard_hashes {
            builder.add(hash);
        }
    } else {
        let mut hasher = H::new();
        for shard in &shards {
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
    }
    let tree = builder.build();

    // Confirm root is consistent
    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    // Extract original data
    Ok(extract_data(shards, k))
}

/// A SIMD-optimized Reed-Solomon coder that emits chunks that can be proven against a [bmt].
///
/// # Behavior
///
/// The encoder takes input data, splits it into `k` data shards, and generates `m` recovery
/// shards using [Reed-Solomon encoding](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction).
/// All `n = k + m` shards are then used to build a [bmt], producing a single root hash. Each shard
/// is packaged as a chunk containing the shard data, its index, and a Merkle proof against the [bmt] root.
///
/// ## Encoding
///
/// ```text
///               +--------------------------------------+
///               |         Original Data (Bytes)        |
///               +--------------------------------------+
///                                  |
///                                  v
///               +--------------------------------------+
///               | [Length Prefix | Original Data...]   |
///               +--------------------------------------+
///                                  |
///                                  v
///              +----------+ +----------+    +-----------+
///              |  Shard 0 | |  Shard 1 | .. | Shard k-1 |  (Data Shards)
///              +----------+ +----------+    +-----------+
///                     |            |             |
///                     |            |             |
///                     +------------+-------------+
///                                  |
///                                  v
///                        +------------------+
///                        | Reed-Solomon     |
///                        | Encoder (k, m)   |
///                        +------------------+
///                                  |
///                                  v
///              +----------+ +----------+    +-----------+
///              |  Shard k | | Shard k+1| .. | Shard n-1 |  (Recovery Shards)
///              +----------+ +----------+    +-----------+
/// ```
///
/// ## Merkle Tree Construction
///
/// All `n` shards (data and recovery) are hashed and used as leaves to build a [bmt].
///
/// ```text
/// Shards:    [Shard 0, Shard 1, ..., Shard n-1]
///             |        |              |
///             v        v              v
/// Hashes:    [H(S_0), H(S_1), ..., H(S_n-1)]
///             \       / \       /
///              \     /   \     /
///               +---+     +---+
///                 |         |
///                 \         /
///                  \       /
///                   +-----+
///                      |
///                      v
///                +----------+
///                |   Root   |
///                +----------+
/// ```
///
/// The final output is the [bmt] root and a set of `n` chunks.
///
/// `(Root, [Chunk 0, Chunk 1, ..., Chunk n-1])`
///
/// Each chunk contains:
/// - `shard`: The shard data (original or recovery).
/// - `index`: The shard's original index (0 to n-1).
/// - `proof`: A Merkle proof of the shard's inclusion in the [bmt].
///
/// ## Decoding and Verification
///
/// The decoder requires any `k` chunks to reconstruct the original data.
/// 1. Each chunk's Merkle proof is verified against the [bmt] root.
/// 2. The shards from the valid chunks are used to reconstruct the original `k` data shards.
/// 3. To ensure consistency, the recovered data shards are re-encoded, and a new [bmt] root is
///    generated. This new root MUST match the original [bmt] root. This prevents attacks where
///    an adversary provides a valid set of chunks that decode to different data.
/// 4. If the roots match, the original data is extracted from the reconstructed data shards.
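///
/// # Example
///
/// A minimal round-trip sketch via the [Scheme] interface (marked `ignore`
/// since exact import paths may vary by crate layout):
///
/// ```ignore
/// use commonware_cryptography::Sha256;
///
/// // 3 data shards plus 2 recovery shards (n = 5, k = 3).
/// let config = Config {
///     minimum_shards: 3,
///     extra_shards: 2,
/// };
///
/// // Encode into a bmt root (the commitment) and 5 provable chunks.
/// let (commitment, chunks) =
///     ReedSolomon::<Sha256>::encode(&config, &b"hello"[..], 1).unwrap();
///
/// // Any 3 verified chunks suffice to reconstruct the original data.
/// let decoded =
///     ReedSolomon::<Sha256>::decode(&config, &commitment, (), &chunks[..3], 1).unwrap();
/// assert_eq!(decoded, b"hello");
/// ```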
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    _marker: PhantomData<H>,
}

impl<H> std::fmt::Debug for ReedSolomon<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReedSolomon").finish()
    }
}

impl<H: Hasher> Scheme for ReedSolomon<H> {
    type Commitment = H::Digest;

    type Shard = Chunk<H>;
    type ReShard = Chunk<H>;
    type CheckedShard = Chunk<H>;
    type CheckingData = ();

    type Error = Error;

    fn encode(
        config: &Config,
        mut data: impl Buf,
        concurrency: usize,
    ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
        let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
        encode(
            total_shards(config)?,
            config.minimum_shards,
            data,
            concurrency,
        )
    }

    fn reshard(
        _config: &Config,
        commitment: &Self::Commitment,
        index: u16,
        shard: Self::Shard,
    ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::ReShard), Self::Error> {
        if shard.index != index {
            return Err(Error::WrongIndex(index));
        }
        if shard.verify(shard.index, commitment) {
            Ok(((), shard.clone(), shard))
        } else {
            Err(Error::InvalidProof)
        }
    }

    fn check(
        _config: &Config,
        commitment: &Self::Commitment,
        _checking_data: &Self::CheckingData,
        index: u16,
        reshard: Self::ReShard,
    ) -> Result<Self::CheckedShard, Self::Error> {
        if reshard.index != index {
            return Err(Error::WrongIndex(reshard.index));
        }
        if !reshard.verify(reshard.index, commitment) {
            return Err(Error::InvalidProof);
        }
        Ok(reshard)
    }

    fn decode(
        config: &Config,
        commitment: &Self::Commitment,
        _checking_data: Self::CheckingData,
        shards: &[Self::CheckedShard],
        concurrency: usize,
    ) -> Result<Vec<u8>, Self::Error> {
        decode(
            total_shards(config)?,
            config.minimum_shards,
            commitment,
            shards,
            concurrency,
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;

    const CONCURRENCY: usize = 1;

    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Use a mix of original and recovery pieces
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[4].clone(), // recovery
            chunks[6].clone(), // recovery
        ];

        // Try to decode with a mix of original and recovery pieces
        let decoded = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Try with fewer than min
        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        // Fail to decode
        let result = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Include duplicate index by cloning the first chunk
        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        // Fail to decode
        let result = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Verification must fail when the wrong index is claimed
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        // total <= min should panic
        encode::<Sha256>(3, 3, data.to_vec(), CONCURRENCY).unwrap();
    }

    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        // min = 0 should panic
        encode::<Sha256>(5, 0, data.to_vec(), CONCURRENCY).unwrap();
    }

    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = 7u16;
        let min = 4u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.clone(), CONCURRENCY).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        // Encode data correctly to get valid chunks
        let (_correct_root, chunks) =
            encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Create a malicious/fake root (simulating a malicious encoder)
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // Verification must fail against the malicious root
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        // Collect valid pieces (these are legitimate fragments)
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();

        // Attempt to decode with malicious root
        let result = decode::<Sha256>(total, min, &malicious_root, &minimal, CONCURRENCY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        // Encode data
        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Tamper with one of the chunks by modifying the shard data
        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF; // Flip bits in first byte
        }

        // Try to decode with the tampered chunk
        let result = decode::<Sha256>(total, min, &root, &chunks, CONCURRENCY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        // Compute original data encoding
        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        // Re-encode the data
        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Tamper with one recovery shard
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        // Build malicious shards
        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        // Build malicious tree
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Generate chunks for min pieces, including the tampered recovery
        let selected_indices = vec![0, 1, 3]; // originals 0,1 and recovery 0 (index 3)
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        // Fail to decode
        let result = decode::<Sha256>(total, min, &malicious_root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        // Encode the data
        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec(), CONCURRENCY).unwrap();

        // Corrupt one chunk's index, then select a mix of pieces
        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[1].clone(), // original with an out-of-range index
            chunks[6].clone(), // recovery
        ];

        // Fail to decode
        let result = decode::<Sha256>(total, min, &root, &pieces, CONCURRENCY);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = u16::MAX;
        let min = u16::MAX / 2;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.clone(), CONCURRENCY).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal, CONCURRENCY).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        // Encode data
        let result = encode::<Sha256>(total, min, data, CONCURRENCY);
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    #[test]
    fn test_too_many_total_shards() {
        assert!(ReedSolomon::<Sha256>::encode(
            &Config {
                minimum_shards: u16::MAX / 2 + 1,
                extra_shards: u16::MAX,
            },
            [].as_slice(),
            CONCURRENCY,
        )
        .is_err())
    }

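    #[test]
    fn test_scheme_roundtrip() {
        // A sketch of the full [Scheme] flow: encode, verify each shard via
        // reshard() as a participant would, then decode from a minimal subset
        // of checked shards.
        let config = Config {
            minimum_shards: 3,
            extra_shards: 2,
        };
        let data = b"scheme round trip".to_vec();
        let (commitment, shards) =
            ReedSolomon::<Sha256>::encode(&config, data.as_slice(), CONCURRENCY).unwrap();

        // Verify every shard against the commitment.
        let mut checked = Vec::new();
        for (i, shard) in shards.into_iter().enumerate() {
            let (_, checked_shard, _) =
                ReedSolomon::<Sha256>::reshard(&config, &commitment, i as u16, shard).unwrap();
            checked.push(checked_shard);
        }

        // Decode from the minimum number of checked shards.
        let decoded = ReedSolomon::<Sha256>::decode(
            &config,
            &commitment,
            (),
            &checked[..config.minimum_shards as usize],
            CONCURRENCY,
        )
        .unwrap();
        assert_eq!(decoded, data);
    }
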
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256>>,
        }
    }
}