//! Reed-Solomon erasure coding scheme (`commonware_coding/reed_solomon.rs`).

1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut, Bytes};
3use commonware_codec::{BufsMut, EncodeSize, FixedSize, RangeCfg, Read, ReadExt, Write};
4use commonware_cryptography::{Digest, Hasher};
5use commonware_parallel::Strategy;
6use commonware_storage::bmt::{self, Builder};
7use commonware_utils::Cached;
8use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
9use std::marker::PhantomData;
10use thiserror::Error;
11
// Thread-local caches for reusing `ReedSolomonEncoder` and `ReedSolomonDecoder`
// instances across calls. Constructing these objects is expensive because
// the underlying engine initializes GF lookup tables. The `reset()` method
// reconfigures the work buffers without rebuilding those tables.
//
// NOTE(review): one cached instance lives per thread, so a pool of T threads
// keeps up to T encoder/decoder pairs alive — confirm this is acceptable for
// the deployment's thread counts.
commonware_utils::thread_local_cache!(static CACHED_ENCODER: ReedSolomonEncoder);
commonware_utils::thread_local_cache!(static CACHED_DECODER: ReedSolomonDecoder);
18
/// Errors that can occur when interacting with the Reed-Solomon coder.
#[derive(Error, Debug)]
pub enum Error {
    /// An error surfaced by the underlying `reed-solomon-simd` engine.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// Provided shards are inconsistent with the commitment (non-zero
    /// padding, truncated payload, or re-encoded root mismatch).
    #[error("inconsistent")]
    Inconsistent,
    /// A Merkle inclusion proof failed verification (or could not be built).
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than `min` distinct chunks were supplied for decoding.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same shard index appeared more than once in the input.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// Input data exceeds the maximum encodable length (`u32::MAX` bytes).
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A shard index was outside `0..total` or did not match expectations.
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// The configured shard count does not fit in a `u16`.
    #[error("too many total shards: {0}")]
    TooManyTotalShards(u32),
    /// A checked shard was bound to a different commitment than the decode target.
    #[error("checked shard commitment does not match decode commitment")]
    CommitmentMismatch,
}
41
42fn total_shards(config: &Config) -> Result<u16, Error> {
43    let total = config.total_shards();
44    total
45        .try_into()
46        .map_err(|_| Error::TooManyTotalShards(total))
47}
48
/// A piece of data from a Reed-Solomon encoded object.
#[derive(Debug, Clone)]
pub struct Chunk<D: Digest> {
    /// The shard of encoded data (original if `index < k`, recovery otherwise).
    shard: Bytes,

    /// The index of [`Chunk`] in the original data.
    index: u16,

    /// The multi-proof of the shard in the [`bmt`] at the given index.
    proof: bmt::Proof<D>,
}
61
62impl<D: Digest> Chunk<D> {
63    /// Create a new [`Chunk`] from the given shard, index, and proof.
64    const fn new(shard: Bytes, index: u16, proof: bmt::Proof<D>) -> Self {
65        Self {
66            shard,
67            index,
68            proof,
69        }
70    }
71
72    /// Verify a [`Chunk`] against the given root.
73    fn verify<H: Hasher<Digest = D>>(&self, index: u16, root: &D) -> Option<CheckedChunk<D>> {
74        // Ensure the index matches
75        if index != self.index {
76            return None;
77        }
78
79        // Compute shard digest
80        let mut hasher = H::new();
81        hasher.update(&self.shard);
82        let shard_digest = hasher.finalize();
83
84        // Verify proof
85        self.proof
86            .verify_element_inclusion(&mut hasher, &shard_digest, self.index as u32, root)
87            .ok()?;
88
89        Some(CheckedChunk::new(
90            *root,
91            self.shard.clone(),
92            self.index,
93            shard_digest,
94        ))
95    }
96}
97
/// A shard that has been checked against a commitment.
///
/// This stores the shard digest computed during [`Chunk::verify`] and the
/// commitment root it was verified against. The root is checked at decode
/// time to prevent cross-commitment shard mixing.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CheckedChunk<D: Digest> {
    // Commitment root the shard was verified against.
    root: D,
    // Raw shard bytes (cheap to clone: `Bytes` is reference-counted).
    shard: Bytes,
    // Position of the shard in the encoding (0..total).
    index: u16,
    // Digest of `shard`, computed during verification and reused at decode.
    digest: D,
}
110
impl<D: Digest> CheckedChunk<D> {
    /// Bundle a verified shard with the commitment it was checked against.
    const fn new(root: D, shard: Bytes, index: u16, digest: D) -> Self {
        Self {
            root,
            shard,
            index,
            digest,
        }
    }
}
121
impl<D: Digest> Write for Chunk<D> {
    /// Serialize as `shard | index | proof` into a contiguous buffer.
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }

    /// Serialize into scattered buffers; only the shard payload uses the
    /// multi-buffer path (it may be carried without copying).
    fn write_bufs(&self, buf: &mut impl BufsMut) {
        self.shard.write_bufs(buf);
        self.index.write(buf);
        self.proof.write(buf);
    }
}
135
impl<D: Digest> Read for Chunk<D> {
    /// Codec configuration; bounds the maximum size of the shard on read.
    type Cfg = crate::CodecConfig;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Cap shard size to guard against unbounded allocation on hostile input.
        let shard = Bytes::read_cfg(reader, &RangeCfg::new(..=cfg.maximum_shard_size))?;
        let index = u16::read(reader)?;
        // NOTE(review): `&1` is the proof read config — presumably the number
        // of proven elements (each chunk proves exactly one shard); confirm
        // against the `bmt::Proof` codec documentation.
        let proof = bmt::Proof::<D>::read_cfg(reader, &1)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
151
impl<D: Digest> EncodeSize for Chunk<D> {
    /// Total encoded size: shard + index + proof.
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }

