commonware_coding/reed_solomon/
mod.rs

1//! A SIMD-optimized Reed-Solomon coder that emits [Chunk]s that can be proven against a [bmt].
2//!
3//! # Behavior
4//!
5//! The encoder takes input data, splits it into `k` data shards, and generates `m` recovery
6//! shards using [Reed-Solomon encoding](https://en.wikipedia.org/wiki/Reed%E2%80%93Solomon_error_correction).
7//! All `n = k + m` shards are then used to build a [bmt], producing a single root hash. Each shard
8//! is packaged as a [Chunk] containing the shard data, its index, and a Merkle proof against the [bmt] root.
9//!
10//! ## Encoding
11//!
12//! ```text
13//!               +--------------------------------------+
14//!               |         Original Data (Bytes)        |
15//!               +--------------------------------------+
16//!                                  |
17//!                                  v
18//!               +--------------------------------------+
19//!               | [Length Prefix | Original Data...]   |
20//!               +--------------------------------------+
21//!                                  |
22//!                                  v
23//!              +----------+ +----------+    +-----------+
24//!              |  Shard 0 | |  Shard 1 | .. | Shard k-1 |  (Data Shards)
25//!              +----------+ +----------+    +-----------+
26//!                     |            |             |
27//!                     |            |             |
28//!                     +------------+-------------+
29//!                                  |
30//!                                  v
31//!                        +------------------+
32//!                        | Reed-Solomon     |
33//!                        | Encoder (k, m)   |
34//!                        +------------------+
35//!                                  |
36//!                                  v
37//!              +----------+ +----------+    +-----------+
38//!              |  Shard k | | Shard k+1| .. | Shard n-1 |  (Recovery Shards)
39//!              +----------+ +----------+    +-----------+
40//! ```
41//!
42//! ## Merkle Tree Construction
43//!
44//! All `n` shards (data and recovery) are hashed and used as leaves to build a [bmt].
45//!
46//! ```text
47//! Shards:    [Shard 0, Shard 1, ..., Shard n-1]
48//!             |        |              |
49//!             v        v              v
50//! Hashes:    [H(S_0), H(S_1), ..., H(S_n-1)]
51//!             \       / \       /
52//!              \     /   \     /
53//!               +---+     +---+
54//!                 |         |
55//!                 \         /
56//!                  \       /
57//!                   +-----+
58//!                      |
59//!                      v
60//!                +----------+
61//!                |   Root   |
62//!                +----------+
63//! ```
64//!
65//! The final output is the [bmt] root and a set of `n` [Chunk]s.
66//!
67//! `(Root, [Chunk 0, Chunk 1, ..., Chunk n-1])`
68//!
69//! Each [Chunk] contains:
70//! - `shard`: The shard data (original or recovery).
71//! - `index`: The shard's original index (0 to n-1).
72//! - `proof`: A Merkle proof of the shard's inclusion in the [bmt].
73//!
74//! ## Decoding and Verification
75//!
76//! The decoder requires any `k` [Chunk]s to reconstruct the original data.
77//! 1. Each [Chunk]'s Merkle proof is verified against the [bmt] root.
78//! 2. The shards from the valid [Chunk]s are used to reconstruct the original `k` data shards.
79//! 3. To ensure consistency, the recovered data shards are re-encoded, and a new [bmt] root is
80//!    generated. This new root MUST match the original [bmt] root. This prevents attacks where
81//!    an adversary provides a valid set of chunks that decode to different data.
82//! 4. If the roots match, the original data is extracted from the reconstructed data shards.
83//!
84//! # Example
85//!
86//! ```rust
87//! use commonware_coding::reed_solomon::{encode, decode};
88//! use commonware_cryptography::Sha256;
89//!
90//! // Generate data to encode
91//! let data = b"Hello, world! This is a test of Reed-Solomon encoding.";
92//!
93//! // Configure the encoder to generate 7 total chunks, with a minimum of 4 required for decoding
94//! let total = 7u16;
95//! let min = 4u16;
96//!
97//! // Encode the data
98//! let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();
99//!
100//! // Pick a few chunks to recover from (a mix of original and recovery shards)
101//! let some_chunks = vec![
102//!     chunks[0].clone(), // original
103//!     chunks[2].clone(), // original
104//!     chunks[5].clone(), // recovery
105//!     chunks[6].clone(), // recovery
106//! ];
107//!
108//! // Decode the data from the subset of chunks
109//! let decoded_data = decode::<Sha256>(total, min, &root, some_chunks).unwrap();
110//!
111//! // Verify that the decoded data matches the original data
112//! assert_eq!(decoded_data, data);
113//! ```
114
115use bytes::{Buf, BufMut};
116use commonware_codec::{EncodeSize, FixedSize, RangeCfg, Read, ReadExt, ReadRangeExt, Write};
117use commonware_cryptography::Hasher;
118use commonware_storage::bmt::{self, Builder};
119use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
120use std::collections::HashSet;
121use thiserror::Error;
122
/// Errors that can occur when interacting with the Reed-Solomon coder.
#[derive(Error, Debug)]
pub enum Error {
    /// An error surfaced by the underlying `reed-solomon-simd` coder
    /// (e.g. unsupported shard count or mismatched shard sizes).
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The re-encoded shards do not hash back to the provided [bmt] root
    /// (detects a malicious or corrupted encoding during [decode]).
    #[error("inconsistent")]
    Inconsistent,
    /// A [Chunk]'s Merkle proof failed verification, or a proof could not
    /// be generated from the tree.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than `min` chunks were provided to [decode].
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// Two provided [Chunk]s share the same index.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The data length is invalid (e.g. exceeds `u32::MAX` bytes).
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A [Chunk]'s index is out of range (`>= total`).
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
}
141
/// A piece of data from a Reed-Solomon encoded object.
#[derive(Clone)]
pub struct Chunk<H: Hasher> {
    /// The shard of encoded data (original or recovery).
    pub shard: Vec<u8>,

