1#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10 hasher::Hasher,
11 iterator::{PathIterator, PeakIterator},
12 Error, Location, Position,
13};
14use alloc::{
15 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16 vec,
17 vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{varint::UInt, EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::{Digest, Hasher as CHasher};
22use core::ops::Range;
23#[cfg(feature = "std")]
24use tracing::debug;
25
/// Errors returned when a root or peak digests cannot be reconstructed from a [`Proof`]
/// because the proof, the supplied elements, or the start location are inconsistent.
#[derive(Error, Debug)]
pub enum ReconstructionError {
    /// Reconstruction needed another digest (or, at a leaf, another element) but none was left.
    #[error("missing digests in proof")]
    MissingDigests,
    /// Reconstruction finished while digests or elements were left unconsumed.
    #[error("extra digests in proof")]
    ExtraDigests,
    /// The start location could not be converted to a position.
    #[error("start location is out of bounds")]
    InvalidStartLoc,
    /// The location of the last element overflowed, could not be converted to a position, or
    /// lies beyond the proof's size.
    #[error("end location is out of bounds")]
    InvalidEndLoc,
    /// No elements were supplied for a non-zero start location.
    #[error("missing elements")]
    MissingElements,
    /// The proof's size is not a valid MMR size.
    #[error("invalid size")]
    InvalidSize,
}
42
/// An inclusion proof for one or more leaves of an MMR.
///
/// `PartialEq` is implemented by hand (field-wise) rather than derived.
#[derive(Clone, Debug, Eq)]
pub struct Proof<D: Digest> {
    /// Size of the MMR (in nodes/positions) the proof was generated against. Verification
    /// rejects proofs whose size is not a valid MMR size.
    pub size: Position,
    /// Digests needed to reconstruct the root. For range proofs these are ordered as produced by
    /// `nodes_required_for_range_proof`: digests of peaks that do not cover the range, followed by
    /// sibling digests of the proven range.
    pub digests: Vec<D>,
}
64
65impl<D: Digest> PartialEq for Proof<D> {
66 fn eq(&self, other: &Self) -> bool {
67 self.size == other.size && self.digests == other.digests
68 }
69}
70
71impl<D: Digest> EncodeSize for Proof<D> {
72 fn encode_size(&self) -> usize {
73 UInt(*self.size).encode_size() + self.digests.encode_size()
74 }
75}
76
77impl<D: Digest> Write for Proof<D> {
78 fn write(&self, buf: &mut impl BufMut) {
79 UInt(*self.size).write(buf);
81
82 self.digests.write(buf);
84 }
85}
86
87impl<D: Digest> Read for Proof<D> {
88 type Cfg = usize;
90
91 fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
92 let size = Position::new(UInt::<u64>::read(buf)?.into());
94
95 let range = ..=max_len;
97 let digests = Vec::<D>::read_range(buf, range)?;
98 Ok(Proof { size, digests })
99 }
100}
101
102impl<D: Digest> Default for Proof<D> {
103 fn default() -> Self {
106 Self {
107 size: Position::new(0),
108 digests: vec![],
109 }
110 }
111}
112
113impl<D: Digest> Proof<D> {
114 pub fn verify_element_inclusion<I, H>(
117 &self,
118 hasher: &mut H,
119 element: &[u8],
120 loc: Location,
121 root: &D,
122 ) -> bool
123 where
124 I: CHasher<Digest = D>,
125 H: Hasher<I>,
126 {
127 self.verify_range_inclusion(hasher, &[element], loc, root)
128 }
129
130 pub fn verify_range_inclusion<I, H, E>(
134 &self,
135 hasher: &mut H,
136 elements: &[E],
137 start_loc: Location,
138 root: &D,
139 ) -> bool
140 where
141 I: CHasher<Digest = D>,
142 H: Hasher<I>,
143 E: AsRef<[u8]>,
144 {
145 if !self.size.is_mmr_size() {
146 #[cfg(feature = "std")]
147 debug!(size = ?self.size, "invalid proof size");
148 return false;
149 }
150
151 match self.reconstruct_root(hasher, elements, start_loc) {
152 Ok(reconstructed_root) => *root == reconstructed_root,
153 Err(_error) => {
154 #[cfg(feature = "std")]
155 tracing::debug!(error = ?_error, "invalid proof input");
156 false
157 }
158 }
159 }
160
161 pub fn verify_multi_inclusion<I, H, E>(
166 &self,
167 hasher: &mut H,
168 elements: &[(E, Location)],
169 root: &D,
170 ) -> bool
171 where
172 I: CHasher<Digest = D>,
173 H: Hasher<I>,
174 E: AsRef<[u8]>,
175 {
176 if elements.is_empty() {
178 return self.size == Position::new(0)
179 && *root == hasher.root(Position::new(0), core::iter::empty());
180 }
181 if !self.size.is_mmr_size() {
182 return false;
183 }
184
185 let mut node_positions = BTreeSet::new();
187 let mut nodes_required = BTreeMap::new();
188
189 for (_, loc) in elements {
190 if !loc.is_valid() {
191 return false;
192 }
193 let Ok(required) = nodes_required_for_range_proof(self.size, *loc..*loc + 1) else {
195 return false;
196 };
197 for req_pos in &required {
198 node_positions.insert(*req_pos);
199 }
200 nodes_required.insert(*loc, required);
201 }
202
203 if node_positions.len() != self.digests.len() {
205 return false;
206 }
207
208 let node_digests: BTreeMap<Position, D> = node_positions
210 .iter()
211 .zip(self.digests.iter())
212 .map(|(&pos, digest)| (pos, *digest))
213 .collect();
214
215 for (element, loc) in elements {
217 let required = &nodes_required[loc];
219
220 let mut digests = Vec::with_capacity(required.len());
222 for req_pos in required {
223 let digest = node_digests
226 .get(req_pos)
227 .expect("must exist by construction of node_digests");
228 digests.push(*digest);
229 }
230 let proof = Proof {
231 size: self.size,
232 digests,
233 };
234
235 if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
237 return false;
238 }
239 }
240
241 true
242 }
243
244 #[cfg(any(feature = "std", test))]
266 pub(crate) fn extract_pinned_nodes(
267 &self,
268 range: std::ops::Range<Location>,
269 ) -> Result<Vec<D>, Error> {
270 let start_pos = Position::try_from(range.start)?;
272 let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
273
274 let required_positions = nodes_required_for_range_proof(self.size, range)?;
276
277 if required_positions.len() != self.digests.len() {
278 #[cfg(feature = "std")]
279 debug!(
280 digests_len = self.digests.len(),
281 required_positions_len = required_positions.len(),
282 "Proof digest count doesn't match required positions",
283 );
284 return Err(Error::InvalidProofLength);
285 }
286
287 if pinned_positions
290 == required_positions[required_positions.len() - pinned_positions.len()..]
291 {
292 return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
293 }
294
295 let position_to_digest: BTreeMap<Position, D> = required_positions
297 .iter()
298 .zip(self.digests.iter())
299 .map(|(&pos, &digest)| (pos, digest))
300 .collect();
301
302 let mut result = Vec::with_capacity(pinned_positions.len());
304 for pinned_pos in pinned_positions {
305 let Some(&digest) = position_to_digest.get(&pinned_pos) else {
306 #[cfg(feature = "std")]
307 debug!(?pinned_pos, "Pinned node not found in proof");
308 return Err(Error::MissingDigest(pinned_pos));
309 };
310 result.push(digest);
311 }
312 Ok(result)
313 }
314
315 pub fn verify_range_inclusion_and_extract_digests<I, H, E>(
321 &self,
322 hasher: &mut H,
323 elements: &[E],
324 start_loc: Location,
325 root: &I::Digest,
326 ) -> Result<Vec<(Position, D)>, super::Error>
327 where
328 I: CHasher<Digest = D>,
329 H: Hasher<I>,
330 E: AsRef<[u8]>,
331 {
332 let mut collected_digests = Vec::new();
333 let Ok(peak_digests) = self.reconstruct_peak_digests(
334 hasher,
335 elements,
336 start_loc,
337 Some(&mut collected_digests),
338 ) else {
339 return Err(Error::InvalidProof);
340 };
341
342 if hasher.root(self.size, peak_digests.iter()) != *root {
343 return Err(Error::RootMismatch);
344 }
345
346 Ok(collected_digests)
347 }
348
349 pub fn reconstruct_root<I, H, E>(
352 &self,
353 hasher: &mut H,
354 elements: &[E],
355 start_loc: Location,
356 ) -> Result<I::Digest, ReconstructionError>
357 where
358 I: CHasher<Digest = D>,
359 H: Hasher<I>,
360 E: AsRef<[u8]>,
361 {
362 if !self.size.is_mmr_size() {
363 return Err(ReconstructionError::InvalidSize);
364 }
365
366 let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
367
368 Ok(hasher.root(self.size, peak_digests.iter()))
369 }
370
371 pub fn reconstruct_peak_digests<I, H, E>(
375 &self,
376 hasher: &mut H,
377 elements: &[E],
378 start_loc: Location,
379 mut collected_digests: Option<&mut Vec<(Position, I::Digest)>>,
380 ) -> Result<Vec<D>, ReconstructionError>
381 where
382 I: CHasher<Digest = D>,
383 H: Hasher<I>,
384 E: AsRef<[u8]>,
385 {
386 if elements.is_empty() {
387 if start_loc == 0 {
388 return Ok(vec![]);
389 }
390 return Err(ReconstructionError::MissingElements);
391 }
392 let start_element_pos =
393 Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
394 let end_element_pos = if elements.len() == 1 {
395 start_element_pos
396 } else {
397 let end_loc = start_loc
398 .checked_add(elements.len() as u64 - 1)
399 .ok_or(ReconstructionError::InvalidEndLoc)?;
400 Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
401 };
402 if end_element_pos >= self.size {
403 return Err(ReconstructionError::InvalidEndLoc);
404 }
405
406 let mut proof_digests_iter = self.digests.iter();
407 let mut siblings_iter = self.digests.iter().rev();
408
409 let mut peak_digests: Vec<D> = Vec::new();
412 let mut proof_digests_used = 0;
413 let mut elements_iter = elements.iter();
414 if !Position::is_mmr_size(self.size) {
415 return Err(ReconstructionError::InvalidSize);
416 }
417 for (peak_pos, height) in PeakIterator::new(self.size) {
418 let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
419 if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
420 let hash = peak_digest_from_range(
421 hasher,
422 RangeInfo {
423 pos: peak_pos,
424 two_h: 1 << height,
425 leftmost_pos: start_element_pos,
426 rightmost_pos: end_element_pos,
427 },
428 &mut elements_iter,
429 &mut siblings_iter,
430 collected_digests.as_deref_mut(),
431 )?;
432 peak_digests.push(hash);
433 if let Some(ref mut collected_digests) = collected_digests {
434 collected_digests.push((peak_pos, hash));
435 }
436 } else if let Some(hash) = proof_digests_iter.next() {
437 proof_digests_used += 1;
438 peak_digests.push(*hash);
439 if let Some(ref mut collected_digests) = collected_digests {
440 collected_digests.push((peak_pos, *hash));
441 }
442 } else {
443 return Err(ReconstructionError::MissingDigests);
444 }
445 }
446
447 if elements_iter.next().is_some() {
448 return Err(ReconstructionError::ExtraDigests);
449 }
450 if let Some(next_sibling) = siblings_iter.next() {
451 if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
452 return Err(ReconstructionError::ExtraDigests);
453 }
454 }
455
456 Ok(peak_digests)
457 }
458}
459
/// Returns the positions of the nodes whose digests a range proof for `range` must contain, in
/// the order they appear in the proof.
///
/// The result lists the positions of peaks whose trees do not overlap the range, in
/// `PeakIterator` order. These are followed by the positions of sibling nodes along the paths
/// from the first and last element of the range up to their tree's peak.
///
/// # Errors
///
/// - [`Error::InvalidSize`] if `size` is not a valid MMR size.
/// - [`Error::Empty`] if `range` is empty.
/// - Errors from converting the range bounds to positions.
/// - [`Error::RangeOutOfBounds`] if the last element's position is not below `size`.
pub(crate) fn nodes_required_for_range_proof(
    size: Position,
    range: Range<Location>,
) -> Result<Vec<Position>, Error> {
    if !size.is_mmr_size() {
        return Err(Error::InvalidSize(*size));
    }
    if range.is_empty() {
        return Err(Error::Empty);
    }
    let start_element_pos = Position::try_from(range.start)?;
    let end_minus_one = range
        .end
        .checked_sub(1)
        .expect("can't underflow because range is non-empty");
    let end_element_pos = Position::try_from(end_minus_one)?;
    if end_element_pos >= size {
        return Err(Error::RangeOutOfBounds(range.end));
    }

    // (peak position, height) of the trees containing the first and last element.
    let mut start_tree_with_element: Option<(Position, u32)> = None;
    let mut end_tree_with_element: Option<(Position, u32)> = None;
    let mut peak_iterator = PeakIterator::new(size);
    let mut positions = Vec::new();
    // The first peak at or beyond `start_element_pos` is taken as the root of the tree holding
    // the start element (this relies on peaks being yielded in ascending position order).
    while let Some(peak) = peak_iterator.next() {
        if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
            start_tree_with_element = Some(peak);
            if peak.0 >= end_element_pos {
                end_tree_with_element = Some(peak);
                continue;
            }
            // Skip (without recording) the peaks up to and including the tree holding the end
            // element; their digests are recomputed from the elements, not supplied.
            for peak in peak_iterator.by_ref() {
                if peak.0 >= end_element_pos {
                    end_tree_with_element = Some(peak);
                    break;
                }
            }
        } else {
            // Peaks entirely before or after the range are supplied directly by the proof.
            positions.push(peak.0);
        }
    }

    let (start_tree_peak, start_tree_height) =
        start_tree_with_element.expect("start_tree_with_element is Some");
    let (end_tree_peak, end_tree_height) =
        end_tree_with_element.expect("end_tree_with_element is Some");

    // PathIterator yields `(parent_pos, pos)` pairs; `pos` is recorded as a sibling position.
    let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);

    let mut siblings = Vec::new();
    if start_element_pos == end_element_pos {
        siblings.extend(left_path_iter);
    } else {
        let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
        // `parent_pos == pos + 1` means `pos` is a right child. Keep right siblings on the end
        // element's path and left siblings on the start element's path; siblings inside the
        // range are recomputed rather than supplied.
        siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
        siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));

        if start_tree_peak == end_tree_peak {
            // Both paths share a tree: order siblings by descending parent position. Presumably
            // this matches the order in which reconstruction draws siblings from the end of the
            // proof. TODO(review): confirm against peak_digest_from_range.
            siblings.sort_by(|a, b| b.0.cmp(&a.0));
        }
    }
    positions.extend(siblings.into_iter().map(|(_, pos)| pos));
    Ok(positions)
}
551
/// Returns the set of node positions needed to prove each of `locations` individually: the
/// union of the single-element range-proof requirements of every location.
///
/// # Errors
///
/// - [`Error::InvalidSize`] if `size` is not a valid MMR size.
/// - [`Error::Empty`] if `locations` is empty.
/// - [`Error::LocationOverflow`] for the first invalid location.
/// - Any error from `nodes_required_for_range_proof`.
#[cfg(any(feature = "std", test))]
pub(crate) fn nodes_required_for_multi_proof(
    size: Position,
    locations: &[Location],
) -> Result<BTreeSet<Position>, Error> {
    if !size.is_mmr_size() {
        return Err(Error::InvalidSize(*size));
    }
    if locations.is_empty() {
        return Err(Error::Empty);
    }

    let mut required = BTreeSet::new();
    for &loc in locations {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        required.extend(nodes_required_for_range_proof(size, loc..loc + 1)?);
    }
    Ok(required)
}
587
/// Parameters for one step of the recursive reconstruction in `peak_digest_from_range`.
struct RangeInfo {
    /// Position of the node whose digest is being computed.
    pos: Position,
    /// `2^height` of that node; 1 denotes a leaf.
    two_h: u64,
    /// Position of the first leaf in the range being proven.
    leftmost_pos: Position,
    /// Position of the last leaf in the range being proven.
    rightmost_pos: Position,
}
595
596fn peak_digest_from_range<'a, I, H, E, S>(
597 hasher: &mut H,
598 range_info: RangeInfo,
599 elements: &mut E,
600 sibling_digests: &mut S,
601 mut collected_digests: Option<&mut Vec<(Position, I::Digest)>>,
602) -> Result<I::Digest, ReconstructionError>
603where
604 I: CHasher,
605 H: Hasher<I>,
606 E: Iterator<Item: AsRef<[u8]>>,
607 S: Iterator<Item = &'a I::Digest>,
608{
609 assert_ne!(range_info.two_h, 0);
610 if range_info.two_h == 1 {
611 match elements.next() {
612 Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
613 None => return Err(ReconstructionError::MissingDigests),
614 }
615 }
616
617 let mut left_digest: Option<I::Digest> = None;
618 let mut right_digest: Option<I::Digest> = None;
619
620 let left_pos = range_info.pos - range_info.two_h;
621 let right_pos = left_pos + range_info.two_h - 1;
622 if left_pos >= range_info.leftmost_pos {
623 let digest = peak_digest_from_range(
625 hasher,
626 RangeInfo {
627 pos: left_pos,
628 two_h: range_info.two_h >> 1,
629 leftmost_pos: range_info.leftmost_pos,
630 rightmost_pos: range_info.rightmost_pos,
631 },
632 elements,
633 sibling_digests,
634 collected_digests.as_deref_mut(),
635 )?;
636 left_digest = Some(digest);
637 }
638 if left_pos < range_info.rightmost_pos {
639 let digest = peak_digest_from_range(
641 hasher,
642 RangeInfo {
643 pos: right_pos,
644 two_h: range_info.two_h >> 1,
645 leftmost_pos: range_info.leftmost_pos,
646 rightmost_pos: range_info.rightmost_pos,
647 },
648 elements,
649 sibling_digests,
650 collected_digests.as_deref_mut(),
651 )?;
652 right_digest = Some(digest);
653 }
654
655 if left_digest.is_none() {
656 match sibling_digests.next() {
657 Some(hash) => left_digest = Some(*hash),
658 None => return Err(ReconstructionError::MissingDigests),
659 }
660 }
661 if right_digest.is_none() {
662 match sibling_digests.next() {
663 Some(hash) => right_digest = Some(*hash),
664 None => return Err(ReconstructionError::MissingDigests),
665 }
666 }
667
668 if let Some(ref mut collected_digests) = collected_digests {
669 collected_digests.push((
670 left_pos,
671 left_digest.expect("left_digest guaranteed to be Some after checks above"),
672 ));
673 collected_digests.push((
674 right_pos,
675 right_digest.expect("right_digest guaranteed to be Some after checks above"),
676 ));
677 }
678
679 Ok(hasher.node_digest(
680 range_info.pos,
681 &left_digest.expect("left_digest guaranteed to be Some after checks above"),
682 &right_digest.expect("right_digest guaranteed to be Some after checks above"),
683 ))
684}
685
686#[cfg(test)]
687mod tests {
688 use super::*;
689 use crate::mmr::{hasher::Standard, location::LocationRangeExt as _, mem::Mmr, MAX_LOCATION};
690 use bytes::Bytes;
691 use commonware_codec::{Decode, Encode};
692 use commonware_cryptography::{sha256::Digest, Sha256};
693 use commonware_macros::test_traced;
694
695 fn test_digest(v: u8) -> Digest {
696 Sha256::hash(&[v])
697 }
698
699 #[test]
700 fn test_proving_proof() {
701 let mmr = Mmr::new();
703 let mut hasher: Standard<Sha256> = Standard::new();
704 let root = mmr.root(&mut hasher);
705 let proof = Proof::default();
706 assert!(proof.verify_range_inclusion(
707 &mut hasher,
708 &[] as &[Digest],
709 Location::new_unchecked(0),
710 &root
711 ));
712
713 assert!(!proof.verify_range_inclusion(
715 &mut hasher,
716 &[] as &[Digest],
717 Location::new_unchecked(1),
718 &root
719 ));
720
721 let test_digest = test_digest(0);
723 assert!(!proof.verify_range_inclusion(
724 &mut hasher,
725 &[] as &[Digest],
726 Location::new_unchecked(0),
727 &test_digest
728 ));
729
730 assert!(!proof.verify_range_inclusion(
732 &mut hasher,
733 &[test_digest],
734 Location::new_unchecked(0),
735 &root
736 ));
737 }
738
739 #[test]
740 fn test_proving_verify_element() {
741 let mut mmr = Mmr::new();
743 let element = Digest::from(*b"01234567012345670123456701234567");
744 let mut hasher: Standard<Sha256> = Standard::new();
745 for _ in 0..11 {
746 mmr.add(&mut hasher, &element);
747 }
748
749 let root = mmr.root(&mut hasher);
750
751 for leaf in 0u64..11 {
753 let leaf = Location::new_unchecked(leaf);
754 let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
755 assert!(
756 proof.verify_element_inclusion(&mut hasher, &element, leaf, &root),
757 "valid proof should verify successfully"
758 );
759 }
760
761 const LEAF: Location = Location::new_unchecked(10);
764 let proof = mmr.proof(LEAF).unwrap();
765 assert!(
766 proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
767 "proof verification should be successful"
768 );
769 let wrong_sizes = [0, 16, 17, 18, 20, u64::MAX - 100];
770 for sz in wrong_sizes {
771 let mut wrong_size_proof = proof.clone();
772 wrong_size_proof.size = Position::new(sz);
773 assert!(
774 !wrong_size_proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
775 "proof with wrong size should fail verification"
776 );
777 }
778 assert!(
779 !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, &root),
780 "proof verification should fail with incorrect element position"
781 );
782 assert!(
783 !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, &root),
784 "proof verification should fail with incorrect element position 2"
785 );
786 assert!(
787 !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, &root),
788 "proof verification should fail with mangled element"
789 );
790 let root2 = test_digest(0);
791 assert!(
792 !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
793 "proof verification should fail with mangled root"
794 );
795 let mut proof2 = proof.clone();
796 proof2.digests[0] = test_digest(0);
797 assert!(
798 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
799 "proof verification should fail with mangled proof hash"
800 );
801 proof2 = proof.clone();
802 proof2.size = Position::new(10);
803 assert!(
804 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
805 "proof verification should fail with incorrect size"
806 );
807 proof2 = proof.clone();
808 proof2.digests.push(test_digest(0));
809 assert!(
810 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
811 "proof verification should fail with extra hash"
812 );
813 proof2 = proof.clone();
814 while !proof2.digests.is_empty() {
815 proof2.digests.pop();
816 assert!(
817 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
818 "proof verification should fail with missing digests"
819 );
820 }
821 proof2 = proof.clone();
822 proof2.digests.clear();
823 const PEAK_COUNT: usize = 3;
824 proof2
825 .digests
826 .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
827 proof2.digests.push(test_digest(0));
830 proof2
831 .digests
832 .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
833 assert!(
834 !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
835 "proof verification should fail with extra hash even if it's unused by the computation"
836 );
837 }
838
839 #[test]
840 fn test_proving_verify_range() {
841 let mut mmr = Mmr::default();
843 let mut elements = Vec::new();
844 let mut hasher: Standard<Sha256> = Standard::new();
845 for i in 0..49 {
846 elements.push(test_digest(i));
847 mmr.add(&mut hasher, elements.last().unwrap());
848 }
849 let root = mmr.root(&mut hasher);
851
852 for i in 0..elements.len() {
853 for j in i + 1..elements.len() {
854 let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
855 let range_proof = mmr.range_proof(range.clone()).unwrap();
856 assert!(
857 range_proof.verify_range_inclusion(
858 &mut hasher,
859 &elements[range.to_usize_range()],
860 range.start,
861 &root,
862 ),
863 "valid range proof should verify successfully {i}:{j}",
864 );
865 }
866 }
867
868 let range = Location::new_unchecked(33)..Location::new_unchecked(40);
871 let range_proof = mmr.range_proof(range.clone()).unwrap();
872 let valid_elements = &elements[range.to_usize_range()];
873 assert!(
874 range_proof.verify_range_inclusion(&mut hasher, valid_elements, range.start, &root),
875 "valid range proof should verify successfully"
876 );
877 let mut invalid_proof = range_proof.clone();
880 for _i in 0..range_proof.digests.len() {
881 invalid_proof.digests.remove(0);
882 assert!(
883 !invalid_proof.verify_range_inclusion(
884 &mut hasher,
885 valid_elements,
886 range.start,
887 &root,
888 ),
889 "range proof with removed elements should fail"
890 );
891 }
892 for i in 0..elements.len() {
895 for j in i + 1..elements.len() {
896 if Location::from(i) == range.start && Location::from(j) == range.end {
897 continue;
899 }
900 assert!(
901 !range_proof.verify_range_inclusion(
902 &mut hasher,
903 &elements[i..j],
904 range.start,
905 &root,
906 ),
907 "range proof with invalid element range should fail {i}:{j}",
908 );
909 }
910 }
911 let invalid_root = test_digest(1);
913 assert!(
914 !range_proof.verify_range_inclusion(
915 &mut hasher,
916 valid_elements,
917 range.start,
918 &invalid_root,
919 ),
920 "range proof with invalid root should fail"
921 );
922 for i in 0..range_proof.digests.len() {
924 let mut invalid_proof = range_proof.clone();
925 invalid_proof.digests[i] = test_digest(0);
926
927 assert!(
928 !invalid_proof.verify_range_inclusion(
929 &mut hasher,
930 valid_elements,
931 range.start,
932 &root,
933 ),
934 "mangled range proof should fail verification"
935 );
936 }
937 for i in 0..range_proof.digests.len() {
939 let mut invalid_proof = range_proof.clone();
940 invalid_proof.digests.insert(i, test_digest(0));
941 assert!(
942 !invalid_proof.verify_range_inclusion(
943 &mut hasher,
944 valid_elements,
945 range.start,
946 &root,
947 ),
948 "mangled range proof should fail verification. inserted element at: {i}",
949 );
950 }
951 for loc in 0..elements.len() {
953 let loc = Location::new_unchecked(loc as u64);
954 if loc == range.start {
955 continue;
956 }
957 assert!(
958 !range_proof.verify_range_inclusion(&mut hasher, valid_elements, loc, &root),
959 "bad start_loc should fail verification {loc}",
960 );
961 }
962 }
963
964 #[test_traced]
965 fn test_proving_retained_nodes_provable_after_pruning() {
966 let mut mmr = Mmr::default();
968 let mut elements = Vec::new();
969 let mut hasher: Standard<Sha256> = Standard::new();
970 for i in 0..49 {
971 elements.push(test_digest(i));
972 mmr.add(&mut hasher, elements.last().unwrap());
973 }
974
975 let root = mmr.root(&mut hasher);
977 for i in 1..*mmr.size() {
978 mmr.prune_to_pos(Position::new(i));
979 let pruned_root = mmr.root(&mut hasher);
980 assert_eq!(root, pruned_root);
981 for loc in 0..elements.len() {
982 let loc = Location::new_unchecked(loc as u64);
983 let proof = mmr.proof(loc);
984 if Position::try_from(loc).unwrap() < Position::new(i) {
985 continue;
986 }
987 assert!(proof.is_ok());
988 assert!(proof.unwrap().verify_element_inclusion(
989 &mut hasher,
990 &elements[*loc as usize],
991 loc,
992 &root
993 ));
994 }
995 }
996 }
997
998 #[test]
999 fn test_proving_ranges_provable_after_pruning() {
1000 let mut mmr = Mmr::default();
1002 let mut elements = Vec::new();
1003 let mut hasher: Standard<Sha256> = Standard::new();
1004 for i in 0..49 {
1005 elements.push(test_digest(i));
1006 mmr.add(&mut hasher, elements.last().unwrap());
1007 }
1008
1009 const PRUNE_POS: Position = Position::new(62);
1011 mmr.prune_to_pos(PRUNE_POS);
1012 assert_eq!(mmr.oldest_retained_pos().unwrap(), PRUNE_POS);
1013
1014 let root = mmr.root(&mut hasher);
1016 for i in 0..elements.len() - 1 {
1017 if Position::try_from(Location::new_unchecked(i as u64)).unwrap() < PRUNE_POS {
1018 continue;
1019 }
1020 for j in (i + 2)..elements.len() {
1021 let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1022 let range_proof = mmr.range_proof(range.clone()).unwrap();
1023 assert!(
1024 range_proof.verify_range_inclusion(
1025 &mut hasher,
1026 &elements[range.to_usize_range()],
1027 range.start,
1028 &root,
1029 ),
1030 "valid range proof over remaining elements should verify successfully",
1031 );
1032 }
1033 }
1034
1035 for i in 0..37 {
1038 elements.push(test_digest(i));
1039 mmr.add(&mut hasher, elements.last().unwrap());
1040 }
1041 mmr.prune_to_pos(Position::new(130)); assert_eq!(mmr.oldest_retained_pos().unwrap(), 130);
1043
1044 let updated_root = mmr.root(&mut hasher);
1045 let range = Location::new_unchecked(elements.len() as u64 - 10)
1046 ..Location::new_unchecked(elements.len() as u64);
1047 let range_proof = mmr.range_proof(range.clone()).unwrap();
1048 assert!(
1049 range_proof.verify_range_inclusion(
1050 &mut hasher,
1051 &elements[range.to_usize_range()],
1052 range.start,
1053 &updated_root,
1054 ),
1055 "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1056 );
1057 }
1058
1059 #[test]
1060 fn test_proving_proof_serialization() {
1061 let mut mmr = Mmr::default();
1063 let mut elements = Vec::new();
1064 let mut hasher: Standard<Sha256> = Standard::new();
1065 for i in 0..25 {
1066 elements.push(test_digest(i));
1067 mmr.add(&mut hasher, elements.last().unwrap());
1068 }
1069
1070 for i in 0..elements.len() {
1073 for j in i + 1..elements.len() {
1074 let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1075 let proof = mmr.range_proof(range).unwrap();
1076
1077 let expected_size = proof.encode_size();
1078 let serialized_proof = proof.encode().freeze();
1079 assert_eq!(
1080 serialized_proof.len(),
1081 expected_size,
1082 "serialized proof should have expected size"
1083 );
1084 let max_digests = proof.digests.len();
1085 let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_digests).unwrap();
1086 assert_eq!(
1087 proof, deserialized_proof,
1088 "deserialized proof should match source proof"
1089 );
1090
1091 let serialized_proof = proof.encode().freeze();
1094 let serialized_proof: Bytes = serialized_proof.slice(0..serialized_proof.len() - 1);
1095 assert!(
1096 Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1097 "proof should not deserialize with truncated data"
1098 );
1099
1100 let mut serialized_proof = proof.encode();
1103 serialized_proof.extend_from_slice(&[0; 10]);
1104 let serialized_proof = serialized_proof.freeze();
1105
1106 assert!(
1107 Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1108 "proof should not deserialize with extra data"
1109 );
1110
1111 if max_digests > 0 {
1113 let serialized_proof = proof.encode().freeze();
1114 assert!(
1115 Proof::<Digest>::decode_cfg(serialized_proof, &(max_digests - 1)).is_err(),
1116 "proof should not deserialize with max length exceeded"
1117 );
1118 }
1119 }
1120 }
1121 }
1122
1123 #[test_traced]
1124 fn test_proving_extract_pinned_nodes() {
1125 for num_elements in 1u64..255 {
1127 let mut mmr = Mmr::new();
1129 let mut hasher: Standard<Sha256> = Standard::new();
1130
1131 for i in 0..num_elements {
1132 let digest = test_digest(i as u8);
1133 mmr.add(&mut hasher, &digest);
1134 }
1135
1136 for leaf in 0..num_elements {
1138 let test_end_locs = if num_elements == 1 {
1140 vec![leaf + 1]
1142 } else {
1143 let mut ends = vec![leaf + 1]; if leaf + 2 <= num_elements {
1148 ends.push(leaf + 2);
1149 }
1150 if leaf + 3 <= num_elements {
1151 ends.push(leaf + 3);
1152 }
1153 if ends.last().unwrap() != &num_elements {
1155 ends.push(num_elements);
1156 }
1157
1158 ends.into_iter()
1159 .collect::<BTreeSet<_>>()
1160 .into_iter()
1161 .collect()
1162 };
1163
1164 for end_loc in test_end_locs {
1165 let range = Location::new_unchecked(leaf)..Location::new_unchecked(end_loc);
1167 let proof_result = mmr.range_proof(range.clone());
1168 let proof = proof_result.unwrap();
1169
1170 let extract_result = proof.extract_pinned_nodes(range.clone());
1172 assert!(
1173 extract_result.is_ok(),
1174 "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1175 );
1176
1177 let pinned_nodes = extract_result.unwrap();
1178 let leaf_loc = Location::new_unchecked(leaf);
1179 let leaf_pos = Position::try_from(leaf_loc).unwrap();
1180 let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1181
1182 assert_eq!(
1184 pinned_nodes.len(),
1185 expected_pinned.len(),
1186 "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1187 );
1188
1189 for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1192 let extracted_hash = pinned_nodes[i];
1193 let actual_hash = mmr.get_node(expected_pos).unwrap();
1194 assert_eq!(
1195 extracted_hash, actual_hash,
1196 "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1197 );
1198 }
1199 }
1200 }
1201 }
1202 }
1203
1204 #[test]
1205 fn test_proving_extract_pinned_nodes_invalid_size() {
1206 let mut mmr = Mmr::new();
1208 let mut hasher: Standard<Sha256> = Standard::new();
1209
1210 for i in 0..10 {
1212 let digest = test_digest(i);
1213 mmr.add(&mut hasher, &digest);
1214 }
1215
1216 let range = Location::new_unchecked(5)..Location::new_unchecked(8);
1218 let mut proof = mmr.range_proof(range.clone()).unwrap();
1219
1220 assert!(proof.extract_pinned_nodes(range.clone()).is_ok());
1222
1223 const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1225 for invalid_size in INVALID_SIZES {
1226 proof.size = Position::new(invalid_size);
1227 let result = proof.extract_pinned_nodes(range.clone());
1228 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == invalid_size));
1229 }
1230
1231 proof.size = Position::new(1u64 << 63);
1233 let result = proof.extract_pinned_nodes(range);
1234 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == 1u64 << 63));
1235 }
1236
1237 #[test]
1238 fn test_proving_digests_from_range() {
1239 let mut mmr = Mmr::default();
1241 let mut elements = Vec::new();
1242 let mut element_positions = Vec::new();
1243 let mut hasher: Standard<Sha256> = Standard::new();
1244 for i in 0..49 {
1245 elements.push(test_digest(i));
1246 element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1247 }
1248 let root = mmr.root(&mut hasher);
1249
1250 let proof = mmr
1253 .range_proof(Location::new_unchecked(0)..mmr.leaves())
1254 .unwrap();
1255 let mut node_digests = proof
1256 .verify_range_inclusion_and_extract_digests(
1257 &mut hasher,
1258 &elements,
1259 Location::new_unchecked(0),
1260 &root,
1261 )
1262 .unwrap();
1263 assert_eq!(node_digests.len() as u64, mmr.size());
1264 node_digests.sort_by_key(|(pos, _)| *pos);
1265 for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1266 assert_eq!(pos, i as u64);
1267 assert_eq!(mmr.get_node(pos).unwrap(), d);
1268 }
1269 let wrong_root = elements[0]; assert!(matches!(
1272 proof.verify_range_inclusion_and_extract_digests(
1273 &mut hasher,
1274 &elements,
1275 Location::new_unchecked(0),
1276 &wrong_root
1277 ),
1278 Err(Error::RootMismatch)
1279 ));
1280
1281 let range = Location::new_unchecked(0)..Location::new_unchecked(1);
1283 let single_proof = mmr.range_proof(range.clone()).unwrap();
1284 let range_start = range.start;
1285 let single_digests = single_proof
1286 .verify_range_inclusion_and_extract_digests(
1287 &mut hasher,
1288 &elements[range.to_usize_range()],
1289 range_start,
1290 &root,
1291 )
1292 .unwrap();
1293 assert!(single_digests.len() > 1);
1294
1295 let mid_idx = 24;
1297 let range = Location::new_unchecked(mid_idx)..Location::new_unchecked(mid_idx + 1);
1298 let range_start = range.start;
1299 let mid_proof = mmr.range_proof(range.clone()).unwrap();
1300 let mid_digests = mid_proof
1301 .verify_range_inclusion_and_extract_digests(
1302 &mut hasher,
1303 &elements[range.to_usize_range()],
1304 range_start,
1305 &root,
1306 )
1307 .unwrap();
1308 assert!(mid_digests.len() > 1);
1309
1310 let last_idx = elements.len() as u64 - 1;
1312 let range = Location::new_unchecked(last_idx)..Location::new_unchecked(last_idx + 1);
1313 let range_start = range.start;
1314 let last_proof = mmr.range_proof(range.clone()).unwrap();
1315 let last_digests = last_proof
1316 .verify_range_inclusion_and_extract_digests(
1317 &mut hasher,
1318 &elements[range.to_usize_range()],
1319 range_start,
1320 &root,
1321 )
1322 .unwrap();
1323 assert!(last_digests.len() > 1);
1324
1325 let range = Location::new_unchecked(0)..Location::new_unchecked(5);
1327 let range_start = range.start;
1328 let small_proof = mmr.range_proof(range.clone()).unwrap();
1329 let small_digests = small_proof
1330 .verify_range_inclusion_and_extract_digests(
1331 &mut hasher,
1332 &elements[range.to_usize_range()],
1333 range_start,
1334 &root,
1335 )
1336 .unwrap();
1337 assert!(small_digests.len() > 5);
1339
1340 let range = Location::new_unchecked(10)..Location::new_unchecked(31);
1342 let range_start = range.start;
1343 let mid_range_proof = mmr.range_proof(range.clone()).unwrap();
1344 let mid_range_digests = mid_range_proof
1345 .verify_range_inclusion_and_extract_digests(
1346 &mut hasher,
1347 &elements[range.to_usize_range()],
1348 range_start,
1349 &root,
1350 )
1351 .unwrap();
1352 let num_elements = range.end - range.start;
1353 assert!(mid_range_digests.len() as u64 > num_elements);
1354 }
1355
1356 #[test]
1357 fn test_proving_multi_proof_generation_and_verify() {
1358 let mut mmr = Mmr::new();
1360 let mut hasher: Standard<Sha256> = Standard::new();
1361 let mut elements = Vec::new();
1362
1363 for i in 0..20 {
1364 elements.push(test_digest(i));
1365 mmr.add(&mut hasher, &elements[i as usize]);
1366 }
1367
1368 let root = mmr.root(&mut hasher);
1369
1370 let locations = &[
1372 Location::new_unchecked(0),
1373 Location::new_unchecked(5),
1374 Location::new_unchecked(10),
1375 ];
1376 let nodes_for_multi_proof =
1377 nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1378 let digests = nodes_for_multi_proof
1379 .into_iter()
1380 .map(|pos| mmr.get_node(pos).unwrap())
1381 .collect();
1382 let multi_proof = Proof {
1383 size: mmr.size(),
1384 digests,
1385 };
1386
1387 assert_eq!(multi_proof.size, mmr.size());
1388
1389 assert!(multi_proof.verify_multi_inclusion(
1391 &mut hasher,
1392 &[
1393 (elements[0], Location::new_unchecked(0)),
1394 (elements[5], Location::new_unchecked(5)),
1395 (elements[10], Location::new_unchecked(10)),
1396 ],
1397 &root
1398 ));
1399
1400 assert!(multi_proof.verify_multi_inclusion(
1402 &mut hasher,
1403 &[
1404 (elements[10], Location::new_unchecked(10)),
1405 (elements[5], Location::new_unchecked(5)),
1406 (elements[0], Location::new_unchecked(0)),
1407 ],
1408 &root
1409 ));
1410
1411 assert!(multi_proof.verify_multi_inclusion(
1413 &mut hasher,
1414 &[
1415 (elements[0], Location::new_unchecked(0)),
1416 (elements[0], Location::new_unchecked(0)),
1417 (elements[10], Location::new_unchecked(10)),
1418 (elements[5], Location::new_unchecked(5)),
1419 ],
1420 &root
1421 ));
1422
1423 let wrong_sizes = [
1426 Position::new(16),
1427 Position::new(17),
1428 Position::new(u64::MAX - 100),
1429 ];
1430 for sz in wrong_sizes {
1431 let mut wrong_size_proof = multi_proof.clone();
1432 wrong_size_proof.size = sz;
1433 assert!(!wrong_size_proof.verify_multi_inclusion(
1434 &mut hasher,
1435 &[
1436 (elements[0], Location::new_unchecked(0)),
1437 (elements[5], Location::new_unchecked(5)),
1438 (elements[10], Location::new_unchecked(10)),
1439 ],
1440 &root,
1441 ));
1442 }
1443
1444 assert!(!multi_proof.verify_multi_inclusion(
1446 &mut hasher,
1447 &[
1448 (elements[0], Location::new_unchecked(1)),
1449 (elements[5], Location::new_unchecked(6)),
1450 (elements[10], Location::new_unchecked(11)),
1451 ],
1452 &root,
1453 ));
1454
1455 let wrong_elements = [
1457 vec![255u8, 254u8, 253u8],
1458 vec![252u8, 251u8, 250u8],
1459 vec![249u8, 248u8, 247u8],
1460 ];
1461 let wrong_verification = multi_proof.verify_multi_inclusion(
1462 &mut hasher,
1463 &[
1464 (wrong_elements[0].as_slice(), Location::new_unchecked(0)),
1465 (wrong_elements[1].as_slice(), Location::new_unchecked(5)),
1466 (wrong_elements[2].as_slice(), Location::new_unchecked(10)),
1467 ],
1468 &root,
1469 );
1470 assert!(!wrong_verification, "Should fail with wrong elements");
1471
1472 let wrong_verification = multi_proof.verify_multi_inclusion(
1474 &mut hasher,
1475 &[
1476 (elements[0], Location::new_unchecked(0)),
1477 (elements[5], Location::new_unchecked(5)),
1478 (elements[10], Location::new_unchecked(1000)),
1479 ],
1480 &root,
1481 );
1482 assert!(
1483 !wrong_verification,
1484 "Should fail with out of range elements"
1485 );
1486
1487 let wrong_root = test_digest(99);
1489 assert!(!multi_proof.verify_multi_inclusion(
1490 &mut hasher,
1491 &[
1492 (elements[0], Location::new_unchecked(0)),
1493 (elements[5], Location::new_unchecked(5)),
1494 (elements[10], Location::new_unchecked(10)),
1495 ],
1496 &wrong_root
1497 ));
1498
1499 let empty_mmr = Mmr::new();
1501 let empty_root = empty_mmr.root(&mut hasher);
1502 let empty_proof = Proof::default();
1503 assert!(empty_proof.verify_multi_inclusion(
1504 &mut hasher,
1505 &[] as &[(Digest, Location)],
1506 &empty_root
1507 ));
1508 }
1509
1510 #[test]
1511 fn test_proving_multi_proof_deduplication() {
1512 let mut mmr = Mmr::new();
1513 let mut hasher: Standard<Sha256> = Standard::new();
1514 let mut elements = Vec::new();
1515
1516 for i in 0..30 {
1518 elements.push(test_digest(i));
1519 mmr.add(&mut hasher, &elements[i as usize]);
1520 }
1521
1522 let proof1 = mmr.proof(Location::new_unchecked(0)).unwrap();
1524 let proof2 = mmr.proof(Location::new_unchecked(1)).unwrap();
1525 let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1526
1527 let locations = &[Location::new_unchecked(0), Location::new_unchecked(1)];
1529 let multi_proof =
1530 nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1531 let digests = multi_proof
1532 .into_iter()
1533 .map(|pos| mmr.get_node(pos).unwrap())
1534 .collect();
1535 let multi_proof = Proof {
1536 size: mmr.size(),
1537 digests,
1538 };
1539
1540 assert!(multi_proof.digests.len() < total_digests_separate);
1542
1543 let root = mmr.root(&mut hasher);
1545 assert!(multi_proof.verify_multi_inclusion(
1546 &mut hasher,
1547 &[
1548 (elements[0], Location::new_unchecked(0)),
1549 (elements[1], Location::new_unchecked(1))
1550 ],
1551 &root
1552 ));
1553 }
1554
1555 #[test]
1556 fn test_max_location_is_provable() {
1557 let max_mmr_size = Position::new((1u64 << 63) - 1);
1560
1561 let max_loc = Location::new_unchecked(MAX_LOCATION);
1562 let max_loc_plus_1 = Location::new_unchecked(MAX_LOCATION + 1);
1563 let max_loc_plus_2 = Location::new_unchecked(MAX_LOCATION + 2);
1564
1565 let result = nodes_required_for_range_proof(max_mmr_size, max_loc..max_loc_plus_1);
1568
1569 assert!(result.is_ok(), "Should be able to prove MAX_LOCATION");
1571
1572 let result_overflow =
1574 nodes_required_for_range_proof(max_mmr_size, max_loc_plus_1..max_loc_plus_2);
1575 assert!(
1576 result_overflow.is_err(),
1577 "Should reject location > MAX_LOCATION"
1578 );
1579 matches!(result_overflow, Err(Error::LocationOverflow(_)));
1580 }
1581
1582 #[test]
1583 fn test_max_location_multi_proof() {
1584 let max_mmr_size = Position::new((1u64 << 63) - 1);
1586 let max_loc = Location::new_unchecked(MAX_LOCATION);
1587
1588 let result = nodes_required_for_multi_proof(max_mmr_size, &[max_loc]);
1590 assert!(
1591 result.is_ok(),
1592 "Should be able to generate multi-proof for MAX_LOCATION"
1593 );
1594
1595 let invalid_loc = Location::new_unchecked(MAX_LOCATION + 1);
1597 let result_overflow = nodes_required_for_multi_proof(max_mmr_size, &[invalid_loc]);
1598 assert!(
1599 result_overflow.is_err(),
1600 "Should reject location > MAX_LOCATION in multi-proof"
1601 );
1602 }
1603
1604 #[test]
1605 fn test_invalid_size_validation() {
1606 let loc = Location::new_unchecked(0);
1608 let range = loc..loc + 1;
1609
1610 const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1611 for size in INVALID_SIZES {
1612 let result = nodes_required_for_range_proof(Position::new(size), range.clone());
1613 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1614
1615 let result = nodes_required_for_multi_proof(Position::new(size), &[loc]);
1616 assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1617 }
1618 }
1619}