Skip to main content

commonware_storage/mmr/
proof.rs

//! Defines the inclusion [Proof] structure, and functions for verifying proofs against a root
//! digest.
//!
//! Also provides lower-level functions for building verifiers against new or extended proof types.
//! These lower-level functions are kept outside of the [Proof] structure and are not re-exported by
//! the parent module.
6
7#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10    hasher::Hasher,
11    iterator::{PathIterator, PeakIterator},
12    Error, Location, Position,
13};
14use alloc::{
15    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16    vec,
17    vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::Digest;
22use core::{cmp::Reverse, ops::Range};
23#[cfg(feature = "std")]
24use tracing::debug;
25
/// The maximum number of digests in a proof per element being proven.
///
/// This accounts for the worst case proof size, in an MMR with 62 peaks. The
/// left-most leaf in such a tree requires 122 digests, for 61 path siblings
/// and 61 peak digests.
///
/// When a [Proof] is decoded, this constant is multiplied (saturating) by the configured maximum
/// number of items to bound how many digests may be read from the buffer.
pub const MAX_PROOF_DIGESTS_PER_ELEMENT: usize = 122;
32
/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
#[derive(Error, Debug)]
pub enum ReconstructionError {
    /// The proof ran out of digests while a peak or sibling digest was still needed. Leaf
    /// reconstruction also reports this variant when the element iterator is exhausted before a
    /// leaf could be hashed.
    #[error("missing digests in proof")]
    MissingDigests,
    /// The proof holds digests that reconstruction did not consume. This variant is also
    /// reported when elements remain unconsumed after every peak has been processed.
    #[error("extra digests in proof")]
    ExtraDigests,
    /// The start location could not be converted to a [Position].
    #[error("start location is out of bounds")]
    InvalidStartLoc,
    /// The end location overflowed, could not be converted to a [Position], or is not smaller
    /// than the size derived from the proof's leaf count.
    #[error("end location is out of bounds")]
    InvalidEndLoc,
    /// No elements were supplied although the start location is non-zero.
    #[error("missing elements")]
    MissingElements,
    /// The proof's leaf count could not be converted to a [Position].
    #[error("invalid size")]
    InvalidSize,
}
49
/// Contains the information necessary for proving the inclusion of an element, or some range of
/// elements, in the MMR from its root digest.
///
/// The `digests` vector contains:
///
/// 1: the digests of each peak corresponding to a mountain containing no elements from the element
/// range being proven in decreasing order of height, followed by:
///
/// 2: the nodes in the remaining mountains necessary for reconstructing their peak digests from the
/// elements within the range, ordered by the position of their parent.
///
/// The encoded form is `leaves` followed by `digests`.
#[derive(Clone, Debug, Eq)]
pub struct Proof<D: Digest> {
    /// The total number of leaves in the MMR for MMR proofs, though other authenticated data
    /// structures may override the meaning of this field. For example, the authenticated
    /// [crate::AuthenticatedBitMap] stores the number of bits in the bitmap within this field.
    pub leaves: Location,
    /// The digests necessary for proving the inclusion of an element, or range of elements, in the
    /// MMR.
    pub digests: Vec<D>,
}
70
71impl<D: Digest> PartialEq for Proof<D> {
72    fn eq(&self, other: &Self) -> bool {
73        self.leaves == other.leaves && self.digests == other.digests
74    }
75}
76
77impl<D: Digest> EncodeSize for Proof<D> {
78    fn encode_size(&self) -> usize {
79        self.leaves.encode_size() + self.digests.encode_size()
80    }
81}
82
83impl<D: Digest> Write for Proof<D> {
84    fn write(&self, buf: &mut impl BufMut) {
85        // Write the number of leaves in the MMR
86        self.leaves.write(buf);
87
88        // Write the digests
89        self.digests.write(buf);
90    }
91}
92
93impl<D: Digest> Read for Proof<D> {
94    /// The maximum number of items being proven.
95    ///
96    /// The upper bound on digests is derived as `max_items * MAX_PROOF_DIGESTS_PER_ELEMENT`.
97    type Cfg = usize;
98
99    fn read_cfg(
100        buf: &mut impl Buf,
101        max_items: &Self::Cfg,
102    ) -> Result<Self, commonware_codec::Error> {
103        // Read the number of nodes in the MMR
104        let leaves = Location::read(buf)?;
105
106        // Read the digests
107        let max_digests = max_items.saturating_mul(MAX_PROOF_DIGESTS_PER_ELEMENT);
108        let digests = Vec::<D>::read_range(buf, ..=max_digests)?;
109
110        Ok(Self { leaves, digests })
111    }
112}
113
114impl<D: Digest> Default for Proof<D> {
115    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
116    /// (`leaves == 0`) MMR.
117    fn default() -> Self {
118        Self {
119            leaves: Location::new_unchecked(0),
120            digests: vec![],
121        }
122    }
123}
124
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Builds a proof from unstructured data, drawing the leaf count first and the digests second.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let leaves = u.arbitrary()?;
        let digests = u.arbitrary()?;
        Ok(Self { leaves, digests })
    }
}
137
138impl<D: Digest> Proof<D> {
139    /// Return true if this proof proves that `element` appears at location `loc` within the MMR
140    /// with root digest `root`.
141    pub fn verify_element_inclusion<H>(
142        &self,
143        hasher: &mut H,
144        element: &[u8],
145        loc: Location,
146        root: &D,
147    ) -> bool
148    where
149        H: Hasher<Digest = D>,
150    {
151        self.verify_range_inclusion(hasher, &[element], loc, root)
152    }
153
154    /// Return true if this proof proves that the `elements` appear consecutively starting at
155    /// position `start_loc` within the MMR with root digest `root`. A malformed proof will return
156    /// false.
157    pub fn verify_range_inclusion<H, E>(
158        &self,
159        hasher: &mut H,
160        elements: &[E],
161        start_loc: Location,
162        root: &D,
163    ) -> bool
164    where
165        H: Hasher<Digest = D>,
166        E: AsRef<[u8]>,
167    {
168        match self.reconstruct_root(hasher, elements, start_loc) {
169            Ok(reconstructed_root) => *root == reconstructed_root,
170            Err(_error) => {
171                #[cfg(feature = "std")]
172                tracing::debug!(error = ?_error, "invalid proof input");
173                false
174            }
175        }
176    }
177
178    /// Return true if this proof proves that the elements at the specified locations are included
179    /// in the MMR with the root digest `root`. A malformed proof will return false.
180    ///
181    /// The order of the elements does not affect the output.
182    pub fn verify_multi_inclusion<H, E>(
183        &self,
184        hasher: &mut H,
185        elements: &[(E, Location)],
186        root: &D,
187    ) -> bool
188    where
189        H: Hasher<Digest = D>,
190        E: AsRef<[u8]>,
191    {
192        // Empty proof is valid for an empty MMR
193        if elements.is_empty() {
194            return self.leaves == Location::new_unchecked(0)
195                && *root == hasher.root(Location::new_unchecked(0), core::iter::empty());
196        }
197
198        // Single pass to collect all required positions with deduplication
199        let mut node_positions = BTreeSet::new();
200        let mut nodes_required = BTreeMap::new();
201
202        for (_, loc) in elements {
203            if !loc.is_valid() {
204                return false;
205            }
206            // `loc` is valid so it won't overflow from +1
207            let Ok(required) = nodes_required_for_range_proof(self.leaves, *loc..*loc + 1) else {
208                return false;
209            };
210            for req_pos in &required {
211                node_positions.insert(*req_pos);
212            }
213            nodes_required.insert(*loc, required);
214        }
215
216        // Verify we have the exact number of digests needed
217        if node_positions.len() != self.digests.len() {
218            return false;
219        }
220
221        // Build position to digest mapping once
222        let node_digests: BTreeMap<Position, D> = node_positions
223            .iter()
224            .zip(self.digests.iter())
225            .map(|(&pos, digest)| (pos, *digest))
226            .collect();
227
228        // Verify each element by reconstructing its path
229        for (element, loc) in elements {
230            // Get required positions for this element
231            let required = &nodes_required[loc];
232
233            // Build proof with required digests
234            let mut digests = Vec::with_capacity(required.len());
235            for req_pos in required {
236                // There must exist a digest for each required position (by
237                // construction of `node_digests`)
238                let digest = node_digests
239                    .get(req_pos)
240                    .expect("must exist by construction of node_digests");
241                digests.push(*digest);
242            }
243            let proof = Self {
244                leaves: self.leaves,
245                digests,
246            };
247
248            // Verify the proof
249            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
250                return false;
251            }
252        }
253
254        true
255    }
256
257    // The functions below are lower level functions that are useful to building verification
258    // functions for new or extended proof types.
259
260    /// Computes the set of pinned nodes for the pruning boundary corresponding to the start of the
261    /// given range, returning the digest of each by extracting it from the proof.
262    ///
263    /// # Arguments
264    /// * `range` - The start and end locations of the proven range, where start is also used as the
265    ///   pruning boundary.
266    ///
267    /// # Returns
268    /// A Vec of digests for all nodes in `nodes_to_pin(pruning_boundary)`, in the same order as
269    /// returned by `nodes_to_pin` (decreasing height order)
270    ///
271    /// # Errors
272    ///
273    /// Returns [Error::InvalidSize] if the proof size is not a valid MMR size.
274    /// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
275    /// Returns [Error::InvalidProofLength] if the proof digest count doesn't match the required
276    /// positions count.
277    /// Returns [Error::MissingDigest] if a pinned node is not found in the proof.
278    #[cfg(any(feature = "std", test))]
279    pub(crate) fn extract_pinned_nodes(
280        &self,
281        range: std::ops::Range<Location>,
282    ) -> Result<Vec<D>, Error> {
283        // Get the positions of all nodes that should be pinned.
284        let start_pos = Position::try_from(range.start)?;
285        let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
286
287        // Get all positions required for the proof.
288        let required_positions = nodes_required_for_range_proof(self.leaves, range)?;
289
290        if required_positions.len() != self.digests.len() {
291            #[cfg(feature = "std")]
292            debug!(
293                digests_len = self.digests.len(),
294                required_positions_len = required_positions.len(),
295                "Proof digest count doesn't match required positions",
296            );
297            return Err(Error::InvalidProofLength);
298        }
299
300        // Happy path: we can extract the pinned nodes directly from the proof.
301        // This happens when the `end_element_pos` is the last element in the MMR.
302        if pinned_positions
303            == required_positions[required_positions.len() - pinned_positions.len()..]
304        {
305            return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
306        }
307
308        // Create a mapping from position to digest.
309        let position_to_digest: BTreeMap<Position, D> = required_positions
310            .iter()
311            .zip(self.digests.iter())
312            .map(|(&pos, &digest)| (pos, digest))
313            .collect();
314
315        // Extract the pinned nodes in the same order as nodes_to_pin.
316        let mut result = Vec::with_capacity(pinned_positions.len());
317        for pinned_pos in pinned_positions {
318            let Some(&digest) = position_to_digest.get(&pinned_pos) else {
319                #[cfg(feature = "std")]
320                debug!(?pinned_pos, "Pinned node not found in proof");
321                return Err(Error::MissingDigest(pinned_pos));
322            };
323            result.push(digest);
324        }
325        Ok(result)
326    }
327
328    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
329    /// of elements, returning the (position,digest) of every node whose digest was required by the
330    /// process (including those from the proof itself). Returns a [Error::InvalidProof] if the
331    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
332    /// root.
333    pub fn verify_range_inclusion_and_extract_digests<H, E>(
334        &self,
335        hasher: &mut H,
336        elements: &[E],
337        start_loc: Location,
338        root: &D,
339    ) -> Result<Vec<(Position, D)>, super::Error>
340    where
341        H: Hasher<Digest = D>,
342        E: AsRef<[u8]>,
343    {
344        let mut collected_digests = Vec::new();
345        let Ok(peak_digests) = self.reconstruct_peak_digests(
346            hasher,
347            elements,
348            start_loc,
349            Some(&mut collected_digests),
350        ) else {
351            return Err(Error::InvalidProof);
352        };
353
354        if hasher.root(self.leaves, peak_digests.iter()) != *root {
355            return Err(Error::RootMismatch);
356        }
357
358        Ok(collected_digests)
359    }
360
361    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
362    /// of elements, or returns a [ReconstructionError] if the input data is invalid.
363    pub fn reconstruct_root<H, E>(
364        &self,
365        hasher: &mut H,
366        elements: &[E],
367        start_loc: Location,
368    ) -> Result<D, ReconstructionError>
369    where
370        H: Hasher<Digest = D>,
371        E: AsRef<[u8]>,
372    {
373        let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
374
375        Ok(hasher.root(self.leaves, peak_digests.iter()))
376    }
377
378    /// Reconstruct the peak digests of the MMR that produced this proof, returning
379    /// [ReconstructionError] if the input data is invalid.  If collected_digests is Some, then all
380    /// node digests used in the process will be added to the wrapped vector.
381    pub fn reconstruct_peak_digests<H, E>(
382        &self,
383        hasher: &mut H,
384        elements: &[E],
385        start_loc: Location,
386        mut collected_digests: Option<&mut Vec<(Position, D)>>,
387    ) -> Result<Vec<D>, ReconstructionError>
388    where
389        H: Hasher<Digest = D>,
390        E: AsRef<[u8]>,
391    {
392        if elements.is_empty() {
393            if start_loc == 0 {
394                return Ok(vec![]);
395            }
396            return Err(ReconstructionError::MissingElements);
397        }
398        let size = Position::try_from(self.leaves).map_err(|_| ReconstructionError::InvalidSize)?;
399        let start_element_pos =
400            Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
401        let end_element_pos = if elements.len() == 1 {
402            start_element_pos
403        } else {
404            let end_loc = start_loc
405                .checked_add(elements.len() as u64 - 1)
406                .ok_or(ReconstructionError::InvalidEndLoc)?;
407            Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
408        };
409        if end_element_pos >= size {
410            return Err(ReconstructionError::InvalidEndLoc);
411        }
412
413        let mut proof_digests_iter = self.digests.iter();
414        let mut siblings_iter = self.digests.iter().rev();
415
416        // Include peak digests only for trees that have no elements from the range, and keep track
417        // of the starting and ending trees of those that do contain some.
418        let mut peak_digests: Vec<D> = Vec::new();
419        let mut proof_digests_used = 0;
420        let mut elements_iter = elements.iter();
421        for (peak_pos, height) in PeakIterator::new(size) {
422            let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
423            if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
424                let hash = peak_digest_from_range(
425                    hasher,
426                    RangeInfo {
427                        pos: peak_pos,
428                        two_h: 1 << height,
429                        leftmost_pos: start_element_pos,
430                        rightmost_pos: end_element_pos,
431                    },
432                    &mut elements_iter,
433                    &mut siblings_iter,
434                    collected_digests.as_deref_mut(),
435                )?;
436                peak_digests.push(hash);
437                if let Some(ref mut collected_digests) = collected_digests {
438                    collected_digests.push((peak_pos, hash));
439                }
440            } else if let Some(hash) = proof_digests_iter.next() {
441                proof_digests_used += 1;
442                peak_digests.push(*hash);
443                if let Some(ref mut collected_digests) = collected_digests {
444                    collected_digests.push((peak_pos, *hash));
445                }
446            } else {
447                return Err(ReconstructionError::MissingDigests);
448            }
449        }
450
451        if elements_iter.next().is_some() {
452            return Err(ReconstructionError::ExtraDigests);
453        }
454        if let Some(next_sibling) = siblings_iter.next() {
455            if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
456                return Err(ReconstructionError::ExtraDigests);
457            }
458        }
459
460        Ok(peak_digests)
461    }
462}
463
/// Return the list of node positions required by the range proof for the specified range of
/// elements, for an MMR holding `leaves` leaves.
///
/// The result lists the peaks of mountains containing no element from the range first, followed
/// by the sibling positions needed to rebuild the peaks of the mountains that do.
///
/// # Errors
///
/// Returns [Error::Empty] if the range is empty.
/// Returns [Error::RangeOutOfBounds] if the last element location in `range` is out of bounds
/// (>= `leaves`).
/// Propagates the conversion error if `leaves` or a location in `range` cannot be converted to a
/// [Position] (presumably [Error::InvalidSize] / [Error::LocationOverflow]; confirm against the
/// `TryFrom` implementation).
pub(crate) fn nodes_required_for_range_proof(
    leaves: Location,
    range: Range<Location>,
) -> Result<Vec<Position>, Error> {
    if range.is_empty() {
        return Err(Error::Empty);
    }
    let end_minus_one = range
        .end
        .checked_sub(1)
        .expect("can't underflow because range is non-empty");
    if end_minus_one >= leaves {
        return Err(Error::RangeOutOfBounds(range.end));
    }

    // Find the mountains that contain no elements from the range. The peaks of these mountains
    // are required to prove the range, so they are added to the result.
    let mut start_tree_with_element: Option<(Position, u32)> = None;
    let mut end_tree_with_element: Option<(Position, u32)> = None;
    let mut positions = Vec::new();
    let size = Position::try_from(leaves)?;
    let start_element_pos = Position::try_from(range.start)?;
    let end_element_pos = Position::try_from(end_minus_one)?;

    // `while let` (rather than `for`) because the inner loop below advances the same iterator
    // through `by_ref()`.
    let mut peak_iterator = PeakIterator::new(size);
    while let Some(peak) = peak_iterator.next() {
        if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
            // Found the first tree to contain an element in the range
            start_tree_with_element = Some(peak);
            if peak.0 >= end_element_pos {
                // Start and end tree are the same
                end_tree_with_element = Some(peak);
                continue;
            }
            // Skip over the trees lying strictly inside the range; their peaks are not pushed.
            for peak in peak_iterator.by_ref() {
                if peak.0 >= end_element_pos {
                    // Found the last tree to contain an element in the range
                    end_tree_with_element = Some(peak);
                    break;
                }
            }
        } else {
            // Tree is outside the range, its peak is thus required.
            positions.push(peak.0);
        }
    }

    // We checked above that all range elements are in this MMR, so some mountain must contain
    // the first and last elements in the range.
    let (start_tree_peak, start_tree_height) =
        start_tree_with_element.expect("start_tree_with_element is Some");
    let (end_tree_peak, end_tree_height) =
        end_tree_with_element.expect("end_tree_with_element is Some");

    // Include the positions of any left-siblings of each node on the path from peak to
    // leftmost-leaf, and right-siblings for the path from peak to rightmost-leaf. These are
    // added in order of decreasing parent position.
    let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);

    let mut siblings = Vec::new();
    if start_element_pos == end_element_pos {
        // For the (common) case of a single element range, the right and left path are the
        // same so no need to process each independently.
        siblings.extend(left_path_iter);
    } else {
        let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
        // filter the right path for right siblings only
        siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
        // filter the left path for left siblings only
        siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));

        // If the range spans more than one tree, then the digests must already be in the correct
        // order. Otherwise, we enforce the desired order through sorting.
        if start_tree_peak == end_tree_peak {
            siblings.sort_by_key(|a| Reverse(a.0));
        }
    }
    // Only the sibling positions are returned; the parent positions were needed for ordering.
    positions.extend(siblings.into_iter().map(|(_, pos)| pos));

    Ok(positions)
}
555
/// Returns the positions of the minimal set of nodes whose digests are required to prove the
/// inclusion of the elements at the specified `locations`, for an MMR holding `leaves` leaves.
///
/// The order of `locations` does not affect the output, which is a sorted set.
///
/// # Errors
///
/// Returns [Error::Empty] if locations is empty.
/// Returns [Error::LocationOverflow] if any location in `locations` is not valid.
/// Propagates any error from [nodes_required_for_range_proof] for a single location, such as
/// [Error::RangeOutOfBounds] when the location is out of bounds for `leaves`.
#[cfg(any(feature = "std", test))]
pub(crate) fn nodes_required_for_multi_proof(
    leaves: Location,
    locations: &[Location],
) -> Result<BTreeSet<Position>, Error> {
    if locations.is_empty() {
        return Err(Error::Empty);
    }

    // Accumulate the union of the positions required by each single-element proof.
    //
    // TODO(#1472): Optimize this loop
    let mut required = BTreeSet::new();
    for loc in locations {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(*loc));
        }
        // `loc` is valid so it won't overflow from +1
        required.extend(nodes_required_for_range_proof(leaves, *loc..*loc + 1)?);
    }

    Ok(required)
}
589
/// Information about the current range of nodes being traversed.
///
/// `leftmost_pos` and `rightmost_pos` are passed down unchanged through the recursion in
/// `peak_digest_from_range`, while `pos` and `two_h` describe the node currently being visited.
struct RangeInfo {
    pos: Position,           // current node position in the tree
    two_h: u64,              // 2^height of the current node
    leftmost_pos: Position,  // leftmost leaf in the tree to be traversed
    rightmost_pos: Position, // rightmost leaf in the tree to be traversed
}
597
598fn peak_digest_from_range<'a, D, H, E, S>(
599    hasher: &mut H,
600    range_info: RangeInfo,
601    elements: &mut E,
602    sibling_digests: &mut S,
603    mut collected_digests: Option<&mut Vec<(Position, D)>>,
604) -> Result<D, ReconstructionError>
605where
606    D: Digest,
607    H: Hasher<Digest = D>,
608    E: Iterator<Item: AsRef<[u8]>>,
609    S: Iterator<Item = &'a D>,
610{
611    assert_ne!(range_info.two_h, 0);
612    if range_info.two_h == 1 {
613        match elements.next() {
614            Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
615            None => return Err(ReconstructionError::MissingDigests),
616        }
617    }
618
619    let mut left_digest: Option<D> = None;
620    let mut right_digest: Option<D> = None;
621
622    let left_pos = range_info.pos - range_info.two_h;
623    let right_pos = left_pos + range_info.two_h - 1;
624    if left_pos >= range_info.leftmost_pos {
625        // Descend left
626        let digest = peak_digest_from_range(
627            hasher,
628            RangeInfo {
629                pos: left_pos,
630                two_h: range_info.two_h >> 1,
631                leftmost_pos: range_info.leftmost_pos,
632                rightmost_pos: range_info.rightmost_pos,
633            },
634            elements,
635            sibling_digests,
636            collected_digests.as_deref_mut(),
637        )?;
638        left_digest = Some(digest);
639    }
640    if left_pos < range_info.rightmost_pos {
641        // Descend right
642        let digest = peak_digest_from_range(
643            hasher,
644            RangeInfo {
645                pos: right_pos,
646                two_h: range_info.two_h >> 1,
647                leftmost_pos: range_info.leftmost_pos,
648                rightmost_pos: range_info.rightmost_pos,
649            },
650            elements,
651            sibling_digests,
652            collected_digests.as_deref_mut(),
653        )?;
654        right_digest = Some(digest);
655    }
656
657    if left_digest.is_none() {
658        match sibling_digests.next() {
659            Some(hash) => left_digest = Some(*hash),
660            None => return Err(ReconstructionError::MissingDigests),
661        }
662    }
663    if right_digest.is_none() {
664        match sibling_digests.next() {
665            Some(hash) => right_digest = Some(*hash),
666            None => return Err(ReconstructionError::MissingDigests),
667        }
668    }
669
670    if let Some(ref mut collected_digests) = collected_digests {
671        collected_digests.push((
672            left_pos,
673            left_digest.expect("left_digest guaranteed to be Some after checks above"),
674        ));
675        collected_digests.push((
676            right_pos,
677            right_digest.expect("right_digest guaranteed to be Some after checks above"),
678        ));
679    }
680
681    Ok(hasher.node_digest(
682        range_info.pos,
683        &left_digest.expect("left_digest guaranteed to be Some after checks above"),
684        &right_digest.expect("right_digest guaranteed to be Some after checks above"),
685    ))
686}
687
688#[cfg(test)]
689mod tests {
690    use super::*;
691    use crate::mmr::{
692        hasher::Standard,
693        location::LocationRangeExt as _,
694        mem::{CleanMmr, DirtyMmr},
695        MAX_LOCATION,
696    };
697    use commonware_codec::{Decode, Encode};
698    use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
699    use commonware_macros::test_traced;
700
701    fn test_digest(v: u8) -> Digest {
702        Sha256::hash(&[v])
703    }
704
    /// Checks that the default (empty) proof verifies only for an empty element list starting at
    /// location 0 against the root of an empty MMR.
    #[test]
    fn test_proving_proof() {
        // Test that an empty proof authenticates an empty MMR.
        let mut hasher: Standard<Sha256> = Standard::new();
        let mmr = CleanMmr::new(&mut hasher);
        let root = mmr.root();
        let proof = Proof::default();
        assert!(proof.verify_range_inclusion(
            &mut hasher,
            &[] as &[Digest],
            Location::new_unchecked(0),
            root
        ));

        // Any starting position other than 0 should fail to verify.
        assert!(!proof.verify_range_inclusion(
            &mut hasher,
            &[] as &[Digest],
            Location::new_unchecked(1),
            root
        ));

        // Invalid root should fail to verify.
        // (The local binding below shadows the `test_digest` helper for the rest of this test.)
        let test_digest = test_digest(0);
        assert!(!proof.verify_range_inclusion(
            &mut hasher,
            &[] as &[Digest],
            Location::new_unchecked(0),
            &test_digest
        ));

        // Non-empty elements list should fail to verify.
        assert!(!proof.verify_range_inclusion(
            &mut hasher,
            &[test_digest],
            Location::new_unchecked(0),
            root
        ));
    }
744
    /// Builds an 11-element MMR, checks that the single-element proof for every leaf verifies, and
    /// then checks that altering the location, element, root, leaf count, or digest list of a valid
    /// proof makes verification fail.
    #[test]
    fn test_proving_verify_element() {
        // create an 11 element MMR over which we'll test single-element inclusion proofs
        let element = Digest::from(*b"01234567012345670123456701234567");
        let mut hasher: Standard<Sha256> = Standard::new();
        let mut mmr = DirtyMmr::new();
        for _ in 0..11 {
            mmr.add(&mut hasher, &element);
        }
        let mmr = mmr.merkleize(&mut hasher, None);
        let root = mmr.root();

        // confirm the proof of inclusion for each leaf successfully verifies
        for leaf in 0u64..11 {
            let leaf = Location::new_unchecked(leaf);
            let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
            assert!(
                proof.verify_element_inclusion(&mut hasher, &element, leaf, root),
                "valid proof should verify successfully"
            );
        }

        // Create a valid proof, then confirm various mangling of the proof or proof args results in
        // verification failure.
        const LEAF: Location = Location::new_unchecked(10);
        let proof = mmr.proof(LEAF).unwrap();
        assert!(
            proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
            "proof verification should be successful"
        );
        assert!(
            !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, root),
            "proof verification should fail with incorrect element position"
        );
        assert!(
            !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, root),
            "proof verification should fail with incorrect element position 2"
        );
        assert!(
            !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, root),
            "proof verification should fail with mangled element"
        );
        let root2 = test_digest(0);
        assert!(
            !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
            "proof verification should fail with mangled root"
        );
        // Corrupt the first digest of an otherwise valid proof.
        let mut proof2 = proof.clone();
        proof2.digests[0] = test_digest(0);
        assert!(
            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
            "proof verification should fail with mangled proof hash"
        );
        // Claim a leaf count that differs from the 11 leaves actually added.
        proof2 = proof.clone();
        proof2.leaves = Location::new_unchecked(10);
        assert!(
            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
            "proof verification should fail with incorrect leaves"
        );
        // Append a superfluous digest.
        proof2 = proof.clone();
        proof2.digests.push(test_digest(0));
        assert!(
            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
            "proof verification should fail with extra hash"
        );
        // Remove digests one at a time from the end; every truncated proof must fail.
        proof2 = proof.clone();
        while !proof2.digests.is_empty() {
            proof2.digests.pop();
            assert!(
                !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
                "proof verification should fail with missing digests"
            );
        }
        // Rebuild the digest list with one extra digest spliced in after the first
        // `PEAK_COUNT - 1` digests.
        proof2 = proof.clone();
        proof2.digests.clear();
        const PEAK_COUNT: usize = 3;
        proof2
            .digests
            .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
        // sneak in an extra hash that won't be used in the computation and make sure it's
        // detected
        proof2.digests.push(test_digest(0));
        proof2
            .digests
            .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
        assert!(
            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
            "proof verification should fail with extra hash even if it's unused by the computation"
        );
    }