    /// The index of the [Chunk] in the encoded object (`0..total`).
    pub index: u16,

    /// The proof of the shard's inclusion in the [bmt] at the given index.
    pub proof: bmt::Proof<H>,
}
154
155impl<H: Hasher> Chunk<H> {
156    /// Create a new [Chunk] from the given shard, index, and proof.
157    pub fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
158        Self {
159            shard,
160            index,
161            proof,
162        }
163    }
164
165    /// Verify a [Chunk] against the given root.
166    pub fn verify(&self, index: u16, root: &H::Digest) -> bool {
167        // Ensure the index matches
168        if index != self.index {
169            return false;
170        }
171
172        // Compute shard digest
173        let mut hasher = H::new();
174        hasher.update(&self.shard);
175        let shard_digest = hasher.finalize();
176
177        // Verify proof
178        self.proof
179            .verify(&mut hasher, &shard_digest, self.index as u32, root)
180            .is_ok()
181    }
182}
183
impl<H: Hasher> Write for Chunk<H> {
    // Serialize as shard || index || proof (mirrors the order in `read_cfg`).
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
191
192impl<H: Hasher> Read for Chunk<H> {
193    /// The maximum size of the shard.
194    type Cfg = usize;
195
196    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
197        let shard = Vec::<u8>::read_range(reader, ..=*cfg)?;
198        let index = u16::read(reader)?;
199        let proof = bmt::Proof::<H>::read(reader)?;
200        Ok(Self {
201            shard,
202            index,
203            proof,
204        })
205    }
206}
207
impl<H: Hasher> EncodeSize for Chunk<H> {
    // Sum of the encoded sizes of the fields written by `write`.
    fn encode_size(&self) -> usize {
        self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
    }
}
213
/// All data shards from a Reed-Solomon encoded object.
///
/// This is more efficient than individual chunks when transmitting complete blocks.
/// Includes a [bmt::RangeProof] to prove the data shards are part of the commitment.
#[derive(Clone)]
pub struct Data<H: Hasher> {
    /// The `k` data shards (no recovery shards).
    pub shards: Vec<Vec<u8>>,