    /// Inline-encoded size. Only the shard differs from `encode_size` —
    /// NOTE(review): presumably because `Bytes` payloads may be written
    /// out-of-line via `write_bufs`; confirm against the codec's inline rules.
    fn encode_inline_size(&self) -> usize {
        self.shard.encode_inline_size() + self.index.encode_size() + self.proof.encode_size()
    }
}
161
162impl<D: Digest> PartialEq for Chunk<D> {
163    fn eq(&self, other: &Self) -> bool {
164        self.shard == other.shard && self.index == other.index && self.proof == other.proof
165    }
166}
167
168impl<D: Digest> Eq for Chunk<D> {}
169
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Chunk<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Generate a random (not necessarily internally consistent) chunk for
    /// fuzzing; the proof is unrelated to the shard bytes.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            shard: u.arbitrary::<Vec<u8>>()?.into(),
            index: u.arbitrary()?,
            proof: u.arbitrary()?,
        })
    }
}
183
184/// Prepare data for encoding.
185///
186/// Returns a contiguous buffer of `k` padded shards and the shard length.
187/// The buffer layout is `[length_prefix | data | zero_padding]` split into
188/// `k` equal-sized shards of `shard_len` bytes each.
189fn prepare_data(mut data: impl Buf, k: usize) -> (Vec<u8>, usize) {
190    // Compute shard length
191    let data_len = data.remaining();
192    let prefixed_len = u32::SIZE + data_len;
193    let mut shard_len = prefixed_len.div_ceil(k);
194
195    // Ensure shard length is even (required for optimizations in `reed-solomon-simd`)
196    if !shard_len.is_multiple_of(2) {
197        shard_len += 1;
198    }
199
200    // Prepare data
201    let length_bytes = (data_len as u32).to_be_bytes();
202    let mut padded = vec![0u8; k * shard_len];
203    padded[..u32::SIZE].copy_from_slice(&length_bytes);
204    data.copy_to_slice(&mut padded[u32::SIZE..u32::SIZE + data_len]);
205
206    (padded, shard_len)
207}
208
209/// Extract data from encoded shards.
210///
211/// The first `k` shards, when concatenated, form `[length_prefix | data | padding]`.
212/// This function copies only the data bytes while validating trailing zero
213/// padding directly from the shard slices.
214fn extract_data(shards: &[&[u8]], k: usize) -> Result<Vec<u8>, Error> {
215    let shards = shards.get(..k).ok_or(Error::NotEnoughChunks)?;
216    let data_len = read_data_len(shards)?;
217    let mut data = Vec::with_capacity(data_len);
218    let mut prefix_bytes_left = u32::SIZE;
219    let mut data_bytes_left = data_len;
220    for shard in shards {
221        // The length prefix may straddle shard boundaries, so ignore bytes until
222        // we reach the first payload byte.
223        if prefix_bytes_left >= shard.len() {
224            prefix_bytes_left -= shard.len();
225            continue;
226        }
227
228        // Copy only the live payload bytes from this shard.
229        let payload = &shard[prefix_bytes_left..];
230        let copy_len = data_bytes_left.min(payload.len());
231        data.extend_from_slice(&payload[..copy_len]);
232        data_bytes_left -= copy_len;
233
234        // Any remaining bytes in this shard must be canonical zero padding.
235        if !payload[copy_len..].iter().all(|byte| *byte == 0) {
236            return Err(Error::Inconsistent);
237        }
238        prefix_bytes_left = 0;
239    }
240
241    // The prefix advertised more payload bytes than were present in the first
242    // `k` shards.
243    if data_bytes_left != 0 {
244        return Err(Error::Inconsistent);
245    }
246
247    Ok(data)
248}
249
250/// Read the 4-byte big-endian length prefix from `shards` and validate that
251/// the decoded length fits in the post-prefix payload region.
252fn read_data_len(shards: &[&[u8]]) -> Result<usize, Error> {
253    let total_len: usize = shards.iter().map(|s| s.len()).sum();
254    if total_len < u32::SIZE {
255        return Err(Error::Inconsistent);
256    }
257
258    // Read the length prefix, which may span multiple shards.
259    let mut prefix = [0u8; u32::SIZE];
260    let mut prefix_len = 0usize;
261    for shard in shards {
262        if prefix_len == u32::SIZE {
263            break;
264        }
265        let read = (u32::SIZE - prefix_len).min(shard.len());
266        prefix[prefix_len..prefix_len + read].copy_from_slice(&shard[..read]);
267        prefix_len += read;
268    }
269
270    let data_len = u32::from_be_bytes(prefix) as usize;
271    let payload_len = total_len - u32::SIZE;
272    if data_len > payload_len {
273        return Err(Error::Inconsistent);
274    }
275    Ok(data_len)
276}
277
/// Type alias for the internal encoding result.
type Encoding<D> = (D, Vec<Chunk<D>>);