835
#[test]
fn test_proving_verify_range() {
    // Populate an MMR with 49 elements, remembering each one so ranges of them can be proven.
    let mut hasher: Standard<Sha256> = Standard::new();
    let mut elements = Vec::new();
    let mut mmr = DirtyMmr::new();
    for i in 0..49 {
        let element = test_digest(i);
        mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mmr = mmr.merkleize(&mut hasher, None);
    let root = mmr.root();

    // Every non-empty range `i..j` with `j < elements.len()` must produce a proof that verifies.
    for i in 0..elements.len() {
        for j in i + 1..elements.len() {
            let start = Location::new_unchecked(i as u64);
            let end = Location::new_unchecked(j as u64);
            let proof = mmr.range_proof(start..end).unwrap();
            let verified = proof.verify_range_inclusion(&mut hasher, &elements[i..j], start, root);
            assert!(
                verified,
                "valid range proof should verify successfully {i}:{j}",
            );
        }
    }

    // Fix one known-good proof, then tamper with the proof or its inputs in several ways and
    // check that each tampering is rejected.
    let (first, last) = (33usize, 40usize);
    let start = Location::new_unchecked(first as u64);
    let end = Location::new_unchecked(last as u64);
    let range_proof = mmr.range_proof(start..end).unwrap();
    let proven = &elements[first..last];
    assert!(
        range_proof.verify_range_inclusion(&mut hasher, proven, start, root),
        "valid range proof should verify successfully"
    );

    // Dropping digests from the front, one at a time until none remain, must never verify.
    let mut truncated = range_proof.clone();
    while !truncated.digests.is_empty() {
        truncated.digests.remove(0);
        let verified = truncated.verify_range_inclusion(&mut hasher, proven, start, root);
        assert!(!verified, "range proof with removed elements should fail");
    }

    // Supplying any element slice other than the one the proof was generated for must fail.
    for i in 0..elements.len() {
        for j in i + 1..elements.len() {
            let is_proven_range = i == first && j == last;
            if is_proven_range {
                continue;
            }
            let verified =
                range_proof.verify_range_inclusion(&mut hasher, &elements[i..j], start, root);
            assert!(
                !verified,
                "range proof with invalid element range should fail {i}:{j}",
            );
        }
    }

    // A different root must be rejected.
    let invalid_root = test_digest(1);
    let verified = range_proof.verify_range_inclusion(&mut hasher, proven, start, &invalid_root);
    assert!(!verified, "range proof with invalid root should fail");

    // Overwriting any single digest of the proof must be rejected.
    for i in 0..range_proof.digests.len() {
        let mut overwritten = range_proof.clone();
        overwritten.digests[i] = test_digest(0);
        let verified = overwritten.verify_range_inclusion(&mut hasher, proven, start, root);
        assert!(!verified, "mangled range proof should fail verification");
    }

    // Malleability check: inserting an extra digest at any index must be rejected.
    for i in 0..range_proof.digests.len() {
        let mut padded = range_proof.clone();
        padded.digests.insert(i, test_digest(0));
        let verified = padded.verify_range_inclusion(&mut hasher, proven, start, root);
        assert!(
            !verified,
            "mangled range proof should fail verification. inserted element at: {i}",
        );
    }

    // Any start location other than the real one must be rejected.
    for loc in (0..elements.len() as u64).map(Location::new_unchecked) {
        if loc == start {
            continue;
        }
        let verified = range_proof.verify_range_inclusion(&mut hasher, proven, loc, root);
        assert!(!verified, "bad start_loc should fail verification {loc}",);
    }
}
961
962    #[test_traced]
963    fn test_proving_retained_nodes_provable_after_pruning() {
964        // create a new MMR and add a non-trivial amount (49) of elements
965        let mut hasher: Standard<Sha256> = Standard::new();
966        let mut mmr = DirtyMmr::new();
967        let mut elements = Vec::new();
968        for i in 0..49 {
969            elements.push(test_digest(i));
970            mmr.add(&mut hasher, elements.last().unwrap());
971        }
972        let mut mmr = mmr.merkleize(&mut hasher, None);
973
974        // Confirm we can successfully prove all retained elements in the MMR after pruning.
975        let root = *mmr.root();
976        for i in 1..*mmr.size() {
977            mmr.prune_to_pos(Position::new(i));
978            let pruned_root = mmr.root();
979            assert_eq!(root, *pruned_root);
980            for loc in 0..elements.len() {
981                let loc = Location::new_unchecked(loc as u64);
982                let proof = mmr.proof(loc);
983                if Position::try_from(loc).unwrap() < Position::new(i) {
984                    continue;
985                }
986                assert!(proof.is_ok());
987                assert!(proof.unwrap().verify_element_inclusion(
988                    &mut hasher,
989                    &elements[*loc as usize],
990                    loc,
991                    &root
992                ));
993            }
994        }
995    }
996
#[test]
fn test_proving_ranges_provable_after_pruning() {
    // Populate an MMR with 49 elements.
    let mut hasher: Standard<Sha256> = Standard::new();
    let mut elements = Vec::new();
    let mut mmr = DirtyMmr::new();
    for i in 0..49 {
        let element = test_digest(i);
        mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mut mmr = mmr.merkleize(&mut hasher, None);

    // First pruning round: prune up to the first peak.
    const PRUNE_POS: Position = Position::new(62);
    mmr.prune_to_pos(PRUNE_POS);
    assert_eq!(mmr.bounds().start, PRUNE_POS);

    // Every range of at least two elements that starts at or after the pruning boundary (and
    // ends before the last element) must still be provable.
    let root = mmr.root();
    for i in 0..elements.len() - 1 {
        let start = Location::new_unchecked(i as u64);
        if Position::try_from(start).unwrap() < PRUNE_POS {
            continue;
        }
        for j in (i + 2)..elements.len() {
            let end = Location::new_unchecked(j as u64);
            let proof = mmr.range_proof(start..end).unwrap();
            let verified = proof.verify_range_inclusion(&mut hasher, &elements[i..j], start, root);
            assert!(
                verified,
                "valid range proof over remaining elements should verify successfully",
            );
        }
    }

    // Second pruning round: grow the MMR by 37 more elements and prune again, to check that
    // repeated pruning does not break proof verification.
    let mut mmr = mmr.into_dirty();
    for i in 0..37 {
        let element = test_digest(i);
        mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mut mmr = mmr.merkleize(&mut hasher, None);
    mmr.prune_to_pos(Position::new(130)); // a bit after the new highest peak
    assert_eq!(mmr.bounds().start, 130);

    // The last ten elements must be provable against the updated root.
    let updated_root = mmr.root();
    let leaf_count = elements.len();
    let tail = leaf_count - 10..leaf_count;
    let start = Location::new_unchecked(tail.start as u64);
    let end = Location::new_unchecked(tail.end as u64);
    let proof = mmr.range_proof(start..end).unwrap();
    let verified = proof.verify_range_inclusion(&mut hasher, &elements[tail], start, updated_root);
    assert!(
        verified,
        "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
    );
}
1060
#[test]
fn test_proving_proof_serialization() {
    // Populate an MMR with 25 elements.
    let mut hasher: Standard<Sha256> = Standard::new();
    let mut elements = Vec::new();
    let mut mmr = DirtyMmr::new();
    for i in 0..25 {
        let element = test_digest(i);
        mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mmr = mmr.merkleize(&mut hasher, None);

    // For each non-empty range `i..j` with `j < elements.len()`, round-trip its proof through
    // the codec and check that malformed encodings are rejected.
    for i in 0..elements.len() {
        for j in i + 1..elements.len() {
            let start = Location::new_unchecked(i as u64);
            let end = Location::new_unchecked(j as u64);
            let proof = mmr.range_proof(start..end).unwrap();

            // The decoder is configured with the number of elements in the range.
            let max_items = j - i;

            // The encoding has exactly the advertised size and decodes back to the same proof.
            let encoded = proof.encode();
            assert_eq!(
                encoded.len(),
                proof.encode_size(),
                "serialized proof should have expected size"
            );
            let decoded = Proof::decode_cfg(encoded, &max_items).unwrap();
            assert_eq!(
                proof, decoded,
                "deserialized proof should match source proof"
            );

            // Dropping the final byte of the encoding must make decoding fail.
            let encoded = proof.encode();
            let truncated = encoded.slice(0..encoded.len() - 1);
            assert!(
                Proof::<Digest>::decode_cfg(truncated, &max_items).is_err(),
                "proof should not deserialize with truncated data"
            );

            // Appending trailing bytes (ten zero bytes) must make decoding fail.
            let mut padded = proof.encode_mut();
            padded.extend_from_slice(&[0; 10]);
            assert!(
                Proof::<Digest>::decode_cfg(padded, &max_items).is_err(),
                "proof should not deserialize with extra data"
            );

            // A `max_items` one below the smallest value that admits this many digests must
            // make decoding fail. Nothing to check when the proof carries no digests.
            let actual_digests = proof.digests.len();
            if actual_digests == 0 {
                continue;
            }
            let min_max_items = actual_digests.div_ceil(MAX_PROOF_DIGESTS_PER_ELEMENT);
            let too_small = min_max_items - 1;
            assert!(
                Proof::<Digest>::decode_cfg(proof.encode(), &too_small).is_err(),
                "proof should not deserialize with max_items too small"
            );
        }
    }
}
1131
1132    #[test_traced]
1133    fn test_proving_extract_pinned_nodes() {
1134        // Test for every number of elements from 1 to 255
1135        for num_elements in 1u64..255 {
1136            // Build MMR with the specified number of elements
1137            let mut hasher: Standard<Sha256> = Standard::new();
1138            let mut mmr = DirtyMmr::new();
1139
1140            for i in 0..num_elements {
1141                let digest = test_digest(i as u8);
1142                mmr.add(&mut hasher, &digest);
1143            }
1144            let mmr = mmr.merkleize(&mut hasher, None);
1145
1146            // Test pruning to each leaf.
1147            for leaf in 0..num_elements {
1148                // Test with a few different end positions to get good coverage
1149                let test_end_locs = if num_elements == 1 {
1150                    // Single element case
1151                    vec![leaf + 1]
1152                } else {
1153                    // Multi-element case: test with various end positions
1154                    let mut ends = vec![leaf + 1]; // Single element proof
1155
1156                    // Add a few more end positions if available
1157                    if leaf + 2 <= num_elements {
1158                        ends.push(leaf + 2);
1159                    }
1160                    if leaf + 3 <= num_elements {
1161                        ends.push(leaf + 3);
1162                    }
1163                    // Always test with the last element if different
1164                    if ends.last().unwrap() != &num_elements {
1165                        ends.push(num_elements);
1166                    }
1167
1168                    ends.into_iter()
1169                        .collect::<BTreeSet<_>>()
1170                        .into_iter()
1171                        .collect()
1172                };
1173
1174                for end_loc in test_end_locs {
1175                    // Generate proof for the range
1176                    let range = Location::new_unchecked(leaf)..Location::new_unchecked(end_loc);
1177                    let proof_result = mmr.range_proof(range.clone());
1178                    let proof = proof_result.unwrap();
1179
1180                    // Extract pinned nodes
1181                    let extract_result = proof.extract_pinned_nodes(range.clone());
1182                    assert!(
1183                            extract_result.is_ok(),
1184                            "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1185                        );
1186
1187                    let pinned_nodes = extract_result.unwrap();
1188                    let leaf_loc = Location::new_unchecked(leaf);
1189                    let leaf_pos = Position::try_from(leaf_loc).unwrap();
1190                    let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1191
1192                    // Verify count matches expected
1193                    assert_eq!(
1194                            pinned_nodes.len(),
1195                            expected_pinned.len(),
1196                            "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1197                        );
1198
1199                    // Verify extracted hashes match actual node values
1200                    // The pinned_nodes Vec is in the same order as expected_pinned
1201                    for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1202                        let extracted_hash = pinned_nodes[i];
1203                        let actual_hash = mmr.get_node(expected_pos).unwrap();
1204                        assert_eq!(
1205                                extracted_hash, actual_hash,
1206                                "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1207                            );
1208                    }
1209                }
1210            }
1211        }
1212    }
1213
#[test]
fn test_proving_extract_pinned_nodes_invalid_size() {
    // `extract_pinned_nodes` must report an error when the proof's leaf count is out of range.
    let mut hasher: Standard<Sha256> = Standard::new();

    // Build a 10-element MMR.
    let mut mmr = DirtyMmr::new();
    for digest in (0..10).map(test_digest) {
        mmr.add(&mut hasher, &digest);
    }
    let mmr = mmr.merkleize(&mut hasher, None);

    // An untouched proof over 5..8 extracts without error.
    let range = Location::new_unchecked(5)..Location::new_unchecked(8);
    let mut proof = mmr.range_proof(range.clone()).unwrap();
    assert!(proof.extract_pinned_nodes(range.clone()).is_ok());

    // Overwrite the leaf count with a value past MAX_LOCATION and expect a location overflow.
    proof.leaves = Location::new_unchecked(MAX_LOCATION + 2);
    assert!(matches!(
        proof.extract_pinned_nodes(range),
        Err(Error::LocationOverflow(_))
    ));
}
1239
#[test]
fn test_proving_digests_from_range() {
    // Populate an MMR with 49 elements.
    let mut hasher: Standard<Sha256> = Standard::new();
    let mut elements = Vec::new();
    let mut mmr = DirtyMmr::new();
    for i in 0..49 {
        let element = test_digest(i);
        mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mmr = mmr.merkleize(&mut hasher, None);
    let root = mmr.root();

    // Test 1: verifying a proof over the entire leaf range should yield a digest for every
    // node in the tree, each matching what the MMR stores at that position.
    let proof = mmr
        .range_proof(Location::new_unchecked(0)..mmr.leaves())
        .unwrap();
    let mut node_digests = proof
        .verify_range_inclusion_and_extract_digests(
            &mut hasher,
            &elements,
            Location::new_unchecked(0),
            root,
        )
        .unwrap();
    assert_eq!(node_digests.len() as u64, mmr.size());
    node_digests.sort_by_key(|(pos, _)| *pos);
    for (i, (pos, d)) in node_digests.into_iter().enumerate() {
        // After sorting, the positions must be exactly 0, 1, 2, ... with no gaps.
        assert_eq!(pos, i as u64);
        assert_eq!(mmr.get_node(pos).unwrap(), d);
    }

    // The same proof must be rejected with a root mismatch when checked against another digest.
    let wrong_root = elements[0]; // any other digest will do
    assert!(matches!(
        proof.verify_range_inclusion_and_extract_digests(
            &mut hasher,
            &elements,
            Location::new_unchecked(0),
            &wrong_root
        ),
        Err(Error::RootMismatch)
    ));

    // Helper for the remaining cases: prove the leaf range `start..end`, verify it against the
    // real root, and return how many node digests were extracted. Panics if proof generation
    // or verification fails.
    let mut extracted_count = |start: u64, end: u64| -> usize {
        let from = Location::new_unchecked(start);
        let to = Location::new_unchecked(end);
        let proof = mmr.range_proof(from..to).unwrap();
        proof
            .verify_range_inclusion_and_extract_digests(
                &mut hasher,
                &elements[start as usize..end as usize],
                from,
                root,
            )
            .unwrap()
            .len()
    };

    // Test 2: single element range (first element).
    assert!(extracted_count(0, 1) > 1);

    // Test 3: single element range (middle element).
    let mid_idx = 24;
    assert!(extracted_count(mid_idx, mid_idx + 1) > 1);

    // Test 4: single element range (last element).
    let last_idx = elements.len() as u64 - 1;
    assert!(extracted_count(last_idx, last_idx + 1) > 1);

    // Test 5: small range at the beginning. More digests than elements are expected, since
    // ancestors of the range elements are extracted as well.
    assert!(extracted_count(0, 5) > 5);

    // Test 6: medium range in the middle; again more digests than elements in the range.
    let (start, end) = (10u64, 31u64);
    assert!(extracted_count(start, end) as u64 > end - start);
}
1359
#[test]
fn test_proving_multi_proof_generation_and_verify() {
    // Populate an MMR with 20 elements.
    let mut hasher: Standard<Sha256> = Standard::new();
    let mut elements = Vec::new();
    let mut dirty_mmr = DirtyMmr::new();
    for i in 0..20 {
        let element = test_digest(i);
        dirty_mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mmr = dirty_mmr.merkleize(&mut hasher, None);
    let root = mmr.root();

    // Assemble a multi-proof for three non-contiguous leaves by fetching the digest of every
    // node the multi-proof requires.
    let required_nodes = nodes_required_for_multi_proof(
        mmr.leaves(),
        &[
            Location::new_unchecked(0),
            Location::new_unchecked(5),
            Location::new_unchecked(10),
        ],
    )
    .expect("test locations valid");
    let multi_proof = Proof {
        leaves: mmr.leaves(),
        digests: required_nodes
            .into_iter()
            .map(|pos| mmr.get_node(pos).unwrap())
            .collect(),
    };
    assert_eq!(multi_proof.leaves, mmr.leaves());

    // The correct (element, location) pairs, reused by several checks below.
    let valid_items = [
        (elements[0], Location::new_unchecked(0)),
        (elements[5], Location::new_unchecked(5)),
        (elements[10], Location::new_unchecked(10)),
    ];

    // The proof verifies with the items in ascending order...
    assert!(multi_proof.verify_multi_inclusion(&mut hasher, &valid_items, root));

    // ...in descending order...
    let mut reversed_items = valid_items;
    reversed_items.reverse();
    assert!(multi_proof.verify_multi_inclusion(&mut hasher, &reversed_items, root));

    // ...and with a duplicated, unordered item list.
    let duplicated_items = [
        valid_items[0],
        valid_items[0],
        valid_items[2],
        valid_items[1],
    ];
    assert!(multi_proof.verify_multi_inclusion(&mut hasher, &duplicated_items, root));

    // A leaf count past MAX_LOCATION must make verification fail.
    let mut wrong_size_proof = multi_proof.clone();
    wrong_size_proof.leaves = Location::new_unchecked(MAX_LOCATION + 2);
    assert!(!wrong_size_proof.verify_multi_inclusion(&mut hasher, &valid_items, root));

    // Correct elements paired with shifted locations must fail.
    let shifted_items = [
        (elements[0], Location::new_unchecked(1)),
        (elements[5], Location::new_unchecked(6)),
        (elements[10], Location::new_unchecked(11)),
    ];
    assert!(!multi_proof.verify_multi_inclusion(&mut hasher, &shifted_items, root));

    // Correct locations paired with unrelated element bytes must fail.
    let wrong_elements = [
        vec![255u8, 254u8, 253u8],
        vec![252u8, 251u8, 250u8],
        vec![249u8, 248u8, 247u8],
    ];
    let wrong_verification = multi_proof.verify_multi_inclusion(
        &mut hasher,
        &[
            (wrong_elements[0].as_slice(), Location::new_unchecked(0)),
            (wrong_elements[1].as_slice(), Location::new_unchecked(5)),
            (wrong_elements[2].as_slice(), Location::new_unchecked(10)),
        ],
        root,
    );
    assert!(!wrong_verification, "Should fail with wrong elements");

    // A location beyond the MMR's leaves must fail.
    let out_of_range_items = [
        valid_items[0],
        valid_items[1],
        (elements[10], Location::new_unchecked(1000)),
    ];
    let wrong_verification =
        multi_proof.verify_multi_inclusion(&mut hasher, &out_of_range_items, root);
    assert!(
        !wrong_verification,
        "Should fail with out of range elements"
    );

    // A different root must fail.
    let wrong_root = test_digest(99);
    assert!(!multi_proof.verify_multi_inclusion(&mut hasher, &valid_items, &wrong_root));

    // A default proof with no items verifies against the root of an empty MMR.
    let mut hasher: Standard<Sha256> = Standard::new();
    let empty_mmr = CleanMmr::new(&mut hasher);
    let empty_root = empty_mmr.root();
    let empty_proof = Proof::default();
    assert!(empty_proof.verify_multi_inclusion(
        &mut hasher,
        &[] as &[(Digest, Location)],
        empty_root
    ));
}
1507
#[test]
fn test_proving_multi_proof_deduplication() {
    // Populate an MMR with 30 elements so that proofs for neighbouring leaves share digests.
    let mut hasher: Standard<Sha256> = Standard::new();
    let mut elements = Vec::new();
    let mut dirty_mmr = DirtyMmr::new();
    for i in 0..30 {
        let element = test_digest(i);
        dirty_mmr.add(&mut hasher, &element);
        elements.push(element);
    }
    let mmr = dirty_mmr.merkleize(&mut hasher, None);

    // Baseline: the digest count when the two leaves are proven independently.
    let first = Location::new_unchecked(0);
    let second = Location::new_unchecked(1);
    let total_digests_separate =
        mmr.proof(first).unwrap().digests.len() + mmr.proof(second).unwrap().digests.len();

    // Assemble a single multi-proof covering the same two leaves.
    let required_nodes = nodes_required_for_multi_proof(mmr.leaves(), &[first, second])
        .expect("test locations valid");
    let multi_proof = Proof {
        leaves: mmr.leaves(),
        digests: required_nodes
            .into_iter()
            .map(|pos| mmr.get_node(pos).unwrap())
            .collect(),
    };

    // The combined proof must be strictly smaller than the two separate proofs together...
    assert!(multi_proof.digests.len() < total_digests_separate);

    // ...and must still verify both elements against the root.
    let root = mmr.root();
    let items = [(elements[0], first), (elements[1], second)];
    assert!(multi_proof.verify_multi_inclusion(&mut hasher, &items, root));
}
1553
1554    #[test]
1555    fn test_max_location_is_provable() {
1556        // Test that the validation logic accepts MAX_LOCATION as a valid location
1557        // We use the maximum valid MMR size (2^63 - 1) which can hold up to 2^62 leaves
1558        let max_loc = Location::new_unchecked(MAX_LOCATION);
1559        let max_loc_plus_1 = Location::new_unchecked(MAX_LOCATION + 1);
1560
1561        // MAX_LOCATION should be accepted by the validation logic
1562        // (The range MAX_LOCATION..MAX_LOCATION+1 proves a single element at MAX_LOCATION)
1563        let result = nodes_required_for_range_proof(max_loc, max_loc - 1..max_loc);
1564
1565        // This should succeed - MAX_LOCATION is a valid location
1566        assert!(result.is_ok(), "Should be able to prove MAX_LOCATION");
1567
1568        // MAX_LOCATION + 1 should be rejected (exceeds MAX_LOCATION)
1569        let result_overflow =
1570            nodes_required_for_range_proof(max_loc_plus_1, max_loc..max_loc_plus_1);
1571        assert!(
1572            result_overflow.is_err(),
1573            "Should reject location > MAX_LOCATION"
1574        );
1575        matches!(result_overflow, Err(Error::LocationOverflow(_)));
1576    }
1577
1578    #[test]
1579    fn test_max_location_multi_proof() {
1580        // Test that multi_proof can handle MAX_LOCATION
1581        let max_loc = Location::new_unchecked(MAX_LOCATION);
1582
1583        // Should be able to generate multi-proof for MAX_LOCATION
1584        let result = nodes_required_for_multi_proof(max_loc, &[max_loc - 1]);
1585        assert!(
1586            result.is_ok(),
1587            "Should be able to generate multi-proof for MAX_LOCATION"
1588        );
1589
1590        // Should reject MAX_LOCATION + 1
1591        let invalid_loc = max_loc + 1;
1592        let result_overflow = nodes_required_for_multi_proof(invalid_loc, &[max_loc]);
1593        assert!(
1594            result_overflow.is_err(),
1595            "Should reject location > MAX_LOCATION in multi-proof"
1596        );
1597    }
1598
1599    #[test]
1600    fn test_max_proof_digests_per_element_sufficient() {
1601        // Verify that MAX_PROOF_DIGESTS_PER_ELEMENT (122) is sufficient for any single-element
1602        // proof in the largest valid MMR.
1603        //
1604        // MMR sizes follow: mmr_size(N) = 2*N - popcount(N) where N = leaf count.
1605        // The number of peaks equals popcount(N).
1606        //
1607        // To maximize peaks, we want N with maximum popcount. N = 2^62 - 1 has 62 one-bits:
1608        //   N = 0x3FFFFFFFFFFFFFFF = 2^0 + 2^1 + ... + 2^61
1609        //
1610        // This gives us 62 perfect binary trees with leaf counts 2^0, 2^1, ..., 2^61
1611        // and corresponding heights 0, 1, ..., 61.
1612        //
1613        // mmr_size(2^62 - 1) = 2*(2^62 - 1) - 62 = 2^63 - 2 - 62 = 2^63 - 64
1614        //
1615        // For a single-element proof in a tree of height h:
1616        //   - Path siblings from leaf to peak: h digests
1617        //   - Other peaks (not containing the element): (62 - 1) = 61 digests
1618        //   - Total: h + 61 digests
1619        //
1620        // Worst case: element in tallest tree (h = 61)
1621        //   - Path siblings: 61
1622        //   - Other peaks: 61
1623        //   - Total: 61 + 61 = 122 digests
1624
1625        const NUM_PEAKS: usize = 62;
1626        const MAX_TREE_HEIGHT: usize = 61;
1627        const EXPECTED_WORST_CASE: usize = MAX_TREE_HEIGHT + (NUM_PEAKS - 1);
1628
1629        let many_peaks_size = Position::new((1u64 << 63) - 64);
1630        assert!(
1631            many_peaks_size.is_mmr_size(),
1632            "Size {many_peaks_size} should be a valid MMR size",
1633        );
1634
1635        let peak_count = PeakIterator::new(many_peaks_size).count();
1636        assert_eq!(peak_count, NUM_PEAKS);
1637
1638        // Verify the peak heights are 61, 60, ..., 1, 0 (from left to right)
1639        let peaks: Vec<_> = PeakIterator::new(many_peaks_size).collect();
1640        for (i, &(_pos, height)) in peaks.iter().enumerate() {
1641            let expected_height = (NUM_PEAKS - 1 - i) as u32;
1642            assert_eq!(
1643                height, expected_height,
1644                "Peak {i} should have height {expected_height}, got {height}",
1645            );
1646        }
1647
1648        // Test location 0 (leftmost leaf, in tallest tree of height 61)
1649        // Expected: 61 path siblings + 61 other peaks = 122 digests
1650        let leaves = Location::try_from(many_peaks_size).unwrap();
1651        let loc = Location::new_unchecked(0);
1652        let positions = nodes_required_for_range_proof(leaves, loc..loc + 1)
1653            .expect("should compute positions for location 0");
1654
1655        assert_eq!(
1656            positions.len(),
1657            EXPECTED_WORST_CASE,
1658            "Location 0 proof should require exactly {EXPECTED_WORST_CASE} digests (61 path + 61 peaks)",
1659        );
1660
1661        // Test the rightmost leaf (in smallest tree of height 0, which is itself a peak)
1662        // Expected: 0 path siblings + 61 other peaks = 61 digests
1663        let last_leaf_loc = leaves - 1;
1664        let positions = nodes_required_for_range_proof(leaves, last_leaf_loc..last_leaf_loc + 1)
1665            .expect("should compute positions for last leaf");
1666
1667        let expected_last_leaf = NUM_PEAKS - 1;
1668        assert_eq!(
1669            positions.len(),
1670            expected_last_leaf,
1671            "Last leaf proof should require exactly {expected_last_leaf} digests (0 path + 61 peaks)",
1672        );
1673    }
1674
1675    #[test]
1676    fn test_max_proof_digests_per_element_is_maximum() {
1677        // For K peaks, the worst-case proof needs: (max_tree_height) + (K - 1) digests
1678        // With K peaks of heights K-1, K-2, ..., 0, this is (K-1) + (K-1) = 2*(K-1)
1679        //
1680        // To get K peaks, leaf count N must have exactly K bits set.
1681        // MMR size = 2*N - popcount(N) = 2*N - K
1682        //
1683        // For 63 peaks: N = 2^63 - 1 (63 bits set), size = 2*(2^63 - 1) - 63 = 2^64 - 65
1684        // This exceeds MAX_POSITION, so is_mmr_size() returns false.
1685
1686        let n_for_63_peaks = (1u128 << 63) - 1;
1687        let size_for_63_peaks = 2 * n_for_63_peaks - 63; // = 2^64 - 65
1688        assert!(
1689            size_for_63_peaks > *crate::mmr::MAX_POSITION as u128,
1690            "63 peaks requires size {size_for_63_peaks} > MAX_POSITION",
1691        );
1692
1693        let size_truncated = size_for_63_peaks as u64;
1694        assert!(
1695            !Position::new(size_truncated).is_mmr_size(),
1696            "Size for 63 peaks should fail is_mmr_size()"
1697        );
1698    }
1699
    // Conformance checks for the encoded form of `Proof`. This module is only compiled when the
    // `arbitrary` feature is enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        // NOTE(review): presumably expands into conformance test(s) for each listed type; here
        // that is the codec encoding of `Proof` over SHA-256 digests. Confirm against the
        // `commonware_conformance` macro documentation before renaming or re-spelling the type,
        // in case the generated tests depend on its exact textual form.
        commonware_conformance::conformance_tests! {
            CodecConformance<Proof<Sha256Digest>>,
        }
    }
1710}