1use crate::{Config, Scheme};
2use bytes::{Buf, BufMut};
3use commonware_codec::{EncodeSize, FixedSize, Read, ReadExt, ReadRangeExt, Write};
4use commonware_cryptography::Hasher;
5use commonware_storage::bmt::{self, Builder};
6use reed_solomon_simd::{Error as RsError, ReedSolomonDecoder, ReedSolomonEncoder};
7use std::{collections::HashSet, marker::PhantomData};
8use thiserror::Error;
9
/// Failures that can occur while encoding, decoding, or verifying chunks.
#[derive(Error, Debug)]
pub enum Error {
    /// Propagated failure from the underlying `reed_solomon_simd` codec.
    #[error("reed-solomon error: {0}")]
    ReedSolomon(#[from] RsError),
    /// The shards decoded, but re-encoding and re-hashing them does not
    /// reproduce the committed Merkle root (the encoding was inconsistent).
    #[error("inconsistent")]
    Inconsistent,
    /// A Merkle inclusion proof could not be produced or did not verify.
    #[error("invalid proof")]
    InvalidProof,
    /// Fewer than `min` chunks were supplied to `decode`.
    #[error("not enough chunks")]
    NotEnoughChunks,
    /// The same chunk index appeared more than once in the input.
    #[error("duplicate chunk index: {0}")]
    DuplicateIndex(u16),
    /// The data to encode is too long to length-prefix with a `u32`.
    #[error("invalid data length: {0}")]
    InvalidDataLength(usize),
    /// A chunk's index is out of range (`>= total`).
    #[error("invalid index: {0}")]
    InvalidIndex(u16),
    /// A chunk's index does not match the index expected by the caller.
    #[error("wrong index: {0}")]
    WrongIndex(u16),
}
30
/// A single erasure-coded shard bundled with its position and a Merkle
/// inclusion proof tying it to the commitment (the tree root).
#[derive(Debug, Clone)]
pub struct Chunk<H: Hasher> {
    /// The shard bytes (original or recovery data).
    shard: Vec<u8>,

    /// Position in `[0, total)`; indices below `min` are original shards,
    /// the rest are recovery shards.
    index: u16,

    /// Merkle proof that the digest of `shard` is the leaf at `index`.
    proof: bmt::Proof<H>,
}
43
44impl<H: Hasher> Chunk<H> {
45 fn new(shard: Vec<u8>, index: u16, proof: bmt::Proof<H>) -> Self {
47 Self {
48 shard,
49 index,
50 proof,
51 }
52 }
53
54 fn verify(&self, index: u16, root: &H::Digest) -> bool {
56 if index != self.index {
58 return false;
59 }
60
61 let mut hasher = H::new();
63 hasher.update(&self.shard);
64 let shard_digest = hasher.finalize();
65
66 self.proof
68 .verify(&mut hasher, &shard_digest, self.index as u32, root)
69 .is_ok()
70 }
71}
72
impl<H: Hasher> Write for Chunk<H> {
    // Wire format: shard bytes, then index, then proof. This field order
    // must match `Read::read_cfg` exactly.
    fn write(&self, writer: &mut impl BufMut) {
        self.shard.write(writer);
        self.index.write(writer);
        self.proof.write(writer);
    }
}
80
impl<H: Hasher> Read for Chunk<H> {
    /// Maximum allowed shard length, used to bound the decoded byte vector.
    type Cfg = usize;