/// Encode data using a Reed-Solomon coder and insert it into a [`bmt`].
///
/// # Parameters
///
/// - `total`: The total number of chunks to generate.
/// - `min`: The minimum number of chunks required to decode the data.
/// - `data`: The data to encode.
/// - `strategy`: The parallelism strategy to use.
///
/// # Returns
///
/// - `root`: The root of the [`bmt`].
/// - `chunks`: [`Chunk`]s of encoded data (that can be proven against `root`).
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn encode<H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    data: impl Buf,
    strategy: &S,
) -> Result<Encoding<H::Digest>, Error> {
    // Validate parameters
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    // The 4-byte big-endian length prefix caps the payload at u32::MAX bytes.
    let data_len = data.remaining();
    if data_len > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data_len));
    }

    // Prepare data as a contiguous buffer of k shards
    let (padded, shard_len) = prepare_data(data, k);

    // Create or reuse encoder. The block scope drops `encoder` as soon as the
    // recovery bytes have been copied out — presumably returning it to the
    // thread-local cache for reuse (confirm `Cached::take` drop semantics).
    let recovery_buf = {
        let mut encoder = Cached::take(
            &CACHED_ENCODER,
            || ReedSolomonEncoder::new(k, m, shard_len),
            |enc| enc.reset(k, m, shard_len),
        )
        .map_err(Error::ReedSolomon)?;
        for shard in padded.chunks(shard_len) {
            encoder
                .add_original_shard(shard)
                .map_err(Error::ReedSolomon)?;
        }

        // Compute recovery shards and collect into a contiguous buffer
        let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
        let mut buf = Vec::with_capacity(m * shard_len);
        for shard in encoding.recovery_iter() {
            buf.extend_from_slice(shard);
        }
        buf
    };

    // Create zero-copy Bytes views into the original and recovery buffers
    let originals: Bytes = padded.into();
    let recoveries: Bytes = recovery_buf.into();

    // Build Merkle tree over all n shards: originals occupy leaves 0..k,
    // recovery shards occupy leaves k..n. Hashing is parallelized per the
    // caller-supplied strategy.
    let mut builder = Builder::<H>::new(n);
    let shard_slices: Vec<Bytes> = (0..k)
        .map(|i| originals.slice(i * shard_len..(i + 1) * shard_len))
        .chain((0..m).map(|i| recoveries.slice(i * shard_len..(i + 1) * shard_len)))
        .collect();
    let shard_hashes = strategy.map_init_collect_vec(&shard_slices, H::new, |hasher, shard| {
        hasher.update(shard);
        hasher.finalize()
    });
    for hash in &shard_hashes {
        builder.add(hash);
    }
    let tree = builder.build();
    let root = tree.root();

    // Generate chunks with zero-copy shard views
    let mut chunks = Vec::with_capacity(n);
    for (i, shard) in shard_slices.into_iter().enumerate() {
        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
        chunks.push(Chunk::new(shard, i as u16, proof));
    }

    Ok((root, chunks))
}
366
/// Decode data from a set of [`CheckedChunk`]s.
///
/// It is assumed that all chunks have already been verified against the given root using [`Chunk::verify`].
///
/// # Parameters
///
/// - `total`: The total number of chunks to generate.
/// - `min`: The minimum number of chunks required to decode the data.
/// - `root`: The root of the [`bmt`].
/// - `chunks`: [`CheckedChunk`]s of encoded data (that can be proven against `root`)
/// - `strategy`: The parallelism strategy to use.
///
/// # Returns
///
/// - `data`: The decoded data.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn decode<'a, H: Hasher, S: Strategy>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: impl Iterator<Item = &'a CheckedChunk<H::Digest>>,
    strategy: &S,
) -> Result<Vec<u8>, Error> {
    // Validate parameters
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    let mut chunks = chunks.peekable();
    let Some(first) = chunks.peek() else {
        return Err(Error::NotEnoughChunks);
    };

    // Process checked chunks.
    // NOTE(review): `shard_len` is taken from the first chunk; chunks with a
    // different length are presumably rejected when added to the decoder
    // below — confirm `reed-solomon-simd` validates shard sizes.
    let shard_len = first.shard.len();
    let mut shard_digests: Vec<Option<H::Digest>> = vec![None; n];
    let mut provided_originals: Vec<(usize, &[u8])> = Vec::new();
    let mut provided_recoveries: Vec<(usize, &[u8])> = Vec::new();
    let mut provided = 0usize;
    for chunk in chunks {
        provided += 1;
        // Reject shards verified against a different commitment to prevent
        // cross-commitment shard mixing.
        if &chunk.root != root {
            return Err(Error::CommitmentMismatch);
        }
        // Check for duplicate index
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        let digest_slot = &mut shard_digests[index as usize];
        if digest_slot.is_some() {
            return Err(Error::DuplicateIndex(index));
        }

        // Add to provided shards and retain the checked digest for this index.
        // Indices below `min` are original shards; the rest are recovery
        // shards, whose decoder index is relative to the first recovery shard.
        *digest_slot = Some(chunk.digest);
        if index < min {
            provided_originals.push((index as usize, chunk.shard.as_ref()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.as_ref()));
        }
    }
    if provided < k {
        return Err(Error::NotEnoughChunks);
    }

    // Decode original data
    let mut decoder = Cached::take(
        &CACHED_DECODER,
        || ReedSolomonDecoder::new(k, m, shard_len),
        |dec| dec.reset(k, m, shard_len),
    )
    .map_err(Error::ReedSolomon)?;
    for (idx, shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Reconstruct all original shards (provided ones plus those restored by
    // the decoder); every original index 0..k is filled exactly once.
    let mut shards = vec![Default::default(); k];
    for (idx, shard) in provided_originals
        .into_iter()
        .chain(decoding.restored_original_iter())
    {
        shards[idx] = shard;
    }

    // Re-encode recovered data to get recovery shards
    let mut encoder = Cached::take(
        &CACHED_ENCODER,
        || ReedSolomonEncoder::new(k, m, shard_len),
        |enc| enc.reset(k, m, shard_len),
    )
    .map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    shards.extend(encoding.recovery_iter());

    // Build Merkle tree. Only shards that were NOT provided need hashing —
    // digests for provided shards were already computed during verification.
    for (i, digest) in strategy.map_init_collect_vec(
        shard_digests
            .iter()
            .enumerate()
            .filter_map(|(i, digest)| digest.is_none().then_some(i)),
        H::new,
        |hasher, i| {
            hasher.update(shards[i]);
            (i, hasher.finalize())
        },
    ) {
        shard_digests[i] = Some(digest);
    }

    let mut builder = Builder::<H>::new(n);
    shard_digests
        .into_iter()
        .map(|digest| digest.expect("digest must be present for every shard"))
        .for_each(|digest| {
            builder.add(&digest);
        });
    let tree = builder.build();

    // Confirm root is consistent. Re-encoding and rebuilding the tree detects
    // maliciously-encoded commitments whose chunks decode to different data.
    if tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    // Extract original data
    extract_data(&shards, k)
}
507
508/// A SIMD-optimized Reed-Solomon coder that emits chunks that can be proven against a [`bmt`].
509///
510/// # Behavior
511///
512/// The encoder takes input data, splits it into `k` data shards, and generates `m` recovery
513/// shards using [Reed-Solomon encoding](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction).
514/// All `n = k + m` shards are then used to build a [`bmt`], producing a single root hash. Each shard
515/// is packaged as a chunk containing the shard data, its index, and a Merkle multi-proof against the [`bmt`] root.
516///
517/// ## Encoding
518///
519/// ```text
520///               +--------------------------------------+
521///               |         Original Data (Bytes)        |
522///               +--------------------------------------+
523///                                  |
524///                                  v
525///               +--------------------------------------+
526///               | [Length Prefix | Original Data...]   |
527///               +--------------------------------------+
528///                                  |
529///                                  v
530///              +----------+ +----------+    +-----------+
531///              |  Shard 0 | |  Shard 1 | .. | Shard k-1 |  (Data Shards)
532///              +----------+ +----------+    +-----------+
533///                     |            |             |
534///                     |            |             |
535///                     +------------+-------------+
536///                                  |
537///                                  v
538///                        +------------------+
539///                        | Reed-Solomon     |
540///                        | Encoder (k, m)   |
541///                        +------------------+
542///                                  |
543///                                  v
544///              +----------+ +----------+    +-----------+
545///              |  Shard k | | Shard k+1| .. | Shard n-1 |  (Recovery Shards)
546///              +----------+ +----------+    +-----------+
547/// ```
548///
549/// ## Merkle Tree Construction
550///
551/// All `n` shards (data and recovery) are hashed and used as leaves to build a [`bmt`].
552///
553/// ```text
554/// Shards:    [Shard 0, Shard 1, ..., Shard n-1]
555///             |        |              |
556///             v        v              v
557/// Hashes:    [H(S_0), H(S_1), ..., H(S_n-1)]
558///             \       / \       /
559///              \     /   \     /
560///               +---+     +---+
561///                 |         |
562///                 \         /
563///                  \       /
564///                   +-----+
565///                      |
566///                      v
567///                +----------+
568///                |   Root   |
569///                +----------+
570/// ```
571///
572/// The final output is the [`bmt`] root and a set of `n` chunks.
573///
574/// `(Root, [Chunk 0, Chunk 1, ..., Chunk n-1])`
575///
576/// Each chunk contains:
577/// - `shard`: The shard data (original or recovery).
578/// - `index`: The shard's original index (0 to n-1).
579/// - `proof`: A Merkle multi-proof of the shard's inclusion in the [`bmt`].
580///
581/// ## Decoding and Verification
582///
583/// The decoder requires any `k` chunks to reconstruct the original data.
584/// 1. Each chunk's Merkle multi-proof is verified against the [`bmt`] root.
585/// 2. The shards from the valid chunks are used to reconstruct the original `k` data shards.
586/// 3. To ensure consistency, the recovered data shards are re-encoded, and a new [`bmt`] root is
587///    generated. This new root MUST match the original [`bmt`] root. This prevents attacks where
588///    an adversary provides a valid set of chunks that decode to different data.
589/// 4. If the roots match, the original data is extracted from the reconstructed data shards.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    /// Zero-sized marker binding the scheme to its hasher type `H`.
    _marker: PhantomData<H>,
}
594
595impl<H> std::fmt::Debug for ReedSolomon<H> {
596    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
597        f.debug_struct("ReedSolomon").finish()
598    }
599}
600
impl<H: Hasher> Scheme for ReedSolomon<H> {
    type Commitment = H::Digest;
    type Shard = Chunk<H::Digest>;
    type CheckedShard = CheckedChunk<H::Digest>;
    type Error = Error;

