1#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10 hasher::Hasher,
11 iterator::{PathIterator, PeakIterator},
12 Error, Location, Position,
13};
14use alloc::{
15 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16 vec,
17 vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{varint::UInt, EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::Digest;
22use core::ops::Range;
23#[cfg(feature = "std")]
24use tracing::debug;
25
/// Errors that can occur when reconstructing a root (or the peak digests) of an MMR from a
/// [Proof] and the elements it is claimed to prove.
#[derive(Error, Debug)]
pub enum ReconstructionError {
    /// The proof ran out of digests needed for reconstruction (also returned if an element was
    /// needed for a leaf but none remained).
    #[error("missing digests in proof")]
    MissingDigests,
    /// The proof contains digests (or the caller supplied elements) that the reconstruction did
    /// not consume.
    #[error("extra digests in proof")]
    ExtraDigests,
    /// The start location could not be converted into a node position.
    #[error("start location is out of bounds")]
    InvalidStartLoc,
    /// The location of the last element overflowed, could not be converted into a node
    /// position, or lies beyond the proof's size.
    #[error("end location is out of bounds")]
    InvalidEndLoc,
    /// No elements were supplied for a range starting at a non-zero location.
    #[error("missing elements")]
    MissingElements,
    /// The proof's size is not a valid MMR size.
    #[error("invalid size")]
    InvalidSize,
}
42
/// An inclusion proof for one or more elements of an MMR.
///
/// Verification reconstructs the MMR root from the proven elements plus `digests` and compares
/// the result against an expected root.
#[derive(Clone, Debug, Eq)]
pub struct Proof<D: Digest> {
    /// Size of the MMR (its total number of nodes) the proof was generated against. Verification
    /// rejects proofs whose size is not a valid MMR size.
    pub size: Position,
    /// Digests needed to reconstruct the root. For range proofs they follow the order produced by
    /// `nodes_required_for_range_proof`: peaks of trees holding no proven element (left to right),
    /// then sibling digests along the proven elements' paths. For multi-element proofs (see
    /// `verify_multi_inclusion`) they are ordered by ascending node position.
    pub digests: Vec<D>,
}
64
65impl<D: Digest> PartialEq for Proof<D> {
66 fn eq(&self, other: &Self) -> bool {
67 self.size == other.size && self.digests == other.digests
68 }
69}
70
71impl<D: Digest> EncodeSize for Proof<D> {
72 fn encode_size(&self) -> usize {
73 UInt(*self.size).encode_size() + self.digests.encode_size()
74 }
75}
76
77impl<D: Digest> Write for Proof<D> {
78 fn write(&self, buf: &mut impl BufMut) {
79 UInt(*self.size).write(buf);
81
82 self.digests.write(buf);
84 }
85}
86
87impl<D: Digest> Read for Proof<D> {
88 type Cfg = usize;
90
91 fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
92 let size = Position::new(UInt::<u64>::read(buf)?.into());
94
95 let range = ..=max_len;
97 let digests = Vec::<D>::read_range(buf, range)?;
98 Ok(Self { size, digests })
99 }
100}
101
102impl<D: Digest> Default for Proof<D> {
103 fn default() -> Self {
106 Self {
107 size: Position::new(0),
108 digests: vec![],
109 }
110 }
111}
112
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws an arbitrary size, then an arbitrary digest list (in that order).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let size = u.arbitrary()?;
        let digests = u.arbitrary()?;
        Ok(Self { size, digests })
    }
}
125
126impl<D: Digest> Proof<D> {
127 pub fn verify_element_inclusion<H>(
130 &self,
131 hasher: &mut H,
132 element: &[u8],
133 loc: Location,
134 root: &D,
135 ) -> bool
136 where
137 H: Hasher<D>,
138 {
139 self.verify_range_inclusion(hasher, &[element], loc, root)
140 }
141
142 pub fn verify_range_inclusion<H, E>(
146 &self,
147 hasher: &mut H,
148 elements: &[E],
149 start_loc: Location,
150 root: &D,
151 ) -> bool
152 where
153 H: Hasher<D>,
154 E: AsRef<[u8]>,
155 {
156 if !self.size.is_mmr_size() {
157 #[cfg(feature = "std")]
158 debug!(size = ?self.size, "invalid proof size");
159 return false;
160 }
161
162 match self.reconstruct_root(hasher, elements, start_loc) {
163 Ok(reconstructed_root) => *root == reconstructed_root,
164 Err(_error) => {
165 #[cfg(feature = "std")]
166 tracing::debug!(error = ?_error, "invalid proof input");
167 false
168 }
169 }
170 }
171
172 pub fn verify_multi_inclusion<H, E>(
177 &self,
178 hasher: &mut H,
179 elements: &[(E, Location)],
180 root: &D,
181 ) -> bool
182 where
183 H: Hasher<D>,
184 E: AsRef<[u8]>,
185 {
186 if elements.is_empty() {
188 return self.size == Position::new(0)
189 && *root == hasher.root(Position::new(0), core::iter::empty());
190 }
191 if !self.size.is_mmr_size() {
192 return false;
193 }
194
195 let mut node_positions = BTreeSet::new();
197 let mut nodes_required = BTreeMap::new();
198
199 for (_, loc) in elements {
200 if !loc.is_valid() {
201 return false;
202 }
203 let Ok(required) = nodes_required_for_range_proof(self.size, *loc..*loc + 1) else {
205 return false;
206 };
207 for req_pos in &required {
208 node_positions.insert(*req_pos);
209 }
210 nodes_required.insert(*loc, required);
211 }
212
213 if node_positions.len() != self.digests.len() {
215 return false;
216 }
217
218 let node_digests: BTreeMap<Position, D> = node_positions
220 .iter()
221 .zip(self.digests.iter())
222 .map(|(&pos, digest)| (pos, *digest))
223 .collect();
224
225 for (element, loc) in elements {
227 let required = &nodes_required[loc];
229
230 let mut digests = Vec::with_capacity(required.len());
232 for req_pos in required {
233 let digest = node_digests
236 .get(req_pos)
237 .expect("must exist by construction of node_digests");
238 digests.push(*digest);
239 }
240 let proof = Self {
241 size: self.size,
242 digests,
243 };
244
245 if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
247 return false;
248 }
249 }
250
251 true
252 }
253
254 #[cfg(any(feature = "std", test))]
276 pub(crate) fn extract_pinned_nodes(
277 &self,
278 range: std::ops::Range<Location>,
279 ) -> Result<Vec<D>, Error> {
280 let start_pos = Position::try_from(range.start)?;
282 let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
283
284 let required_positions = nodes_required_for_range_proof(self.size, range)?;
286
287 if required_positions.len() != self.digests.len() {
288 #[cfg(feature = "std")]
289 debug!(
290 digests_len = self.digests.len(),
291 required_positions_len = required_positions.len(),
292 "Proof digest count doesn't match required positions",
293 );
294 return Err(Error::InvalidProofLength);
295 }
296
297 if pinned_positions
300 == required_positions[required_positions.len() - pinned_positions.len()..]
301 {
302 return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
303 }
304
305 let position_to_digest: BTreeMap<Position, D> = required_positions
307 .iter()
308 .zip(self.digests.iter())
309 .map(|(&pos, &digest)| (pos, digest))
310 .collect();
311
312 let mut result = Vec::with_capacity(pinned_positions.len());
314 for pinned_pos in pinned_positions {
315 let Some(&digest) = position_to_digest.get(&pinned_pos) else {
316 #[cfg(feature = "std")]
317 debug!(?pinned_pos, "Pinned node not found in proof");
318 return Err(Error::MissingDigest(pinned_pos));
319 };
320 result.push(digest);
321 }
322 Ok(result)
323 }
324
325 pub fn verify_range_inclusion_and_extract_digests<H, E>(
331 &self,
332 hasher: &mut H,
333 elements: &[E],
334 start_loc: Location,
335 root: &D,
336 ) -> Result<Vec<(Position, D)>, super::Error>
337 where
338 H: Hasher<D>,
339 E: AsRef<[u8]>,
340 {
341 let mut collected_digests = Vec::new();
342 let Ok(peak_digests) = self.reconstruct_peak_digests(
343 hasher,
344 elements,
345 start_loc,
346 Some(&mut collected_digests),
347 ) else {
348 return Err(Error::InvalidProof);
349 };
350
351 if hasher.root(self.size, peak_digests.iter()) != *root {
352 return Err(Error::RootMismatch);
353 }
354
355 Ok(collected_digests)
356 }
357
358 pub fn reconstruct_root<H, E>(
361 &self,
362 hasher: &mut H,
363 elements: &[E],
364 start_loc: Location,
365 ) -> Result<D, ReconstructionError>
366 where
367 H: Hasher<D>,
368 E: AsRef<[u8]>,
369 {
370 if !self.size.is_mmr_size() {
371 return Err(ReconstructionError::InvalidSize);
372 }
373
374 let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
375
376 Ok(hasher.root(self.size, peak_digests.iter()))
377 }
378
379 pub fn reconstruct_peak_digests<H, E>(
383 &self,
384 hasher: &mut H,
385 elements: &[E],
386 start_loc: Location,
387 mut collected_digests: Option<&mut Vec<(Position, D)>>,
388 ) -> Result<Vec<D>, ReconstructionError>
389 where
390 H: Hasher<D>,
391 E: AsRef<[u8]>,
392 {
393 if elements.is_empty() {
394 if start_loc == 0 {
395 return Ok(vec![]);
396 }
397 return Err(ReconstructionError::MissingElements);
398 }
399 let start_element_pos =
400 Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
401 let end_element_pos = if elements.len() == 1 {
402 start_element_pos
403 } else {
404 let end_loc = start_loc
405 .checked_add(elements.len() as u64 - 1)
406 .ok_or(ReconstructionError::InvalidEndLoc)?;
407 Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
408 };
409 if end_element_pos >= self.size {
410 return Err(ReconstructionError::InvalidEndLoc);
411 }
412
413 let mut proof_digests_iter = self.digests.iter();
414 let mut siblings_iter = self.digests.iter().rev();
415
416 let mut peak_digests: Vec<D> = Vec::new();
419 let mut proof_digests_used = 0;
420 let mut elements_iter = elements.iter();
421 if !Position::is_mmr_size(self.size) {
422 return Err(ReconstructionError::InvalidSize);
423 }
424 for (peak_pos, height) in PeakIterator::new(self.size) {
425 let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
426 if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
427 let hash = peak_digest_from_range(
428 hasher,
429 RangeInfo {
430 pos: peak_pos,
431 two_h: 1 << height,
432 leftmost_pos: start_element_pos,
433 rightmost_pos: end_element_pos,
434 },
435 &mut elements_iter,
436 &mut siblings_iter,
437 collected_digests.as_deref_mut(),
438 )?;
439 peak_digests.push(hash);
440 if let Some(ref mut collected_digests) = collected_digests {
441 collected_digests.push((peak_pos, hash));
442 }
443 } else if let Some(hash) = proof_digests_iter.next() {
444 proof_digests_used += 1;
445 peak_digests.push(*hash);
446 if let Some(ref mut collected_digests) = collected_digests {
447 collected_digests.push((peak_pos, *hash));
448 }
449 } else {
450 return Err(ReconstructionError::MissingDigests);
451 }
452 }
453
454 if elements_iter.next().is_some() {
455 return Err(ReconstructionError::ExtraDigests);
456 }
457 if let Some(next_sibling) = siblings_iter.next() {
458 if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
459 return Err(ReconstructionError::ExtraDigests);
460 }
461 }
462
463 Ok(peak_digests)
464 }
465}
466
467pub(crate) fn nodes_required_for_range_proof(
478 size: Position,
479 range: Range<Location>,
480) -> Result<Vec<Position>, Error> {
481 if !size.is_mmr_size() {
482 return Err(Error::InvalidSize(*size));
483 }
484 if range.is_empty() {
485 return Err(Error::Empty);
486 }
487 let start_element_pos = Position::try_from(range.start)?;
488 let end_minus_one = range
489 .end
490 .checked_sub(1)
491 .expect("can't underflow because range is non-empty");
492 let end_element_pos = Position::try_from(end_minus_one)?;
493 if end_element_pos >= size {
494 return Err(Error::RangeOutOfBounds(range.end));
495 }
496
497 let mut start_tree_with_element: Option<(Position, u32)> = None;
500 let mut end_tree_with_element: Option<(Position, u32)> = None;
501 let mut peak_iterator = PeakIterator::new(size);
502 let mut positions = Vec::new();
503 while let Some(peak) = peak_iterator.next() {
504 if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
505 start_tree_with_element = Some(peak);
507 if peak.0 >= end_element_pos {
508 end_tree_with_element = Some(peak);
510 continue;
511 }
512 for peak in peak_iterator.by_ref() {
513 if peak.0 >= end_element_pos {
514 end_tree_with_element = Some(peak);
516 break;
517 }
518 }
519 } else {
520 positions.push(peak.0);
522 }
523 }
524
525 let (start_tree_peak, start_tree_height) =
528 start_tree_with_element.expect("start_tree_with_element is Some");
529 let (end_tree_peak, end_tree_height) =
530 end_tree_with_element.expect("end_tree_with_element is Some");
531
532 let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);
536
537 let mut siblings = Vec::new();
538 if start_element_pos == end_element_pos {
539 siblings.extend(left_path_iter);
542 } else {
543 let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
544 siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
546 siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));
548
549 if start_tree_peak == end_tree_peak {
552 siblings.sort_by(|a, b| b.0.cmp(&a.0));
553 }
554 }
555 positions.extend(siblings.into_iter().map(|(_, pos)| pos));
556 Ok(positions)
557}
558
#[cfg(any(feature = "std", test))]
/// Returns the deduplicated, ascending set of node positions needed to prove each of
/// `locations` individually in an MMR of size `size`.
///
/// # Errors
///
/// - [Error::InvalidSize] if `size` is not a valid MMR size.
/// - [Error::Empty] if `locations` is empty.
/// - [Error::LocationOverflow] if any location is invalid.
/// - Any error from [nodes_required_for_range_proof] for an individual location.
pub(crate) fn nodes_required_for_multi_proof(
    size: Position,
    locations: &[Location],
) -> Result<BTreeSet<Position>, Error> {
    if !size.is_mmr_size() {
        return Err(Error::InvalidSize(*size));
    }
    if locations.is_empty() {
        return Err(Error::Empty);
    }

    let mut required = BTreeSet::new();
    for &loc in locations {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        required.extend(nodes_required_for_range_proof(size, loc..loc + 1)?);
    }
    Ok(required)
}
594
/// The subtree currently visited by [peak_digest_from_range], together with the bounds of the
/// element range being proven.
struct RangeInfo {
    /// Position of the subtree's root node.
    pos: Position,
    /// `2^h`, where `h` is the height of the subtree's root (1 for a leaf).
    two_h: u64,
    /// Position of the first element of the proven range.
    leftmost_pos: Position,
    /// Position of the last element of the proven range.
    rightmost_pos: Position,
}
602
603fn peak_digest_from_range<'a, D, H, E, S>(
604 hasher: &mut H,
605 range_info: RangeInfo,
606 elements: &mut E,
607 sibling_digests: &mut S,
608 mut collected_digests: Option<&mut Vec<(Position, D)>>,
609) -> Result<D, ReconstructionError>
610where
611 D: Digest,
612 H: Hasher<D>,
613 E: Iterator<Item: AsRef<[u8]>>,
614 S: Iterator<Item = &'a D>,
615{
616 assert_ne!(range_info.two_h, 0);
617 if range_info.two_h == 1 {
618 match elements.next() {
619 Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
620 None => return Err(ReconstructionError::MissingDigests),
621 }
622 }
623
624 let mut left_digest: Option<D> = None;
625 let mut right_digest: Option<D> = None;
626
627 let left_pos = range_info.pos - range_info.two_h;
628 let right_pos = left_pos + range_info.two_h - 1;
629 if left_pos >= range_info.leftmost_pos {
630 let digest = peak_digest_from_range(
632 hasher,
633 RangeInfo {
634 pos: left_pos,
635 two_h: range_info.two_h >> 1,
636 leftmost_pos: range_info.leftmost_pos,
637 rightmost_pos: range_info.rightmost_pos,
638 },
639 elements,
640 sibling_digests,
641 collected_digests.as_deref_mut(),
642 )?;
643 left_digest = Some(digest);
644 }
645 if left_pos < range_info.rightmost_pos {
646 let digest = peak_digest_from_range(
648 hasher,
649 RangeInfo {
650 pos: right_pos,
651 two_h: range_info.two_h >> 1,
652 leftmost_pos: range_info.leftmost_pos,
653 rightmost_pos: range_info.rightmost_pos,
654 },
655 elements,
656 sibling_digests,
657 collected_digests.as_deref_mut(),
658 )?;
659 right_digest = Some(digest);
660 }
661
662 if left_digest.is_none() {
663 match sibling_digests.next() {
664 Some(hash) => left_digest = Some(*hash),
665 None => return Err(ReconstructionError::MissingDigests),
666 }
667 }
668 if right_digest.is_none() {
669 match sibling_digests.next() {
670 Some(hash) => right_digest = Some(*hash),
671 None => return Err(ReconstructionError::MissingDigests),
672 }
673 }
674
675 if let Some(ref mut collected_digests) = collected_digests {
676 collected_digests.push((
677 left_pos,
678 left_digest.expect("left_digest guaranteed to be Some after checks above"),
679 ));
680 collected_digests.push((
681 right_pos,
682 right_digest.expect("right_digest guaranteed to be Some after checks above"),
683 ));
684 }
685
686 Ok(hasher.node_digest(
687 range_info.pos,
688 &left_digest.expect("left_digest guaranteed to be Some after checks above"),
689 &right_digest.expect("right_digest guaranteed to be Some after checks above"),
690 ))
691}
692
693#[cfg(test)]
694mod tests {
695 use super::*;
696 use crate::mmr::{
697 hasher::Standard, location::LocationRangeExt as _, mem::CleanMmr, MAX_LOCATION,
698 };
699 use bytes::Bytes;
700 use commonware_codec::{Decode, Encode};
701 use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
702 use commonware_macros::test_traced;
703
704 fn test_digest(v: u8) -> Digest {
705 Sha256::hash(&[v])
706 }
707
708 #[test]
709 fn test_proving_proof() {
710 let mut hasher: Standard<Sha256> = Standard::new();
712 let mmr = CleanMmr::new(&mut hasher);
713 let root = mmr.root();
714 let proof = Proof::default();
715 assert!(proof.verify_range_inclusion(
716 &mut hasher,
717 &[] as &[Digest],
718 Location::new_unchecked(0),
719 root
720 ));
721
722 assert!(!proof.verify_range_inclusion(
724 &mut hasher,
725 &[] as &[Digest],
726 Location::new_unchecked(1),
727 root
728 ));
729
730 let test_digest = test_digest(0);
732 assert!(!proof.verify_range_inclusion(
733 &mut hasher,
734 &[] as &[Digest],
735 Location::new_unchecked(0),
736 &test_digest
737 ));
738
739 assert!(!proof.verify_range_inclusion(
741 &mut hasher,
742 &[test_digest],
743 Location::new_unchecked(0),
744 root
745 ));
746 }
747
748 #[test]
749 fn test_proving_verify_element() {
750 let element = Digest::from(*b"01234567012345670123456701234567");
752 let mut hasher: Standard<Sha256> = Standard::new();
753 let mut mmr = CleanMmr::new(&mut hasher);
754 for _ in 0..11 {
755 mmr.add(&mut hasher, &element);
756 }
757 let root = mmr.root();
758
759 for leaf in 0u64..11 {
761 let leaf = Location::new_unchecked(leaf);
762 let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
763 assert!(
764 proof.verify_element_inclusion(&mut hasher, &element, leaf, root),
765 "valid proof should verify successfully"
766 );
767 }
768
769 const LEAF: Location = Location::new_unchecked(10);
772 let proof = mmr.proof(LEAF).unwrap();
773 assert!(
774 proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
775 "proof verification should be successful"
776 );
777 let wrong_sizes = [0, 16, 17, 18, 20, u64::MAX - 100];
778 for sz in wrong_sizes {
779 let mut wrong_size_proof = proof.clone();
780 wrong_size_proof.size = Position::new(sz);
781 assert!(
782 !wrong_size_proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
783 "proof with wrong size should fail verification"
784 );
785 }
786 assert!(
787 !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, root),
788 "proof verification should fail with incorrect element position"
789 );
790 assert!(
791 !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, root),
792 "proof verification should fail with incorrect element position 2"
793 );
794 assert!(
795 !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, root),
796 "proof verification should fail with mangled element"
797 );
798 let root2 = test_digest(0);
799 assert!(
800 !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
801 "proof verification should fail with mangled root"
802 );
803 let mut proof2 = proof.clone();
804 proof2.digests[0] = test_digest(0);
805 assert!(
806 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
807 "proof verification should fail with mangled proof hash"
808 );
809 proof2 = proof.clone();
810 proof2.size = Position::new(10);
811 assert!(
812 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
813 "proof verification should fail with incorrect size"
814 );
815 proof2 = proof.clone();
816 proof2.digests.push(test_digest(0));
817 assert!(
818 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
819 "proof verification should fail with extra hash"
820 );
821 proof2 = proof.clone();
822 while !proof2.digests.is_empty() {
823 proof2.digests.pop();
824 assert!(
825 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
826 "proof verification should fail with missing digests"
827 );
828 }
829 proof2 = proof.clone();
830 proof2.digests.clear();
831 const PEAK_COUNT: usize = 3;
832 proof2
833 .digests
834 .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
835 proof2.digests.push(test_digest(0));
838 proof2
839 .digests
840 .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
841 assert!(
842 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
843 "proof verification should fail with extra hash even if it's unused by the computation"
844 );
845 }
846
847 #[test]
848 fn test_proving_verify_range() {
849 let mut hasher: Standard<Sha256> = Standard::new();
851 let mut mmr = CleanMmr::new(&mut hasher);
852 let mut elements = Vec::new();
853 for i in 0..49 {
854 elements.push(test_digest(i));
855 mmr.add(&mut hasher, elements.last().unwrap());
856 }
857 let root = mmr.root();
859
860 for i in 0..elements.len() {
861 for j in i + 1..elements.len() {
862 let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
863 let range_proof = mmr.range_proof(range.clone()).unwrap();
864 assert!(
865 range_proof.verify_range_inclusion(
866 &mut hasher,
867 &elements[range.to_usize_range()],
868 range.start,
869 root,
870 ),
871 "valid range proof should verify successfully {i}:{j}",
872 );
873 }
874 }
875
876 let range = Location::new_unchecked(33)..Location::new_unchecked(40);
879 let range_proof = mmr.range_proof(range.clone()).unwrap();
880 let valid_elements = &elements[range.to_usize_range()];
881 assert!(
882 range_proof.verify_range_inclusion(&mut hasher, valid_elements, range.start, root),
883 "valid range proof should verify successfully"
884 );
885 let mut invalid_proof = range_proof.clone();
888 for _i in 0..range_proof.digests.len() {
889 invalid_proof.digests.remove(0);
890 assert!(
891 !invalid_proof.verify_range_inclusion(
892 &mut hasher,
893 valid_elements,
894 range.start,
895 root,
896 ),
897 "range proof with removed elements should fail"
898 );
899 }
900 for i in 0..elements.len() {
903 for j in i + 1..elements.len() {
904 if Location::from(i) == range.start && Location::from(j) == range.end {
905 continue;
907 }
908 assert!(
909 !range_proof.verify_range_inclusion(
910 &mut hasher,
911 &elements[i..j],
912 range.start,
913 root,
914 ),
915 "range proof with invalid element range should fail {i}:{j}",
916 );
917 }
918 }
919 let invalid_root = test_digest(1);
921 assert!(
922 !range_proof.verify_range_inclusion(
923 &mut hasher,
924 valid_elements,
925 range.start,
926 &invalid_root,
927 ),
928 "range proof with invalid root should fail"
929 );
930 for i in 0..range_proof.digests.len() {
932 let mut invalid_proof = range_proof.clone();
933 invalid_proof.digests[i] = test_digest(0);
934
935 assert!(
936 !invalid_proof.verify_range_inclusion(
937 &mut hasher,
938 valid_elements,
939 range.start,
940 root,
941 ),
942 "mangled range proof should fail verification"
943 );
944 }
945 for i in 0..range_proof.digests.len() {
947 let mut invalid_proof = range_proof.clone();
948 invalid_proof.digests.insert(i, test_digest(0));
949 assert!(
950 !invalid_proof.verify_range_inclusion(
951 &mut hasher,
952 valid_elements,
953 range.start,
954 root,
955 ),
956 "mangled range proof should fail verification. inserted element at: {i}",
957 );
958 }
959 for loc in 0..elements.len() {
961 let loc = Location::new_unchecked(loc as u64);
962 if loc == range.start {
963 continue;
964 }
965 assert!(
966 !range_proof.verify_range_inclusion(&mut hasher, valid_elements, loc, root),
967 "bad start_loc should fail verification {loc}",
968 );
969 }
970 }
971
972 #[test_traced]
973 fn test_proving_retained_nodes_provable_after_pruning() {
974 let mut hasher: Standard<Sha256> = Standard::new();
976 let mut mmr = CleanMmr::new(&mut hasher);
977 let mut elements = Vec::new();
978 for i in 0..49 {
979 elements.push(test_digest(i));
980 mmr.add(&mut hasher, elements.last().unwrap());
981 }
982
983 let root = *mmr.root();
985 for i in 1..*mmr.size() {
986 mmr.prune_to_pos(Position::new(i));
987 let pruned_root = mmr.root();
988 assert_eq!(root, *pruned_root);
989 for loc in 0..elements.len() {
990 let loc = Location::new_unchecked(loc as u64);
991 let proof = mmr.proof(loc);
992 if Position::try_from(loc).unwrap() < Position::new(i) {
993 continue;
994 }
995 assert!(proof.is_ok());
996 assert!(proof.unwrap().verify_element_inclusion(
997 &mut hasher,
998 &elements[*loc as usize],
999 loc,
1000 &root
1001 ));
1002 }
1003 }
1004 }
1005
1006 #[test]
1007 fn test_proving_ranges_provable_after_pruning() {
1008 let mut hasher: Standard<Sha256> = Standard::new();
1010 let mut mmr = CleanMmr::new(&mut hasher);
1011 let mut elements = Vec::new();
1012 for i in 0..49 {
1013 elements.push(test_digest(i));
1014 mmr.add(&mut hasher, elements.last().unwrap());
1015 }
1016
1017 const PRUNE_POS: Position = Position::new(62);
1019 mmr.prune_to_pos(PRUNE_POS);
1020 assert_eq!(mmr.oldest_retained_pos().unwrap(), PRUNE_POS);
1021
1022 let root = mmr.root();
1024 for i in 0..elements.len() - 1 {
1025 if Position::try_from(Location::new_unchecked(i as u64)).unwrap() < PRUNE_POS {
1026 continue;
1027 }
1028 for j in (i + 2)..elements.len() {
1029 let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1030 let range_proof = mmr.range_proof(range.clone()).unwrap();
1031 assert!(
1032 range_proof.verify_range_inclusion(
1033 &mut hasher,
1034 &elements[range.to_usize_range()],
1035 range.start,
1036 root,
1037 ),
1038 "valid range proof over remaining elements should verify successfully",
1039 );
1040 }
1041 }
1042
1043 for i in 0..37 {
1046 elements.push(test_digest(i));
1047 mmr.add(&mut hasher, elements.last().unwrap());
1048 }
1049 mmr.prune_to_pos(Position::new(130)); assert_eq!(mmr.oldest_retained_pos().unwrap(), 130);
1051
1052 let updated_root = mmr.root();
1053 let range = Location::new_unchecked(elements.len() as u64 - 10)
1054 ..Location::new_unchecked(elements.len() as u64);
1055 let range_proof = mmr.range_proof(range.clone()).unwrap();
1056 assert!(
1057 range_proof.verify_range_inclusion(
1058 &mut hasher,
1059 &elements[range.to_usize_range()],
1060 range.start,
1061 updated_root,
1062 ),
1063 "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1064 );
1065 }
1066
1067 #[test]
1068 fn test_proving_proof_serialization() {
1069 let mut hasher: Standard<Sha256> = Standard::new();
1071 let mut mmr = CleanMmr::new(&mut hasher);
1072 let mut elements = Vec::new();
1073 for i in 0..25 {
1074 elements.push(test_digest(i));
1075 mmr.add(&mut hasher, elements.last().unwrap());
1076 }
1077
1078 for i in 0..elements.len() {
1081 for j in i + 1..elements.len() {
1082 let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1083 let proof = mmr.range_proof(range).unwrap();
1084
1085 let expected_size = proof.encode_size();
1086 let serialized_proof = proof.encode().freeze();
1087 assert_eq!(
1088 serialized_proof.len(),
1089 expected_size,
1090 "serialized proof should have expected size"
1091 );
1092 let max_digests = proof.digests.len();
1093 let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_digests).unwrap();
1094 assert_eq!(
1095 proof, deserialized_proof,
1096 "deserialized proof should match source proof"
1097 );
1098
1099 let serialized_proof = proof.encode().freeze();
1102 let serialized_proof: Bytes = serialized_proof.slice(0..serialized_proof.len() - 1);
1103 assert!(
1104 Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1105 "proof should not deserialize with truncated data"
1106 );
1107
1108 let mut serialized_proof = proof.encode();
1111 serialized_proof.extend_from_slice(&[0; 10]);
1112 let serialized_proof = serialized_proof.freeze();
1113
1114 assert!(
1115 Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1116 "proof should not deserialize with extra data"
1117 );
1118
1119 if max_digests > 0 {
1121 let serialized_proof = proof.encode().freeze();
1122 assert!(
1123 Proof::<Digest>::decode_cfg(serialized_proof, &(max_digests - 1)).is_err(),
1124 "proof should not deserialize with max length exceeded"
1125 );
1126 }
1127 }
1128 }
1129 }
1130
1131 #[test_traced]
1132 fn test_proving_extract_pinned_nodes() {
1133 for num_elements in 1u64..255 {
1135 let mut hasher: Standard<Sha256> = Standard::new();
1137 let mut mmr = CleanMmr::new(&mut hasher);
1138
1139 for i in 0..num_elements {
1140 let digest = test_digest(i as u8);
1141 mmr.add(&mut hasher, &digest);
1142 }
1143
1144 for leaf in 0..num_elements {
1146 let test_end_locs = if num_elements == 1 {
1148 vec![leaf + 1]
1150 } else {
1151 let mut ends = vec![leaf + 1]; if leaf + 2 <= num_elements {
1156 ends.push(leaf + 2);
1157 }
1158 if leaf + 3 <= num_elements {
1159 ends.push(leaf + 3);
1160 }
1161 if ends.last().unwrap() != &num_elements {
1163 ends.push(num_elements);
1164 }
1165
1166 ends.into_iter()
1167 .collect::<BTreeSet<_>>()
1168 .into_iter()
1169 .collect()
1170 };
1171
1172 for end_loc in test_end_locs {
1173 let range = Location::new_unchecked(leaf)..Location::new_unchecked(end_loc);
1175 let proof_result = mmr.range_proof(range.clone());
1176 let proof = proof_result.unwrap();
1177
1178 let extract_result = proof.extract_pinned_nodes(range.clone());
1180 assert!(
1181 extract_result.is_ok(),
1182 "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1183 );
1184
1185 let pinned_nodes = extract_result.unwrap();
1186 let leaf_loc = Location::new_unchecked(leaf);
1187 let leaf_pos = Position::try_from(leaf_loc).unwrap();
1188 let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1189
1190 assert_eq!(
1192 pinned_nodes.len(),
1193 expected_pinned.len(),
1194 "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1195 );
1196
1197 for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1200 let extracted_hash = pinned_nodes[i];
1201 let actual_hash = mmr.get_node(expected_pos).unwrap();
1202 assert_eq!(
1203 extracted_hash, actual_hash,
1204 "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1205 );
1206 }
1207 }
1208 }
1209 }
1210 }
1211
1212 #[test]
1213 fn test_proving_extract_pinned_nodes_invalid_size() {
1214 let mut hasher: Standard<Sha256> = Standard::new();
1216 let mut mmr = CleanMmr::new(&mut hasher);
1217
1218 for i in 0..10 {
1220 let digest = test_digest(i);
1221 mmr.add(&mut hasher, &digest);
1222 }
1223
1224 let range = Location::new_unchecked(5)..Location::new_unchecked(8);
1226 let mut proof = mmr.range_proof(range.clone()).unwrap();
1227
1228 assert!(proof.extract_pinned_nodes(range.clone()).is_ok());
1230
1231 const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1233 for invalid_size in INVALID_SIZES {
1234 proof.size = Position::new(invalid_size);
1235 let result = proof.extract_pinned_nodes(range.clone());
1236 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == invalid_size));
1237 }
1238
1239 proof.size = Position::new(1u64 << 63);
1241 let result = proof.extract_pinned_nodes(range);
1242 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == 1u64 << 63));
1243 }
1244
1245 #[test]
1246 fn test_proving_digests_from_range() {
1247 let mut hasher: Standard<Sha256> = Standard::new();
1249 let mut mmr = CleanMmr::new(&mut hasher);
1250 let mut elements = Vec::new();
1251 let mut element_positions = Vec::new();
1252 for i in 0..49 {
1253 elements.push(test_digest(i));
1254 element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1255 }
1256 let root = mmr.root();
1257
1258 let proof = mmr
1261 .range_proof(Location::new_unchecked(0)..mmr.leaves())
1262 .unwrap();
1263 let mut node_digests = proof
1264 .verify_range_inclusion_and_extract_digests(
1265 &mut hasher,
1266 &elements,
1267 Location::new_unchecked(0),
1268 root,
1269 )
1270 .unwrap();
1271 assert_eq!(node_digests.len() as u64, mmr.size());
1272 node_digests.sort_by_key(|(pos, _)| *pos);
1273 for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1274 assert_eq!(pos, i as u64);
1275 assert_eq!(mmr.get_node(pos).unwrap(), d);
1276 }
1277 let wrong_root = elements[0]; assert!(matches!(
1280 proof.verify_range_inclusion_and_extract_digests(
1281 &mut hasher,
1282 &elements,
1283 Location::new_unchecked(0),
1284 &wrong_root
1285 ),
1286 Err(Error::RootMismatch)
1287 ));
1288
1289 let range = Location::new_unchecked(0)..Location::new_unchecked(1);
1291 let single_proof = mmr.range_proof(range.clone()).unwrap();
1292 let range_start = range.start;
1293 let single_digests = single_proof
1294 .verify_range_inclusion_and_extract_digests(
1295 &mut hasher,
1296 &elements[range.to_usize_range()],
1297 range_start,
1298 root,
1299 )
1300 .unwrap();
1301 assert!(single_digests.len() > 1);
1302
1303 let mid_idx = 24;
1305 let range = Location::new_unchecked(mid_idx)..Location::new_unchecked(mid_idx + 1);
1306 let range_start = range.start;
1307 let mid_proof = mmr.range_proof(range.clone()).unwrap();
1308 let mid_digests = mid_proof
1309 .verify_range_inclusion_and_extract_digests(
1310 &mut hasher,
1311 &elements[range.to_usize_range()],
1312 range_start,
1313 root,
1314 )
1315 .unwrap();
1316 assert!(mid_digests.len() > 1);
1317
1318 let last_idx = elements.len() as u64 - 1;
1320 let range = Location::new_unchecked(last_idx)..Location::new_unchecked(last_idx + 1);
1321 let range_start = range.start;
1322 let last_proof = mmr.range_proof(range.clone()).unwrap();
1323 let last_digests = last_proof
1324 .verify_range_inclusion_and_extract_digests(
1325 &mut hasher,
1326 &elements[range.to_usize_range()],
1327 range_start,
1328 root,
1329 )
1330 .unwrap();
1331 assert!(last_digests.len() > 1);
1332
1333 let range = Location::new_unchecked(0)..Location::new_unchecked(5);
1335 let range_start = range.start;
1336 let small_proof = mmr.range_proof(range.clone()).unwrap();
1337 let small_digests = small_proof
1338 .verify_range_inclusion_and_extract_digests(
1339 &mut hasher,
1340 &elements[range.to_usize_range()],
1341 range_start,
1342 root,
1343 )
1344 .unwrap();
1345 assert!(small_digests.len() > 5);
1347
1348 let range = Location::new_unchecked(10)..Location::new_unchecked(31);
1350 let range_start = range.start;
1351 let mid_range_proof = mmr.range_proof(range.clone()).unwrap();
1352 let mid_range_digests = mid_range_proof
1353 .verify_range_inclusion_and_extract_digests(
1354 &mut hasher,
1355 &elements[range.to_usize_range()],
1356 range_start,
1357 root,
1358 )
1359 .unwrap();
1360 let num_elements = range.end - range.start;
1361 assert!(mid_range_digests.len() as u64 > num_elements);
1362 }
1363
1364 #[test]
1365 fn test_proving_multi_proof_generation_and_verify() {
1366 let mut hasher: Standard<Sha256> = Standard::new();
1368 let mut mmr = CleanMmr::new(&mut hasher);
1369 let mut elements = Vec::new();
1370
1371 for i in 0..20 {
1372 elements.push(test_digest(i));
1373 mmr.add(&mut hasher, &elements[i as usize]);
1374 }
1375
1376 let root = mmr.root();
1377
1378 let locations = &[
1380 Location::new_unchecked(0),
1381 Location::new_unchecked(5),
1382 Location::new_unchecked(10),
1383 ];
1384 let nodes_for_multi_proof =
1385 nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1386 let digests = nodes_for_multi_proof
1387 .into_iter()
1388 .map(|pos| mmr.get_node(pos).unwrap())
1389 .collect();
1390 let multi_proof = Proof {
1391 size: mmr.size(),
1392 digests,
1393 };
1394
1395 assert_eq!(multi_proof.size, mmr.size());
1396
1397 assert!(multi_proof.verify_multi_inclusion(
1399 &mut hasher,
1400 &[
1401 (elements[0], Location::new_unchecked(0)),
1402 (elements[5], Location::new_unchecked(5)),
1403 (elements[10], Location::new_unchecked(10)),
1404 ],
1405 root
1406 ));
1407
1408 assert!(multi_proof.verify_multi_inclusion(
1410 &mut hasher,
1411 &[
1412 (elements[10], Location::new_unchecked(10)),
1413 (elements[5], Location::new_unchecked(5)),
1414 (elements[0], Location::new_unchecked(0)),
1415 ],
1416 root
1417 ));
1418
1419 assert!(multi_proof.verify_multi_inclusion(
1421 &mut hasher,
1422 &[
1423 (elements[0], Location::new_unchecked(0)),
1424 (elements[0], Location::new_unchecked(0)),
1425 (elements[10], Location::new_unchecked(10)),
1426 (elements[5], Location::new_unchecked(5)),
1427 ],
1428 root
1429 ));
1430
1431 let wrong_sizes = [
1434 Position::new(16),
1435 Position::new(17),
1436 Position::new(u64::MAX - 100),
1437 ];
1438 for sz in wrong_sizes {
1439 let mut wrong_size_proof = multi_proof.clone();
1440 wrong_size_proof.size = sz;
1441 assert!(!wrong_size_proof.verify_multi_inclusion(
1442 &mut hasher,
1443 &[
1444 (elements[0], Location::new_unchecked(0)),
1445 (elements[5], Location::new_unchecked(5)),
1446 (elements[10], Location::new_unchecked(10)),
1447 ],
1448 root,
1449 ));
1450 }
1451
1452 assert!(!multi_proof.verify_multi_inclusion(
1454 &mut hasher,
1455 &[
1456 (elements[0], Location::new_unchecked(1)),
1457 (elements[5], Location::new_unchecked(6)),
1458 (elements[10], Location::new_unchecked(11)),
1459 ],
1460 root,
1461 ));
1462
1463 let wrong_elements = [
1465 vec![255u8, 254u8, 253u8],
1466 vec![252u8, 251u8, 250u8],
1467 vec![249u8, 248u8, 247u8],
1468 ];
1469 let wrong_verification = multi_proof.verify_multi_inclusion(
1470 &mut hasher,
1471 &[
1472 (wrong_elements[0].as_slice(), Location::new_unchecked(0)),
1473 (wrong_elements[1].as_slice(), Location::new_unchecked(5)),
1474 (wrong_elements[2].as_slice(), Location::new_unchecked(10)),
1475 ],
1476 root,
1477 );
1478 assert!(!wrong_verification, "Should fail with wrong elements");
1479
1480 let wrong_verification = multi_proof.verify_multi_inclusion(
1482 &mut hasher,
1483 &[
1484 (elements[0], Location::new_unchecked(0)),
1485 (elements[5], Location::new_unchecked(5)),
1486 (elements[10], Location::new_unchecked(1000)),
1487 ],
1488 root,
1489 );
1490 assert!(
1491 !wrong_verification,
1492 "Should fail with out of range elements"
1493 );
1494
1495 let wrong_root = test_digest(99);
1497 assert!(!multi_proof.verify_multi_inclusion(
1498 &mut hasher,
1499 &[
1500 (elements[0], Location::new_unchecked(0)),
1501 (elements[5], Location::new_unchecked(5)),
1502 (elements[10], Location::new_unchecked(10)),
1503 ],
1504 &wrong_root
1505 ));
1506
1507 let mut hasher: Standard<Sha256> = Standard::new();
1509 let empty_mmr = CleanMmr::new(&mut hasher);
1510 let empty_root = empty_mmr.root();
1511 let empty_proof = Proof::default();
1512 assert!(empty_proof.verify_multi_inclusion(
1513 &mut hasher,
1514 &[] as &[(Digest, Location)],
1515 empty_root
1516 ));
1517 }
1518
1519 #[test]
1520 fn test_proving_multi_proof_deduplication() {
1521 let mut hasher: Standard<Sha256> = Standard::new();
1522 let mut mmr = CleanMmr::new(&mut hasher);
1523 let mut elements = Vec::new();
1524
1525 for i in 0..30 {
1527 elements.push(test_digest(i));
1528 mmr.add(&mut hasher, &elements[i as usize]);
1529 }
1530
1531 let proof1 = mmr.proof(Location::new_unchecked(0)).unwrap();
1533 let proof2 = mmr.proof(Location::new_unchecked(1)).unwrap();
1534 let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1535
1536 let locations = &[Location::new_unchecked(0), Location::new_unchecked(1)];
1538 let multi_proof =
1539 nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1540 let digests = multi_proof
1541 .into_iter()
1542 .map(|pos| mmr.get_node(pos).unwrap())
1543 .collect();
1544 let multi_proof = Proof {
1545 size: mmr.size(),
1546 digests,
1547 };
1548
1549 assert!(multi_proof.digests.len() < total_digests_separate);
1551
1552 let root = mmr.root();
1554 assert!(multi_proof.verify_multi_inclusion(
1555 &mut hasher,
1556 &[
1557 (elements[0], Location::new_unchecked(0)),
1558 (elements[1], Location::new_unchecked(1))
1559 ],
1560 root
1561 ));
1562 }
1563
1564 #[test]
1565 fn test_max_location_is_provable() {
1566 let max_mmr_size = Position::new((1u64 << 63) - 1);
1569
1570 let max_loc = Location::new_unchecked(MAX_LOCATION);
1571 let max_loc_plus_1 = Location::new_unchecked(MAX_LOCATION + 1);
1572 let max_loc_plus_2 = Location::new_unchecked(MAX_LOCATION + 2);
1573
1574 let result = nodes_required_for_range_proof(max_mmr_size, max_loc..max_loc_plus_1);
1577
1578 assert!(result.is_ok(), "Should be able to prove MAX_LOCATION");
1580
1581 let result_overflow =
1583 nodes_required_for_range_proof(max_mmr_size, max_loc_plus_1..max_loc_plus_2);
1584 assert!(
1585 result_overflow.is_err(),
1586 "Should reject location > MAX_LOCATION"
1587 );
1588 matches!(result_overflow, Err(Error::LocationOverflow(_)));
1589 }
1590
1591 #[test]
1592 fn test_max_location_multi_proof() {
1593 let max_mmr_size = Position::new((1u64 << 63) - 1);
1595 let max_loc = Location::new_unchecked(MAX_LOCATION);
1596
1597 let result = nodes_required_for_multi_proof(max_mmr_size, &[max_loc]);
1599 assert!(
1600 result.is_ok(),
1601 "Should be able to generate multi-proof for MAX_LOCATION"
1602 );
1603
1604 let invalid_loc = Location::new_unchecked(MAX_LOCATION + 1);
1606 let result_overflow = nodes_required_for_multi_proof(max_mmr_size, &[invalid_loc]);
1607 assert!(
1608 result_overflow.is_err(),
1609 "Should reject location > MAX_LOCATION in multi-proof"
1610 );
1611 }
1612
1613 #[test]
1614 fn test_invalid_size_validation() {
1615 let loc = Location::new_unchecked(0);
1617 let range = loc..loc + 1;
1618
1619 const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1620 for size in INVALID_SIZES {
1621 let result = nodes_required_for_range_proof(Position::new(size), range.clone());
1622 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1623
1624 let result = nodes_required_for_multi_proof(Position::new(size), &[loc]);
1625 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1626 }
1627 }
1628
    // Codec conformance tests for `Proof`, compiled only when the `arbitrary` feature is enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        // Runs the `CodecConformance` suite for `Proof` instantiated with SHA-256 digests.
        // NOTE(review): presumably this round-trips arbitrary `Proof` values through the
        // `Write`/`Read` impls. Confirm against `commonware_conformance::conformance_tests!`.
        commonware_conformance::conformance_tests! {
            CodecConformance<Proof<Sha256Digest>>,
        }
    }
1639}