Skip to main content

commonware_coding/
reed_solomon.rs

1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut, Bytes};
3use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
4use commonware_cryptography::{Digest, Hasher};
5use commonware_parallel::Strategy;
6use commonware_storage::bmt::{self, Builder};
7use commonware_utils::Cached;
8use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
9use std::marker::PhantomData;
10use thiserror::Error;
11
12// Thread-local caches for reusing `ReedSolomonEncoder` and `ReedSolomonDecoder`
13// instances across calls. Constructing these objects is expensive because
14// the underlying engine initializes GF lookup tables. The `reset()` method
15// reconfigures the work buffers without rebuilding those tables.
16commonware_utils::thread_local_cache!(static CACHED_ENCODER: ReedSolomonEncoder);
17commonware_utils::thread_local_cache!(static CACHED_DECODER: ReedSolomonDecoder);
18
19/// Errors that can occur when interacting with the Reed-Solomon coder.
20#[derive(Error, Debug)]
21pub enum Error {
22    #[error("reed-solomon error: {0}")]
23    ReedSolomon(#[from] RsError),
24    #[error("inconsistent")]
25    Inconsistent,
26    #[error("invalid proof")]
27    InvalidProof,
28    #[error("not enough chunks")]
29    NotEnoughChunks,
30    #[error("duplicate chunk index: {0}")]
31    DuplicateIndex(u16),
32    #[error("invalid data length: {0}")]
33    InvalidDataLength(usize),
34    #[error("invalid index: {0}")]
35    InvalidIndex(u16),
36    #[error("too many total shards: {0}")]
37    TooManyTotalShards(u32),
38}
39
40fn total_shards(config: &Config) -> Result<u16, Error> {
41    let total = config.total_shards();
42    total
43        .try_into()
44        .map_err(|_| Error::TooManyTotalShards(total))
45}
46
47/// A piece of data from a Reed-Solomon encoded object.
48#[derive(Debug, Clone)]
49pub struct Chunk<D: Digest> {
50    /// The shard of encoded data.
51    shard: Bytes,
52
53    /// The index of [Chunk] in the original data.
54    index: u16,
55
56    /// The multi-proof of the shard in the [bmt] at the given index.
57    proof: bmt::Proof<D>,
58}
59
60impl<D: Digest> Chunk<D> {
61    /// Create a new [Chunk] from the given shard, index, and proof.
62    const fn new(shard: Bytes, index: u16, proof: bmt::Proof<D>) -> Self {
63        Self {
64            shard,
65            index,
66            proof,
67        }
68    }
69
70    /// Verify a [Chunk] against the given root.
71    fn verify<H: Hasher<Digest = D>>(&self, index: u16, root: &D) -> Option<CheckedChunk<D>> {
72        // Ensure the index matches
73        if index != self.index {
74            return None;
75        }
76
77        // Compute shard digest
78        let mut hasher = H::new();
79        hasher.update(&self.shard);
80        let shard_digest = hasher.finalize();
81
82        // Verify proof
83        self.proof
84            .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
85            .ok()?;
86
87        Some(CheckedChunk::new(
88            self.shard.clone(),
89            self.index,
90            shard_digest,
91        ))
92    }
93}
94
95/// A shard that has been checked against a commitment.
96///
97/// This stores the shard digest computed during [Chunk::verify], so decode
98/// can reuse it without hashing the same shard again.
99#[derive(Clone, Debug, PartialEq, Eq)]
100pub struct CheckedChunk<D: Digest> {
101    shard: Bytes,
102    index: u16,
103    digest: D,
104}
105
106impl<D: Digest> CheckedChunk<D> {
107    const fn new(shard: Bytes, index: u16, digest: D) -> Self {
108        Self {
109            shard,
110            index,
111            digest,
112        }
113    }
114}
115
116impl<D: Digest> Write for Chunk<D> {
117    fn write(&self, writer: &mut impl BufMut) {
118        self.shard.write(writer);
119        self.index.write(writer);
120        self.proof.write(writer);
121    }
122}
123
124impl<D: Digest> Read for Chunk<D> {
125    /// The maximum size of the shard.
126    type Cfg = crate::CodecConfig;
127
128    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
129        let shard = Bytes::read_cfg(reader, &RangeCfg::new(..=cfg.maximum_shard_size))?;
130        let index = u16::read(reader)?;
131        let proof = bmt::Proof::<D>::read_cfg(reader, &1)?;
132        Ok(Self {
133            shard,
134            index,
135            proof,
136        })
137    }
138}
139
140impl<D: Digest> EncodeSize for Chunk<D> {
141    fn encode_size(&self) -> usize {
142        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
143    }
144}
145
146impl<D: Digest> PartialEq for Chunk<D> {
147    fn eq(&self, other: &Self) -> bool {
148        self.shard == other.shard && self.index == other.index && self.proof == other.proof
149    }
150}
151
152impl<D: Digest> Eq for Chunk<D> {}
153
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Chunk<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draw a chunk from fuzzer input (no validity guarantees).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Fields are drawn in declaration order so a given byte stream always maps
        // to the same chunk.
        let shard = Bytes::from(u.arbitrary::<Vec<u8>>()?);
        let index = u.arbitrary()?;
        let proof = u.arbitrary()?;
        Ok(Self::new(shard, index, proof))
    }
}
167
168/// Prepare data for encoding.
169///
170/// Returns a contiguous buffer of `k` padded shards and the shard length.
171/// The buffer layout is `[length_prefix | data | zero_padding]` split into
172/// `k` equal-sized shards of `shard_len` bytes each.
173fn prepare_data(data: &[u8], k: usize) -> (Vec<u8>, usize) {
174    // Compute shard length
175    let data_len = data.len();
176    let prefixed_len = u32::SIZE + data_len;
177    let mut shard_len = prefixed_len.div_ceil(k);
178
179    // Ensure shard length is even (required for optimizations in `reed-solomon-simd`)
180    if !shard_len.is_multiple_of(2) {
181        shard_len += 1;
182    }
183
184    // Prepare data
185    let length_bytes = (data_len as u32).to_be_bytes();
186    let mut padded = vec![0u8; k * shard_len];
187    padded[..u32::SIZE].copy_from_slice(&length_bytes);
188    padded[u32::SIZE..u32::SIZE + data_len].copy_from_slice(data);
189
190    (padded, shard_len)
191}
192
193/// Extract data from encoded shards.
194///
195/// The first `k` shards, when concatenated, form `[length_prefix | data | padding]`.
196/// This function bulk-copies shard slices while skipping the 4-byte prefix.
197fn extract_data(shards: &[&[u8]], k: usize) -> Result<Vec<u8>, Error> {
198    let shards = shards.get(..k).ok_or(Error::NotEnoughChunks)?;
199    let (data_len, payload_len) = read_prefix_and_payload_len(shards)?;
200    let mut payload = copy_payload_after_prefix(shards, payload_len);
201    validate_zero_padding(&payload, data_len)?;
202    payload.truncate(data_len);
203    Ok(payload)
204}
205
206/// Read the 4-byte big-endian length prefix from `shards` and validate that
207/// the decoded length fits in the post-prefix payload region.
208fn read_prefix_and_payload_len(shards: &[&[u8]]) -> Result<(usize, usize), Error> {
209    let total_len: usize = shards.iter().map(|s| s.len()).sum();
210    if total_len < u32::SIZE {
211        return Err(Error::Inconsistent);
212    }
213
214    // Read the length prefix, which may span multiple shards.
215    let mut prefix = [0u8; u32::SIZE];
216    let mut prefix_len = 0usize;
217    for shard in shards {
218        if prefix_len == u32::SIZE {
219            break;
220        }
221        let read = (u32::SIZE - prefix_len).min(shard.len());
222        prefix[prefix_len..prefix_len + read].copy_from_slice(&shard[..read]);
223        prefix_len += read;
224    }
225
226    let data_len = u32::from_be_bytes(prefix) as usize;
227    let payload_len = total_len - u32::SIZE;
228    if data_len > payload_len {
229        return Err(Error::Inconsistent);
230    }
231    Ok((data_len, payload_len))
232}
233
234/// Bulk-copy bytes after the 4-byte prefix from `shards` into a contiguous
235/// payload buffer.
236fn copy_payload_after_prefix(shards: &[&[u8]], payload_len: usize) -> Vec<u8> {
237    let mut payload = Vec::with_capacity(payload_len);
238    let mut prefix_bytes_left = u32::SIZE;
239    for shard in shards {
240        if prefix_bytes_left >= shard.len() {
241            prefix_bytes_left -= shard.len();
242            continue;
243        }
244        payload.extend_from_slice(&shard[prefix_bytes_left..]);
245        prefix_bytes_left = 0;
246    }
247    payload
248}
249
250/// Validate canonical encoding by requiring trailing bytes after `data_len`
251/// to be zero.
252fn validate_zero_padding(payload: &[u8], data_len: usize) -> Result<(), Error> {
253    // Canonical encoding requires all trailing bytes to be zero.
254    if !payload[data_len..].iter().all(|byte| *byte == 0) {
255        return Err(Error::Inconsistent);
256    }
257    Ok(())
258}
259
260/// Type alias for the internal encoding result.
261type Encoding<D> = (D, Vec<Chunk<D>>);
262
263/// Encode data using a Reed-Solomon coder and insert it into a [bmt].
264///
265/// # Parameters
266///
267/// - `total`: The total number of chunks to generate.
268/// - `min`: The minimum number of chunks required to decode the data.
269/// - `data`: The data to encode.
270/// - `strategy`: The parallelism strategy to use.
271///
272/// # Returns
273///
274/// - `root`: The root of the [bmt].
275/// - `chunks`: [Chunk]s of encoded data (that can be proven against `root`).
276fn encode<H: Hasher, S: Strategy>(
277    total: u16,
278    min: u16,
279    data: Vec<u8>,
280    strategy: &S,
281) -> Result<Encoding<H::Digest>, Error> {
282    // Validate parameters
283    assert!(total > min);
284    assert!(min > 0);
285    let n = total as usize;
286    let k = min as usize;
287    let m = n - k;
288    if data.len() > u32::MAX as usize {
289        return Err(Error::InvalidDataLength(data.len()));
290    }
291
292    // Prepare data as a contiguous buffer of k shards
293    let (padded, shard_len) = prepare_data(&data, k);
294
295    // Create or reuse encoder
296    let recovery_buf = {
297        let mut encoder = Cached::take(
298            &CACHED_ENCODER,
299            || ReedSolomonEncoder::new(k, m, shard_len),
300            |enc| enc.reset(k, m, shard_len),
301        )
302        .map_err(Error::ReedSolomon)?;
303        for shard in padded.chunks(shard_len) {
304            encoder
305                .add_original_shard(shard)
306                .map_err(Error::ReedSolomon)?;
307        }
308
309        // Compute recovery shards and collect into a contiguous buffer
310        let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
311        let mut buf = Vec::with_capacity(m * shard_len);
312        for shard in encoding.recovery_iter() {
313            buf.extend_from_slice(shard);
314        }
315        buf
316    };
317
318    // Create zero-copy Bytes views into the original and recovery buffers
319    let originals: Bytes = padded.into();
320    let recoveries: Bytes = recovery_buf.into();
321
322    // Build Merkle tree
323    let mut builder = Builder::<H>::new(n);
324    let shard_slices: Vec<Bytes> = (0..k)
325        .map(|i| originals.slice(i * shard_len..(i + 1) * shard_len))
326        .chain((0..m).map(|i| recoveries.slice(i * shard_len..(i + 1) * shard_len)))
327        .collect();
328    let shard_hashes = strategy.map_init_collect_vec(&shard_slices, H::new, |hasher, shard| {
329        hasher.update(shard);
330        hasher.finalize()
331    });
332    for hash in &shard_hashes {
333        builder.add(hash);
334    }
335    let tree = builder.build();
336    let root = tree.root();
337
338    // Generate chunks with zero-copy shard views
339    let mut chunks = Vec::with_capacity(n);
340    for (i, shard) in shard_slices.into_iter().enumerate() {
341        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
342        chunks.push(Chunk::new(shard, i as u16, proof));
343    }
344
345    Ok((root, chunks))
346}
347
348/// Decode data from a set of [CheckedChunk]s.
349///
350/// It is assumed that all chunks have already been verified against the given root using [Chunk::verify].
351///
352/// # Parameters
353///
354/// - `total`: The total number of chunks to generate.
355/// - `min`: The minimum number of chunks required to decode the data.
356/// - `root`: The root of the [bmt].
357/// - `chunks`: [CheckedChunk]s of encoded data (that can be proven against `root`)
358///
359/// # Returns
360///
361/// - `data`: The decoded data.
362fn decode<H: Hasher, S: Strategy>(
363    total: u16,
364    min: u16,
365    root: &H::Digest,
366    chunks: &[CheckedChunk<H::Digest>],
367    strategy: &S,
368) -> Result<Vec<u8>, Error> {
369    // Validate parameters
370    assert!(total > min);
371    assert!(min > 0);
372    let n = total as usize;
373    let k = min as usize;
374    let m = n - k;
375    if chunks.len() < k {
376        return Err(Error::NotEnoughChunks);
377    }
378
379    // Process checked chunks
380    let shard_len = chunks[0].shard.len();
381    let mut shard_digests: Vec<Option<H::Digest>> = vec![None; n];
382    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
383    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
384    for chunk in chunks {
385        // Check for duplicate index
386        let index = chunk.index;
387        if index >= total {
388            return Err(Error::InvalidIndex(index));
389        }
390        let digest_slot = &mut shard_digests[index as usize];
391        if digest_slot.is_some() {
392            return Err(Error::DuplicateIndex(index));
393        }
394
395        // Add to provided shards and retain the checked digest for this index.
396        *digest_slot = Some(chunk.digest);
397        if index < min {
398            provided_originals.push((index as usize, chunk.shard.as_ref()));
399        } else {
400            provided_recoveries.push((index as usize - k, chunk.shard.as_ref()));
401        }
402    }
403
404    // Decode original data
405    let mut decoder = Cached::take(
406        &CACHED_DECODER,
407        || ReedSolomonDecoder::new(k, m, shard_len),
408        |dec| dec.reset(k, m, shard_len),
409    )
410    .map_err(Error::ReedSolomon)?;
411    for (idx, ref shard) in &provided_originals {
412        decoder
413            .add_original_shard(*idx, shard)
414            .map_err(Error::ReedSolomon)?;
415    }
416    for (idx, ref shard) in &provided_recoveries {
417        decoder
418            .add_recovery_shard(*idx, shard)
419            .map_err(Error::ReedSolomon)?;
420    }
421    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;
422
423    // Reconstruct all original shards
424    let mut shards = vec![Default::default(); k];
425    for (idx, shard) in provided_originals
426        .into_iter()
427        .chain(decoding.restored_original_iter())
428    {
429        shards[idx] = shard;
430    }
431
432    // Re-encode recovered data to get recovery shards
433    let mut encoder = Cached::take(
434        &CACHED_ENCODER,
435        || ReedSolomonEncoder::new(k, m, shard_len),
436        |enc| enc.reset(k, m, shard_len),
437    )
438    .map_err(Error::ReedSolomon)?;
439    for shard in shards.iter().take(k) {
440        encoder
441            .add_original_shard(shard)
442            .map_err(Error::ReedSolomon)?;
443    }
444    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
445    shards.extend(encoding.recovery_iter());
446
447    // Build Merkle tree
448    for (i, digest) in strategy.map_init_collect_vec(
449        shard_digests
450            .iter()
451            .enumerate()
452            .filter_map(|(i, digest)| digest.is_none().then_some(i)),
453        H::new,
454        |hasher, i| {
455            hasher.update(shards[i]);
456            (i, hasher.finalize())
457        },
458    ) {
459        shard_digests[i] = Some(digest);
460    }
461
462    let mut builder = Builder::<H>::new(n);
463    shard_digests
464        .into_iter()
465        .map(|digest| digest.expect("digest must be present for every shard"))
466        .for_each(|digest| {
467            builder.add(&digest);
468        });
469    let tree = builder.build();
470
471    // Confirm root is consistent
472    if tree.root() != *root {
473        return Err(Error::Inconsistent);
474    }
475
476    // Extract original data
477    extract_data(&shards, k)
478}
479
/// A SIMD-accelerated Reed-Solomon coder whose chunks can be proven against a [bmt] root.
///
/// # Behavior
///
/// Input data is split into `k` original shards, from which `m` recovery shards are
/// derived with [Reed-Solomon encoding](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction).
/// The digests of all `n = k + m` shards form the leaves of a [bmt], whose root is the
/// commitment. Every shard is shipped as a chunk holding the shard bytes, its index, and
/// a Merkle inclusion proof against that root.
///
/// ## Encoding
///
/// ```text
///               +--------------------------------------+
///               |         Original Data (Bytes)        |
///               +--------------------------------------+
///                                  |
///                                  v
///               +--------------------------------------+
///               | [Length Prefix | Original Data...]   |
///               +--------------------------------------+
///                                  |
///                                  v
///              +----------+ +----------+    +-----------+
///              |  Shard 0 | |  Shard 1 | .. | Shard k-1 |  (Data Shards)
///              +----------+ +----------+    +-----------+
///                     |            |             |
///                     |            |             |
///                     +------------+-------------+
///                                  |
///                                  v
///                        +------------------+
///                        | Reed-Solomon     |
///                        | Encoder (k, m)   |
///                        +------------------+
///                                  |
///                                  v
///              +----------+ +----------+    +-----------+
///              |  Shard k | | Shard k+1| .. | Shard n-1 |  (Recovery Shards)
///              +----------+ +----------+    +-----------+
/// ```
///
/// ## Merkle Tree Construction
///
/// All `n` shards (data and recovery) are hashed, and the digests are the leaves of a [bmt].
///
/// ```text
/// Shards:    [Shard 0, Shard 1, ..., Shard n-1]
///             |        |              |
///             v        v              v
/// Hashes:    [H(S_0), H(S_1), ..., H(S_n-1)]
///             \       / \       /
///              \     /   \     /
///               +---+     +---+
///                 |         |
///                 \         /
///                  \       /
///                   +-----+
///                      |
///                      v
///                +----------+
///                |   Root   |
///                +----------+
/// ```
///
/// Encoding therefore yields the [bmt] root and `n` chunks:
///
/// `(Root, [Chunk 0, Chunk 1, ..., Chunk n-1])`
///
/// Each chunk contains:
/// - `shard`: The shard data (original or recovery).
/// - `index`: The shard's position among all shards (0 to n-1).
/// - `proof`: A Merkle proof of the shard's inclusion in the [bmt].
///
/// ## Decoding and Verification
///
/// Any `k` chunks suffice to reconstruct the original data:
/// 1. Before decoding, each chunk's proof is checked against the [bmt] root.
/// 2. The shards of the valid chunks are used to reconstruct the `k` original shards.
/// 3. The reconstructed originals are re-encoded and a fresh [bmt] root is computed. This
///    root MUST equal the committed root. The check prevents an adversary from handing out
///    valid chunks that decode to different data depending on which chunks are used.
/// 4. If the roots match, the original data is extracted from the reconstructed shards.
#[derive(Copy, Clone)]
pub struct ReedSolomon<H> {
    _marker: PhantomData<H>,
}