    /// Range proof for indices 0..k (proving all data shards are part of the commitment).
    pub proof: bmt::RangeProof<H>,
}
226
227impl<H: Hasher> Data<H> {
228    /// Create a new [Data] from data shards and range proof.
229    pub fn new(shards: Vec<Vec<u8>>, proof: bmt::RangeProof<H>) -> Self {
230        Self { shards, proof }
231    }
232
233    /// Verify the [Data] against a known root.
234    ///
235    /// This verifies that the data shards are part of the [bmt] commitment.
236    pub fn verify(&self, min: u16, root: &H::Digest) -> bool {
237        // Ensure we have the correct number of data shards
238        let k = min as usize;
239        if self.shards.len() != k {
240            return false;
241        }
242
243        // Hash the data shards
244        let mut hasher = H::new();
245        let mut shards = Vec::with_capacity(k);
246        for shard in &self.shards {
247            hasher.update(shard);
248            shards.push(hasher.finalize());
249        }
250
251        // Verify the range proof for indices 0..k
252        self.proof.verify(&mut hasher, 0, &shards, root).is_ok()
253    }
254
255    /// Extract the original data from the [Data].
256    pub fn extract(self) -> Vec<u8> {
257        let k = self.shards.len();
258        extract_data(self.shards, k)
259    }
260}
261
impl<H: Hasher> Write for Data<H> {
    // Serialize as shards || range proof (mirrors the order in `read_cfg`).
    fn write(&self, writer: &mut impl BufMut) {
        self.shards.write(writer);
        self.proof.write(writer);
    }
}
268
impl<H: Hasher> Read for Data<H> {
    /// The maximum size of each shard.
    type Cfg = usize;

    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        // Allow between 1 and u16::MAX shards (indices are u16), each at
        // most `cfg` bytes, to bound allocations while decoding.
        let shard_count: RangeCfg = (1..=u16::MAX as usize).into();
        let shard_size: RangeCfg = (..=*cfg).into();
        let shards = Vec::<Vec<u8>>::read_cfg(reader, &(shard_count, (shard_size, ())))?;
        let proof = bmt::RangeProof::<H>::read(reader)?;

