commonware_storage/mmr/
proof.rs

1//! Defines the inclusion [Proof] structure, and functions for verifying them against a root digest.
2//!
//! Also provides lower-level functions for building verifiers for new or extended proof types.
//! These lower-level functions are kept outside of the [Proof] structure and are not re-exported
//! by the parent module.
6
7#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10    hasher::Hasher,
11    iterator::{PathIterator, PeakIterator},
12    Error, Location, Position,
13};
14use alloc::{
15    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16    vec,
17    vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{varint::UInt, EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::Digest;
22use core::ops::Range;
23#[cfg(feature = "std")]
24use tracing::debug;
25
/// The maximum number of digests in a proof per element being proven.
///
/// The worst case is an MMR with 62 peaks. Proving the left-most leaf of such an MMR needs the
/// 61 siblings along the path from that leaf up to its peak, plus the digests of the 61 other
/// peaks: 61 + 61 = 122 digests.
pub const MAX_PROOF_DIGESTS_PER_ELEMENT: usize = 61 + 61;
32
/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
#[derive(Error, Debug)]
pub enum ReconstructionError {
    /// The proof did not supply enough digests (or too few elements were supplied) to complete
    /// reconstruction.
    #[error("missing digests in proof")]
    MissingDigests,
    /// The proof contained digests (or more elements were supplied) than reconstruction consumed.
    #[error("extra digests in proof")]
    ExtraDigests,
    /// The start location could not be converted to a node position.
    #[error("start location is out of bounds")]
    InvalidStartLoc,
    /// The location of the last element overflowed, could not be converted to a node position,
    /// or lies at or beyond the proof's `size`.
    #[error("end location is out of bounds")]
    InvalidEndLoc,
    /// No elements were supplied, but the start location was non-zero.
    #[error("missing elements")]
    MissingElements,
    /// The proof's `size` is not a valid MMR size.
    #[error("invalid size")]
    InvalidSize,
}
49
/// Contains the information necessary for proving the inclusion of an element, or some range of
/// elements, in the MMR from its root digest.
///
/// For a range proof, the `digests` vector contains, in order:
///
/// 1. the digests of each peak corresponding to a mountain containing no elements from the
///    element range being proven, in decreasing order of height, followed by
/// 2. the nodes in the remaining mountains necessary for reconstructing their peak digests from
///    the elements within the range, in decreasing order of their parent's position.
///
/// Proofs checked by [Proof::verify_multi_inclusion] instead list each required node digest
/// exactly once, in ascending order of node position.
#[derive(Clone, Debug, Eq)]
pub struct Proof<D: Digest> {
    /// The total number of nodes in the MMR for MMR proofs, though other authenticated data
    /// structures may override the meaning of this field. For example, the authenticated
    /// [crate::AuthenticatedBitMap] stores the number of bits in the bitmap within
    /// this field.
    pub size: Position,
    /// The digests necessary for proving the inclusion of an element, or range of elements, in the
    /// MMR, ordered as described on [Proof].
    pub digests: Vec<D>,
}
71
72impl<D: Digest> PartialEq for Proof<D> {
73    fn eq(&self, other: &Self) -> bool {
74        self.size == other.size && self.digests == other.digests
75    }
76}
77
78impl<D: Digest> EncodeSize for Proof<D> {
79    fn encode_size(&self) -> usize {
80        UInt(*self.size).encode_size() + self.digests.encode_size()
81    }
82}
83
84impl<D: Digest> Write for Proof<D> {
85    fn write(&self, buf: &mut impl BufMut) {
86        // Write the number of nodes in the MMR as a varint
87        UInt(*self.size).write(buf);
88
89        // Write the digests
90        self.digests.write(buf);
91    }
92}
93
94impl<D: Digest> Read for Proof<D> {
95    /// The maximum number of items being proven.
96    ///
97    /// The upper bound on digests is derived as `max_items * MAX_PROOF_DIGESTS_PER_ELEMENT`.
98    type Cfg = usize;
99
100    fn read_cfg(
101        buf: &mut impl Buf,
102        max_items: &Self::Cfg,
103    ) -> Result<Self, commonware_codec::Error> {
104        // Read the number of nodes in the MMR
105        let size = Position::new(UInt::<u64>::read(buf)?.into());
106
107        // Read the digests
108        let max_digests = max_items.saturating_mul(MAX_PROOF_DIGESTS_PER_ELEMENT);
109        let digests = Vec::<D>::read_range(buf, ..=max_digests)?;
110        Ok(Self { size, digests })
111    }
112}
113
114impl<D: Digest> Default for Proof<D> {
115    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
116    /// (`size == 0`) MMR.
117    fn default() -> Self {
118        Self {
119            size: Position::new(0),
120            digests: vec![],
121        }
122    }
123}
124
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Builds a (generally invalid) proof from fuzzer input, drawing `size` before `digests`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let size = u.arbitrary()?;
        let digests = u.arbitrary()?;
        Ok(Self { size, digests })
    }
}
137
138impl<D: Digest> Proof<D> {
139    /// Return true if this proof proves that `element` appears at location `loc` within the MMR
140    /// with root digest `root`.
141    pub fn verify_element_inclusion<H>(
142        &self,
143        hasher: &mut H,
144        element: &[u8],
145        loc: Location,
146        root: &D,
147    ) -> bool
148    where
149        H: Hasher<D>,
150    {
151        self.verify_range_inclusion(hasher, &[element], loc, root)
152    }
153
154    /// Return true if this proof proves that the `elements` appear consecutively starting at
155    /// position `start_loc` within the MMR with root digest `root`. A malformed proof will return
156    /// false.
157    pub fn verify_range_inclusion<H, E>(
158        &self,
159        hasher: &mut H,
160        elements: &[E],
161        start_loc: Location,
162        root: &D,
163    ) -> bool
164    where
165        H: Hasher<D>,
166        E: AsRef<[u8]>,
167    {
168        if !self.size.is_mmr_size() {
169            #[cfg(feature = "std")]
170            debug!(size = ?self.size, "invalid proof size");
171            return false;
172        }
173
174        match self.reconstruct_root(hasher, elements, start_loc) {
175            Ok(reconstructed_root) => *root == reconstructed_root,
176            Err(_error) => {
177                #[cfg(feature = "std")]
178                tracing::debug!(error = ?_error, "invalid proof input");
179                false
180            }
181        }
182    }
183
184    /// Return true if this proof proves that the elements at the specified locations are included
185    /// in the MMR with the root digest `root`. A malformed proof will return false.
186    ///
187    /// The order of the elements does not affect the output.
188    pub fn verify_multi_inclusion<H, E>(
189        &self,
190        hasher: &mut H,
191        elements: &[(E, Location)],
192        root: &D,
193    ) -> bool
194    where
195        H: Hasher<D>,
196        E: AsRef<[u8]>,
197    {
198        // Empty proof is valid for an empty MMR
199        if elements.is_empty() {
200            return self.size == Position::new(0)
201                && *root == hasher.root(Position::new(0), core::iter::empty());
202        }
203        if !self.size.is_mmr_size() {
204            return false;
205        }
206
207        // Single pass to collect all required positions with deduplication
208        let mut node_positions = BTreeSet::new();
209        let mut nodes_required = BTreeMap::new();
210
211        for (_, loc) in elements {
212            if !loc.is_valid() {
213                return false;
214            }
215            // `loc` is valid so it won't overflow from +1
216            let Ok(required) = nodes_required_for_range_proof(self.size, *loc..*loc + 1) else {
217                return false;
218            };
219            for req_pos in &required {
220                node_positions.insert(*req_pos);
221            }
222            nodes_required.insert(*loc, required);
223        }
224
225        // Verify we have the exact number of digests needed
226        if node_positions.len() != self.digests.len() {
227            return false;
228        }
229
230        // Build position to digest mapping once
231        let node_digests: BTreeMap<Position, D> = node_positions
232            .iter()
233            .zip(self.digests.iter())
234            .map(|(&pos, digest)| (pos, *digest))
235            .collect();
236
237        // Verify each element by reconstructing its path
238        for (element, loc) in elements {
239            // Get required positions for this element
240            let required = &nodes_required[loc];
241
242            // Build proof with required digests
243            let mut digests = Vec::with_capacity(required.len());
244            for req_pos in required {
245                // There must exist a digest for each required position (by
246                // construction of `node_digests`)
247                let digest = node_digests
248                    .get(req_pos)
249                    .expect("must exist by construction of node_digests");
250                digests.push(*digest);
251            }
252            let proof = Self {
253                size: self.size,
254                digests,
255            };
256
257            // Verify the proof
258            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
259                return false;
260            }
261        }
262
263        true
264    }
265
266    // The functions below are lower level functions that are useful to building verification
267    // functions for new or extended proof types.
268
269    /// Computes the set of pinned nodes for the pruning boundary corresponding to the start of the
270    /// given range, returning the digest of each by extracting it from the proof.
271    ///
272    /// # Arguments
273    /// * `range` - The start and end locations of the proven range, where start is also used as the
274    ///   pruning boundary.
275    ///
276    /// # Returns
277    /// A Vec of digests for all nodes in `nodes_to_pin(pruning_boundary)`, in the same order as
278    /// returned by `nodes_to_pin` (decreasing height order)
279    ///
280    /// # Errors
281    ///
282    /// Returns [Error::InvalidSize] if the proof size is not a valid MMR size.
283    /// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
284    /// Returns [Error::InvalidProofLength] if the proof digest count doesn't match the required
285    /// positions count.
286    /// Returns [Error::MissingDigest] if a pinned node is not found in the proof.
287    #[cfg(any(feature = "std", test))]
288    pub(crate) fn extract_pinned_nodes(
289        &self,
290        range: std::ops::Range<Location>,
291    ) -> Result<Vec<D>, Error> {
292        // Get the positions of all nodes that should be pinned.
293        let start_pos = Position::try_from(range.start)?;
294        let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
295
296        // Get all positions required for the proof.
297        let required_positions = nodes_required_for_range_proof(self.size, range)?;
298
299        if required_positions.len() != self.digests.len() {
300            #[cfg(feature = "std")]
301            debug!(
302                digests_len = self.digests.len(),
303                required_positions_len = required_positions.len(),
304                "Proof digest count doesn't match required positions",
305            );
306            return Err(Error::InvalidProofLength);
307        }
308
309        // Happy path: we can extract the pinned nodes directly from the proof.
310        // This happens when the `end_element_pos` is the last element in the MMR.
311        if pinned_positions
312            == required_positions[required_positions.len() - pinned_positions.len()..]
313        {
314            return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
315        }
316
317        // Create a mapping from position to digest.
318        let position_to_digest: BTreeMap<Position, D> = required_positions
319            .iter()
320            .zip(self.digests.iter())
321            .map(|(&pos, &digest)| (pos, digest))
322            .collect();
323
324        // Extract the pinned nodes in the same order as nodes_to_pin.
325        let mut result = Vec::with_capacity(pinned_positions.len());
326        for pinned_pos in pinned_positions {
327            let Some(&digest) = position_to_digest.get(&pinned_pos) else {
328                #[cfg(feature = "std")]
329                debug!(?pinned_pos, "Pinned node not found in proof");
330                return Err(Error::MissingDigest(pinned_pos));
331            };
332            result.push(digest);
333        }
334        Ok(result)
335    }
336
337    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
338    /// of elements, returning the (position,digest) of every node whose digest was required by the
339    /// process (including those from the proof itself). Returns a [Error::InvalidProof] if the
340    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
341    /// root.
342    pub fn verify_range_inclusion_and_extract_digests<H, E>(
343        &self,
344        hasher: &mut H,
345        elements: &[E],
346        start_loc: Location,
347        root: &D,
348    ) -> Result<Vec<(Position, D)>, super::Error>
349    where
350        H: Hasher<D>,
351        E: AsRef<[u8]>,
352    {
353        let mut collected_digests = Vec::new();
354        let Ok(peak_digests) = self.reconstruct_peak_digests(
355            hasher,
356            elements,
357            start_loc,
358            Some(&mut collected_digests),
359        ) else {
360            return Err(Error::InvalidProof);
361        };
362
363        if hasher.root(self.size, peak_digests.iter()) != *root {
364            return Err(Error::RootMismatch);
365        }
366
367        Ok(collected_digests)
368    }
369
370    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
371    /// of elements, or returns a [ReconstructionError] if the input data is invalid.
372    pub fn reconstruct_root<H, E>(
373        &self,
374        hasher: &mut H,
375        elements: &[E],
376        start_loc: Location,
377    ) -> Result<D, ReconstructionError>
378    where
379        H: Hasher<D>,
380        E: AsRef<[u8]>,
381    {
382        if !self.size.is_mmr_size() {
383            return Err(ReconstructionError::InvalidSize);
384        }
385
386        let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
387
388        Ok(hasher.root(self.size, peak_digests.iter()))
389    }
390
391    /// Reconstruct the peak digests of the MMR that produced this proof, returning
392    /// [ReconstructionError] if the input data is invalid.  If collected_digests is Some, then all
393    /// node digests used in the process will be added to the wrapped vector.
394    pub fn reconstruct_peak_digests<H, E>(
395        &self,
396        hasher: &mut H,
397        elements: &[E],
398        start_loc: Location,
399        mut collected_digests: Option<&mut Vec<(Position, D)>>,
400    ) -> Result<Vec<D>, ReconstructionError>
401    where
402        H: Hasher<D>,
403        E: AsRef<[u8]>,
404    {
405        if elements.is_empty() {
406            if start_loc == 0 {
407                return Ok(vec![]);
408            }
409            return Err(ReconstructionError::MissingElements);
410        }
411        let start_element_pos =
412            Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
413        let end_element_pos = if elements.len() == 1 {
414            start_element_pos
415        } else {
416            let end_loc = start_loc
417                .checked_add(elements.len() as u64 - 1)
418                .ok_or(ReconstructionError::InvalidEndLoc)?;
419            Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
420        };
421        if end_element_pos >= self.size {
422            return Err(ReconstructionError::InvalidEndLoc);
423        }
424
425        let mut proof_digests_iter = self.digests.iter();
426        let mut siblings_iter = self.digests.iter().rev();
427
428        // Include peak digests only for trees that have no elements from the range, and keep track
429        // of the starting and ending trees of those that do contain some.
430        let mut peak_digests: Vec<D> = Vec::new();
431        let mut proof_digests_used = 0;
432        let mut elements_iter = elements.iter();
433        if !Position::is_mmr_size(self.size) {
434            return Err(ReconstructionError::InvalidSize);
435        }
436        for (peak_pos, height) in PeakIterator::new(self.size) {
437            let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
438            if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
439                let hash = peak_digest_from_range(
440                    hasher,
441                    RangeInfo {
442                        pos: peak_pos,
443                        two_h: 1 << height,
444                        leftmost_pos: start_element_pos,
445                        rightmost_pos: end_element_pos,
446                    },
447                    &mut elements_iter,
448                    &mut siblings_iter,
449                    collected_digests.as_deref_mut(),
450                )?;
451                peak_digests.push(hash);
452                if let Some(ref mut collected_digests) = collected_digests {
453                    collected_digests.push((peak_pos, hash));
454                }
455            } else if let Some(hash) = proof_digests_iter.next() {
456                proof_digests_used += 1;
457                peak_digests.push(*hash);
458                if let Some(ref mut collected_digests) = collected_digests {
459                    collected_digests.push((peak_pos, *hash));
460                }
461            } else {
462                return Err(ReconstructionError::MissingDigests);
463            }
464        }
465
466        if elements_iter.next().is_some() {
467            return Err(ReconstructionError::ExtraDigests);
468        }
469        if let Some(next_sibling) = siblings_iter.next() {
470            if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
471                return Err(ReconstructionError::ExtraDigests);
472            }
473        }
474
475        Ok(peak_digests)
476    }
477}
478
479/// Return the list of node positions required by the range proof for the specified range of
480/// elements.
481///
482/// # Errors
483///
484/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
485/// Returns [Error::Empty] if the range is empty.
486/// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
487/// Returns [Error::RangeOutOfBounds] if the last element position in `range` is out of bounds
488/// (>= `size`).
489pub(crate) fn nodes_required_for_range_proof(
490    size: Position,
491    range: Range<Location>,
492) -> Result<Vec<Position>, Error> {
493    if !size.is_mmr_size() {
494        return Err(Error::InvalidSize(*size));
495    }
496    if range.is_empty() {
497        return Err(Error::Empty);
498    }
499    let start_element_pos = Position::try_from(range.start)?;
500    let end_minus_one = range
501        .end
502        .checked_sub(1)
503        .expect("can't underflow because range is non-empty");
504    let end_element_pos = Position::try_from(end_minus_one)?;
505    if end_element_pos >= size {
506        return Err(Error::RangeOutOfBounds(range.end));
507    }
508
509    // Find the mountains that contain no elements from the range. The peaks of these mountains
510    // are required to prove the range, so they are added to the result.
511    let mut start_tree_with_element: Option<(Position, u32)> = None;
512    let mut end_tree_with_element: Option<(Position, u32)> = None;
513    let mut peak_iterator = PeakIterator::new(size);
514    let mut positions = Vec::new();
515    while let Some(peak) = peak_iterator.next() {
516        if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
517            // Found the first tree to contain an element in the range
518            start_tree_with_element = Some(peak);
519            if peak.0 >= end_element_pos {
520                // Start and end tree are the same
521                end_tree_with_element = Some(peak);
522                continue;
523            }
524            for peak in peak_iterator.by_ref() {
525                if peak.0 >= end_element_pos {
526                    // Found the last tree to contain an element in the range
527                    end_tree_with_element = Some(peak);
528                    break;
529                }
530            }
531        } else {
532            // Tree is outside the range, its peak is thus required.
533            positions.push(peak.0);
534        }
535    }
536
537    // We checked above that all range elements are in this MMR, so some mountain must contain
538    // the first and last elements in the range.
539    let (start_tree_peak, start_tree_height) =
540        start_tree_with_element.expect("start_tree_with_element is Some");
541    let (end_tree_peak, end_tree_height) =
542        end_tree_with_element.expect("end_tree_with_element is Some");
543
544    // Include the positions of any left-siblings of each node on the path from peak to
545    // leftmost-leaf, and right-siblings for the path from peak to rightmost-leaf. These are
546    // added in order of decreasing parent position.
547    let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);
548
549    let mut siblings = Vec::new();
550    if start_element_pos == end_element_pos {
551        // For the (common) case of a single element range, the right and left path are the
552        // same so no need to process each independently.
553        siblings.extend(left_path_iter);
554    } else {
555        let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
556        // filter the right path for right siblings only
557        siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
558        // filter the left path for left siblings only
559        siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));
560
561        // If the range spans more than one tree, then the digests must already be in the correct
562        // order. Otherwise, we enforce the desired order through sorting.
563        if start_tree_peak == end_tree_peak {
564            siblings.sort_by(|a, b| b.0.cmp(&a.0));
565        }
566    }
567    positions.extend(siblings.into_iter().map(|(_, pos)| pos));
568    Ok(positions)
569}
570
/// Returns the positions of the minimal set of nodes whose digests are required to prove the
/// inclusion of the elements at the specified `locations`.
///
/// The result is the union of the positions required by each location's single-element proof.
/// Because it is a [BTreeSet], it iterates in ascending position order regardless of the order
/// of `locations`.
///
/// # Errors
///
/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
/// Returns [Error::Empty] if locations is empty.
/// Returns [Error::LocationOverflow] if any location in `locations` > [crate::mmr::MAX_LOCATION].
/// Returns [Error::RangeOutOfBounds] if any location is out of bounds for the given `size`.
#[cfg(any(feature = "std", test))]
pub(crate) fn nodes_required_for_multi_proof(
    size: Position,
    locations: &[Location],
) -> Result<BTreeSet<Position>, Error> {
    if !size.is_mmr_size() {
        return Err(Error::InvalidSize(*size));
    }
    if locations.is_empty() {
        return Err(Error::Empty);
    }

    // Collect all required node positions.
    //
    // TODO(#1472): Optimize this loop
    let mut required = BTreeSet::new();
    for &loc in locations {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        // `loc` is valid so it won't overflow from +1
        required.extend(nodes_required_for_range_proof(size, loc..loc + 1)?);
    }
    Ok(required)
}
606
/// Information about the current range of nodes being traversed.
///
/// `leftmost_pos` and `rightmost_pos` are the positions of the first and last elements of the
/// range being proven. They stay fixed throughout the recursion and may lie outside the subtree
/// rooted at `pos`.
struct RangeInfo {
    /// Current node position in the tree.
    pos: Position,
    /// 2^height of the current node (1 for a leaf).
    two_h: u64,
    /// Position of the first (leftmost) element of the proven range.
    leftmost_pos: Position,
    /// Position of the last (rightmost) element of the proven range.
    rightmost_pos: Position,
}
614
615fn peak_digest_from_range<'a, D, H, E, S>(
616    hasher: &mut H,
617    range_info: RangeInfo,
618    elements: &mut E,
619    sibling_digests: &mut S,
620    mut collected_digests: Option<&mut Vec<(Position, D)>>,
621) -> Result<D, ReconstructionError>
622where
623    D: Digest,
624    H: Hasher<D>,
625    E: Iterator<Item: AsRef<[u8]>>,
626    S: Iterator<Item = &'a D>,
627{
628    assert_ne!(range_info.two_h, 0);
629    if range_info.two_h == 1 {
630        match elements.next() {
631            Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
632            None => return Err(ReconstructionError::MissingDigests),
633        }
634    }
635
636    let mut left_digest: Option<D> = None;
637    let mut right_digest: Option<D> = None;
638
639    let left_pos = range_info.pos - range_info.two_h;
640    let right_pos = left_pos + range_info.two_h - 1;
641    if left_pos >= range_info.leftmost_pos {
642        // Descend left
643        let digest = peak_digest_from_range(
644            hasher,
645            RangeInfo {
646                pos: left_pos,
647                two_h: range_info.two_h >> 1,
648                leftmost_pos: range_info.leftmost_pos,
649                rightmost_pos: range_info.rightmost_pos,
650            },
651            elements,
652            sibling_digests,
653            collected_digests.as_deref_mut(),
654        )?;
655        left_digest = Some(digest);
656    }
657    if left_pos < range_info.rightmost_pos {
658        // Descend right
659        let digest = peak_digest_from_range(
660            hasher,
661            RangeInfo {
662                pos: right_pos,
663                two_h: range_info.two_h >> 1,
664                leftmost_pos: range_info.leftmost_pos,
665                rightmost_pos: range_info.rightmost_pos,
666            },
667            elements,
668            sibling_digests,
669            collected_digests.as_deref_mut(),
670        )?;
671        right_digest = Some(digest);
672    }
673
674    if left_digest.is_none() {
675        match sibling_digests.next() {
676            Some(hash) => left_digest = Some(*hash),
677            None => return Err(ReconstructionError::MissingDigests),
678        }
679    }
680    if right_digest.is_none() {
681        match sibling_digests.next() {
682            Some(hash) => right_digest = Some(*hash),
683            None => return Err(ReconstructionError::MissingDigests),
684        }
685    }
686
687    if let Some(ref mut collected_digests) = collected_digests {
688        collected_digests.push((
689            left_pos,
690            left_digest.expect("left_digest guaranteed to be Some after checks above"),
691        ));
692        collected_digests.push((
693            right_pos,
694            right_digest.expect("right_digest guaranteed to be Some after checks above"),
695        ));
696    }
697
698    Ok(hasher.node_digest(
699        range_info.pos,
700        &left_digest.expect("left_digest guaranteed to be Some after checks above"),
701        &right_digest.expect("right_digest guaranteed to be Some after checks above"),
702    ))
703}
704
705#[cfg(test)]
706mod tests {
707    use super::*;
708    use crate::mmr::{
709        hasher::Standard, location::LocationRangeExt as _, mem::CleanMmr, MAX_LOCATION,
710    };
711    use bytes::Bytes;
712    use commonware_codec::{Decode, Encode};
713    use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
714    use commonware_macros::test_traced;
715
716    fn test_digest(v: u8) -> Digest {
717        Sha256::hash(&[v])
718    }
719
720    #[test]
721    fn test_proving_proof() {
722        // Test that an empty proof authenticates an empty MMR.
723        let mut hasher: Standard<Sha256> = Standard::new();
724        let mmr = CleanMmr::new(&mut hasher);
725        let root = mmr.root();
726        let proof = Proof::default();
727        assert!(proof.verify_range_inclusion(
728            &mut hasher,
729            &[] as &[Digest],
730            Location::new_unchecked(0),
731            root
732        ));
733
734        // Any starting position other than 0 should fail to verify.
735        assert!(!proof.verify_range_inclusion(
736            &mut hasher,
737            &[] as &[Digest],
738            Location::new_unchecked(1),
739            root
740        ));
741
742        // Invalid root should fail to verify.
743        let test_digest = test_digest(0);
744        assert!(!proof.verify_range_inclusion(
745            &mut hasher,
746            &[] as &[Digest],
747            Location::new_unchecked(0),
748            &test_digest
749        ));
750
751        // Non-empty elements list should fail to verify.
752        assert!(!proof.verify_range_inclusion(
753            &mut hasher,
754            &[test_digest],
755            Location::new_unchecked(0),
756            root
757        ));
758    }
759
760    #[test]
761    fn test_proving_verify_element() {
762        // create an 11 element MMR over which we'll test single-element inclusion proofs
763        let element = Digest::from(*b"01234567012345670123456701234567");
764        let mut hasher: Standard<Sha256> = Standard::new();
765        let mut mmr = CleanMmr::new(&mut hasher);
766        for _ in 0..11 {
767            mmr.add(&mut hasher, &element);
768        }
769        let root = mmr.root();
770
771        // confirm the proof of inclusion for each leaf successfully verifies
772        for leaf in 0u64..11 {
773            let leaf = Location::new_unchecked(leaf);
774            let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
775            assert!(
776                proof.verify_element_inclusion(&mut hasher, &element, leaf, root),
777                "valid proof should verify successfully"
778            );
779        }
780
781        // Create a valid proof, then confirm various mangling of the proof or proof args results in
782        // verification failure.
783        const LEAF: Location = Location::new_unchecked(10);
784        let proof = mmr.proof(LEAF).unwrap();
785        assert!(
786            proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
787            "proof verification should be successful"
788        );
789        let wrong_sizes = [0, 16, 17, 18, 20, u64::MAX - 100];
790        for sz in wrong_sizes {
791            let mut wrong_size_proof = proof.clone();
792            wrong_size_proof.size = Position::new(sz);
793            assert!(
794                !wrong_size_proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
795                "proof with wrong size should fail verification"
796            );
797        }
798        assert!(
799            !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, root),
800            "proof verification should fail with incorrect element position"
801        );
802        assert!(
803            !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, root),
804            "proof verification should fail with incorrect element position 2"
805        );
806        assert!(
807            !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, root),
808            "proof verification should fail with mangled element"
809        );
810        let root2 = test_digest(0);
811        assert!(
812            !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
813            "proof verification should fail with mangled root"
814        );
815        let mut proof2 = proof.clone();
816        proof2.digests[0] = test_digest(0);
817        assert!(
818            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
819            "proof verification should fail with mangled proof hash"
820        );
821        proof2 = proof.clone();
822        proof2.size = Position::new(10);
823        assert!(
824            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
825            "proof verification should fail with incorrect size"
826        );
827        proof2 = proof.clone();
828        proof2.digests.push(test_digest(0));
829        assert!(
830            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
831            "proof verification should fail with extra hash"
832        );
833        proof2 = proof.clone();
834        while !proof2.digests.is_empty() {
835            proof2.digests.pop();
836            assert!(
837                !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
838                "proof verification should fail with missing digests"
839            );
840        }
841        proof2 = proof.clone();
842        proof2.digests.clear();
843        const PEAK_COUNT: usize = 3;
844        proof2
845            .digests
846            .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
847        // sneak in an extra hash that won't be used in the computation and make sure it's
848        // detected
849        proof2.digests.push(test_digest(0));
850        proof2
851            .digests
852            .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
853        assert!(
854            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
855            "proof verification should fail with extra hash even if it's unused by the computation"
856        );
857    }
858
859    #[test]
860    fn test_proving_verify_range() {
861        // create a new MMR and add a non-trivial amount (49) of elements
862        let mut hasher: Standard<Sha256> = Standard::new();
863        let mut mmr = CleanMmr::new(&mut hasher);
864        let mut elements = Vec::new();
865        for i in 0..49 {
866            elements.push(test_digest(i));
867            mmr.add(&mut hasher, elements.last().unwrap());
868        }
869        // test range proofs over all possible ranges of at least 2 elements
870        let root = mmr.root();
871
872        for i in 0..elements.len() {
873            for j in i + 1..elements.len() {
874                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
875                let range_proof = mmr.range_proof(range.clone()).unwrap();
876                assert!(
877                    range_proof.verify_range_inclusion(
878                        &mut hasher,
879                        &elements[range.to_usize_range()],
880                        range.start,
881                        root,
882                    ),
883                    "valid range proof should verify successfully {i}:{j}",
884                );
885            }
886        }
887
888        // Create a proof over a range of elements, confirm it verifies successfully, then mangle
889        // the proof & proof input in various ways, confirming verification fails.
890        let range = Location::new_unchecked(33)..Location::new_unchecked(40);
891        let range_proof = mmr.range_proof(range.clone()).unwrap();
892        let valid_elements = &elements[range.to_usize_range()];
893        assert!(
894            range_proof.verify_range_inclusion(&mut hasher, valid_elements, range.start, root),
895            "valid range proof should verify successfully"
896        );
897        // Remove digests from the proof until it's empty, confirming proof verification fails for
898        // each.
899        let mut invalid_proof = range_proof.clone();
900        for _i in 0..range_proof.digests.len() {
901            invalid_proof.digests.remove(0);
902            assert!(
903                !invalid_proof.verify_range_inclusion(
904                    &mut hasher,
905                    valid_elements,
906                    range.start,
907                    root,
908                ),
909                "range proof with removed elements should fail"
910            );
911        }
912        // Confirm proof verification fails when providing an element range different than the one
913        // used to generate the proof.
914        for i in 0..elements.len() {
915            for j in i + 1..elements.len() {
916                if Location::from(i) == range.start && Location::from(j) == range.end {
917                    // skip the valid range
918                    continue;
919                }
920                assert!(
921                    !range_proof.verify_range_inclusion(
922                        &mut hasher,
923                        &elements[i..j],
924                        range.start,
925                        root,
926                    ),
927                    "range proof with invalid element range should fail {i}:{j}",
928                );
929            }
930        }
931        // Confirm proof fails to verify with an invalid root.
932        let invalid_root = test_digest(1);
933        assert!(
934            !range_proof.verify_range_inclusion(
935                &mut hasher,
936                valid_elements,
937                range.start,
938                &invalid_root,
939            ),
940            "range proof with invalid root should fail"
941        );
942        // Mangle each element of the proof and confirm it fails to verify.
943        for i in 0..range_proof.digests.len() {
944            let mut invalid_proof = range_proof.clone();
945            invalid_proof.digests[i] = test_digest(0);
946
947            assert!(
948                !invalid_proof.verify_range_inclusion(
949                    &mut hasher,
950                    valid_elements,
951                    range.start,
952                    root,
953                ),
954                "mangled range proof should fail verification"
955            );
956        }
957        // Inserting elements into the proof should also cause it to fail (malleability check)
958        for i in 0..range_proof.digests.len() {
959            let mut invalid_proof = range_proof.clone();
960            invalid_proof.digests.insert(i, test_digest(0));
961            assert!(
962                !invalid_proof.verify_range_inclusion(
963                    &mut hasher,
964                    valid_elements,
965                    range.start,
966                    root,
967                ),
968                "mangled range proof should fail verification. inserted element at: {i}",
969            );
970        }
971        // Bad start_loc should cause verification to fail.
972        for loc in 0..elements.len() {
973            let loc = Location::new_unchecked(loc as u64);
974            if loc == range.start {
975                continue;
976            }
977            assert!(
978                !range_proof.verify_range_inclusion(&mut hasher, valid_elements, loc, root),
979                "bad start_loc should fail verification {loc}",
980            );
981        }
982    }
983
984    #[test_traced]
985    fn test_proving_retained_nodes_provable_after_pruning() {
986        // create a new MMR and add a non-trivial amount (49) of elements
987        let mut hasher: Standard<Sha256> = Standard::new();
988        let mut mmr = CleanMmr::new(&mut hasher);
989        let mut elements = Vec::new();
990        for i in 0..49 {
991            elements.push(test_digest(i));
992            mmr.add(&mut hasher, elements.last().unwrap());
993        }
994
995        // Confirm we can successfully prove all retained elements in the MMR after pruning.
996        let root = *mmr.root();
997        for i in 1..*mmr.size() {
998            mmr.prune_to_pos(Position::new(i));
999            let pruned_root = mmr.root();
1000            assert_eq!(root, *pruned_root);
1001            for loc in 0..elements.len() {
1002                let loc = Location::new_unchecked(loc as u64);
1003                let proof = mmr.proof(loc);
1004                if Position::try_from(loc).unwrap() < Position::new(i) {
1005                    continue;
1006                }
1007                assert!(proof.is_ok());
1008                assert!(proof.unwrap().verify_element_inclusion(
1009                    &mut hasher,
1010                    &elements[*loc as usize],
1011                    loc,
1012                    &root
1013                ));
1014            }
1015        }
1016    }
1017
1018    #[test]
1019    fn test_proving_ranges_provable_after_pruning() {
1020        // create a new MMR and add a non-trivial amount (49) of elements
1021        let mut hasher: Standard<Sha256> = Standard::new();
1022        let mut mmr = CleanMmr::new(&mut hasher);
1023        let mut elements = Vec::new();
1024        for i in 0..49 {
1025            elements.push(test_digest(i));
1026            mmr.add(&mut hasher, elements.last().unwrap());
1027        }
1028
1029        // prune up to the first peak
1030        const PRUNE_POS: Position = Position::new(62);
1031        mmr.prune_to_pos(PRUNE_POS);
1032        assert_eq!(mmr.oldest_retained_pos().unwrap(), PRUNE_POS);
1033
1034        // Test range proofs over all possible ranges of at least 2 elements
1035        let root = mmr.root();
1036        for i in 0..elements.len() - 1 {
1037            if Position::try_from(Location::new_unchecked(i as u64)).unwrap() < PRUNE_POS {
1038                continue;
1039            }
1040            for j in (i + 2)..elements.len() {
1041                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1042                let range_proof = mmr.range_proof(range.clone()).unwrap();
1043                assert!(
1044                    range_proof.verify_range_inclusion(
1045                        &mut hasher,
1046                        &elements[range.to_usize_range()],
1047                        range.start,
1048                        root,
1049                    ),
1050                    "valid range proof over remaining elements should verify successfully",
1051                );
1052            }
1053        }
1054
1055        // Add a few more nodes, prune again, and test again to make sure repeated pruning doesn't
1056        // break proof verification.
1057        for i in 0..37 {
1058            elements.push(test_digest(i));
1059            mmr.add(&mut hasher, elements.last().unwrap());
1060        }
1061        mmr.prune_to_pos(Position::new(130)); // a bit after the new highest peak
1062        assert_eq!(mmr.oldest_retained_pos().unwrap(), 130);
1063
1064        let updated_root = mmr.root();
1065        let range = Location::new_unchecked(elements.len() as u64 - 10)
1066            ..Location::new_unchecked(elements.len() as u64);
1067        let range_proof = mmr.range_proof(range.clone()).unwrap();
1068        assert!(
1069                range_proof.verify_range_inclusion(
1070                    &mut hasher,
1071                    &elements[range.to_usize_range()],
1072                    range.start,
1073                    updated_root,
1074                ),
1075                "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1076            );
1077    }
1078
1079    #[test]
1080    fn test_proving_proof_serialization() {
1081        // create a new MMR and add a non-trivial amount of elements
1082        let mut hasher: Standard<Sha256> = Standard::new();
1083        let mut mmr = CleanMmr::new(&mut hasher);
1084        let mut elements = Vec::new();
1085        for i in 0..25 {
1086            elements.push(test_digest(i));
1087            mmr.add(&mut hasher, elements.last().unwrap());
1088        }
1089
1090        // Generate proofs over all possible ranges of elements and confirm each
1091        // serializes=>deserializes correctly.
1092        for i in 0..elements.len() {
1093            for j in i + 1..elements.len() {
1094                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1095                let proof = mmr.range_proof(range).unwrap();
1096
1097                let expected_size = proof.encode_size();
1098                let serialized_proof = proof.encode();
1099                assert_eq!(
1100                    serialized_proof.len(),
1101                    expected_size,
1102                    "serialized proof should have expected size"
1103                );
1104                // max_items is the number of elements in the range
1105                let max_items = j - i;
1106                let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_items).unwrap();
1107                assert_eq!(
1108                    proof, deserialized_proof,
1109                    "deserialized proof should match source proof"
1110                );
1111
1112                // Remove one byte from the end of the serialized
1113                // proof and confirm it fails to deserialize.
1114                let serialized_proof = proof.encode();
1115                let serialized_proof: Bytes = serialized_proof.slice(0..serialized_proof.len() - 1);
1116                assert!(
1117                    Proof::<Digest>::decode_cfg(serialized_proof, &max_items).is_err(),
1118                    "proof should not deserialize with truncated data"
1119                );
1120
1121                // Add 1 byte of extra data to the end of the serialized
1122                // proof and confirm it fails to deserialize.
1123                let mut serialized_proof = proof.encode_mut();
1124                serialized_proof.extend_from_slice(&[0; 10]);
1125                let serialized_proof = serialized_proof;
1126
1127                assert!(
1128                    Proof::<Digest>::decode_cfg(serialized_proof, &max_items).is_err(),
1129                    "proof should not deserialize with extra data"
1130                );
1131
1132                // Confirm deserialization fails when max_items is too small.
1133                let actual_digests = proof.digests.len();
1134                if actual_digests > 0 {
1135                    // Find the minimum max_items that would allow this many digests
1136                    let min_max_items = actual_digests.div_ceil(MAX_PROOF_DIGESTS_PER_ELEMENT);
1137                    // Using one less should fail
1138                    let too_small = min_max_items - 1;
1139                    let serialized_proof = proof.encode();
1140                    assert!(
1141                        Proof::<Digest>::decode_cfg(serialized_proof, &too_small).is_err(),
1142                        "proof should not deserialize with max_items too small"
1143                    );
1144                }
1145            }
1146        }
1147    }
1148
1149    #[test_traced]
1150    fn test_proving_extract_pinned_nodes() {
1151        // Test for every number of elements from 1 to 255
1152        for num_elements in 1u64..255 {
1153            // Build MMR with the specified number of elements
1154            let mut hasher: Standard<Sha256> = Standard::new();
1155            let mut mmr = CleanMmr::new(&mut hasher);
1156
1157            for i in 0..num_elements {
1158                let digest = test_digest(i as u8);
1159                mmr.add(&mut hasher, &digest);
1160            }
1161
1162            // Test pruning to each leaf.
1163            for leaf in 0..num_elements {
1164                // Test with a few different end positions to get good coverage
1165                let test_end_locs = if num_elements == 1 {
1166                    // Single element case
1167                    vec![leaf + 1]
1168                } else {
1169                    // Multi-element case: test with various end positions
1170                    let mut ends = vec![leaf + 1]; // Single element proof
1171
1172                    // Add a few more end positions if available
1173                    if leaf + 2 <= num_elements {
1174                        ends.push(leaf + 2);
1175                    }
1176                    if leaf + 3 <= num_elements {
1177                        ends.push(leaf + 3);
1178                    }
1179                    // Always test with the last element if different
1180                    if ends.last().unwrap() != &num_elements {
1181                        ends.push(num_elements);
1182                    }
1183
1184                    ends.into_iter()
1185                        .collect::<BTreeSet<_>>()
1186                        .into_iter()
1187                        .collect()
1188                };
1189
1190                for end_loc in test_end_locs {
1191                    // Generate proof for the range
1192                    let range = Location::new_unchecked(leaf)..Location::new_unchecked(end_loc);
1193                    let proof_result = mmr.range_proof(range.clone());
1194                    let proof = proof_result.unwrap();
1195
1196                    // Extract pinned nodes
1197                    let extract_result = proof.extract_pinned_nodes(range.clone());
1198                    assert!(
1199                            extract_result.is_ok(),
1200                            "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1201                        );
1202
1203                    let pinned_nodes = extract_result.unwrap();
1204                    let leaf_loc = Location::new_unchecked(leaf);
1205                    let leaf_pos = Position::try_from(leaf_loc).unwrap();
1206                    let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1207
1208                    // Verify count matches expected
1209                    assert_eq!(
1210                            pinned_nodes.len(),
1211                            expected_pinned.len(),
1212                            "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1213                        );
1214
1215                    // Verify extracted hashes match actual node values
1216                    // The pinned_nodes Vec is in the same order as expected_pinned
1217                    for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1218                        let extracted_hash = pinned_nodes[i];
1219                        let actual_hash = mmr.get_node(expected_pos).unwrap();
1220                        assert_eq!(
1221                                extracted_hash, actual_hash,
1222                                "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1223                            );
1224                    }
1225                }
1226            }
1227        }
1228    }
1229
1230    #[test]
1231    fn test_proving_extract_pinned_nodes_invalid_size() {
1232        // Test that extract_pinned_nodes returns an error for invalid MMR size
1233        let mut hasher: Standard<Sha256> = Standard::new();
1234        let mut mmr = CleanMmr::new(&mut hasher);
1235
1236        // Build MMR with 10 elements
1237        for i in 0..10 {
1238            let digest = test_digest(i);
1239            mmr.add(&mut hasher, &digest);
1240        }
1241
1242        // Generate a valid proof
1243        let range = Location::new_unchecked(5)..Location::new_unchecked(8);
1244        let mut proof = mmr.range_proof(range.clone()).unwrap();
1245
1246        // Verify the proof works with valid size
1247        assert!(proof.extract_pinned_nodes(range.clone()).is_ok());
1248
1249        // Test various invalid sizes
1250        const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1251        for invalid_size in INVALID_SIZES {
1252            proof.size = Position::new(invalid_size);
1253            let result = proof.extract_pinned_nodes(range.clone());
1254            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == invalid_size));
1255        }
1256
1257        // Test with MSB set (invalid due to leading zero requirement)
1258        proof.size = Position::new(1u64 << 63);
1259        let result = proof.extract_pinned_nodes(range);
1260        assert!(matches!(result, Err(Error::InvalidSize(s)) if s == 1u64 << 63));
1261    }
1262
1263    #[test]
1264    fn test_proving_digests_from_range() {
1265        // create a new MMR and add a non-trivial amount (49) of elements
1266        let mut hasher: Standard<Sha256> = Standard::new();
1267        let mut mmr = CleanMmr::new(&mut hasher);
1268        let mut elements = Vec::new();
1269        let mut element_positions = Vec::new();
1270        for i in 0..49 {
1271            elements.push(test_digest(i));
1272            element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1273        }
1274        let root = mmr.root();
1275
1276        // Test 1: compute_digests over the entire range should contain a digest for every node
1277        // in the tree.
1278        let proof = mmr
1279            .range_proof(Location::new_unchecked(0)..mmr.leaves())
1280            .unwrap();
1281        let mut node_digests = proof
1282            .verify_range_inclusion_and_extract_digests(
1283                &mut hasher,
1284                &elements,
1285                Location::new_unchecked(0),
1286                root,
1287            )
1288            .unwrap();
1289        assert_eq!(node_digests.len() as u64, mmr.size());
1290        node_digests.sort_by_key(|(pos, _)| *pos);
1291        for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1292            assert_eq!(pos, i as u64);
1293            assert_eq!(mmr.get_node(pos).unwrap(), d);
1294        }
1295        // Make sure the wrong root fails.
1296        let wrong_root = elements[0]; // any other digest will do
1297        assert!(matches!(
1298            proof.verify_range_inclusion_and_extract_digests(
1299                &mut hasher,
1300                &elements,
1301                Location::new_unchecked(0),
1302                &wrong_root
1303            ),
1304            Err(Error::RootMismatch)
1305        ));
1306
1307        // Test 2: Single element range (first element)
1308        let range = Location::new_unchecked(0)..Location::new_unchecked(1);
1309        let single_proof = mmr.range_proof(range.clone()).unwrap();
1310        let range_start = range.start;
1311        let single_digests = single_proof
1312            .verify_range_inclusion_and_extract_digests(
1313                &mut hasher,
1314                &elements[range.to_usize_range()],
1315                range_start,
1316                root,
1317            )
1318            .unwrap();
1319        assert!(single_digests.len() > 1);
1320
1321        // Test 3: Single element range (middle element)
1322        let mid_idx = 24;
1323        let range = Location::new_unchecked(mid_idx)..Location::new_unchecked(mid_idx + 1);
1324        let range_start = range.start;
1325        let mid_proof = mmr.range_proof(range.clone()).unwrap();
1326        let mid_digests = mid_proof
1327            .verify_range_inclusion_and_extract_digests(
1328                &mut hasher,
1329                &elements[range.to_usize_range()],
1330                range_start,
1331                root,
1332            )
1333            .unwrap();
1334        assert!(mid_digests.len() > 1);
1335
1336        // Test 4: Single element range (last element)
1337        let last_idx = elements.len() as u64 - 1;
1338        let range = Location::new_unchecked(last_idx)..Location::new_unchecked(last_idx + 1);
1339        let range_start = range.start;
1340        let last_proof = mmr.range_proof(range.clone()).unwrap();
1341        let last_digests = last_proof
1342            .verify_range_inclusion_and_extract_digests(
1343                &mut hasher,
1344                &elements[range.to_usize_range()],
1345                range_start,
1346                root,
1347            )
1348            .unwrap();
1349        assert!(last_digests.len() > 1);
1350
1351        // Test 5: Small range at the beginning
1352        let range = Location::new_unchecked(0)..Location::new_unchecked(5);
1353        let range_start = range.start;
1354        let small_proof = mmr.range_proof(range.clone()).unwrap();
1355        let small_digests = small_proof
1356            .verify_range_inclusion_and_extract_digests(
1357                &mut hasher,
1358                &elements[range.to_usize_range()],
1359                range_start,
1360                root,
1361            )
1362            .unwrap();
1363        // Verify that we get digests for the range elements and their ancestors
1364        assert!(small_digests.len() > 5);
1365
1366        // Test 6: Medium range in the middle
1367        let range = Location::new_unchecked(10)..Location::new_unchecked(31);
1368        let range_start = range.start;
1369        let mid_range_proof = mmr.range_proof(range.clone()).unwrap();
1370        let mid_range_digests = mid_range_proof
1371            .verify_range_inclusion_and_extract_digests(
1372                &mut hasher,
1373                &elements[range.to_usize_range()],
1374                range_start,
1375                root,
1376            )
1377            .unwrap();
1378        let num_elements = range.end - range.start;
1379        assert!(mid_range_digests.len() as u64 > num_elements);
1380    }
1381
1382    #[test]
1383    fn test_proving_multi_proof_generation_and_verify() {
1384        // Create an MMR with multiple elements
1385        let mut hasher: Standard<Sha256> = Standard::new();
1386        let mut mmr = CleanMmr::new(&mut hasher);
1387        let mut elements = Vec::new();
1388
1389        for i in 0..20 {
1390            elements.push(test_digest(i));
1391            mmr.add(&mut hasher, &elements[i as usize]);
1392        }
1393
1394        let root = mmr.root();
1395
1396        // Generate proof for non-contiguous single elements
1397        let locations = &[
1398            Location::new_unchecked(0),
1399            Location::new_unchecked(5),
1400            Location::new_unchecked(10),
1401        ];
1402        let nodes_for_multi_proof =
1403            nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1404        let digests = nodes_for_multi_proof
1405            .into_iter()
1406            .map(|pos| mmr.get_node(pos).unwrap())
1407            .collect();
1408        let multi_proof = Proof {
1409            size: mmr.size(),
1410            digests,
1411        };
1412
1413        assert_eq!(multi_proof.size, mmr.size());
1414
1415        // Verify the proof
1416        assert!(multi_proof.verify_multi_inclusion(
1417            &mut hasher,
1418            &[
1419                (elements[0], Location::new_unchecked(0)),
1420                (elements[5], Location::new_unchecked(5)),
1421                (elements[10], Location::new_unchecked(10)),
1422            ],
1423            root
1424        ));
1425
1426        // Verify in different order
1427        assert!(multi_proof.verify_multi_inclusion(
1428            &mut hasher,
1429            &[
1430                (elements[10], Location::new_unchecked(10)),
1431                (elements[5], Location::new_unchecked(5)),
1432                (elements[0], Location::new_unchecked(0)),
1433            ],
1434            root
1435        ));
1436
1437        // Verify with duplicate items
1438        assert!(multi_proof.verify_multi_inclusion(
1439            &mut hasher,
1440            &[
1441                (elements[0], Location::new_unchecked(0)),
1442                (elements[0], Location::new_unchecked(0)),
1443                (elements[10], Location::new_unchecked(10)),
1444                (elements[5], Location::new_unchecked(5)),
1445            ],
1446            root
1447        ));
1448
1449        // Verify mangling the size to something invalid should fail. Test three cases: valid MMR
1450        // size but wrong value, invalid MMR size, and overflowing MMR size.
1451        let wrong_sizes = [
1452            Position::new(16),
1453            Position::new(17),
1454            Position::new(u64::MAX - 100),
1455        ];
1456        for sz in wrong_sizes {
1457            let mut wrong_size_proof = multi_proof.clone();
1458            wrong_size_proof.size = sz;
1459            assert!(!wrong_size_proof.verify_multi_inclusion(
1460                &mut hasher,
1461                &[
1462                    (elements[0], Location::new_unchecked(0)),
1463                    (elements[5], Location::new_unchecked(5)),
1464                    (elements[10], Location::new_unchecked(10)),
1465                ],
1466                root,
1467            ));
1468        }
1469
1470        // Verify with wrong positions
1471        assert!(!multi_proof.verify_multi_inclusion(
1472            &mut hasher,
1473            &[
1474                (elements[0], Location::new_unchecked(1)),
1475                (elements[5], Location::new_unchecked(6)),
1476                (elements[10], Location::new_unchecked(11)),
1477            ],
1478            root,
1479        ));
1480
1481        // Verify with wrong elements
1482        let wrong_elements = [
1483            vec![255u8, 254u8, 253u8],
1484            vec![252u8, 251u8, 250u8],
1485            vec![249u8, 248u8, 247u8],
1486        ];
1487        let wrong_verification = multi_proof.verify_multi_inclusion(
1488            &mut hasher,
1489            &[
1490                (wrong_elements[0].as_slice(), Location::new_unchecked(0)),
1491                (wrong_elements[1].as_slice(), Location::new_unchecked(5)),
1492                (wrong_elements[2].as_slice(), Location::new_unchecked(10)),
1493            ],
1494            root,
1495        );
1496        assert!(!wrong_verification, "Should fail with wrong elements");
1497
1498        // Verify with out of range element
1499        let wrong_verification = multi_proof.verify_multi_inclusion(
1500            &mut hasher,
1501            &[
1502                (elements[0], Location::new_unchecked(0)),
1503                (elements[5], Location::new_unchecked(5)),
1504                (elements[10], Location::new_unchecked(1000)),
1505            ],
1506            root,
1507        );
1508        assert!(
1509            !wrong_verification,
1510            "Should fail with out of range elements"
1511        );
1512
1513        // Verify with wrong root should fail
1514        let wrong_root = test_digest(99);
1515        assert!(!multi_proof.verify_multi_inclusion(
1516            &mut hasher,
1517            &[
1518                (elements[0], Location::new_unchecked(0)),
1519                (elements[5], Location::new_unchecked(5)),
1520                (elements[10], Location::new_unchecked(10)),
1521            ],
1522            &wrong_root
1523        ));
1524
1525        // Empty multi-proof
1526        let mut hasher: Standard<Sha256> = Standard::new();
1527        let empty_mmr = CleanMmr::new(&mut hasher);
1528        let empty_root = empty_mmr.root();
1529        let empty_proof = Proof::default();
1530        assert!(empty_proof.verify_multi_inclusion(
1531            &mut hasher,
1532            &[] as &[(Digest, Location)],
1533            empty_root
1534        ));
1535    }
1536
1537    #[test]
1538    fn test_proving_multi_proof_deduplication() {
1539        let mut hasher: Standard<Sha256> = Standard::new();
1540        let mut mmr = CleanMmr::new(&mut hasher);
1541        let mut elements = Vec::new();
1542
1543        // Create an MMR with enough elements to have shared digests
1544        for i in 0..30 {
1545            elements.push(test_digest(i));
1546            mmr.add(&mut hasher, &elements[i as usize]);
1547        }
1548
1549        // Get individual proofs that will share some digests (elements in same subtree)
1550        let proof1 = mmr.proof(Location::new_unchecked(0)).unwrap();
1551        let proof2 = mmr.proof(Location::new_unchecked(1)).unwrap();
1552        let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1553
1554        // Generate multi-proof for the same positions
1555        let locations = &[Location::new_unchecked(0), Location::new_unchecked(1)];
1556        let multi_proof =
1557            nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1558        let digests = multi_proof
1559            .into_iter()
1560            .map(|pos| mmr.get_node(pos).unwrap())
1561            .collect();
1562        let multi_proof = Proof {
1563            size: mmr.size(),
1564            digests,
1565        };
1566
1567        // The combined proof should have fewer digests due to deduplication
1568        assert!(multi_proof.digests.len() < total_digests_separate);
1569
1570        // Verify it still works
1571        let root = mmr.root();
1572        assert!(multi_proof.verify_multi_inclusion(
1573            &mut hasher,
1574            &[
1575                (elements[0], Location::new_unchecked(0)),
1576                (elements[1], Location::new_unchecked(1))
1577            ],
1578            root
1579        ));
1580    }
1581
1582    #[test]
1583    fn test_max_location_is_provable() {
1584        // Test that the validation logic accepts MAX_LOCATION as a valid location
1585        // We use the maximum valid MMR size (2^63 - 1) which can hold up to 2^62 leaves
1586        let max_mmr_size = Position::new((1u64 << 63) - 1);
1587
1588        let max_loc = Location::new_unchecked(MAX_LOCATION);
1589        let max_loc_plus_1 = Location::new_unchecked(MAX_LOCATION + 1);
1590        let max_loc_plus_2 = Location::new_unchecked(MAX_LOCATION + 2);
1591
1592        // MAX_LOCATION should be accepted by the validation logic
1593        // (The range MAX_LOCATION..MAX_LOCATION+1 proves a single element at MAX_LOCATION)
1594        let result = nodes_required_for_range_proof(max_mmr_size, max_loc..max_loc_plus_1);
1595
1596        // This should succeed - MAX_LOCATION is a valid location
1597        assert!(result.is_ok(), "Should be able to prove MAX_LOCATION");
1598
1599        // MAX_LOCATION + 1 should be rejected (exceeds MAX_LOCATION)
1600        let result_overflow =
1601            nodes_required_for_range_proof(max_mmr_size, max_loc_plus_1..max_loc_plus_2);
1602        assert!(
1603            result_overflow.is_err(),
1604            "Should reject location > MAX_LOCATION"
1605        );
1606        matches!(result_overflow, Err(Error::LocationOverflow(_)));
1607    }
1608
1609    #[test]
1610    fn test_max_location_multi_proof() {
1611        // Test that multi_proof can handle MAX_LOCATION
1612        let max_mmr_size = Position::new((1u64 << 63) - 1);
1613        let max_loc = Location::new_unchecked(MAX_LOCATION);
1614
1615        // Should be able to generate multi-proof for MAX_LOCATION
1616        let result = nodes_required_for_multi_proof(max_mmr_size, &[max_loc]);
1617        assert!(
1618            result.is_ok(),
1619            "Should be able to generate multi-proof for MAX_LOCATION"
1620        );
1621
1622        // Should reject MAX_LOCATION + 1
1623        let invalid_loc = Location::new_unchecked(MAX_LOCATION + 1);
1624        let result_overflow = nodes_required_for_multi_proof(max_mmr_size, &[invalid_loc]);
1625        assert!(
1626            result_overflow.is_err(),
1627            "Should reject location > MAX_LOCATION in multi-proof"
1628        );
1629    }
1630
1631    #[test]
1632    fn test_invalid_size_validation() {
1633        // Test that invalid sizes are rejected
1634        let loc = Location::new_unchecked(0);
1635        let range = loc..loc + 1;
1636
1637        const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1638        for size in INVALID_SIZES {
1639            let result = nodes_required_for_range_proof(Position::new(size), range.clone());
1640            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1641
1642            let result = nodes_required_for_multi_proof(Position::new(size), &[loc]);
1643            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1644        }
1645    }
1646
1647    #[test]
1648    fn test_max_proof_digests_per_element_sufficient() {
1649        // Verify that MAX_PROOF_DIGESTS_PER_ELEMENT (122) is sufficient for any single-element
1650        // proof in the largest valid MMR.
1651        //
1652        // MMR sizes follow: mmr_size(N) = 2*N - popcount(N) where N = leaf count.
1653        // The number of peaks equals popcount(N).
1654        //
1655        // To maximize peaks, we want N with maximum popcount. N = 2^62 - 1 has 62 one-bits:
1656        //   N = 0x3FFFFFFFFFFFFFFF = 2^0 + 2^1 + ... + 2^61
1657        //
1658        // This gives us 62 perfect binary trees with leaf counts 2^0, 2^1, ..., 2^61
1659        // and corresponding heights 0, 1, ..., 61.
1660        //
1661        // mmr_size(2^62 - 1) = 2*(2^62 - 1) - 62 = 2^63 - 2 - 62 = 2^63 - 64
1662        //
1663        // For a single-element proof in a tree of height h:
1664        //   - Path siblings from leaf to peak: h digests
1665        //   - Other peaks (not containing the element): (62 - 1) = 61 digests
1666        //   - Total: h + 61 digests
1667        //
1668        // Worst case: element in tallest tree (h = 61)
1669        //   - Path siblings: 61
1670        //   - Other peaks: 61
1671        //   - Total: 61 + 61 = 122 digests
1672
1673        const NUM_PEAKS: usize = 62;
1674        const MAX_TREE_HEIGHT: usize = 61;
1675        const EXPECTED_WORST_CASE: usize = MAX_TREE_HEIGHT + (NUM_PEAKS - 1);
1676
1677        let many_peaks_size = Position::new((1u64 << 63) - 64);
1678        assert!(
1679            many_peaks_size.is_mmr_size(),
1680            "Size {many_peaks_size} should be a valid MMR size",
1681        );
1682
1683        let peak_count = PeakIterator::new(many_peaks_size).count();
1684        assert_eq!(peak_count, NUM_PEAKS);
1685
1686        // Verify the peak heights are 61, 60, ..., 1, 0 (from left to right)
1687        let peaks: Vec<_> = PeakIterator::new(many_peaks_size).collect();
1688        for (i, &(_pos, height)) in peaks.iter().enumerate() {
1689            let expected_height = (NUM_PEAKS - 1 - i) as u32;
1690            assert_eq!(
1691                height, expected_height,
1692                "Peak {i} should have height {expected_height}, got {height}",
1693            );
1694        }
1695
1696        // Test location 0 (leftmost leaf, in tallest tree of height 61)
1697        // Expected: 61 path siblings + 61 other peaks = 122 digests
1698        let loc = Location::new_unchecked(0);
1699        let positions = nodes_required_for_range_proof(many_peaks_size, loc..loc + 1)
1700            .expect("should compute positions for location 0");
1701
1702        assert_eq!(
1703            positions.len(),
1704            EXPECTED_WORST_CASE,
1705            "Location 0 proof should require exactly {EXPECTED_WORST_CASE} digests (61 path + 61 peaks)",
1706        );
1707
1708        // Test the rightmost leaf (in smallest tree of height 0, which is itself a peak)
1709        // Expected: 0 path siblings + 61 other peaks = 61 digests
1710        let last_leaf_loc = (1u64 << 62) - 2; // Last leaf location
1711        let positions = nodes_required_for_range_proof(
1712            many_peaks_size,
1713            Location::new_unchecked(last_leaf_loc)..Location::new_unchecked(last_leaf_loc + 1),
1714        )
1715        .expect("should compute positions for last leaf");
1716
1717        let expected_last_leaf = NUM_PEAKS - 1;
1718        assert_eq!(
1719            positions.len(),
1720            expected_last_leaf,
1721            "Last leaf proof should require exactly {expected_last_leaf} digests (0 path + 61 peaks)",
1722        );
1723    }
1724
1725    #[test]
1726    fn test_max_proof_digests_per_element_is_maximum() {
1727        // For K peaks, the worst-case proof needs: (max_tree_height) + (K - 1) digests
1728        // With K peaks of heights K-1, K-2, ..., 0, this is (K-1) + (K-1) = 2*(K-1)
1729        //
1730        // To get K peaks, leaf count N must have exactly K bits set.
1731        // MMR size = 2*N - popcount(N) = 2*N - K
1732        //
1733        // For 63 peaks: N = 2^63 - 1 (63 bits set), size = 2*(2^63 - 1) - 63 = 2^64 - 65
1734        // This exceeds MAX_POSITION, so is_mmr_size() returns false.
1735
1736        let n_for_63_peaks = (1u128 << 63) - 1;
1737        let size_for_63_peaks = 2 * n_for_63_peaks - 63; // = 2^64 - 65
1738        assert!(
1739            size_for_63_peaks > *crate::mmr::MAX_POSITION as u128,
1740            "63 peaks requires size {size_for_63_peaks} > MAX_POSITION",
1741        );
1742
1743        let size_truncated = size_for_63_peaks as u64;
1744        assert!(
1745            !Position::new(size_truncated).is_mmr_size(),
1746            "Size for 63 peaks should fail is_mmr_size()"
1747        );
1748    }
1749
    // Codec conformance tests for `Proof`, compiled only when the `arbitrary` feature is
    // enabled (presumably because the harness generates `Proof` values via `Arbitrary` —
    // confirm in `commonware_codec::conformance`).
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        // Instantiate the codec conformance suite for proofs over SHA-256 digests.
        commonware_conformance::conformance_tests! {
            CodecConformance<Proof<Sha256Digest>>,
        }
    }
1760}