impl<H> std::fmt::Debug for ReedSolomon<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `H` is only a type-level marker, so there is no state worth printing.
        f.write_str("ReedSolomon")
    }
}
572
573impl<H: Hasher> Scheme for ReedSolomon<H> {
574    type Commitment = H::Digest;
575
576    type StrongShard = Chunk<H::Digest>;
577    type WeakShard = Chunk<H::Digest>;
578    type CheckedShard = CheckedChunk<H::Digest>;
579    type CheckingData = ();
580
581    type Error = Error;
582
583    fn encode(
584        config: &Config,
585        mut data: impl Buf,
586        strategy: &impl Strategy,
587    ) -> Result<(Self::Commitment, Vec<Self::StrongShard>), Self::Error> {
588        let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
589        encode::<H, _>(
590            total_shards(config)?,
591            config.minimum_shards.get(),
592            data,
593            strategy,
594        )
595    }
596
597    fn weaken(
598        config: &Config,
599        commitment: &Self::Commitment,
600        index: u16,
601        shard: Self::StrongShard,
602    ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::WeakShard), Self::Error> {
603        let total = total_shards(config)?;
604        if index >= total {
605            return Err(Error::InvalidIndex(index));
606        }
607        if shard.proof.leaf_count != u32::from(total) {
608            return Err(Error::InvalidProof);
609        }
610        if shard.index != index {
611            return Err(Error::InvalidIndex(index));
612        }
613        let checked_shard = shard
614            .verify::<H>(shard.index, commitment)
615            .ok_or(Error::InvalidProof)?;
616        Ok(((), checked_shard, shard))
617    }
618
619    fn check(
620        config: &Config,
621        commitment: &Self::Commitment,
622        _checking_data: &Self::CheckingData,
623        index: u16,
624        weak_shard: Self::WeakShard,
625    ) -> Result<Self::CheckedShard, Self::Error> {
626        let total = total_shards(config)?;
627        if index >= total {
628            return Err(Error::InvalidIndex(index));
629        }
630        if weak_shard.proof.leaf_count != u32::from(total) {
631            return Err(Error::InvalidProof);
632        }
633        if weak_shard.index != index {
634            return Err(Error::InvalidIndex(weak_shard.index));
635        }
636        weak_shard
637            .verify::<H>(weak_shard.index, commitment)
638            .ok_or(Error::InvalidProof)
639    }
640
641    fn decode(
642        config: &Config,
643        commitment: &Self::Commitment,
644        _checking_data: Self::CheckingData,
645        shards: &[Self::CheckedShard],
646        strategy: &impl Strategy,
647    ) -> Result<Vec<u8>, Self::Error> {
648        decode::<H, _>(
649            total_shards(config)?,
650            config.minimum_shards.get(),
651            commitment,
652            shards,
653            strategy,
654        )
655    }
656}
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use commonware_cryptography::Sha256;
662    use commonware_parallel::Sequential;
663    use commonware_utils::NZU16;
664
665    type RS = ReedSolomon<Sha256>;
666    const STRATEGY: Sequential = Sequential;
667
668    fn checked(
669        chunk: Chunk<<Sha256 as Hasher>::Digest>,
670    ) -> CheckedChunk<<Sha256 as Hasher>::Digest> {
671        let Chunk { shard, index, .. } = chunk;
672        let digest = Sha256::hash(&shard);
673        CheckedChunk::new(shard, index, digest)
674    }
675
676    #[test]
677    fn test_recovery() {
678        let data = b"Testing recovery pieces";
679        let total = 8u16;
680        let min = 3u16;
681
682        // Encode the data
683        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
684
685        // Use a mix of original and recovery pieces
686        let pieces: Vec<_> = vec![
687            checked(chunks[0].clone()), // original
688            checked(chunks[4].clone()), // recovery
689            checked(chunks[6].clone()), // recovery
690        ];
691
692        // Try to decode with a mix of original and recovery pieces
693        let decoded = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY).unwrap();
694        assert_eq!(decoded, data);
695    }
696
697    #[test]
698    fn test_not_enough_pieces() {
699        let data = b"Test insufficient pieces";
700        let total = 6u16;
701        let min = 4u16;
702
703        // Encode data
704        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
705
706        // Try with fewer than min
707        let pieces: Vec<_> = chunks.into_iter().take(2).map(checked).collect();
708
709        // Fail to decode
710        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
711        assert!(matches!(result, Err(Error::NotEnoughChunks)));
712    }
713
714    #[test]
715    fn test_duplicate_index() {
716        let data = b"Test duplicate detection";
717        let total = 5u16;
718        let min = 3u16;
719
720        // Encode data
721        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
722
723        // Include duplicate index by cloning the first chunk
724        let pieces = vec![
725            checked(chunks[0].clone()),
726            checked(chunks[0].clone()),
727            checked(chunks[1].clone()),
728        ];
729
730        // Fail to decode
731        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
732        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
733    }
734
735    #[test]
736    fn test_invalid_index() {
737        let data = b"Test invalid index";
738        let total = 5u16;
739        let min = 3u16;
740
741        // Encode data
742        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
743
744        // Verify all proofs at invalid index
745        for i in 0..total {
746            assert!(chunks[i as usize].verify::<Sha256>(i + 1, &root).is_none());
747        }
748    }
749
750    #[test]
751    #[should_panic(expected = "assertion failed: total > min")]
752    fn test_invalid_total() {
753        let data = b"Test parameter validation";
754
755        // total <= min should panic
756        encode::<Sha256, _>(3, 3, data.to_vec(), &STRATEGY).unwrap();
757    }
758
759    #[test]
760    #[should_panic(expected = "assertion failed: min > 0")]
761    fn test_invalid_min() {
762        let data = b"Test parameter validation";
763
764        // min = 0 should panic
765        encode::<Sha256, _>(5, 0, data.to_vec(), &STRATEGY).unwrap();
766    }
767
768    #[test]
769    fn test_empty_data() {
770        let data = b"";
771        let total = 100u16;
772        let min = 30u16;
773
774        // Encode data
775        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
776
777        // Try to decode with min
778        let minimal = chunks
779            .into_iter()
780            .take(min as usize)
781            .map(checked)
782            .collect::<Vec<_>>();
783        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
784        assert_eq!(decoded, data);
785    }
786
787    #[test]
788    fn test_large_data() {
789        let data = vec![42u8; 1000]; // 1KB of data
790        let total = 7u16;
791        let min = 4u16;
792
793        // Encode data
794        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();
795
796        // Try to decode with min
797        let minimal = chunks
798            .into_iter()
799            .take(min as usize)
800            .map(checked)
801            .collect::<Vec<_>>();
802        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
803        assert_eq!(decoded, data);
804    }
805
806    #[test]
807    fn test_malicious_root_detection() {
808        let data = b"Original data that should be protected";
809        let total = 7u16;
810        let min = 4u16;
811
812        // Encode data correctly to get valid chunks
813        let (_correct_root, chunks) =
814            encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
815
816        // Create a malicious/fake root (simulating a malicious encoder)
817        let mut hasher = Sha256::new();
818        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
819        let malicious_root = hasher.finalize();
820
821        // Verify all proofs at incorrect root
822        for i in 0..total {
823            assert!(chunks[i as usize]
824                .clone()
825                .verify::<Sha256>(i, &malicious_root)
826                .is_none());
827        }
828
829        // Collect valid pieces (these are legitimate fragments)
830        let minimal = chunks
831            .into_iter()
832            .take(min as usize)
833            .map(checked)
834            .collect::<Vec<_>>();
835
836        // Attempt to decode with malicious root
837        let result = decode::<Sha256, _>(total, min, &malicious_root, &minimal, &STRATEGY);
838        assert!(matches!(result, Err(Error::Inconsistent)));
839    }
840
841    #[test]
842    fn test_mismatched_config_rejected_during_weaken_and_check() {
843        let config_expected = Config {
844            minimum_shards: NZU16!(2),
845            extra_shards: NZU16!(2),
846        };
847        let config_actual = Config {
848            minimum_shards: NZU16!(3),
849            extra_shards: NZU16!(3),
850        };
851
852        let data = b"leaf_count mismatch proof";
853        let (commitment, shards) = RS::encode(&config_actual, data.as_slice(), &STRATEGY).unwrap();
854
855        // Previously this passed because weaken() ignored config and only verified
856        // against commitment root. It must now fail immediately.
857        let strong = shards[0].clone();
858        let weaken_result = RS::weaken(&config_expected, &commitment, 0, strong);
859        assert!(matches!(weaken_result, Err(Error::InvalidProof)));
860
861        // Produce a valid weak shard under the actual config, then ensure check()
862        // also rejects it when validated under the wrong config.
863        let (checking_data, _, weak) =
864            RS::weaken(&config_actual, &commitment, 1, shards[1].clone()).unwrap();
865        let check_result = RS::check(&config_expected, &commitment, &checking_data, 1, weak);
866        assert!(matches!(check_result, Err(Error::InvalidProof)));
867    }
868
869    #[test]
870    fn test_manipulated_chunk_detection() {
871        let data = b"Data integrity must be maintained";
872        let total = 6u16;
873        let min = 3u16;
874
875        // Encode data
876        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
877        let mut pieces: Vec<_> = chunks.into_iter().map(checked).collect();
878
879        // Tamper with one of the checked chunks by modifying the shard data.
880        if !pieces[1].shard.is_empty() {
881            let mut shard = pieces[1].shard.to_vec();
882            shard[0] ^= 0xFF; // Flip bits in first byte
883            pieces[1].shard = shard.into();
884            pieces[1].digest = Sha256::hash(&pieces[1].shard);
885        }
886
887        // Try to decode with the tampered chunk
888        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
889        assert!(matches!(result, Err(Error::Inconsistent)));
890    }
891
892    #[test]
893    fn test_inconsistent_shards() {
894        let data = b"Test data for malicious encoding";
895        let total = 5u16;
896        let min = 3u16;
897        let m = total - min;
898
899        // Compute original data encoding
900        let (padded, shard_size) = prepare_data(data, min as usize);
901
902        // Re-encode the data
903        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
904        for shard in padded.chunks(shard_size) {
905            encoder.add_original_shard(shard).unwrap();
906        }
907        let recovery_result = encoder.encode().unwrap();
908        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
909            .recovery_iter()
910            .map(|s| s.to_vec())
911            .collect();
912
913        // Tamper with one recovery shard
914        if !recovery_shards[0].is_empty() {
915            recovery_shards[0][0] ^= 0xFF;
916        }
917
918        // Build malicious shards
919        let mut malicious_shards: Vec<Vec<u8>> =
920            padded.chunks(shard_size).map(|s| s.to_vec()).collect();
921        malicious_shards.extend(recovery_shards);
922
923        // Build malicious tree
924        let mut builder = Builder::<Sha256>::new(total as usize);
925        for shard in &malicious_shards {
926            let mut hasher = Sha256::new();
927            hasher.update(shard);
928            builder.add(&hasher.finalize());
929        }
930        let malicious_tree = builder.build();
931        let malicious_root = malicious_tree.root();
932
933        // Generate chunks for min pieces, including the tampered recovery
934        let selected_indices = vec![0, 1, 3]; // originals 0,1 and recovery 0 (index 3)
935        let mut pieces = Vec::new();
936        for &i in &selected_indices {
937            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
938            let shard = malicious_shards[i].clone();
939            let chunk = Chunk::new(shard.into(), i as u16, merkle_proof);
940            pieces.push(chunk);
941        }
942        let pieces: Vec<_> = pieces.into_iter().map(checked).collect();
943
944        // Fail to decode
945        let result = decode::<Sha256, _>(total, min, &malicious_root, &pieces, &STRATEGY);
946        assert!(matches!(result, Err(Error::Inconsistent)));
947    }
948
949    // Regression: a commitment built from shards with non-zero trailing padding
950    // used to pass decode(), even though canonical re-encoding (zero padding)
951    // produces a different root. decode() must reject such non-canonical shards.
952    #[test]
953    fn test_non_canonical_padding_rejected() {
954        let data = b"X";
955        let total = 6u16;
956        let min = 3u16;
957        let k = min as usize;
958        let m = total as usize - k;
959
960        let (mut padded, shard_len) = prepare_data(data, k);
961        let payload_end = u32::SIZE + data.len();
962        let total_original_len = k * shard_len;
963        assert!(payload_end < total_original_len, "test requires padding");
964
965        // Corrupt one canonical padding byte while keeping payload unchanged.
966        let pad_shard = payload_end / shard_len;
967        let pad_offset = payload_end % shard_len;
968        padded[pad_shard * shard_len + pad_offset] = 0xAA;
969
970        let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).unwrap();
971        for shard in padded.chunks(shard_len) {
972            encoder.add_original_shard(shard).unwrap();
973        }
974        let recovery = encoder.encode().unwrap();
975        let mut shards: Vec<Vec<u8>> = padded.chunks(shard_len).map(|s| s.to_vec()).collect();
976        shards.extend(recovery.recovery_iter().map(|s| s.to_vec()));
977
978        let mut builder = Builder::<Sha256>::new(total as usize);
979        for shard in &shards {
980            let mut hasher = Sha256::new();
981            hasher.update(shard);
982            builder.add(&hasher.finalize());
983        }
984        let tree = builder.build();
985        let non_canonical_root = tree.root();
986
987        let mut pieces = Vec::with_capacity(k);
988        for (i, shard) in shards.iter().take(k).enumerate() {
989            let proof = tree.proof(i as u32).unwrap();
990            pieces.push(checked(Chunk::new(shard.clone().into(), i as u16, proof)));
991        }
992
993        let result = decode::<Sha256, _>(total, min, &non_canonical_root, &pieces, &STRATEGY);
994        assert!(matches!(result, Err(Error::Inconsistent)));
995    }
996
997    #[test]
998    fn test_decode_invalid_index() {
999        let data = b"Testing recovery pieces";
1000        let total = 8u16;
1001        let min = 3u16;
1002
1003        // Encode the data
1004        let (root, chunks) = encode::<Sha256, _>(total, min, data.to_vec(), &STRATEGY).unwrap();
1005
1006        // Use a mix of original and recovery pieces
1007        let mut invalid = checked(chunks[1].clone());
1008        invalid.index = 8;
1009        let pieces: Vec<_> = vec![
1010            checked(chunks[0].clone()), // original
1011            invalid,                    // recovery with invalid index
1012            checked(chunks[6].clone()), // recovery
1013        ];
1014
1015        // Fail to decode
1016        let result = decode::<Sha256, _>(total, min, &root, &pieces, &STRATEGY);
1017        assert!(matches!(result, Err(Error::InvalidIndex(8))));
1018    }
1019
1020    #[test]
1021    fn test_max_chunks() {
1022        let data = vec![42u8; 1000]; // 1KB of data
1023        let total = u16::MAX;
1024        let min = u16::MAX / 2;
1025
1026        // Encode data
1027        let (root, chunks) = encode::<Sha256, _>(total, min, data.clone(), &STRATEGY).unwrap();
1028
1029        // Try to decode with min
1030        let minimal = chunks
1031            .into_iter()
1032            .take(min as usize)
1033            .map(checked)
1034            .collect::<Vec<_>>();
1035        let decoded = decode::<Sha256, _>(total, min, &root, &minimal, &STRATEGY).unwrap();
1036        assert_eq!(decoded, data);
1037    }
1038
1039    #[test]
1040    fn test_too_many_chunks() {
1041        let data = vec![42u8; 1000]; // 1KB of data
1042        let total = u16::MAX;
1043        let min = u16::MAX / 2 - 1;
1044
1045        // Encode data
1046        let result = encode::<Sha256, _>(total, min, data, &STRATEGY);
1047        assert!(matches!(
1048            result,
1049            Err(Error::ReedSolomon(
1050                reed_solomon_simd::Error::UnsupportedShardCount {
1051                    original_count: _,
1052                    recovery_count: _,
1053                }
1054            ))
1055        ));
1056    }
1057
1058    #[test]
1059    fn test_too_many_total_shards() {
1060        assert!(RS::encode(
1061            &Config {
1062                minimum_shards: NZU16!(u16::MAX / 2 + 1),
1063                extra_shards: NZU16!(u16::MAX),
1064            },
1065            [].as_slice(),
1066            &STRATEGY,
1067        )
1068        .is_err())
1069    }
1070
    // Codec conformance tests for `Chunk`, compiled only when the `arbitrary`
    // feature is enabled.
    // NOTE(review): the macro may key its generated tests/fixtures on the type
    // tokens as written below; confirm before renaming the `Sha256Digest` alias
    // or changing the type path.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        // Instantiate `CodecConformance` for chunks parameterized by SHA-256 digests.
        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256Digest>>,
        }
    }
1081}