        Ok(Self { shards, proof })
    }
}
282
impl<H: Hasher> EncodeSize for Data<H> {
    // Sum of the encoded sizes of the fields written by `write`.
    fn encode_size(&self) -> usize {
        self.shards.encode_size() + self.proof.encode_size()
    }
}
288
/// Prepare data for encoding: length-prefix it and split it into `k`
/// equally-sized, zero-padded data shards.
///
/// The returned vector has capacity for `k + m` shards so recovery shards
/// can later be appended without reallocation.
fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
    // Shard length: ceil((prefix + data) / k), rounded up to the next even
    // number (an even shard length is required for optimizations in
    // `reed-solomon-simd`). The prefix is the 4-byte big-endian u32 below.
    let data_len = data.len();
    let prefixed_len = std::mem::size_of::<u32>() + data_len;
    let shard_len = prefixed_len.div_ceil(k).next_multiple_of(2);

    // Stream the big-endian length prefix followed by the data
    let mut src = (data_len as u32).to_be_bytes().into_iter().chain(data);

    // Fill each shard from the stream, zero-padding the tail of the last one(s)
    let mut shards = Vec::with_capacity(k + m); // recovery shards added later
    for _ in 0..k {
        let mut shard = Vec::with_capacity(shard_len);
        shard.extend((&mut src).take(shard_len));
        shard.resize(shard_len, 0);
        shards.push(shard);
    }
    shards
}
314
/// Extract the original data from the first `k` (data) shards.
///
/// The first four bytes of the concatenated shards are a big-endian u32
/// length prefix (written by [prepare_data]); the `data_len` bytes that
/// follow are the original data (any remainder is zero padding).
///
/// # Panics
///
/// Panics if the first `k` shards hold fewer than four bytes. Callers must
/// ensure `k * shard_len` is at least the size of the length prefix.
fn extract_data(shards: Vec<Vec<u8>>, k: usize) -> Vec<u8> {
    // Flatten the data shards into a single byte stream
    let mut bytes = shards.into_iter().take(k).flatten();

    // Read the length prefix directly into a fixed array (no heap allocation)
    let mut prefix = [0u8; std::mem::size_of::<u32>()];
    for slot in prefix.iter_mut() {
        *slot = bytes.next().expect("insufficient data");
    }
    let data_len = u32::from_be_bytes(prefix) as usize;

    // The next `data_len` bytes are the original data
    bytes.take(data_len).collect()
}
331
/// Type alias for the internal encoding result: the [bmt] built over all
/// shard digests, plus the full list of `n` shards (data followed by recovery).
pub type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);
334
335/// Common encoding logic shared by [encode()] and [translate()].
336pub fn encode_inner<H: Hasher>(total: u16, min: u16, data: Vec<u8>) -> Result<Encoding<H>, Error> {
337    // Validate parameters
338    assert!(total > min);
339    assert!(min > 0);
340    let n = total as usize;
341    let k = min as usize;
342    let m = n - k;
343    if data.len() > u32::MAX as usize {
344        return Err(Error::InvalidDataLength(data.len()));
345    }
346
347    // Prepare data
348    let mut shards = prepare_data(data, k, m);
349    let shard_len = shards[0].len();
350
351    // Create encoder
352    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
353    for shard in &shards {
354        encoder
355            .add_original_shard(shard)
356            .map_err(Error::ReedSolomon)?;
357    }
358
359    // Compute recovery shards
360    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
361    let recovery_shards: Vec<Vec<u8>> = encoding
362        .recovery_iter()
363        .map(|shard| shard.to_vec())
364        .collect();
365    shards.extend(recovery_shards);
366
367    // Build Merkle tree
368    let mut builder = Builder::<H>::new(n);
369    let mut hasher = H::new();
370    for shard in &shards {
371        builder.add(&{
372            hasher.update(shard);
373            hasher.finalize()
374        });
375    }
376    let tree = builder.build();
377
378    Ok((tree, shards))
379}
380
381/// Encode data using a Reed-Solomon coder and insert it into a [bmt].
382///
383/// # Parameters
384///
385/// - `total`: The total number of chunks to generate.
386/// - `min`: The minimum number of chunks required to decode the data.
387/// - `data`: The data to encode.
388///
389/// # Returns
390///
391/// - `root`: The root of the [bmt].
392/// - `chunks`: [Chunk]s of encoded data (that can be proven against `root`).
393pub fn encode<H: Hasher>(
394    total: u16,
395    min: u16,
396    data: Vec<u8>,
397) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
398    // Encode data
399    let (tree, shards) = encode_inner::<H>(total, min, data)?;
400    let root = tree.root();
401    let n = total as usize;
402
403    // Generate chunks
404    let mut chunks = Vec::with_capacity(n);
405    for (i, shard) in shards.into_iter().enumerate() {
406        let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
407        chunks.push(Chunk::new(shard, i as u16, proof));
408    }
409
410    Ok((root, chunks))
411}
412
413/// Translate data into a single proof over all data shards.
414///
415/// # Parameters
416///
417/// - `total`: The total number of chunks to generate.
418/// - `min`: The minimum number of chunks required to decode the data.
419/// - `data`: The data to encode.
420///
421/// # Returns
422///
423/// - `root`: The root of the [bmt].
424/// - `data`: All data shards from the encoded object (and a range proof that they are part of the `root`).
425pub fn translate<H: Hasher>(
426    total: u16,
427    min: u16,
428    data: Vec<u8>,
429) -> Result<(H::Digest, Data<H>), Error> {
430    // Encode data
431    let (tree, mut shards) = encode_inner::<H>(total, min, data)?;
432    let root = tree.root();
433    let k = min as usize;
434
435    // Truncate to data shards only
436    shards.truncate(k);
437
438    // Generate range proof for minimal proof
439    let proof = tree
440        .range_proof(0, (min - 1) as u32)
441        .map_err(|_| Error::InvalidProof)?;
442    let proof = Data::new(shards, proof);
443
444    Ok((root, proof))
445}
446
447/// Decode data from a set of [Chunk]s.
448///
449/// # Parameters
450///
451/// - `total`: The total number of chunks to generate.
452/// - `min`: The minimum number of chunks required to decode the data.
453/// - `root`: The root of the [bmt].
454/// - `chunks`: [Chunk]s of encoded data (that can be proven against `root`).
455///
456/// # Returns
457///
458/// - `data`: The decoded data.
459pub fn decode<H: Hasher>(
460    total: u16,
461    min: u16,
462    root: &H::Digest,
463    chunks: Vec<Chunk<H>>,
464) -> Result<Vec<u8>, Error> {
465    // Validate parameters
466    assert!(total > min);
467    assert!(min > 0);
468    let n = total as usize;
469    let k = min as usize;
470    let m = n - k;
471    if chunks.len() < k {
472        return Err(Error::NotEnoughChunks);
473    }
474
475    // Verify chunks
476    let shard_len = chunks[0].shard.len();
477    let mut seen = HashSet::new();
478    let mut provided_originals: Vec<(usize, Vec<u8>)> = Vec::new();
479    let mut provided_recoveries: Vec<(usize, Vec<u8>)> = Vec::new();
480    for chunk in chunks {
481        // Check for duplicate index
482        let index = chunk.index;
483        if index >= total {
484            return Err(Error::InvalidIndex(index));
485        }
486        if seen.contains(&index) {
487            return Err(Error::DuplicateIndex(index));
488        }
489        seen.insert(index);
490
491        // Verify Merkle proof
492        if !chunk.verify(chunk.index, root) {
493            return Err(Error::InvalidProof);
494        }
495
496        // Add to provided shards
497        if index < min {
498            provided_originals.push((index as usize, chunk.shard));
499        } else {
500            provided_recoveries.push((index as usize - k, chunk.shard));
501        }
502    }
503
504    // Decode original data
505    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
506    for (idx, ref shard) in &provided_originals {
507        decoder
508            .add_original_shard(*idx, shard)
509            .map_err(Error::ReedSolomon)?;
510    }
511    for (idx, ref shard) in &provided_recoveries {
512        decoder
513            .add_recovery_shard(*idx, shard)
514            .map_err(Error::ReedSolomon)?;
515    }
516    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;
517
518    // Reconstruct all original shards
519    let mut shards = Vec::with_capacity(n);
520    shards.resize(k, Vec::new());
521    for (idx, shard) in provided_originals {
522        shards[idx] = shard;
523    }
524    for (idx, shard) in decoding.restored_original_iter() {
525        shards[idx] = shard.to_vec();
526    }
527
528    // Encode recovered data to get recovery shards
529    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
530    for shard in shards.iter().take(k) {
531        encoder
532            .add_original_shard(shard)
533            .map_err(Error::ReedSolomon)?;
534    }
535    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
536    let recovery_shards: Vec<Vec<u8>> = encoding
537        .recovery_iter()
538        .map(|shard| shard.to_vec())
539        .collect();
540    shards.extend(recovery_shards);
541
542    // Build Merkle tree
543    let mut builder = Builder::<H>::new(n);
544    let mut hasher = H::new();
545    for shard in &shards {
546        builder.add(&{
547            hasher.update(shard);
548            hasher.finalize()
549        });
550    }
551    let computed_tree = builder.build();
552
553    // Confirm root is consistent
554    if computed_tree.root() != *root {
555        return Err(Error::Inconsistent);
556    }
557
558    // Extract original data
559    Ok(extract_data(shards, k))
560}
561
// Unit tests covering round-trips, tamper detection, parameter validation,
// and serialization of chunks and data proofs.
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_codec::{Decode, Encode};
    use commonware_cryptography::Sha256;

    #[test]
    fn test_basic() {
        let data = b"Hello, Reed-Solomon!";
        let total = 7u16;
        let min = 4u16;

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();
        assert_eq!(chunks.len(), total as usize);

        // Verify all chunks
        for i in 0..total {
            assert!(chunks[i as usize].verify(i, &root));
        }

        // Try to decode with exactly min (all original shards)
        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_moderate() {
        let data = b"Testing with more pieces than minimum";
        let total = 10u16;
        let min = 4u16;

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Try to decode with min (all original shards)
        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        // Encode the data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Use a mix of original and recovery pieces
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[4].clone(), // recovery
            chunks[6].clone(), // recovery
        ];

        // Try to decode with a mix of original and recovery pieces
        let decoded = decode::<Sha256>(total, min, &root, pieces).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Try with fewer than min
        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        // Fail to decode
        let result = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Include duplicate index by cloning the first chunk
        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        // Fail to decode
        let result = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Verify all proofs at invalid index (verify must reject a mismatched index)
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        // total <= min should panic
        encode::<Sha256>(3, 3, data.to_vec()).unwrap();
    }

    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        // min = 0 should panic
        encode::<Sha256>(5, 0, data.to_vec()).unwrap();
    }

    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        // Encode data (the length prefix makes empty input round-trippable)
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = 7u16;
        let min = 4u16;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        // Encode data correctly to get valid chunks
        let (_correct_root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Create a malicious/fake root (simulating a malicious encoder)
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        // Verify all proofs at incorrect root
        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        // Collect valid pieces (these are legitimate fragments)
        let minimal = chunks.into_iter().take(min as usize).collect();

        // Attempt to decode with malicious root
        let result = decode::<Sha256>(total, min, &malicious_root, minimal);
        assert!(matches!(result, Err(Error::InvalidProof)));
    }

    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        // Encode data
        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Tamper with one of the chunks by modifying the shard data
        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF; // Flip bits in first byte
        }

        // Try to decode with the tampered chunk
        let result = decode::<Sha256>(total, min, &root, chunks);
        assert!(matches!(result, Err(Error::InvalidProof)));
    }

    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        // Compute original data encoding
        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        // Re-encode the data
        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Tamper with one recovery shard
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        // Build malicious shards (valid originals + one tampered recovery)
        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        // Build malicious tree (so all Merkle proofs verify individually)
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Generate chunks for min pieces, including the tampered recovery
        let selected_indices = vec![0, 1, 3]; // originals 0,1 and recovery 0 (index 3)
        let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        // Fail to decode: the re-encoded root will not match the malicious root
        let result = decode::<Sha256>(total, min, &malicious_root, pieces);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    #[test]
    fn test_odd_shard_len() {
        let data = b"a";
        let total = 3u16;
        let min = 2u16;

        // Encode the data (shard length must be rounded up to an even value)
        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Use a mix of original and recovery pieces
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[2].clone(), // recovery
        ];

        // Try to decode with a mix of original and recovery pieces
        let decoded = decode::<Sha256>(total, min, &root, pieces).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        // Encode the data
        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Use a mix of original and recovery pieces
        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(), // original
            chunks[1].clone(), // recovery with invalid index
            chunks[6].clone(), // recovery
        ];

        // Fail to decode
        let result = decode::<Sha256>(total, min, &root, pieces);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = u16::MAX;
        let min = u16::MAX / 2;

        // Encode data
        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        // Try to decode with min
        let minimal = chunks.into_iter().take(min as usize).collect();
        let decoded = decode::<Sha256>(total, min, &root, minimal).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000]; // 1KB of data
        let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        // Encode data (recovery count exceeds what the coder supports)
        let result = encode::<Sha256>(total, min, data.clone());
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }

    #[test]
    fn test_translate() {
        let data = b"Test data for minimal proof functionality";
        let total = 7u16;
        let min = 4u16;

        // Encode with minimal proof
        let (root, proof) = translate::<Sha256>(total, min, data.to_vec()).unwrap();

        // Verify minimal proof
        assert!(proof.verify(min, &root));

        // Extract data from minimal proof
        let extracted = proof.extract();
        assert_eq!(extracted, data);
    }

    #[test]
    fn test_translate_wrong_root() {
        let data = b"Test data for minimal proof";
        let total = 5u16;
        let min = 3u16;

        // Encode with minimal proof
        let (_, proof) = translate::<Sha256>(total, min, data.to_vec()).unwrap();

        // Try to verify with wrong root
        let mut hasher = Sha256::new();
        hasher.update(b"wrong_root");
        let wrong_root = hasher.finalize();

        assert!(!proof.verify(min, &wrong_root));
    }

    #[test]
    fn test_translate_serialization() {
        let data = b"Test serialization of minimal proof";
        let total = 5u16;
        let min = 3u16;

        // Encode with minimal proof
        let (root, proof) = translate::<Sha256>(total, min, data.to_vec()).unwrap();

        // Serialize
        let serialized = proof.encode();

        // Deserialize (shards are equal-length, so shard 0 bounds them all)
        let max_shard_size = proof.shards[0].len();
        let deserialized = Data::<Sha256>::decode_cfg(serialized, &max_shard_size).unwrap();

        // Verify deserialized proof
        assert!(deserialized.verify(min, &root));
        assert_eq!(deserialized.extract(), data);
    }
}