    // Fields are decoded in the same order `Write::write` emits them.
    fn read_cfg(reader: &mut impl Buf, cfg: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
        let shard = Vec::<u8>::read_range(reader, ..=*cfg)?;
        let index = u16::read(reader)?;
        let proof = bmt::Proof::<H>::read(reader)?;
        Ok(Self {
            shard,
            index,
            proof,
        })
    }
}
96
97impl<H: Hasher> EncodeSize for Chunk<H> {
98 fn encode_size(&self) -> usize {
99 self.shard.encode_size() + self.index.encode_size() + self.proof.encode_size()
100 }
101}
102
103impl<H: Hasher> PartialEq for Chunk<H> {
104 fn eq(&self, other: &Self) -> bool {
105 self.shard == other.shard && self.index == other.index && self.proof == other.proof
106 }
107}
108
// Field-wise equality is a full equivalence relation, so `Eq` can be asserted.
impl<H: Hasher> Eq for Chunk<H> {}
110
111fn prepare_data(data: Vec<u8>, k: usize, m: usize) -> Vec<Vec<u8>> {
113 let data_len = data.len();
115 let prefixed_len = u32::SIZE + data_len;
116 let mut shard_len = prefixed_len.div_ceil(k);
117
118 if !shard_len.is_multiple_of(2) {
120 shard_len += 1;
121 }
122
123 let length_bytes = (data_len as u32).to_be_bytes();
125 let mut src = length_bytes.into_iter().chain(data);
126 let mut shards = Vec::with_capacity(k + m); for _ in 0..k {
128 let mut shard = Vec::with_capacity(shard_len);
129 for _ in 0..shard_len {
130 shard.push(src.next().unwrap_or(0));
131 }
132 shards.push(shard);
133 }
134 shards
135}
136
137fn extract_data(shards: Vec<Vec<u8>>, k: usize) -> Vec<u8> {
139 let mut data = shards.into_iter().take(k).flatten();
141
142 let data_len = (&mut data)
144 .take(u32::SIZE)
145 .collect::<Vec<_>>()
146 .try_into()
147 .expect("insufficient data");
148 let data_len = u32::from_be_bytes(data_len) as usize;
149
150 data.take(data_len).collect()
152}
153
/// Result of `encode_inner`: the Merkle tree over all shard digests and the
/// shards themselves (originals followed by recovery shards).
type Encoding<H> = (bmt::Tree<H>, Vec<Vec<u8>>);
156
/// Erasure-code `data` into `total` shards (`min` originals plus recovery
/// shards) and build a Merkle tree over the shard digests.
///
/// Returns the tree and all `total` shards, originals first.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn encode_inner<H: Hasher>(total: u16, min: u16, data: Vec<u8>) -> Result<Encoding<H>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    // The length prefix is a u32, so longer payloads cannot be represented.
    if data.len() > u32::MAX as usize {
        return Err(Error::InvalidDataLength(data.len()));
    }

    // Split into k originals; all shards share the same (even) length.
    let mut shards = prepare_data(data, k, m);
    let shard_len = shards[0].len();

    // Feed originals to the encoder in index order.
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in &shards {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }

    // Generate the m recovery shards and append them after the originals.
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<Vec<u8>> = encoding
        .recovery_iter()
        .map(|shard| shard.to_vec())
        .collect();
    shards.extend(recovery_shards);

    // Commit to every shard by hashing it into a Merkle tree leaf.
    let mut builder = Builder::<H>::new(n);
    let mut hasher = H::new();
    for shard in &shards {
        builder.add(&{
            hasher.update(shard);
            hasher.finalize()
        });
    }
    let tree = builder.build();