    /// Encode `data` into `total_shards(config)` chunks committed to by a
    /// [`bmt`] root.
    fn encode(
        config: &Config,
        data: impl Buf,
        strategy: &impl Strategy,
    ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
        encode::<H, _>(
            total_shards(config)?,
            config.minimum_shards.get(),
            data,
            strategy,
        )
    }

    /// Verify a shard against `commitment` at `index`, binding the result to
    /// that commitment for later decoding.
    fn check(
        config: &Config,
        commitment: &Self::Commitment,
        index: u16,
        shard: &Self::Shard,
    ) -> Result<Self::CheckedShard, Self::Error> {
        let total = total_shards(config)?;
        // The requested index must be a valid shard position.
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        // The proof must cover exactly the configured shard count; otherwise
        // it was generated under a different config and must be rejected.
        if shard.proof.leaf_count != u32::from(total) {
            return Err(Error::InvalidProof);
        }
        // The shard must claim the index it is being checked at.
        if shard.index != index {
            return Err(Error::InvalidIndex(shard.index));
        }
        shard
            .verify::<H>(shard.index, commitment)
            .ok_or(Error::InvalidProof)
    }

    /// Decode the original data from at least `minimum_shards` checked shards.
    fn decode<'a>(
        config: &Config,
        commitment: &Self::Commitment,
        shards: impl Iterator<Item = &'a Self::CheckedShard>,
        strategy: &impl Strategy,
    ) -> Result<Vec<u8>, Self::Error> {
        decode::<H, _>(
            total_shards(config)?,
            config.minimum_shards.get(),
            commitment,
            shards,
            strategy,
        )
    }
}
656
657#[cfg(test)]
658mod tests {
659    use super::*;
660    use commonware_codec::Encode;
661    use commonware_cryptography::Sha256;
662    use commonware_parallel::Sequential;
663    use commonware_runtime::{deterministic, iobuf::EncodeExt, BufferPooler, Runner};
664    use commonware_utils::NZU16;
665
666    type RS = ReedSolomon<Sha256>;
667    const STRATEGY: Sequential = Sequential;
668
    /// Build a [`CheckedChunk`] from a raw chunk WITHOUT running proof
    /// verification — tests use this to bind chunks to arbitrary roots.
    fn checked(
        root: <Sha256 as Hasher>::Digest,
        chunk: Chunk<<Sha256 as Hasher>::Digest>,
    ) -> CheckedChunk<<Sha256 as Hasher>::Digest> {
        let Chunk { shard, index, .. } = chunk;
        let digest = Sha256::hash(&shard);
        CheckedChunk::new(root, shard, index, digest)
    }
677
678    #[test]
679    fn test_recovery() {
680        let data = b"Testing recovery pieces";
681        let total = 8u16;
682        let min = 3u16;
683
684        // Encode the data
685        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
686
687        // Use a mix of original and recovery pieces
688        let pieces: Vec<_> = vec![
689            checked(root, chunks[0].clone()), // original
690            checked(root, chunks[4].clone()), // recovery
691            checked(root, chunks[6].clone()), // recovery
692        ];
693
694        // Try to decode with a mix of original and recovery pieces
695        let decoded = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY).unwrap();
696        assert_eq!(decoded, data);
697    }
698
699    #[test]
700    fn test_not_enough_pieces() {
701        let data = b"Test insufficient pieces";
702        let total = 6u16;
703        let min = 4u16;
704
705        // Encode data
706        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
707
708        // Try with fewer than min
709        let pieces: Vec<_> = chunks
710            .into_iter()
711            .take(2)
712            .map(|c| checked(root, c))
713            .collect();
714
715        // Fail to decode
716        let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
717        assert!(matches!(result, Err(Error::NotEnoughChunks)));
718    }
719
720    #[test]
721    fn test_duplicate_index() {
722        let data = b"Test duplicate detection";
723        let total = 5u16;
724        let min = 3u16;
725
726        // Encode data
727        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
728
729        // Include duplicate index by cloning the first chunk
730        let pieces = [
731            checked(root, chunks[0].clone()),
732            checked(root, chunks[0].clone()),
733            checked(root, chunks[1].clone()),
734        ];
735
736        // Fail to decode
737        let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
738        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
739    }
740
741    #[test]
742    fn test_invalid_index() {
743        let data = b"Test invalid index";
744        let total = 5u16;
745        let min = 3u16;
746
747        // Encode data
748        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
749
750        // Verify all proofs at invalid index
751        for i in 0..total {
752            assert!(chunks[i as usize].verify::<Sha256>(i + 1, &root).is_none());
753        }
754    }
755
    /// `encode` asserts `total > min`; equal values must panic.
    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        // total <= min should panic
        encode::<Sha256, _>(3, 3, data.as_slice(), &STRATEGY).unwrap();
    }
