Skip to main content

commonware_storage/mmr/
proof.rs

1//! Defines the inclusion [Proof] structure, and functions for verifying them against a root digest.
2//!
3//! Also provides lower-level functions for building verifiers against new or extended proof types.
4//! These lower level functions are kept outside of the [Proof] structure and not re-exported by the
5//! parent module.
6
7#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10    hasher::Hasher,
11    iterator::{PathIterator, PeakIterator},
12    Error, Location, Position,
13};
14use alloc::{
15    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16    vec,
17    vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::Digest;
22use core::{cmp::Reverse, ops::Range};
23#[cfg(feature = "std")]
24use tracing::debug;
25
26/// The maximum number of digests in a proof per element being proven.
27///
28/// This accounts for the worst case proof size, in an MMR with 62 peaks. The
29/// left-most leaf in such a tree requires 122 digests, for 61 path siblings
30/// and 61 peak digests.
31pub const MAX_PROOF_DIGESTS_PER_ELEMENT: usize = 122;
32
33/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
34#[derive(Error, Debug)]
35pub enum ReconstructionError {
36    #[error("missing digests in proof")]
37    MissingDigests,
38    #[error("extra digests in proof")]
39    ExtraDigests,
40    #[error("start location is out of bounds")]
41    InvalidStartLoc,
42    #[error("end location is out of bounds")]
43    InvalidEndLoc,
44    #[error("missing elements")]
45    MissingElements,
46    #[error("invalid size")]
47    InvalidSize,
48}
49
50/// Contains the information necessary for proving the inclusion of an element, or some range of
51/// elements, in the MMR from its root digest.
52///
53/// The `digests` vector contains:
54///
55/// 1: the digests of each peak corresponding to a mountain containing no elements from the element
56/// range being proven in decreasing order of height, followed by:
57///
58/// 2: the nodes in the remaining mountains necessary for reconstructing their peak digests from the
59/// elements within the range, ordered by the position of their parent.
60#[derive(Clone, Debug, Eq)]
61pub struct Proof<D: Digest> {
62    /// The total number of leaves in the MMR for MMR proofs, though other authenticated data
63    /// structures may override the meaning of this field. For example, the authenticated
64    /// [crate::AuthenticatedBitMap] stores the number of bits in the bitmap within this field.
65    pub leaves: Location,
66    /// The digests necessary for proving the inclusion of an element, or range of elements, in the
67    /// MMR.
68    pub digests: Vec<D>,
69}
70
71impl<D: Digest> PartialEq for Proof<D> {
72    fn eq(&self, other: &Self) -> bool {
73        self.leaves == other.leaves && self.digests == other.digests
74    }
75}
76
77impl<D: Digest> EncodeSize for Proof<D> {
78    fn encode_size(&self) -> usize {
79        self.leaves.encode_size() + self.digests.encode_size()
80    }
81}
82
83impl<D: Digest> Write for Proof<D> {
84    fn write(&self, buf: &mut impl BufMut) {
85        // Write the number of leaves in the MMR
86        self.leaves.write(buf);
87
88        // Write the digests
89        self.digests.write(buf);
90    }
91}
92
93impl<D: Digest> Read for Proof<D> {
94    /// The maximum number of items being proven.
95    ///
96    /// The upper bound on digests is derived as `max_items * MAX_PROOF_DIGESTS_PER_ELEMENT`.
97    type Cfg = usize;
98
99    fn read_cfg(
100        buf: &mut impl Buf,
101        max_items: &Self::Cfg,
102    ) -> Result<Self, commonware_codec::Error> {
103        // Read the number of nodes in the MMR
104        let leaves = Location::read(buf)?;
105
106        // Read the digests
107        let max_digests = max_items.saturating_mul(MAX_PROOF_DIGESTS_PER_ELEMENT);
108        let digests = Vec::<D>::read_range(buf, ..=max_digests)?;
109
110        Ok(Self { leaves, digests })
111    }
112}
113
114impl<D: Digest> Default for Proof<D> {
115    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
116    /// (`leaves == 0`) MMR.
117    fn default() -> Self {
118        Self {
119            leaves: Location::new(0),
120            digests: vec![],
121        }
122    }
123}
124
125#[cfg(feature = "arbitrary")]
126impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
127where
128    D: for<'a> arbitrary::Arbitrary<'a>,
129{
130    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
131        Ok(Self {
132            leaves: u.arbitrary()?,
133            digests: u.arbitrary()?,
134        })
135    }
136}
137
138impl<D: Digest> Proof<D> {
139    /// Return true if this proof proves that `element` appears at location `loc` within the MMR
140    /// with root digest `root`.
141    pub fn verify_element_inclusion<H>(
142        &self,
143        hasher: &mut H,
144        element: &[u8],
145        loc: Location,
146        root: &D,
147    ) -> bool
148    where
149        H: Hasher<Digest = D>,
150    {
151        self.verify_range_inclusion(hasher, &[element], loc, root)
152    }
153
154    /// Return true if this proof proves that the `elements` appear consecutively starting at
155    /// position `start_loc` within the MMR with root digest `root`. A malformed proof will return
156    /// false.
157    pub fn verify_range_inclusion<H, E>(
158        &self,
159        hasher: &mut H,
160        elements: &[E],
161        start_loc: Location,
162        root: &D,
163    ) -> bool
164    where
165        H: Hasher<Digest = D>,
166        E: AsRef<[u8]>,
167    {
168        match self.reconstruct_root(hasher, elements, start_loc) {
169            Ok(reconstructed_root) => *root == reconstructed_root,
170            Err(_error) => {
171                #[cfg(feature = "std")]
172                tracing::debug!(error = ?_error, "invalid proof input");
173                false
174            }
175        }
176    }
177
178    /// Return true if this proof proves that the elements at the specified locations are included
179    /// in the MMR with the root digest `root`. A malformed proof will return false.
180    ///
181    /// The order of the elements does not affect the output.
182    pub fn verify_multi_inclusion<H, E>(
183        &self,
184        hasher: &mut H,
185        elements: &[(E, Location)],
186        root: &D,
187    ) -> bool
188    where
189        H: Hasher<Digest = D>,
190        E: AsRef<[u8]>,
191    {
192        // Empty proof is valid for an empty MMR
193        if elements.is_empty() {
194            return self.leaves == Location::new(0)
195                && *root == hasher.root(Location::new(0), core::iter::empty());
196        }
197
198        // Single pass to collect all required positions with deduplication
199        let mut node_positions = BTreeSet::new();
200        let mut nodes_required = BTreeMap::new();
201
202        for (_, loc) in elements {
203            if !loc.is_valid() {
204                return false;
205            }
206            // `loc` is valid so it won't overflow from +1
207            let Ok(required) = nodes_required_for_range_proof(self.leaves, *loc..*loc + 1) else {
208                return false;
209            };
210            for req_pos in &required {
211                node_positions.insert(*req_pos);
212            }
213            nodes_required.insert(*loc, required);
214        }
215
216        // Verify we have the exact number of digests needed
217        if node_positions.len() != self.digests.len() {
218            return false;
219        }
220
221        // Build position to digest mapping once
222        let node_digests: BTreeMap<Position, D> = node_positions
223            .iter()
224            .zip(self.digests.iter())
225            .map(|(&pos, digest)| (pos, *digest))
226            .collect();
227
228        // Verify each element by reconstructing its path
229        for (element, loc) in elements {
230            // Get required positions for this element
231            let required = &nodes_required[loc];
232
233            // Build proof with required digests
234            let mut digests = Vec::with_capacity(required.len());
235            for req_pos in required {
236                // There must exist a digest for each required position (by
237                // construction of `node_digests`)
238                let digest = node_digests
239                    .get(req_pos)
240                    .expect("must exist by construction of node_digests");
241                digests.push(*digest);
242            }
243            let proof = Self {
244                leaves: self.leaves,
245                digests,
246            };
247
248            // Verify the proof
249            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
250                return false;
251            }
252        }
253
254        true
255    }
256
257    // The functions below are lower level functions that are useful to building verification
258    // functions for new or extended proof types.
259
260    /// Computes the set of pinned nodes for the pruning boundary corresponding to the start of the
261    /// given range, returning the digest of each by extracting it from the proof.
262    ///
263    /// # Arguments
264    /// * `range` - The start and end locations of the proven range, where start is also used as the
265    ///   pruning boundary.
266    ///
267    /// # Returns
268    /// A Vec of digests for all nodes in `nodes_to_pin(pruning_boundary)`, in the same order as
269    /// returned by `nodes_to_pin` (decreasing height order)
270    ///
271    /// # Errors
272    ///
273    /// Returns [Error::InvalidSize] if the proof size is not a valid MMR size.
274    /// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
275    /// Returns [Error::InvalidProofLength] if the proof digest count doesn't match the required
276    /// positions count.
277    /// Returns [Error::MissingDigest] if a pinned node is not found in the proof.
278    #[cfg(any(feature = "std", test))]
279    pub(crate) fn extract_pinned_nodes(
280        &self,
281        range: std::ops::Range<Location>,
282    ) -> Result<Vec<D>, Error> {
283        // Get the positions of all nodes that should be pinned.
284        let start_pos = Position::try_from(range.start)?;
285        let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
286
287        // Get all positions required for the proof.
288        let required_positions = nodes_required_for_range_proof(self.leaves, range)?;
289
290        if required_positions.len() != self.digests.len() {
291            #[cfg(feature = "std")]
292            debug!(
293                digests_len = self.digests.len(),
294                required_positions_len = required_positions.len(),
295                "Proof digest count doesn't match required positions",
296            );
297            return Err(Error::InvalidProofLength);
298        }
299
300        // Happy path: we can extract the pinned nodes directly from the proof.
301        // This happens when the `end_element_pos` is the last element in the MMR.
302        if pinned_positions
303            == required_positions[required_positions.len() - pinned_positions.len()..]
304        {
305            return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
306        }
307
308        // Create a mapping from position to digest.
309        let position_to_digest: BTreeMap<Position, D> = required_positions
310            .iter()
311            .zip(self.digests.iter())
312            .map(|(&pos, &digest)| (pos, digest))
313            .collect();
314
315        // Extract the pinned nodes in the same order as nodes_to_pin.
316        let mut result = Vec::with_capacity(pinned_positions.len());
317        for pinned_pos in pinned_positions {
318            let Some(&digest) = position_to_digest.get(&pinned_pos) else {
319                #[cfg(feature = "std")]
320                debug!(?pinned_pos, "Pinned node not found in proof");
321                return Err(Error::MissingDigest(pinned_pos));
322            };
323            result.push(digest);
324        }
325        Ok(result)
326    }
327
328    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
329    /// of elements, returning the (position,digest) of every node whose digest was required by the
330    /// process (including those from the proof itself). Returns a [Error::InvalidProof] if the
331    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
332    /// root.
333    pub fn verify_range_inclusion_and_extract_digests<H, E>(
334        &self,
335        hasher: &mut H,
336        elements: &[E],
337        start_loc: Location,
338        root: &D,
339    ) -> Result<Vec<(Position, D)>, super::Error>
340    where
341        H: Hasher<Digest = D>,
342        E: AsRef<[u8]>,
343    {
344        let mut collected_digests = Vec::new();
345        let Ok(peak_digests) = self.reconstruct_peak_digests(
346            hasher,
347            elements,
348            start_loc,
349            Some(&mut collected_digests),
350        ) else {
351            return Err(Error::InvalidProof);
352        };
353
354        if hasher.root(self.leaves, peak_digests.iter()) != *root {
355            return Err(Error::RootMismatch);
356        }
357
358        Ok(collected_digests)
359    }
360
361    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
362    /// of elements, or returns a [ReconstructionError] if the input data is invalid.
363    pub fn reconstruct_root<H, E>(
364        &self,
365        hasher: &mut H,
366        elements: &[E],
367        start_loc: Location,
368    ) -> Result<D, ReconstructionError>
369    where
370        H: Hasher<Digest = D>,
371        E: AsRef<[u8]>,
372    {
373        let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
374
375        Ok(hasher.root(self.leaves, peak_digests.iter()))
376    }
377
378    /// Reconstruct the peak digests of the MMR that produced this proof, returning
379    /// [ReconstructionError] if the input data is invalid.  If collected_digests is Some, then all
380    /// node digests used in the process will be added to the wrapped vector.
381    pub fn reconstruct_peak_digests<H, E>(
382        &self,
383        hasher: &mut H,
384        elements: &[E],
385        start_loc: Location,
386        mut collected_digests: Option<&mut Vec<(Position, D)>>,
387    ) -> Result<Vec<D>, ReconstructionError>
388    where
389        H: Hasher<Digest = D>,
390        E: AsRef<[u8]>,
391    {
392        if elements.is_empty() {
393            if start_loc == 0 {
394                return Ok(vec![]);
395            }
396            return Err(ReconstructionError::MissingElements);
397        }
398        let size = Position::try_from(self.leaves).map_err(|_| ReconstructionError::InvalidSize)?;
399        let start_element_pos =
400            Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
401        let end_element_pos = if elements.len() == 1 {
402            start_element_pos
403        } else {
404            let end_loc = start_loc
405                .checked_add(elements.len() as u64 - 1)
406                .ok_or(ReconstructionError::InvalidEndLoc)?;
407            Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
408        };
409        if end_element_pos >= size {
410            return Err(ReconstructionError::InvalidEndLoc);
411        }
412
413        let mut proof_digests_iter = self.digests.iter();
414        let mut siblings_iter = self.digests.iter().rev();
415
416        // Include peak digests only for trees that have no elements from the range, and keep track
417        // of the starting and ending trees of those that do contain some.
418        let mut peak_digests: Vec<D> = Vec::new();
419        let mut proof_digests_used = 0;
420        let mut elements_iter = elements.iter();
421        for (peak_pos, height) in PeakIterator::new(size) {
422            let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
423            if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
424                let hash = peak_digest_from_range(
425                    hasher,
426                    RangeInfo {
427                        pos: peak_pos,
428                        two_h: 1 << height,
429                        leftmost_pos: start_element_pos,
430                        rightmost_pos: end_element_pos,
431                    },
432                    &mut elements_iter,
433                    &mut siblings_iter,
434                    collected_digests.as_deref_mut(),
435                )?;
436                peak_digests.push(hash);
437                if let Some(ref mut collected_digests) = collected_digests {
438                    collected_digests.push((peak_pos, hash));
439                }
440            } else if let Some(hash) = proof_digests_iter.next() {
441                proof_digests_used += 1;
442                peak_digests.push(*hash);
443                if let Some(ref mut collected_digests) = collected_digests {
444                    collected_digests.push((peak_pos, *hash));
445                }
446            } else {
447                return Err(ReconstructionError::MissingDigests);
448            }
449        }
450
451        if elements_iter.next().is_some() {
452            return Err(ReconstructionError::ExtraDigests);
453        }
454        if let Some(next_sibling) = siblings_iter.next() {
455            if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
456                return Err(ReconstructionError::ExtraDigests);
457            }
458        }
459
460        Ok(peak_digests)
461    }
462}
463
464/// Return the list of node positions required by the range proof for the specified range of
465/// elements.
466///
467/// # Errors
468///
469/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
470/// Returns [Error::Empty] if the range is empty.
471/// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
472/// Returns [Error::RangeOutOfBounds] if the last element position in `range` is out of bounds
473/// (>= `size`).
474pub(crate) fn nodes_required_for_range_proof(
475    leaves: Location,
476    range: Range<Location>,
477) -> Result<Vec<Position>, Error> {
478    if range.is_empty() {
479        return Err(Error::Empty);
480    }
481    let end_minus_one = range
482        .end
483        .checked_sub(1)
484        .expect("can't underflow because range is non-empty");
485    if end_minus_one >= leaves {
486        return Err(Error::RangeOutOfBounds(range.end));
487    }
488
489    // Find the mountains that contain no elements from the range. The peaks of these mountains
490    // are required to prove the range, so they are added to the result.
491    let mut start_tree_with_element: Option<(Position, u32)> = None;
492    let mut end_tree_with_element: Option<(Position, u32)> = None;
493    let mut positions = Vec::new();
494    let size = Position::try_from(leaves)?;
495    let start_element_pos = Position::try_from(range.start)?;
496    let end_element_pos = Position::try_from(end_minus_one)?;
497
498    let mut peak_iterator = PeakIterator::new(size);
499    while let Some(peak) = peak_iterator.next() {
500        if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
501            // Found the first tree to contain an element in the range
502            start_tree_with_element = Some(peak);
503            if peak.0 >= end_element_pos {
504                // Start and end tree are the same
505                end_tree_with_element = Some(peak);
506                continue;
507            }
508            for peak in peak_iterator.by_ref() {
509                if peak.0 >= end_element_pos {
510                    // Found the last tree to contain an element in the range
511                    end_tree_with_element = Some(peak);
512                    break;
513                }
514            }
515        } else {
516            // Tree is outside the range, its peak is thus required.
517            positions.push(peak.0);
518        }
519    }
520
521    // We checked above that all range elements are in this MMR, so some mountain must contain
522    // the first and last elements in the range.
523    let (start_tree_peak, start_tree_height) =
524        start_tree_with_element.expect("start_tree_with_element is Some");
525    let (end_tree_peak, end_tree_height) =
526        end_tree_with_element.expect("end_tree_with_element is Some");
527
528    // Include the positions of any left-siblings of each node on the path from peak to
529    // leftmost-leaf, and right-siblings for the path from peak to rightmost-leaf. These are
530    // added in order of decreasing parent position.
531    let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);
532
533    let mut siblings = Vec::new();
534    if start_element_pos == end_element_pos {
535        // For the (common) case of a single element range, the right and left path are the
536        // same so no need to process each independently.
537        siblings.extend(left_path_iter);
538    } else {
539        let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
540        // filter the right path for right siblings only
541        siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
542        // filter the left path for left siblings only
543        siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));
544
545        // If the range spans more than one tree, then the digests must already be in the correct
546        // order. Otherwise, we enforce the desired order through sorting.
547        if start_tree_peak == end_tree_peak {
548            siblings.sort_by_key(|a| Reverse(a.0));
549        }
550    }
551    positions.extend(siblings.into_iter().map(|(_, pos)| pos));
552
553    Ok(positions)
554}
555
556/// Returns the positions of the minimal set of nodes whose digests are required to prove the
557/// inclusion of the elements at the specified `locations`.
558///
559/// The order of positions does not affect the output (sorted internally).
560///
561/// # Errors
562///
563/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
564/// Returns [Error::Empty] if locations is empty.
565/// Returns [Error::LocationOverflow] if any location in `locations` > [crate::mmr::MAX_LOCATION].
566/// Returns [Error::RangeOutOfBounds] if any location is out of bounds for the given `size`.
567#[cfg(any(feature = "std", test))]
568pub(crate) fn nodes_required_for_multi_proof(
569    leaves: Location,
570    locations: &[Location],
571) -> Result<BTreeSet<Position>, Error> {
572    // Collect all required node positions
573    //
574    // TODO(#1472): Optimize this loop
575    if locations.is_empty() {
576        return Err(Error::Empty);
577    }
578    locations.iter().try_fold(BTreeSet::new(), |mut acc, loc| {
579        if !loc.is_valid() {
580            return Err(Error::LocationOverflow(*loc));
581        }
582        // `loc` is valid so it won't overflow from +1
583        let positions = nodes_required_for_range_proof(leaves, *loc..*loc + 1)?;
584        acc.extend(positions);
585
586        Ok(acc)
587    })
588}
589
590/// Information about the current range of nodes being traversed.
591struct RangeInfo {
592    pos: Position,           // current node position in the tree
593    two_h: u64,              // 2^height of the current node
594    leftmost_pos: Position,  // leftmost leaf in the tree to be traversed
595    rightmost_pos: Position, // rightmost leaf in the tree to be traversed
596}
597
598fn peak_digest_from_range<'a, D, H, E, S>(
599    hasher: &mut H,
600    range_info: RangeInfo,
601    elements: &mut E,
602    sibling_digests: &mut S,
603    mut collected_digests: Option<&mut Vec<(Position, D)>>,
604) -> Result<D, ReconstructionError>
605where
606    D: Digest,
607    H: Hasher<Digest = D>,
608    E: Iterator<Item: AsRef<[u8]>>,
609    S: Iterator<Item = &'a D>,
610{
611    assert_ne!(range_info.two_h, 0);
612    if range_info.two_h == 1 {
613        match elements.next() {
614            Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
615            None => return Err(ReconstructionError::MissingDigests),
616        }
617    }
618
619    let mut left_digest: Option<D> = None;
620    let mut right_digest: Option<D> = None;
621
622    let left_pos = range_info.pos - range_info.two_h;
623    let right_pos = left_pos + range_info.two_h - 1;
624    if left_pos >= range_info.leftmost_pos {
625        // Descend left
626        let digest = peak_digest_from_range(
627            hasher,
628            RangeInfo {
629                pos: left_pos,
630                two_h: range_info.two_h >> 1,
631                leftmost_pos: range_info.leftmost_pos,
632                rightmost_pos: range_info.rightmost_pos,
633            },
634            elements,
635            sibling_digests,
636            collected_digests.as_deref_mut(),
637        )?;
638        left_digest = Some(digest);
639    }
640    if left_pos < range_info.rightmost_pos {
641        // Descend right
642        let digest = peak_digest_from_range(
643            hasher,
644            RangeInfo {
645                pos: right_pos,
646                two_h: range_info.two_h >> 1,
647                leftmost_pos: range_info.leftmost_pos,
648                rightmost_pos: range_info.rightmost_pos,
649            },
650            elements,
651            sibling_digests,
652            collected_digests.as_deref_mut(),
653        )?;
654        right_digest = Some(digest);
655    }
656
657    if left_digest.is_none() {
658        match sibling_digests.next() {
659            Some(hash) => left_digest = Some(*hash),
660            None => return Err(ReconstructionError::MissingDigests),
661        }
662    }
663    if right_digest.is_none() {
664        match sibling_digests.next() {
665            Some(hash) => right_digest = Some(*hash),
666            None => return Err(ReconstructionError::MissingDigests),
667        }
668    }
669
670    if let Some(ref mut collected_digests) = collected_digests {
671        collected_digests.push((
672            left_pos,
673            left_digest.expect("left_digest guaranteed to be Some after checks above"),
674        ));
675        collected_digests.push((
676            right_pos,
677            right_digest.expect("right_digest guaranteed to be Some after checks above"),
678        ));
679    }
680
681    Ok(hasher.node_digest(
682        range_info.pos,
683        &left_digest.expect("left_digest guaranteed to be Some after checks above"),
684        &right_digest.expect("right_digest guaranteed to be Some after checks above"),
685    ))
686}
687
688#[cfg(test)]
689mod tests {
690    use super::*;
691    use crate::mmr::{hasher::Standard, location::LocationRangeExt as _, mem::Mmr, MAX_LOCATION};
692    use commonware_codec::{Decode, Encode};
693    use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
694    use commonware_macros::test_traced;
695
696    fn test_digest(v: u8) -> Digest {
697        Sha256::hash(&[v])
698    }
699
700    #[test]
701    fn test_proving_proof() {
702        // Test that an empty proof authenticates an empty MMR.
703        let mut hasher: Standard<Sha256> = Standard::new();
704        let mmr = Mmr::new(&mut hasher);
705        let root = mmr.root();
706        let proof = Proof::default();
707        assert!(proof.verify_range_inclusion(
708            &mut hasher,
709            &[] as &[Digest],
710            Location::new(0),
711            root
712        ));
713
714        // Any starting position other than 0 should fail to verify.
715        assert!(!proof.verify_range_inclusion(
716            &mut hasher,
717            &[] as &[Digest],
718            Location::new(1),
719            root
720        ));
721
722        // Invalid root should fail to verify.
723        let test_digest = test_digest(0);
724        assert!(!proof.verify_range_inclusion(
725            &mut hasher,
726            &[] as &[Digest],
727            Location::new(0),
728            &test_digest
729        ));
730
731        // Non-empty elements list should fail to verify.
732        assert!(!proof.verify_range_inclusion(&mut hasher, &[test_digest], Location::new(0), root));
733    }
734
735    #[test]
736    fn test_proving_verify_element() {
737        // create an 11 element MMR over which we'll test single-element inclusion proofs
738        let element = Digest::from(*b"01234567012345670123456701234567");
739        let mut hasher: Standard<Sha256> = Standard::new();
740        let mut mmr = Mmr::new(&mut hasher);
741        let changeset = {
742            let mut batch = mmr.new_batch();
743            for _ in 0..11 {
744                batch.add(&mut hasher, &element);
745            }
746            batch.merkleize(&mut hasher).finalize()
747        };
748        mmr.apply(changeset).unwrap();
749        let root = mmr.root();
750
751        // confirm the proof of inclusion for each leaf successfully verifies
752        for leaf in 0u64..11 {
753            let leaf = Location::new(leaf);
754            let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
755            assert!(
756                proof.verify_element_inclusion(&mut hasher, &element, leaf, root),
757                "valid proof should verify successfully"
758            );
759        }
760
761        // Create a valid proof, then confirm various mangling of the proof or proof args results in
762        // verification failure.
763        const LEAF: Location = Location::new(10);
764        let proof = mmr.proof(LEAF).unwrap();
765        assert!(
766            proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
767            "proof verification should be successful"
768        );
769        assert!(
770            !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, root),
771            "proof verification should fail with incorrect element position"
772        );
773        assert!(
774            !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, root),
775            "proof verification should fail with incorrect element position 2"
776        );
777        assert!(
778            !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, root),
779            "proof verification should fail with mangled element"
780        );
781        let root2 = test_digest(0);
782        assert!(
783            !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
784            "proof verification should fail with mangled root"
785        );
786        let mut proof2 = proof.clone();
787        proof2.digests[0] = test_digest(0);
788        assert!(
789            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
790            "proof verification should fail with mangled proof hash"
791        );
792        proof2 = proof.clone();
793        proof2.leaves = Location::new(10);
794        assert!(
795            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
796            "proof verification should fail with incorrect leaves"
797        );
798        proof2 = proof.clone();
799        proof2.digests.push(test_digest(0));
800        assert!(
801            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
802            "proof verification should fail with extra hash"
803        );
804        proof2 = proof.clone();
805        while !proof2.digests.is_empty() {
806            proof2.digests.pop();
807            assert!(
808                !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
809                "proof verification should fail with missing digests"
810            );
811        }
812        proof2 = proof.clone();
813        proof2.digests.clear();
814        const PEAK_COUNT: usize = 3;
815        proof2
816            .digests
817            .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
818        // sneak in an extra hash that won't be used in the computation and make sure it's
819        // detected
820        proof2.digests.push(test_digest(0));
821        proof2
822            .digests
823            .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
824        assert!(
825            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
826            "proof verification should fail with extra hash even if it's unused by the computation"
827        );
828    }
829
830    #[test]
831    fn test_proving_verify_range() {
832        // create a new MMR and add a non-trivial amount (49) of elements
833        let mut hasher: Standard<Sha256> = Standard::new();
834        let mut mmr = Mmr::new(&mut hasher);
835        let elements: Vec<_> = (0..49).map(test_digest).collect();
836        let changeset = {
837            let mut batch = mmr.new_batch();
838            for element in &elements {
839                batch.add(&mut hasher, element);
840            }
841            batch.merkleize(&mut hasher).finalize()
842        };
843        mmr.apply(changeset).unwrap();
844        // test range proofs over all possible ranges of at least 2 elements
845        let root = mmr.root();
846
847        for i in 0..elements.len() {
848            for j in i + 1..elements.len() {
849                let range = Location::new(i as u64)..Location::new(j as u64);
850                let range_proof = mmr.range_proof(range.clone()).unwrap();
851                assert!(
852                    range_proof.verify_range_inclusion(
853                        &mut hasher,
854                        &elements[range.to_usize_range()],
855                        range.start,
856                        root,
857                    ),
858                    "valid range proof should verify successfully {i}:{j}",
859                );
860            }
861        }
862
863        // Create a proof over a range of elements, confirm it verifies successfully, then mangle
864        // the proof & proof input in various ways, confirming verification fails.
865        let range = Location::new(33)..Location::new(40);
866        let range_proof = mmr.range_proof(range.clone()).unwrap();
867        let valid_elements = &elements[range.to_usize_range()];
868        assert!(
869            range_proof.verify_range_inclusion(&mut hasher, valid_elements, range.start, root),
870            "valid range proof should verify successfully"
871        );
872        // Remove digests from the proof until it's empty, confirming proof verification fails for
873        // each.
874        let mut invalid_proof = range_proof.clone();
875        for _i in 0..range_proof.digests.len() {
876            invalid_proof.digests.remove(0);
877            assert!(
878                !invalid_proof.verify_range_inclusion(
879                    &mut hasher,
880                    valid_elements,
881                    range.start,
882                    root,
883                ),
884                "range proof with removed elements should fail"
885            );
886        }
887        // Confirm proof verification fails when providing an element range different than the one
888        // used to generate the proof.
889        for i in 0..elements.len() {
890            for j in i + 1..elements.len() {
891                if Location::from(i) == range.start && Location::from(j) == range.end {
892                    // skip the valid range
893                    continue;
894                }
895                assert!(
896                    !range_proof.verify_range_inclusion(
897                        &mut hasher,
898                        &elements[i..j],
899                        range.start,
900                        root,
901                    ),
902                    "range proof with invalid element range should fail {i}:{j}",
903                );
904            }
905        }
906        // Confirm proof fails to verify with an invalid root.
907        let invalid_root = test_digest(1);
908        assert!(
909            !range_proof.verify_range_inclusion(
910                &mut hasher,
911                valid_elements,
912                range.start,
913                &invalid_root,
914            ),
915            "range proof with invalid root should fail"
916        );
917        // Mangle each element of the proof and confirm it fails to verify.
918        for i in 0..range_proof.digests.len() {
919            let mut invalid_proof = range_proof.clone();
920            invalid_proof.digests[i] = test_digest(0);
921
922            assert!(
923                !invalid_proof.verify_range_inclusion(
924                    &mut hasher,
925                    valid_elements,
926                    range.start,
927                    root,
928                ),
929                "mangled range proof should fail verification"
930            );
931        }
932        // Inserting elements into the proof should also cause it to fail (malleability check)
933        for i in 0..range_proof.digests.len() {
934            let mut invalid_proof = range_proof.clone();
935            invalid_proof.digests.insert(i, test_digest(0));
936            assert!(
937                !invalid_proof.verify_range_inclusion(
938                    &mut hasher,
939                    valid_elements,
940                    range.start,
941                    root,
942                ),
943                "mangled range proof should fail verification. inserted element at: {i}",
944            );
945        }
946        // Bad start_loc should cause verification to fail.
947        for loc in 0..elements.len() {
948            let loc = Location::new(loc as u64);
949            if loc == range.start {
950                continue;
951            }
952            assert!(
953                !range_proof.verify_range_inclusion(&mut hasher, valid_elements, loc, root),
954                "bad start_loc should fail verification {loc}",
955            );
956        }
957    }
958
959    #[test_traced]
960    fn test_proving_retained_nodes_provable_after_pruning() {
961        // create a new MMR and add a non-trivial amount (49) of elements
962        let mut hasher: Standard<Sha256> = Standard::new();
963        let mut mmr = Mmr::new(&mut hasher);
964        let elements: Vec<_> = (0..49).map(test_digest).collect();
965        let changeset = {
966            let mut batch = mmr.new_batch();
967            for element in &elements {
968                batch.add(&mut hasher, element);
969            }
970            batch.merkleize(&mut hasher).finalize()
971        };
972        mmr.apply(changeset).unwrap();
973
974        // Confirm we can successfully prove all retained elements in the MMR after pruning.
975        let root = *mmr.root();
976        for prune_leaf in 1..*mmr.leaves() {
977            let prune_loc = Location::new(prune_leaf);
978            mmr.prune(prune_loc).unwrap();
979            let pruned_root = mmr.root();
980            assert_eq!(root, *pruned_root);
981            for loc in 0..elements.len() {
982                let loc = Location::new(loc as u64);
983                let proof = mmr.proof(loc);
984                if loc < prune_loc {
985                    continue;
986                }
987                assert!(proof.is_ok());
988                assert!(proof.unwrap().verify_element_inclusion(
989                    &mut hasher,
990                    &elements[*loc as usize],
991                    loc,
992                    &root
993                ));
994            }
995        }
996    }
997
998    #[test]
999    fn test_proving_ranges_provable_after_pruning() {
1000        // create a new MMR and add a non-trivial amount (49) of elements
1001        let mut hasher: Standard<Sha256> = Standard::new();
1002        let mut mmr = Mmr::new(&mut hasher);
1003        let mut elements: Vec<_> = (0..49).map(test_digest).collect();
1004        let changeset = {
1005            let mut batch = mmr.new_batch();
1006            for element in &elements {
1007                batch.add(&mut hasher, element);
1008            }
1009            batch.merkleize(&mut hasher).finalize()
1010        };
1011        mmr.apply(changeset).unwrap();
1012
1013        // prune up to the first peak
1014        const PRUNE_LOC: Location = Location::new(32);
1015        mmr.prune(PRUNE_LOC).unwrap();
1016        assert_eq!(mmr.bounds().start, PRUNE_LOC);
1017
1018        // Test range proofs over all possible ranges of at least 2 elements
1019        let root = mmr.root();
1020        for i in 0..elements.len() - 1 {
1021            if Location::new(i as u64) < PRUNE_LOC {
1022                continue;
1023            }
1024            for j in (i + 2)..elements.len() {
1025                let range = Location::new(i as u64)..Location::new(j as u64);
1026                let range_proof = mmr.range_proof(range.clone()).unwrap();
1027                assert!(
1028                    range_proof.verify_range_inclusion(
1029                        &mut hasher,
1030                        &elements[range.to_usize_range()],
1031                        range.start,
1032                        root,
1033                    ),
1034                    "valid range proof over remaining elements should verify successfully",
1035                );
1036            }
1037        }
1038
1039        // Add a few more nodes, prune again, and test again to make sure repeated pruning doesn't
1040        // break proof verification.
1041        let new_elements: Vec<_> = (0..37).map(test_digest).collect();
1042        let changeset = {
1043            let mut batch = mmr.new_batch();
1044            for element in &new_elements {
1045                batch.add(&mut hasher, element);
1046            }
1047            batch.merkleize(&mut hasher).finalize()
1048        };
1049        mmr.apply(changeset).unwrap();
1050        elements.extend(new_elements);
1051        mmr.prune(Location::new(66)).unwrap(); // a bit after the new highest peak
1052        assert_eq!(mmr.bounds().start, Location::new(66));
1053
1054        let updated_root = mmr.root();
1055        let range = Location::new(elements.len() as u64 - 10)..Location::new(elements.len() as u64);
1056        let range_proof = mmr.range_proof(range.clone()).unwrap();
1057        assert!(
1058                range_proof.verify_range_inclusion(
1059                    &mut hasher,
1060                    &elements[range.to_usize_range()],
1061                    range.start,
1062                    updated_root,
1063                ),
1064                "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1065            );
1066    }
1067
1068    #[test]
1069    fn test_proving_proof_serialization() {
1070        // create a new MMR and add a non-trivial amount of elements
1071        let mut hasher: Standard<Sha256> = Standard::new();
1072        let mut mmr = Mmr::new(&mut hasher);
1073        let elements: Vec<_> = (0..25).map(test_digest).collect();
1074        let changeset = {
1075            let mut batch = mmr.new_batch();
1076            for element in &elements {
1077                batch.add(&mut hasher, element);
1078            }
1079            batch.merkleize(&mut hasher).finalize()
1080        };
1081        mmr.apply(changeset).unwrap();
1082
1083        // Generate proofs over all possible ranges of elements and confirm each
1084        // serializes=>deserializes correctly.
1085        for i in 0..elements.len() {
1086            for j in i + 1..elements.len() {
1087                let range = Location::new(i as u64)..Location::new(j as u64);
1088                let proof = mmr.range_proof(range).unwrap();
1089
1090                let expected_size = proof.encode_size();
1091                let serialized_proof = proof.encode();
1092                assert_eq!(
1093                    serialized_proof.len(),
1094                    expected_size,
1095                    "serialized proof should have expected size"
1096                );
1097                // max_items is the number of elements in the range
1098                let max_items = j - i;
1099                let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_items).unwrap();
1100                assert_eq!(
1101                    proof, deserialized_proof,
1102                    "deserialized proof should match source proof"
1103                );
1104
1105                // Remove one byte from the end of the serialized
1106                // proof and confirm it fails to deserialize.
1107                let serialized_proof = proof.encode();
1108                let serialized_proof = serialized_proof.slice(0..serialized_proof.len() - 1);
1109                assert!(
1110                    Proof::<Digest>::decode_cfg(serialized_proof, &max_items).is_err(),
1111                    "proof should not deserialize with truncated data"
1112                );
1113
1114                // Add 1 byte of extra data to the end of the serialized
1115                // proof and confirm it fails to deserialize.
1116                let mut serialized_proof = proof.encode_mut();
1117                serialized_proof.extend_from_slice(&[0; 10]);
1118                let serialized_proof = serialized_proof;
1119
1120                assert!(
1121                    Proof::<Digest>::decode_cfg(serialized_proof, &max_items).is_err(),
1122                    "proof should not deserialize with extra data"
1123                );
1124
1125                // Confirm deserialization fails when max_items is too small.
1126                let actual_digests = proof.digests.len();
1127                if actual_digests > 0 {
1128                    // Find the minimum max_items that would allow this many digests
1129                    let min_max_items = actual_digests.div_ceil(MAX_PROOF_DIGESTS_PER_ELEMENT);
1130                    // Using one less should fail
1131                    let too_small = min_max_items - 1;
1132                    let serialized_proof = proof.encode();
1133                    assert!(
1134                        Proof::<Digest>::decode_cfg(serialized_proof, &too_small).is_err(),
1135                        "proof should not deserialize with max_items too small"
1136                    );
1137                }
1138            }
1139        }
1140    }
1141
1142    #[test_traced]
1143    fn test_proving_extract_pinned_nodes() {
1144        // Test for every number of elements from 1 to 255
1145        for num_elements in 1u64..255 {
1146            // Build MMR with the specified number of elements
1147            let mut hasher: Standard<Sha256> = Standard::new();
1148            let mut mmr = Mmr::new(&mut hasher);
1149
1150            let changeset = {
1151                let mut batch = mmr.new_batch();
1152                for i in 0..num_elements {
1153                    batch.add(&mut hasher, &test_digest(i as u8));
1154                }
1155                batch.merkleize(&mut hasher).finalize()
1156            };
1157            mmr.apply(changeset).unwrap();
1158
1159            // Test pruning to each leaf.
1160            for leaf in 0..num_elements {
1161                // Test with a few different end positions to get good coverage
1162                let test_end_locs = if num_elements == 1 {
1163                    // Single element case
1164                    vec![leaf + 1]
1165                } else {
1166                    // Multi-element case: test with various end positions
1167                    let mut ends = vec![leaf + 1]; // Single element proof
1168
1169                    // Add a few more end positions if available
1170                    if leaf + 2 <= num_elements {
1171                        ends.push(leaf + 2);
1172                    }
1173                    if leaf + 3 <= num_elements {
1174                        ends.push(leaf + 3);
1175                    }
1176                    // Always test with the last element if different
1177                    if ends.last().unwrap() != &num_elements {
1178                        ends.push(num_elements);
1179                    }
1180
1181                    ends.into_iter()
1182                        .collect::<BTreeSet<_>>()
1183                        .into_iter()
1184                        .collect()
1185                };
1186
1187                for end_loc in test_end_locs {
1188                    // Generate proof for the range
1189                    let range = Location::new(leaf)..Location::new(end_loc);
1190                    let proof_result = mmr.range_proof(range.clone());
1191                    let proof = proof_result.unwrap();
1192
1193                    // Extract pinned nodes
1194                    let extract_result = proof.extract_pinned_nodes(range.clone());
1195                    assert!(
1196                            extract_result.is_ok(),
1197                            "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1198                        );
1199
1200                    let pinned_nodes = extract_result.unwrap();
1201                    let leaf_loc = Location::new(leaf);
1202                    let leaf_pos = Position::try_from(leaf_loc).unwrap();
1203                    let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1204
1205                    // Verify count matches expected
1206                    assert_eq!(
1207                            pinned_nodes.len(),
1208                            expected_pinned.len(),
1209                            "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1210                        );
1211
1212                    // Verify extracted hashes match actual node values
1213                    // The pinned_nodes Vec is in the same order as expected_pinned
1214                    for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1215                        let extracted_hash = pinned_nodes[i];
1216                        let actual_hash = mmr.get_node(expected_pos).unwrap();
1217                        assert_eq!(
1218                                extracted_hash, actual_hash,
1219                                "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1220                            );
1221                    }
1222                }
1223            }
1224        }
1225    }
1226
1227    #[test]
1228    fn test_proving_extract_pinned_nodes_invalid_size() {
1229        // Test that extract_pinned_nodes returns an error for invalid MMR size
1230        let mut hasher: Standard<Sha256> = Standard::new();
1231        let mut mmr = Mmr::new(&mut hasher);
1232
1233        // Build MMR with 10 elements
1234        let changeset = {
1235            let mut batch = mmr.new_batch();
1236            for i in 0..10 {
1237                batch.add(&mut hasher, &test_digest(i));
1238            }
1239            batch.merkleize(&mut hasher).finalize()
1240        };
1241        mmr.apply(changeset).unwrap();
1242
1243        // Generate a valid proof
1244        let range = Location::new(5)..Location::new(8);
1245        let mut proof = mmr.range_proof(range.clone()).unwrap();
1246
1247        // Verify the proof works with valid size
1248        assert!(proof.extract_pinned_nodes(range.clone()).is_ok());
1249
1250        // Test with invalid location.
1251        proof.leaves = Location::new(*MAX_LOCATION + 2);
1252        let result = proof.extract_pinned_nodes(range);
1253        assert!(matches!(result, Err(Error::LocationOverflow(_))));
1254    }
1255
1256    #[test]
1257    fn test_proving_digests_from_range() {
1258        // create a new MMR and add a non-trivial amount (49) of elements
1259        let mut hasher: Standard<Sha256> = Standard::new();
1260        let mut mmr = Mmr::new(&mut hasher);
1261        let elements: Vec<_> = (0..49).map(test_digest).collect();
1262        let changeset = {
1263            let mut batch = mmr.new_batch();
1264            for element in &elements {
1265                batch.add(&mut hasher, element);
1266            }
1267            batch.merkleize(&mut hasher).finalize()
1268        };
1269        mmr.apply(changeset).unwrap();
1270        let root = mmr.root();
1271
1272        // Test 1: compute_digests over the entire range should contain a digest for every node
1273        // in the tree.
1274        let proof = mmr.range_proof(Location::new(0)..mmr.leaves()).unwrap();
1275        let mut node_digests = proof
1276            .verify_range_inclusion_and_extract_digests(
1277                &mut hasher,
1278                &elements,
1279                Location::new(0),
1280                root,
1281            )
1282            .unwrap();
1283        assert_eq!(node_digests.len() as u64, mmr.size());
1284        node_digests.sort_by_key(|(pos, _)| *pos);
1285        for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1286            assert_eq!(pos, i as u64);
1287            assert_eq!(mmr.get_node(pos).unwrap(), d);
1288        }
1289        // Make sure the wrong root fails.
1290        let wrong_root = elements[0]; // any other digest will do
1291        assert!(matches!(
1292            proof.verify_range_inclusion_and_extract_digests(
1293                &mut hasher,
1294                &elements,
1295                Location::new(0),
1296                &wrong_root
1297            ),
1298            Err(Error::RootMismatch)
1299        ));
1300
1301        // Test 2: Single element range (first element)
1302        let range = Location::new(0)..Location::new(1);
1303        let single_proof = mmr.range_proof(range.clone()).unwrap();
1304        let range_start = range.start;
1305        let single_digests = single_proof
1306            .verify_range_inclusion_and_extract_digests(
1307                &mut hasher,
1308                &elements[range.to_usize_range()],
1309                range_start,
1310                root,
1311            )
1312            .unwrap();
1313        assert!(single_digests.len() > 1);
1314
1315        // Test 3: Single element range (middle element)
1316        let mid_idx = 24;
1317        let range = Location::new(mid_idx)..Location::new(mid_idx + 1);
1318        let range_start = range.start;
1319        let mid_proof = mmr.range_proof(range.clone()).unwrap();
1320        let mid_digests = mid_proof
1321            .verify_range_inclusion_and_extract_digests(
1322                &mut hasher,
1323                &elements[range.to_usize_range()],
1324                range_start,
1325                root,
1326            )
1327            .unwrap();
1328        assert!(mid_digests.len() > 1);
1329
1330        // Test 4: Single element range (last element)
1331        let last_idx = elements.len() as u64 - 1;
1332        let range = Location::new(last_idx)..Location::new(last_idx + 1);
1333        let range_start = range.start;
1334        let last_proof = mmr.range_proof(range.clone()).unwrap();
1335        let last_digests = last_proof
1336            .verify_range_inclusion_and_extract_digests(
1337                &mut hasher,
1338                &elements[range.to_usize_range()],
1339                range_start,
1340                root,
1341            )
1342            .unwrap();
1343        assert!(last_digests.len() > 1);
1344
1345        // Test 5: Small range at the beginning
1346        let range = Location::new(0)..Location::new(5);
1347        let range_start = range.start;
1348        let small_proof = mmr.range_proof(range.clone()).unwrap();
1349        let small_digests = small_proof
1350            .verify_range_inclusion_and_extract_digests(
1351                &mut hasher,
1352                &elements[range.to_usize_range()],
1353                range_start,
1354                root,
1355            )
1356            .unwrap();
1357        // Verify that we get digests for the range elements and their ancestors
1358        assert!(small_digests.len() > 5);
1359
1360        // Test 6: Medium range in the middle
1361        let range = Location::new(10)..Location::new(31);
1362        let range_start = range.start;
1363        let mid_range_proof = mmr.range_proof(range.clone()).unwrap();
1364        let mid_range_digests = mid_range_proof
1365            .verify_range_inclusion_and_extract_digests(
1366                &mut hasher,
1367                &elements[range.to_usize_range()],
1368                range_start,
1369                root,
1370            )
1371            .unwrap();
1372        let num_elements = range.end - range.start;
1373        assert!(mid_range_digests.len() as u64 > num_elements);
1374    }
1375
1376    #[test]
1377    fn test_proving_multi_proof_generation_and_verify() {
1378        // Create an MMR with multiple elements
1379        let mut hasher: Standard<Sha256> = Standard::new();
1380        let mut mmr = Mmr::new(&mut hasher);
1381        let elements: Vec<_> = (0..20).map(test_digest).collect();
1382        let changeset = {
1383            let mut batch = mmr.new_batch();
1384            for element in &elements {
1385                batch.add(&mut hasher, element);
1386            }
1387            batch.merkleize(&mut hasher).finalize()
1388        };
1389        mmr.apply(changeset).unwrap();
1390
1391        let root = mmr.root();
1392
1393        // Generate proof for non-contiguous single elements
1394        let locations = &[Location::new(0), Location::new(5), Location::new(10)];
1395        let nodes_for_multi_proof =
1396            nodes_required_for_multi_proof(mmr.leaves(), locations).expect("test locations valid");
1397        let digests = nodes_for_multi_proof
1398            .into_iter()
1399            .map(|pos| mmr.get_node(pos).unwrap())
1400            .collect();
1401        let multi_proof = Proof {
1402            leaves: mmr.leaves(),
1403            digests,
1404        };
1405
1406        assert_eq!(multi_proof.leaves, mmr.leaves());
1407
1408        // Verify the proof
1409        assert!(multi_proof.verify_multi_inclusion(
1410            &mut hasher,
1411            &[
1412                (elements[0], Location::new(0)),
1413                (elements[5], Location::new(5)),
1414                (elements[10], Location::new(10)),
1415            ],
1416            root
1417        ));
1418
1419        // Verify in different order
1420        assert!(multi_proof.verify_multi_inclusion(
1421            &mut hasher,
1422            &[
1423                (elements[10], Location::new(10)),
1424                (elements[5], Location::new(5)),
1425                (elements[0], Location::new(0)),
1426            ],
1427            root
1428        ));
1429
1430        // Verify with duplicate items
1431        assert!(multi_proof.verify_multi_inclusion(
1432            &mut hasher,
1433            &[
1434                (elements[0], Location::new(0)),
1435                (elements[0], Location::new(0)),
1436                (elements[10], Location::new(10)),
1437                (elements[5], Location::new(5)),
1438            ],
1439            root
1440        ));
1441
1442        // Verify mangling the location to something invalid should fail.
1443        let mut wrong_size_proof = multi_proof.clone();
1444        wrong_size_proof.leaves = Location::new(*MAX_LOCATION + 2);
1445        assert!(!wrong_size_proof.verify_multi_inclusion(
1446            &mut hasher,
1447            &[
1448                (elements[0], Location::new(0)),
1449                (elements[5], Location::new(5)),
1450                (elements[10], Location::new(10)),
1451            ],
1452            root,
1453        ));
1454
1455        // Verify with wrong positions
1456        assert!(!multi_proof.verify_multi_inclusion(
1457            &mut hasher,
1458            &[
1459                (elements[0], Location::new(1)),
1460                (elements[5], Location::new(6)),
1461                (elements[10], Location::new(11)),
1462            ],
1463            root,
1464        ));
1465
1466        // Verify with wrong elements
1467        let wrong_elements = [
1468            vec![255u8, 254u8, 253u8],
1469            vec![252u8, 251u8, 250u8],
1470            vec![249u8, 248u8, 247u8],
1471        ];
1472        let wrong_verification = multi_proof.verify_multi_inclusion(
1473            &mut hasher,
1474            &[
1475                (wrong_elements[0].as_slice(), Location::new(0)),
1476                (wrong_elements[1].as_slice(), Location::new(5)),
1477                (wrong_elements[2].as_slice(), Location::new(10)),
1478            ],
1479            root,
1480        );
1481        assert!(!wrong_verification, "Should fail with wrong elements");
1482
1483        // Verify with out of range element
1484        let wrong_verification = multi_proof.verify_multi_inclusion(
1485            &mut hasher,
1486            &[
1487                (elements[0], Location::new(0)),
1488                (elements[5], Location::new(5)),
1489                (elements[10], Location::new(1000)),
1490            ],
1491            root,
1492        );
1493        assert!(
1494            !wrong_verification,
1495            "Should fail with out of range elements"
1496        );
1497
1498        // Verify with wrong root should fail
1499        let wrong_root = test_digest(99);
1500        assert!(!multi_proof.verify_multi_inclusion(
1501            &mut hasher,
1502            &[
1503                (elements[0], Location::new(0)),
1504                (elements[5], Location::new(5)),
1505                (elements[10], Location::new(10)),
1506            ],
1507            &wrong_root
1508        ));
1509
1510        // Empty multi-proof
1511        let mut hasher: Standard<Sha256> = Standard::new();
1512        let empty_mmr = Mmr::new(&mut hasher);
1513        let empty_root = empty_mmr.root();
1514        let empty_proof = Proof::default();
1515        assert!(empty_proof.verify_multi_inclusion(
1516            &mut hasher,
1517            &[] as &[(Digest, Location)],
1518            empty_root
1519        ));
1520    }
1521
1522    #[test]
1523    fn test_proving_multi_proof_deduplication() {
1524        let mut hasher: Standard<Sha256> = Standard::new();
1525        let mut mmr = Mmr::new(&mut hasher);
1526        // Create an MMR with enough elements to have shared digests
1527        let elements: Vec<_> = (0..30).map(test_digest).collect();
1528        let changeset = {
1529            let mut batch = mmr.new_batch();
1530            for element in &elements {
1531                batch.add(&mut hasher, element);
1532            }
1533            batch.merkleize(&mut hasher).finalize()
1534        };
1535        mmr.apply(changeset).unwrap();
1536
1537        // Get individual proofs that will share some digests (elements in same subtree)
1538        let proof1 = mmr.proof(Location::new(0)).unwrap();
1539        let proof2 = mmr.proof(Location::new(1)).unwrap();
1540        let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1541
1542        // Generate multi-proof for the same positions
1543        let locations = &[Location::new(0), Location::new(1)];
1544        let multi_proof =
1545            nodes_required_for_multi_proof(mmr.leaves(), locations).expect("test locations valid");
1546        let digests = multi_proof
1547            .into_iter()
1548            .map(|pos| mmr.get_node(pos).unwrap())
1549            .collect();
1550        let multi_proof = Proof {
1551            leaves: mmr.leaves(),
1552            digests,
1553        };
1554
1555        // The combined proof should have fewer digests due to deduplication
1556        assert!(multi_proof.digests.len() < total_digests_separate);
1557
1558        // Verify it still works
1559        let root = mmr.root();
1560        assert!(multi_proof.verify_multi_inclusion(
1561            &mut hasher,
1562            &[
1563                (elements[0], Location::new(0)),
1564                (elements[1], Location::new(1))
1565            ],
1566            root
1567        ));
1568    }
1569
1570    #[test]
1571    fn test_max_location_is_provable() {
1572        // Test that the validation logic accepts MAX_LOCATION as a valid leaf count.
1573        // With MAX_LOCATION leaves, valid locations are 0..MAX_LOCATION-1.
1574        // The range MAX_LOCATION-1..MAX_LOCATION proves the last element.
1575        let max_loc_plus_1 = Location::new(*MAX_LOCATION + 1);
1576
1577        let result = nodes_required_for_range_proof(MAX_LOCATION, MAX_LOCATION - 1..MAX_LOCATION);
1578        assert!(
1579            result.is_ok(),
1580            "Should be able to prove with MAX_LOCATION leaves"
1581        );
1582
1583        // MAX_LOCATION + 1 should be rejected (exceeds MAX_LOCATION)
1584        let result_overflow =
1585            nodes_required_for_range_proof(max_loc_plus_1, MAX_LOCATION..max_loc_plus_1);
1586        assert!(
1587            result_overflow.is_err(),
1588            "Should reject location > MAX_LOCATION"
1589        );
1590        matches!(result_overflow, Err(Error::LocationOverflow(_)));
1591    }
1592
1593    #[test]
1594    fn test_max_location_multi_proof() {
1595        // Test that multi_proof can handle MAX_LOCATION
1596        // Should be able to generate multi-proof for MAX_LOCATION
1597        let result = nodes_required_for_multi_proof(MAX_LOCATION, &[MAX_LOCATION - 1]);
1598        assert!(
1599            result.is_ok(),
1600            "Should be able to generate multi-proof for MAX_LOCATION"
1601        );
1602
1603        // Should reject MAX_LOCATION + 1
1604        let invalid_loc = MAX_LOCATION + 1;
1605        let result_overflow = nodes_required_for_multi_proof(invalid_loc, &[MAX_LOCATION]);
1606        assert!(
1607            result_overflow.is_err(),
1608            "Should reject location > MAX_LOCATION in multi-proof"
1609        );
1610    }
1611
1612    #[test]
1613    fn test_max_proof_digests_per_element_sufficient() {
1614        // Verify that MAX_PROOF_DIGESTS_PER_ELEMENT (122) is sufficient for any single-element
1615        // proof in the largest valid MMR.
1616        //
1617        // MMR sizes follow: mmr_size(N) = 2*N - popcount(N) where N = leaf count.
1618        // The number of peaks equals popcount(N).
1619        //
1620        // To maximize peaks, we want N with maximum popcount. N = 2^62 - 1 has 62 one-bits:
1621        //   N = 0x3FFFFFFFFFFFFFFF = 2^0 + 2^1 + ... + 2^61
1622        //
1623        // This gives us 62 perfect binary trees with leaf counts 2^0, 2^1, ..., 2^61
1624        // and corresponding heights 0, 1, ..., 61.
1625        //
1626        // mmr_size(2^62 - 1) = 2*(2^62 - 1) - 62 = 2^63 - 2 - 62 = 2^63 - 64
1627        //
1628        // For a single-element proof in a tree of height h:
1629        //   - Path siblings from leaf to peak: h digests
1630        //   - Other peaks (not containing the element): (62 - 1) = 61 digests
1631        //   - Total: h + 61 digests
1632        //
1633        // Worst case: element in tallest tree (h = 61)
1634        //   - Path siblings: 61
1635        //   - Other peaks: 61
1636        //   - Total: 61 + 61 = 122 digests
1637
1638        const NUM_PEAKS: usize = 62;
1639        const MAX_TREE_HEIGHT: usize = 61;
1640        const EXPECTED_WORST_CASE: usize = MAX_TREE_HEIGHT + (NUM_PEAKS - 1);
1641
1642        let many_peaks_size = Position::new((1u64 << 63) - 64);
1643        assert!(
1644            many_peaks_size.is_mmr_size(),
1645            "Size {many_peaks_size} should be a valid MMR size",
1646        );
1647
1648        let peak_count = PeakIterator::new(many_peaks_size).count();
1649        assert_eq!(peak_count, NUM_PEAKS);
1650
1651        // Verify the peak heights are 61, 60, ..., 1, 0 (from left to right)
1652        let peaks: Vec<_> = PeakIterator::new(many_peaks_size).collect();
1653        for (i, &(_pos, height)) in peaks.iter().enumerate() {
1654            let expected_height = (NUM_PEAKS - 1 - i) as u32;
1655            assert_eq!(
1656                height, expected_height,
1657                "Peak {i} should have height {expected_height}, got {height}",
1658            );
1659        }
1660
1661        // Test location 0 (leftmost leaf, in tallest tree of height 61)
1662        // Expected: 61 path siblings + 61 other peaks = 122 digests
1663        let leaves = Location::try_from(many_peaks_size).unwrap();
1664        let loc = Location::new(0);
1665        let positions = nodes_required_for_range_proof(leaves, loc..loc + 1)
1666            .expect("should compute positions for location 0");
1667
1668        assert_eq!(
1669            positions.len(),
1670            EXPECTED_WORST_CASE,
1671            "Location 0 proof should require exactly {EXPECTED_WORST_CASE} digests (61 path + 61 peaks)",
1672        );
1673
1674        // Test the rightmost leaf (in smallest tree of height 0, which is itself a peak)
1675        // Expected: 0 path siblings + 61 other peaks = 61 digests
1676        let last_leaf_loc = leaves - 1;
1677        let positions = nodes_required_for_range_proof(leaves, last_leaf_loc..last_leaf_loc + 1)
1678            .expect("should compute positions for last leaf");
1679
1680        let expected_last_leaf = NUM_PEAKS - 1;
1681        assert_eq!(
1682            positions.len(),
1683            expected_last_leaf,
1684            "Last leaf proof should require exactly {expected_last_leaf} digests (0 path + 61 peaks)",
1685        );
1686    }
1687
1688    #[test]
1689    fn test_max_proof_digests_per_element_is_maximum() {
1690        // For K peaks, the worst-case proof needs: (max_tree_height) + (K - 1) digests
1691        // With K peaks of heights K-1, K-2, ..., 0, this is (K-1) + (K-1) = 2*(K-1)
1692        //
1693        // To get K peaks, leaf count N must have exactly K bits set.
1694        // MMR size = 2*N - popcount(N) = 2*N - K
1695        //
1696        // For 63 peaks: N = 2^63 - 1 (63 bits set), size = 2*(2^63 - 1) - 63 = 2^64 - 65
1697        // This exceeds MAX_POSITION, so is_mmr_size() returns false.
1698
1699        let n_for_63_peaks = (1u128 << 63) - 1;
1700        let size_for_63_peaks = 2 * n_for_63_peaks - 63; // = 2^64 - 65
1701        assert!(
1702            size_for_63_peaks > *crate::mmr::MAX_POSITION as u128,
1703            "63 peaks requires size {size_for_63_peaks} > MAX_POSITION",
1704        );
1705
1706        let size_truncated = size_for_63_peaks as u64;
1707        assert!(
1708            !Position::new(size_truncated).is_mmr_size(),
1709            "Size for 63 peaks should fail is_mmr_size()"
1710        );
1711    }
1712
    // Codec conformance tests for `Proof<Sha256Digest>`, generated by the
    // `conformance_tests!` macro. Compiled only with the `arbitrary` feature enabled
    // (presumably because the harness needs `Arbitrary` impls to generate values —
    // TODO confirm).
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        commonware_conformance::conformance_tests! {
            CodecConformance<Proof<Sha256Digest>>,
        }
    }
1723}