    Ok((tree, shards))
}
202
203fn encode<H: Hasher>(
216 total: u16,
217 min: u16,
218 data: Vec<u8>,
219) -> Result<(H::Digest, Vec<Chunk<H>>), Error> {
220 let (tree, shards) = encode_inner::<H>(total, min, data)?;
222 let root = tree.root();
223 let n = total as usize;
224
225 let mut chunks = Vec::with_capacity(n);
227 for (i, shard) in shards.into_iter().enumerate() {
228 let proof = tree.proof(i as u32).map_err(|_| Error::InvalidProof)?;
229 chunks.push(Chunk::new(shard, i as u16, proof));
230 }
231
232 Ok((root, chunks))
233}
234
/// Reconstruct the original data from at least `min` chunks, verifying the
/// result against the committed Merkle `root`.
///
/// Chunks may be any mix of original and recovery shards. After recovering
/// the originals, the full shard set is re-encoded and re-hashed; if the
/// recomputed root differs from `root`, the encoding was inconsistent.
///
/// # Panics
///
/// Panics if `total <= min` or `min == 0`.
fn decode<H: Hasher>(
    total: u16,
    min: u16,
    root: &H::Digest,
    chunks: &[Chunk<H>],
) -> Result<Vec<u8>, Error> {
    assert!(total > min);
    assert!(min > 0);
    let n = total as usize;
    let k = min as usize;
    let m = n - k;
    if chunks.len() < k {
        return Err(Error::NotEnoughChunks);
    }

    // Partition the provided chunks, rejecting out-of-range and duplicate
    // indices. Indices below `min` are originals; the rest are recovery
    // shards (re-based so the first recovery shard is index 0).
    let shard_len = chunks[0].shard.len();
    let mut seen = HashSet::new();
    let mut provided_originals: Vec<(usize, Vec<u8>)> = Vec::new();
    let mut provided_recoveries: Vec<(usize, Vec<u8>)> = Vec::new();
    for chunk in chunks {
        let index = chunk.index;
        if index >= total {
            return Err(Error::InvalidIndex(index));
        }
        if seen.contains(&index) {
            return Err(Error::DuplicateIndex(index));
        }
        seen.insert(index);

        if index < min {
            provided_originals.push((index as usize, chunk.shard.clone()));
        } else {
            provided_recoveries.push((index as usize - k, chunk.shard.clone()));
        }
    }

    // Recover any missing original shards.
    let mut decoder = ReedSolomonDecoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for (idx, ref shard) in &provided_originals {
        decoder
            .add_original_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    for (idx, ref shard) in &provided_recoveries {
        decoder
            .add_recovery_shard(*idx, shard)
            .map_err(Error::ReedSolomon)?;
    }
    let decoding = decoder.decode().map_err(Error::ReedSolomon)?;

    // Assemble the complete set of k original shards (provided + restored).
    let mut shards = Vec::with_capacity(n);
    shards.resize(k, Vec::new());
    for (idx, shard) in provided_originals {
        shards[idx] = shard;
    }
    for (idx, shard) in decoding.restored_original_iter() {
        shards[idx] = shard.to_vec();
    }

    // Re-encode the originals to regenerate all recovery shards.
    let mut encoder = ReedSolomonEncoder::new(k, m, shard_len).map_err(Error::ReedSolomon)?;
    for shard in shards.iter().take(k) {
        encoder
            .add_original_shard(shard)
            .map_err(Error::ReedSolomon)?;
    }
    let encoding = encoder.encode().map_err(Error::ReedSolomon)?;
    let recovery_shards: Vec<Vec<u8>> = encoding
        .recovery_iter()
        .map(|shard| shard.to_vec())
        .collect();
    shards.extend(recovery_shards);

    // Recompute the Merkle root over the full shard set; a mismatch proves
    // the commitment did not cover a consistent encoding.
    let mut builder = Builder::<H>::new(n);
    let mut hasher = H::new();
    for shard in &shards {
        builder.add(&{
            hasher.update(shard);
            hasher.finalize()
        });
    }
    let computed_tree = builder.build();

    if computed_tree.root() != *root {
        return Err(Error::Inconsistent);
    }

    Ok(extract_data(shards, k))
}
346
/// A coding scheme backed by `reed_solomon_simd` erasure coding with a
/// binary Merkle tree commitment over the shard digests.
#[derive(Clone, Copy)]
pub struct ReedSolomon<H> {
    // Zero-sized marker tying the scheme to a hasher without storing one.
    _marker: PhantomData<H>,
}
433
// Manual impl so `ReedSolomon<H>` is `Debug` without requiring `H: Debug`
// (the type parameter is only a phantom marker).
impl<H> std::fmt::Debug for ReedSolomon<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReedSolomon").finish()
    }
}
439
440impl<H: Hasher> Scheme for ReedSolomon<H> {
441 type Commitment = H::Digest;
442
443 type Shard = Chunk<H>;
444 type ReShard = Chunk<H>;
445 type CheckedShard = Chunk<H>;
446 type CheckingData = ();
447
448 type Error = Error;
449
450 fn encode(
451 config: &Config,
452 mut data: impl Buf,
453 ) -> Result<(Self::Commitment, Vec<Self::Shard>), Self::Error> {
454 let data: Vec<u8> = data.copy_to_bytes(data.remaining()).to_vec();
455 encode(
456 config.minimum_shards + config.extra_shards,
457 config.minimum_shards,
458 data,
459 )
460 }
461
462 fn reshard(
463 _config: &Config,
464 commitment: &Self::Commitment,
465 index: u16,
466 shard: Self::Shard,
467 ) -> Result<(Self::CheckingData, Self::CheckedShard, Self::ReShard), Self::Error> {
468 if shard.index != index {
469 return Err(Error::WrongIndex(index));
470 }
471 if shard.verify(shard.index, commitment) {
472 Ok(((), shard.clone(), shard))
473 } else {
474 Err(Error::InvalidProof)
475 }
476 }
477
478 fn check(
479 _config: &Config,
480 commitment: &Self::Commitment,
481 _checking_data: &Self::CheckingData,
482 index: u16,
483 reshard: Self::ReShard,
484 ) -> Result<Self::CheckedShard, Self::Error> {
485 if reshard.index != index {
486 return Err(Error::WrongIndex(reshard.index));
487 }
488 if !reshard.verify(reshard.index, commitment) {
489 return Err(Error::InvalidProof);
490 }
491 Ok(reshard)
492 }
493
494 fn decode(
495 config: &Config,
496 commitment: &Self::Commitment,
497 _checking_data: Self::CheckingData,
498 shards: &[Self::CheckedShard],
499 ) -> Result<Vec<u8>, Self::Error> {
500 decode(
501 config.minimum_shards + config.extra_shards,
502 config.minimum_shards,
503 commitment,
504 shards,
505 )
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_cryptography::Sha256;

    // Any `min` distinct chunks (mix of original and recovery) decode.
    #[test]
    fn test_recovery() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // One original shard (0) and two recovery shards (4, 6).
        let pieces: Vec<_> = vec![
            chunks[0].clone(), chunks[4].clone(), chunks[6].clone(), ];

        let decoded = decode::<Sha256>(total, min, &root, &pieces).unwrap();
        assert_eq!(decoded, data);
    }

    // Fewer than `min` chunks is rejected before any decoding work.
    #[test]
    fn test_not_enough_pieces() {
        let data = b"Test insufficient pieces";
        let total = 6u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let pieces: Vec<_> = chunks.into_iter().take(2).collect();

        let result = decode::<Sha256>(total, min, &root, &pieces);
        assert!(matches!(result, Err(Error::NotEnoughChunks)));
    }

    // A repeated chunk index is detected even when enough chunks are given.
    #[test]
    fn test_duplicate_index() {
        let data = b"Test duplicate detection";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let pieces = vec![chunks[0].clone(), chunks[0].clone(), chunks[1].clone()];

        let result = decode::<Sha256>(total, min, &root, &pieces);
        assert!(matches!(result, Err(Error::DuplicateIndex(0))));
    }

    // `Chunk::verify` rejects any index other than the chunk's own.
    #[test]
    fn test_invalid_index() {
        let data = b"Test invalid index";
        let total = 5u16;
        let min = 3u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        for i in 0..total {
            assert!(!chunks[i as usize].verify(i + 1, &root));
        }
    }

    // `total == min` violates the `total > min` precondition and panics.
    #[test]
    #[should_panic(expected = "assertion failed: total > min")]
    fn test_invalid_total() {
        let data = b"Test parameter validation";

        encode::<Sha256>(3, 3, data.to_vec()).unwrap();
    }

    // `min == 0` violates the `min > 0` precondition and panics.
    #[test]
    #[should_panic(expected = "assertion failed: min > 0")]
    fn test_invalid_min() {
        let data = b"Test parameter validation";

        encode::<Sha256>(5, 0, data.to_vec()).unwrap();
    }

    // Zero-length payloads round-trip (the length prefix alone is encoded).
    #[test]
    fn test_empty_data() {
        let data = b"";
        let total = 100u16;
        let min = 30u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // A payload much larger than the shard count round-trips.
    #[test]
    fn test_large_data() {
        let data = vec![42u8; 1000]; let total = 7u16;
        let min = 4u16;

        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // A root that does not commit to the chunks fails both per-chunk
    // verification and the decode-time consistency check.
    #[test]
    fn test_malicious_root_detection() {
        let data = b"Original data that should be protected";
        let total = 7u16;
        let min = 4u16;

        let (_correct_root, chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        // Fabricate a root unrelated to the actual encoding.
        let mut hasher = Sha256::new();
        hasher.update(b"malicious_data_that_wasnt_actually_encoded");
        let malicious_root = hasher.finalize();

        for i in 0..total {
            assert!(!chunks[i as usize].verify(i, &malicious_root));
        }

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();

        let result = decode::<Sha256>(total, min, &malicious_root, &minimal);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // Flipping a byte in one shard makes the recomputed root mismatch.
    #[test]
    fn test_manipulated_chunk_detection() {
        let data = b"Data integrity must be maintained";
        let total = 6u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        if !chunks[1].shard.is_empty() {
            chunks[1].shard[0] ^= 0xFF; }

        let result = decode::<Sha256>(total, min, &root, &chunks);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // An encoder that commits to corrupted recovery shards is caught by the
    // re-encode-and-compare step in `decode`.
    #[test]
    fn test_inconsistent_shards() {
        let data = b"Test data for malicious encoding";
        let total = 5u16;
        let min = 3u16;
        let m = total - min;

        let shards = prepare_data(data.to_vec(), min as usize, total as usize - min as usize);
        let shard_size = shards[0].len();

        let mut encoder = ReedSolomonEncoder::new(min as usize, m as usize, shard_size).unwrap();
        for shard in &shards {
            encoder.add_original_shard(shard).unwrap();
        }
        let recovery_result = encoder.encode().unwrap();
        let mut recovery_shards: Vec<Vec<u8>> = recovery_result
            .recovery_iter()
            .map(|s| s.to_vec())
            .collect();

        // Corrupt a recovery shard *before* committing to it.
        if !recovery_shards[0].is_empty() {
            recovery_shards[0][0] ^= 0xFF;
        }

        let mut malicious_shards = shards.clone();
        malicious_shards.extend(recovery_shards);

        // Build a tree that honestly commits to the dishonest shard set.
        let mut builder = Builder::<Sha256>::new(total as usize);
        for shard in &malicious_shards {
            let mut hasher = Sha256::new();
            hasher.update(shard);
            builder.add(&hasher.finalize());
        }
        let malicious_tree = builder.build();
        let malicious_root = malicious_tree.root();

        // Select two originals and the corrupted recovery shard (index 3).
        let selected_indices = vec![0, 1, 3]; let mut pieces = Vec::new();
        for &i in &selected_indices {
            let merkle_proof = malicious_tree.proof(i as u32).unwrap();
            let shard = malicious_shards[i].clone();
            let chunk = Chunk::new(shard, i as u16, merkle_proof);
            pieces.push(chunk);
        }

        let result = decode::<Sha256>(total, min, &malicious_root, &pieces);
        assert!(matches!(result, Err(Error::Inconsistent)));
    }

    // An index `>= total` smuggled into a chunk is rejected with the index.
    #[test]
    fn test_decode_invalid_index() {
        let data = b"Testing recovery pieces";
        let total = 8u16;
        let min = 3u16;

        let (root, mut chunks) = encode::<Sha256>(total, min, data.to_vec()).unwrap();

        chunks[1].index = 8;
        let pieces: Vec<_> = vec![
            chunks[0].clone(), chunks[1].clone(), chunks[6].clone(), ];

        let result = decode::<Sha256>(total, min, &root, &pieces);
        assert!(matches!(result, Err(Error::InvalidIndex(8))));
    }

    // The largest supported shard count round-trips.
    #[test]
    fn test_max_chunks() {
        let data = vec![42u8; 1000]; let total = u16::MAX;
        let min = u16::MAX / 2;

        let (root, chunks) = encode::<Sha256>(total, min, data.clone()).unwrap();

        let minimal = chunks.into_iter().take(min as usize).collect::<Vec<_>>();
        let decoded = decode::<Sha256>(total, min, &root, &minimal).unwrap();
        assert_eq!(decoded, data);
    }

    // One more recovery shard than the codec supports surfaces the
    // underlying `UnsupportedShardCount` error.
    #[test]
    fn test_too_many_chunks() {
        let data = vec![42u8; 1000]; let total = u16::MAX;
        let min = u16::MAX / 2 - 1;

        let result = encode::<Sha256>(total, min, data.clone());
        assert!(matches!(
            result,
            Err(Error::ReedSolomon(
                reed_solomon_simd::Error::UnsupportedShardCount {
                    original_count: _,
                    recovery_count: _,
                }
            ))
        ));
    }
}