764
    /// `encode` asserts `min > 0`; a zero minimum must panic.
    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        // min = 0 should panic
        encode::<Sha256, _>(5, 0, data.as_slice(), &STRATEGY).unwrap();
    }
773
774    #[test]
775    fn test_empty_data() {
776        let data = b"";
777        let total = 100u16;
778        let min = 30u16;
779
780        // Encode data
781        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
782
783        // Try to decode with min
784        let minimal = chunks
785            .into_iter()
786            .take(min as usize)
787            .map(|c| checked(root, c))
788            .collect::<Vec<_>>();
789        let decoded = decode::<Sha256, _>(total, min, &root, minimal.iter(), &STRATEGY).unwrap();
790        assert_eq!(decoded, data);
791    }
792
793    #[test]
794    fn test_large_data() {
795        let data = vec![42u8; 1000]; // 1KB of data
796        let total = 7u16;
797        let min = 4u16;
798
799        // Encode data
800        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
801
802        // Try to decode with min
803        let minimal = chunks
804            .into_iter()
805            .take(min as usize)
806            .map(|c| checked(root, c))
807            .collect::<Vec<_>>();
808        let decoded = decode::<Sha256, _>(total, min, &root, minimal.iter(), &STRATEGY).unwrap();
809        assert_eq!(decoded, data);
810    }
811
    /// Chunks bound to the correct commitment must neither verify nor decode
    /// under a forged root.
    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        // Encode data correctly to get valid chunks
        let (_correct_root, chunks) =
            encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();

        // Create a malicious/fake root (simulating a malicious encoder)
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // Verify all proofs at incorrect root — every one must fail.
        for i in 0..total {
            assert!(chunks[i as usize]
                .clone()
                .verify::<Sha256>(i, &malicious_root)
                .is_none());
        }

        // Collect valid pieces (these are legitimate fragments checked against
        // the correct root).
        let minimal = chunks
            .into_iter()
            .take(min as usize)
            .map(|c| checked(_correct_root, c))
            .collect::<Vec<_>>();

        // Attempt to decode with malicious root - rejected because checked
        // chunks are bound to a different commitment.
        let result = decode::<Sha256, _>(total, min, &malicious_root, minimal.iter(), &STRATEGY);
        assert!(matches!(result, Err(Error::CommitmentMismatch)));
    }
848
    /// A proof generated under a different config (different `leaf_count`)
    /// must be rejected by `check` even though its root verification would
    /// otherwise succeed.
    #[test]
    fn test_mismatched_config_rejected_during_check() {
        let config_expected = Config {
            minimum_shards: NZU16!(2),
            extra_shards: NZU16!(2),
        };
        let config_actual = Config {
            minimum_shards: NZU16!(3),
            extra_shards: NZU16!(3),
        };

        let data = b"leaf_count mismatch proof";
        let (commitment, shards) = RS::encode(&config_actual, data.as_slice(), &STRATEGY).unwrap();

        // Previously this passed because check() ignored config and only verified
        // against commitment root. It must now fail immediately.
        let check_result = RS::check(&config_expected, &commitment, 0, &shards[0]);
        assert!(matches!(check_result, Err(Error::InvalidProof)));
    }
868
    /// A tampered shard (with a recomputed digest, so it looks "checked")
    /// must still be caught by the decode-time root reconstruction.
    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
        let mut pieces: Vec<_> = chunks.into_iter().map(|c| checked(root, c)).collect();

        // Tamper with one of the checked chunks by modifying the shard data.
        if !pieces[1].shard.is_empty() {
            let mut shard = pieces[1].shard.to_vec();
            shard[0] ^= 0xFF; // Flip bits in first byte
            pieces[1].shard = shard.into();
            // Recompute the digest so only the tree rebuild can catch it.
            pieces[1].digest = Sha256::hash(&pieces[1].shard);
        }

        // Try to decode with the tampered chunk
        let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }
891
    #[test]
    fn test_inconsistent_shards() {
        // Adversarial scenario: an encoder commits to a shard set that is NOT
        // a valid Reed-Solomon codeword (one recovery shard is tampered BEFORE
        // the Merkle tree is built). Every presented chunk then carries a valid
        // proof against the malicious root, so per-chunk verification alone
        // cannot catch it — decode() must still detect the inconsistency.
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min; // number of recovery shards

        // Compute original data encoding
        let (padded, shard_size) = prepare_data(data.as_slice(), min as usize);

        // Re-encode the data
        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in padded.chunks(shard_size) {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Tamper with one recovery shard (flip bits in its first byte)
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        // Build malicious shards: originals followed by the tampered recovery set
        let mut malicious_shards: Vec<Vec<u8>> =
            padded.chunks(shard_size).map(|s| s.to_vec()).collect();
        malicious_shards.extend(recovery_shards);

        // Build malicious tree committing to the inconsistent shard set
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Generate chunks for min pieces, including the tampered recovery
        let selected_indices = vec![0, 1, 3]; // originals 0,1 and recovery 0 (index 3)
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard.into(), i as u16, merkle_proof);
            pieces.push(chunk);
        }
        // Each proof verifies against the malicious root, so checking passes...
        let pieces: Vec<_> = pieces
            .into_iter()
            .map(|c| checked(malicious_root, c))
            .collect();

        // ...but decode must reject the set as not forming a consistent codeword
        let result = decode::<Sha256, _>(total, min, &malicious_root, pieces.iter(), &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }
951
    // Regression: a commitment built from shards with non-zero trailing padding
    // used to pass decode(), even though canonical re-encoding (zero padding)
    // produces a different root. decode() must reject such non-canonical shards
    // with Error::Inconsistent.
    #[test]
    fn test_non_canonical_padding_rejected() {
        let data = b"X";
        let total = 6u16;
        let min = 3u16;
        let k = min as usize; // original shard count
        let m = total as usize - k; // recovery shard count

        // Canonically padded payload; payload_end marks the first padding byte
        // (a u32-sized length prefix precedes the data — see u32::SIZE below).
        let (mut padded, shard_len) = prepare_data(data.as_slice(), k);
        let payload_end = u32::SIZE + data.len();
        let total_original_len = k * shard_len;
        assert!(payload_end < total_original_len, "test requires padding");

        // Corrupt one canonical padding byte while keeping payload unchanged.
        let pad_shard = payload_end / shard_len; // shard containing the first pad byte
        let pad_offset = payload_end % shard_len; // offset of that byte within it
        padded[pad_shard * shard_len + pad_offset] = 0xAA;

        // Encode the non-canonical (tampered-padding) originals.
        let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).unwrap();
        for shard in padded.chunks(shard_len) {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery = encoder.encode().unwrap();
        let mut shards: Vec<Vec<u8>> = padded.chunks(shard_len).map(|s| s.to_vec()).collect();
        shards.extend(recovery.recovery_iter().map(|s| s.to_vec()));

        // Commit to the non-canonical shard set.
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let tree = builder.build();
        let non_canonical_root = tree.root();

        // Present the first k (original) shards, each with a valid Merkle
        // proof against the non-canonical root.
        let mut pieces = Vec::with_capacity(k);
        for (i, shard) in shards.iter().take(k).enumerate() {
            let proof = tree.proof(i as u32).unwrap();
            pieces.push(checked(
                non_canonical_root,
                Chunk::new(shard.clone().into(), i as u16, proof),
            ));
        }

        // decode() must reject: canonical re-encoding yields a different root.
        let result = decode::<Sha256, _>(total, min, &non_canonical_root, pieces.iter(), &STRATEGY);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }
1002
1003    #[test]
1004    fn test_decode_invalid_index() {
1005        let data = b"Testing recovery pieces";
1006        let total = 8u16;
1007        let min = 3u16;
1008
1009        // Encode the data
1010        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
1011
1012        // Use a mix of original and recovery pieces
1013        let mut invalid = checked(root, chunks[1].clone());
1014        invalid.index = 8;
1015        let pieces: Vec<_> = vec![
1016            checked(root, chunks[0].clone()), // original
1017            invalid,                          // recovery with invalid index
1018            checked(root, chunks[6].clone()), // recovery
1019        ];
1020
1021        // Fail to decode
1022        let result = decode::<Sha256, _>(total, min, &root, pieces.iter(), &STRATEGY);
1023        assert!(matches!(result, Err(Error::InvalidIndex(8))));
1024    }
1025
1026    #[test]
1027    fn test_max_chunks() {
1028        let data = vec![42u8; 1000]; // 1KB of data
1029        let total = u16::MAX;
1030        let min = u16::MAX / 2;
1031
1032        // Encode data
1033        let (root, chunks) = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY).unwrap();
1034
1035        // Try to decode with min
1036        let minimal = chunks
1037            .into_iter()
1038            .take(min as usize)
1039            .map(|c| checked(root, c))
1040            .collect::<Vec<_>>();
1041        let decoded = decode::<Sha256, _>(total, min, &root, minimal.iter(), &STRATEGY).unwrap();
1042        assert_eq!(decoded, data);
1043    }
1044
1045    #[test]
1046    fn test_too_many_chunks() {
1047        let data = vec![42u8; 1000]; // 1KB of data
1048        let total = u16::MAX;
1049        let min = u16::MAX / 2 - 1;
1050
1051        // Encode data
1052        let result = encode::<Sha256, _>(total, min, data.as_slice(), &STRATEGY);
1053        assert!(matches!(
1054            result,
1055            Err(Error::ReedSolomon(
1056                reed_solomon_simd::Error::UnsupportedShardCount {
1057                    original_count: _,
1058                    recovery_count: _,
1059                }
1060            ))
1061        ));
1062    }
1063
1064    #[test]
1065    fn test_too_many_total_shards() {
1066        assert!(RS::encode(
1067            &Config {
1068                minimum_shards: NZU16!(u16::MAX / 2 + 1),
1069                extra_shards: NZU16!(u16::MAX),
1070            },
1071            [].as_slice(),
1072            &STRATEGY,
1073        )
1074        .is_err())
1075    }
1076
1077    #[test]
1078    fn test_chunk_encode_with_pool_matches_encode() {
1079        let executor = deterministic::Runner::default();
1080        executor.start(|context| async move {
1081            let pool = context.network_buffer_pool();
1082
1083            let data = b"pool encoding test";
1084            let (_root, chunks) = encode::<Sha256, _>(5, 3, data.as_slice(), &STRATEGY).unwrap();
1085            let chunk = &chunks[0];
1086
1087            let encoded = chunk.encode();
1088            let mut encoded_pool = chunk.encode_with_pool(pool);
1089            let mut encoded_pool_bytes = vec![0u8; encoded_pool.remaining()];
1090            encoded_pool.copy_to_slice(&mut encoded_pool_bytes);
1091            assert_eq!(encoded_pool_bytes, encoded.as_ref());
1092        });
1093    }
1094
    // Codec conformance suite for `Chunk`: generated tests driven by
    // `arbitrary` inputs (hence gated on that feature) that exercise the
    // `CodecConformance` harness against `Chunk<Sha256Digest>`.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        commonware_conformance::conformance_tests! {
            CodecConformance<Chunk<Sha256Digest>>,
        }
    }
1105}