commonware_storage/mmr/
proof.rs

1//! Defines the inclusion [Proof] structure, and functions for verifying them against a root digest.
2//!
3//! Also provides lower-level functions for building verifiers against new or extended proof types.
4//! These lower level functions are kept outside of the [Proof] structure and not re-exported by the
5//! parent module.
6
7#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10    hasher::Hasher,
11    iterator::{PathIterator, PeakIterator},
12    Error, Location, Position,
13};
14use alloc::{
15    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16    vec,
17    vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{varint::UInt, EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::Digest;
22use core::ops::Range;
23#[cfg(feature = "std")]
24use tracing::debug;
25
26/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
27#[derive(Error, Debug)]
28pub enum ReconstructionError {
29    #[error("missing digests in proof")]
30    MissingDigests,
31    #[error("extra digests in proof")]
32    ExtraDigests,
33    #[error("start location is out of bounds")]
34    InvalidStartLoc,
35    #[error("end location is out of bounds")]
36    InvalidEndLoc,
37    #[error("missing elements")]
38    MissingElements,
39    #[error("invalid size")]
40    InvalidSize,
41}
42
43/// Contains the information necessary for proving the inclusion of an element, or some range of
44/// elements, in the MMR from its root digest.
45///
46/// The `digests` vector contains:
47///
48/// 1: the digests of each peak corresponding to a mountain containing no elements from the element
49/// range being proven in decreasing order of height, followed by:
50///
51/// 2: the nodes in the remaining mountains necessary for reconstructing their peak digests from the
52/// elements within the range, ordered by the position of their parent.
53#[derive(Clone, Debug, Eq)]
54pub struct Proof<D: Digest> {
55    /// The total number of nodes in the MMR for MMR proofs, though other authenticated data
56    /// structures may override the meaning of this field. For example, the authenticated
57    /// [crate::AuthenticatedBitMap] stores the number of bits in the bitmap within
58    /// this field.
59    pub size: Position,
60    /// The digests necessary for proving the inclusion of an element, or range of elements, in the
61    /// MMR.
62    pub digests: Vec<D>,
63}
64
65impl<D: Digest> PartialEq for Proof<D> {
66    fn eq(&self, other: &Self) -> bool {
67        self.size == other.size && self.digests == other.digests
68    }
69}
70
71impl<D: Digest> EncodeSize for Proof<D> {
72    fn encode_size(&self) -> usize {
73        UInt(*self.size).encode_size() + self.digests.encode_size()
74    }
75}
76
77impl<D: Digest> Write for Proof<D> {
78    fn write(&self, buf: &mut impl BufMut) {
79        // Write the number of nodes in the MMR as a varint
80        UInt(*self.size).write(buf);
81
82        // Write the digests
83        self.digests.write(buf);
84    }
85}
86
87impl<D: Digest> Read for Proof<D> {
88    /// The maximum number of digests in the proof.
89    type Cfg = usize;
90
91    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
92        // Read the number of nodes in the MMR
93        let size = Position::new(UInt::<u64>::read(buf)?.into());
94
95        // Read the digests
96        let range = ..=max_len;
97        let digests = Vec::<D>::read_range(buf, range)?;
98        Ok(Self { size, digests })
99    }
100}
101
102impl<D: Digest> Default for Proof<D> {
103    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
104    /// (`size == 0`) MMR.
105    fn default() -> Self {
106        Self {
107            size: Position::new(0),
108            digests: vec![],
109        }
110    }
111}
112
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Draw the size before the digests so the unstructured bytes are consumed in the same
        // order as before.
        let size = u.arbitrary()?;
        let digests = u.arbitrary()?;
        Ok(Self { size, digests })
    }
}
125
126impl<D: Digest> Proof<D> {
127    /// Return true if this proof proves that `element` appears at location `loc` within the MMR
128    /// with root digest `root`.
129    pub fn verify_element_inclusion<H>(
130        &self,
131        hasher: &mut H,
132        element: &[u8],
133        loc: Location,
134        root: &D,
135    ) -> bool
136    where
137        H: Hasher<D>,
138    {
139        self.verify_range_inclusion(hasher, &[element], loc, root)
140    }
141
142    /// Return true if this proof proves that the `elements` appear consecutively starting at
143    /// position `start_loc` within the MMR with root digest `root`. A malformed proof will return
144    /// false.
145    pub fn verify_range_inclusion<H, E>(
146        &self,
147        hasher: &mut H,
148        elements: &[E],
149        start_loc: Location,
150        root: &D,
151    ) -> bool
152    where
153        H: Hasher<D>,
154        E: AsRef<[u8]>,
155    {
156        if !self.size.is_mmr_size() {
157            #[cfg(feature = "std")]
158            debug!(size = ?self.size, "invalid proof size");
159            return false;
160        }
161
162        match self.reconstruct_root(hasher, elements, start_loc) {
163            Ok(reconstructed_root) => *root == reconstructed_root,
164            Err(_error) => {
165                #[cfg(feature = "std")]
166                tracing::debug!(error = ?_error, "invalid proof input");
167                false
168            }
169        }
170    }
171
172    /// Return true if this proof proves that the elements at the specified locations are included
173    /// in the MMR with the root digest `root`. A malformed proof will return false.
174    ///
175    /// The order of the elements does not affect the output.
176    pub fn verify_multi_inclusion<H, E>(
177        &self,
178        hasher: &mut H,
179        elements: &[(E, Location)],
180        root: &D,
181    ) -> bool
182    where
183        H: Hasher<D>,
184        E: AsRef<[u8]>,
185    {
186        // Empty proof is valid for an empty MMR
187        if elements.is_empty() {
188            return self.size == Position::new(0)
189                && *root == hasher.root(Position::new(0), core::iter::empty());
190        }
191        if !self.size.is_mmr_size() {
192            return false;
193        }
194
195        // Single pass to collect all required positions with deduplication
196        let mut node_positions = BTreeSet::new();
197        let mut nodes_required = BTreeMap::new();
198
199        for (_, loc) in elements {
200            if !loc.is_valid() {
201                return false;
202            }
203            // `loc` is valid so it won't overflow from +1
204            let Ok(required) = nodes_required_for_range_proof(self.size, *loc..*loc + 1) else {
205                return false;
206            };
207            for req_pos in &required {
208                node_positions.insert(*req_pos);
209            }
210            nodes_required.insert(*loc, required);
211        }
212
213        // Verify we have the exact number of digests needed
214        if node_positions.len() != self.digests.len() {
215            return false;
216        }
217
218        // Build position to digest mapping once
219        let node_digests: BTreeMap<Position, D> = node_positions
220            .iter()
221            .zip(self.digests.iter())
222            .map(|(&pos, digest)| (pos, *digest))
223            .collect();
224
225        // Verify each element by reconstructing its path
226        for (element, loc) in elements {
227            // Get required positions for this element
228            let required = &nodes_required[loc];
229
230            // Build proof with required digests
231            let mut digests = Vec::with_capacity(required.len());
232            for req_pos in required {
233                // There must exist a digest for each required position (by
234                // construction of `node_digests`)
235                let digest = node_digests
236                    .get(req_pos)
237                    .expect("must exist by construction of node_digests");
238                digests.push(*digest);
239            }
240            let proof = Self {
241                size: self.size,
242                digests,
243            };
244
245            // Verify the proof
246            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
247                return false;
248            }
249        }
250
251        true
252    }
253
254    // The functions below are lower level functions that are useful to building verification
255    // functions for new or extended proof types.
256
257    /// Computes the set of pinned nodes for the pruning boundary corresponding to the start of the
258    /// given range, returning the digest of each by extracting it from the proof.
259    ///
260    /// # Arguments
261    /// * `range` - The start and end locations of the proven range, where start is also used as the
262    ///   pruning boundary.
263    ///
264    /// # Returns
265    /// A Vec of digests for all nodes in `nodes_to_pin(pruning_boundary)`, in the same order as
266    /// returned by `nodes_to_pin` (decreasing height order)
267    ///
268    /// # Errors
269    ///
270    /// Returns [Error::InvalidSize] if the proof size is not a valid MMR size.
271    /// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
272    /// Returns [Error::InvalidProofLength] if the proof digest count doesn't match the required
273    /// positions count.
274    /// Returns [Error::MissingDigest] if a pinned node is not found in the proof.
275    #[cfg(any(feature = "std", test))]
276    pub(crate) fn extract_pinned_nodes(
277        &self,
278        range: std::ops::Range<Location>,
279    ) -> Result<Vec<D>, Error> {
280        // Get the positions of all nodes that should be pinned.
281        let start_pos = Position::try_from(range.start)?;
282        let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
283
284        // Get all positions required for the proof.
285        let required_positions = nodes_required_for_range_proof(self.size, range)?;
286
287        if required_positions.len() != self.digests.len() {
288            #[cfg(feature = "std")]
289            debug!(
290                digests_len = self.digests.len(),
291                required_positions_len = required_positions.len(),
292                "Proof digest count doesn't match required positions",
293            );
294            return Err(Error::InvalidProofLength);
295        }
296
297        // Happy path: we can extract the pinned nodes directly from the proof.
298        // This happens when the `end_element_pos` is the last element in the MMR.
299        if pinned_positions
300            == required_positions[required_positions.len() - pinned_positions.len()..]
301        {
302            return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
303        }
304
305        // Create a mapping from position to digest.
306        let position_to_digest: BTreeMap<Position, D> = required_positions
307            .iter()
308            .zip(self.digests.iter())
309            .map(|(&pos, &digest)| (pos, digest))
310            .collect();
311
312        // Extract the pinned nodes in the same order as nodes_to_pin.
313        let mut result = Vec::with_capacity(pinned_positions.len());
314        for pinned_pos in pinned_positions {
315            let Some(&digest) = position_to_digest.get(&pinned_pos) else {
316                #[cfg(feature = "std")]
317                debug!(?pinned_pos, "Pinned node not found in proof");
318                return Err(Error::MissingDigest(pinned_pos));
319            };
320            result.push(digest);
321        }
322        Ok(result)
323    }
324
325    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
326    /// of elements, returning the (position,digest) of every node whose digest was required by the
327    /// process (including those from the proof itself). Returns a [Error::InvalidProof] if the
328    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
329    /// root.
330    pub fn verify_range_inclusion_and_extract_digests<H, E>(
331        &self,
332        hasher: &mut H,
333        elements: &[E],
334        start_loc: Location,
335        root: &D,
336    ) -> Result<Vec<(Position, D)>, super::Error>
337    where
338        H: Hasher<D>,
339        E: AsRef<[u8]>,
340    {
341        let mut collected_digests = Vec::new();
342        let Ok(peak_digests) = self.reconstruct_peak_digests(
343            hasher,
344            elements,
345            start_loc,
346            Some(&mut collected_digests),
347        ) else {
348            return Err(Error::InvalidProof);
349        };
350
351        if hasher.root(self.size, peak_digests.iter()) != *root {
352            return Err(Error::RootMismatch);
353        }
354
355        Ok(collected_digests)
356    }
357
358    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
359    /// of elements, or returns a [ReconstructionError] if the input data is invalid.
360    pub fn reconstruct_root<H, E>(
361        &self,
362        hasher: &mut H,
363        elements: &[E],
364        start_loc: Location,
365    ) -> Result<D, ReconstructionError>
366    where
367        H: Hasher<D>,
368        E: AsRef<[u8]>,
369    {
370        if !self.size.is_mmr_size() {
371            return Err(ReconstructionError::InvalidSize);
372        }
373
374        let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
375
376        Ok(hasher.root(self.size, peak_digests.iter()))
377    }
378
379    /// Reconstruct the peak digests of the MMR that produced this proof, returning
380    /// [ReconstructionError] if the input data is invalid.  If collected_digests is Some, then all
381    /// node digests used in the process will be added to the wrapped vector.
382    pub fn reconstruct_peak_digests<H, E>(
383        &self,
384        hasher: &mut H,
385        elements: &[E],
386        start_loc: Location,
387        mut collected_digests: Option<&mut Vec<(Position, D)>>,
388    ) -> Result<Vec<D>, ReconstructionError>
389    where
390        H: Hasher<D>,
391        E: AsRef<[u8]>,
392    {
393        if elements.is_empty() {
394            if start_loc == 0 {
395                return Ok(vec![]);
396            }
397            return Err(ReconstructionError::MissingElements);
398        }
399        let start_element_pos =
400            Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
401        let end_element_pos = if elements.len() == 1 {
402            start_element_pos
403        } else {
404            let end_loc = start_loc
405                .checked_add(elements.len() as u64 - 1)
406                .ok_or(ReconstructionError::InvalidEndLoc)?;
407            Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
408        };
409        if end_element_pos >= self.size {
410            return Err(ReconstructionError::InvalidEndLoc);
411        }
412
413        let mut proof_digests_iter = self.digests.iter();
414        let mut siblings_iter = self.digests.iter().rev();
415
416        // Include peak digests only for trees that have no elements from the range, and keep track
417        // of the starting and ending trees of those that do contain some.
418        let mut peak_digests: Vec<D> = Vec::new();
419        let mut proof_digests_used = 0;
420        let mut elements_iter = elements.iter();
421        if !Position::is_mmr_size(self.size) {
422            return Err(ReconstructionError::InvalidSize);
423        }
424        for (peak_pos, height) in PeakIterator::new(self.size) {
425            let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
426            if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
427                let hash = peak_digest_from_range(
428                    hasher,
429                    RangeInfo {
430                        pos: peak_pos,
431                        two_h: 1 << height,
432                        leftmost_pos: start_element_pos,
433                        rightmost_pos: end_element_pos,
434                    },
435                    &mut elements_iter,
436                    &mut siblings_iter,
437                    collected_digests.as_deref_mut(),
438                )?;
439                peak_digests.push(hash);
440                if let Some(ref mut collected_digests) = collected_digests {
441                    collected_digests.push((peak_pos, hash));
442                }
443            } else if let Some(hash) = proof_digests_iter.next() {
444                proof_digests_used += 1;
445                peak_digests.push(*hash);
446                if let Some(ref mut collected_digests) = collected_digests {
447                    collected_digests.push((peak_pos, *hash));
448                }
449            } else {
450                return Err(ReconstructionError::MissingDigests);
451            }
452        }
453
454        if elements_iter.next().is_some() {
455            return Err(ReconstructionError::ExtraDigests);
456        }
457        if let Some(next_sibling) = siblings_iter.next() {
458            if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
459                return Err(ReconstructionError::ExtraDigests);
460            }
461        }
462
463        Ok(peak_digests)
464    }
465}
466
467/// Return the list of node positions required by the range proof for the specified range of
468/// elements.
469///
470/// # Errors
471///
472/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
473/// Returns [Error::Empty] if the range is empty.
474/// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
475/// Returns [Error::RangeOutOfBounds] if the last element position in `range` is out of bounds
476/// (>= `size`).
477pub(crate) fn nodes_required_for_range_proof(
478    size: Position,
479    range: Range<Location>,
480) -> Result<Vec<Position>, Error> {
481    if !size.is_mmr_size() {
482        return Err(Error::InvalidSize(*size));
483    }
484    if range.is_empty() {
485        return Err(Error::Empty);
486    }
487    let start_element_pos = Position::try_from(range.start)?;
488    let end_minus_one = range
489        .end
490        .checked_sub(1)
491        .expect("can't underflow because range is non-empty");
492    let end_element_pos = Position::try_from(end_minus_one)?;
493    if end_element_pos >= size {
494        return Err(Error::RangeOutOfBounds(range.end));
495    }
496
497    // Find the mountains that contain no elements from the range. The peaks of these mountains
498    // are required to prove the range, so they are added to the result.
499    let mut start_tree_with_element: Option<(Position, u32)> = None;
500    let mut end_tree_with_element: Option<(Position, u32)> = None;
501    let mut peak_iterator = PeakIterator::new(size);
502    let mut positions = Vec::new();
503    while let Some(peak) = peak_iterator.next() {
504        if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
505            // Found the first tree to contain an element in the range
506            start_tree_with_element = Some(peak);
507            if peak.0 >= end_element_pos {
508                // Start and end tree are the same
509                end_tree_with_element = Some(peak);
510                continue;
511            }
512            for peak in peak_iterator.by_ref() {
513                if peak.0 >= end_element_pos {
514                    // Found the last tree to contain an element in the range
515                    end_tree_with_element = Some(peak);
516                    break;
517                }
518            }
519        } else {
520            // Tree is outside the range, its peak is thus required.
521            positions.push(peak.0);
522        }
523    }
524
525    // We checked above that all range elements are in this MMR, so some mountain must contain
526    // the first and last elements in the range.
527    let (start_tree_peak, start_tree_height) =
528        start_tree_with_element.expect("start_tree_with_element is Some");
529    let (end_tree_peak, end_tree_height) =
530        end_tree_with_element.expect("end_tree_with_element is Some");
531
532    // Include the positions of any left-siblings of each node on the path from peak to
533    // leftmost-leaf, and right-siblings for the path from peak to rightmost-leaf. These are
534    // added in order of decreasing parent position.
535    let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);
536
537    let mut siblings = Vec::new();
538    if start_element_pos == end_element_pos {
539        // For the (common) case of a single element range, the right and left path are the
540        // same so no need to process each independently.
541        siblings.extend(left_path_iter);
542    } else {
543        let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
544        // filter the right path for right siblings only
545        siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
546        // filter the left path for left siblings only
547        siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));
548
549        // If the range spans more than one tree, then the digests must already be in the correct
550        // order. Otherwise, we enforce the desired order through sorting.
551        if start_tree_peak == end_tree_peak {
552            siblings.sort_by(|a, b| b.0.cmp(&a.0));
553        }
554    }
555    positions.extend(siblings.into_iter().map(|(_, pos)| pos));
556    Ok(positions)
557}
558
/// Returns the positions of the minimal set of nodes whose digests are required to prove the
/// inclusion of the elements at the specified `locations`.
///
/// The result is the union of the positions required by each location's single-element proof.
/// The order of `locations` does not affect the output (the set is sorted).
///
/// # Errors
///
/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
/// Returns [Error::Empty] if locations is empty.
/// Returns [Error::LocationOverflow] if any location in `locations` > [crate::mmr::MAX_LOCATION].
/// Returns [Error::RangeOutOfBounds] if any location is out of bounds for the given `size`.
#[cfg(any(feature = "std", test))]
pub(crate) fn nodes_required_for_multi_proof(
    size: Position,
    locations: &[Location],
) -> Result<BTreeSet<Position>, Error> {
    if !size.is_mmr_size() {
        return Err(Error::InvalidSize(*size));
    }
    if locations.is_empty() {
        return Err(Error::Empty);
    }

    // TODO(#1472): Optimize this loop
    let mut required = BTreeSet::new();
    for &loc in locations {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        // `loc` is valid so it won't overflow from +1
        required.extend(nodes_required_for_range_proof(size, loc..loc + 1)?);
    }
    Ok(required)
}
594
595/// Information about the current range of nodes being traversed.
596struct RangeInfo {
597    pos: Position,           // current node position in the tree
598    two_h: u64,              // 2^height of the current node
599    leftmost_pos: Position,  // leftmost leaf in the tree to be traversed
600    rightmost_pos: Position, // rightmost leaf in the tree to be traversed
601}
602
603fn peak_digest_from_range<'a, D, H, E, S>(
604    hasher: &mut H,
605    range_info: RangeInfo,
606    elements: &mut E,
607    sibling_digests: &mut S,
608    mut collected_digests: Option<&mut Vec<(Position, D)>>,
609) -> Result<D, ReconstructionError>
610where
611    D: Digest,
612    H: Hasher<D>,
613    E: Iterator<Item: AsRef<[u8]>>,
614    S: Iterator<Item = &'a D>,
615{
616    assert_ne!(range_info.two_h, 0);
617    if range_info.two_h == 1 {
618        match elements.next() {
619            Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
620            None => return Err(ReconstructionError::MissingDigests),
621        }
622    }
623
624    let mut left_digest: Option<D> = None;
625    let mut right_digest: Option<D> = None;
626
627    let left_pos = range_info.pos - range_info.two_h;
628    let right_pos = left_pos + range_info.two_h - 1;
629    if left_pos >= range_info.leftmost_pos {
630        // Descend left
631        let digest = peak_digest_from_range(
632            hasher,
633            RangeInfo {
634                pos: left_pos,
635                two_h: range_info.two_h >> 1,
636                leftmost_pos: range_info.leftmost_pos,
637                rightmost_pos: range_info.rightmost_pos,
638            },
639            elements,
640            sibling_digests,
641            collected_digests.as_deref_mut(),
642        )?;
643        left_digest = Some(digest);
644    }
645    if left_pos < range_info.rightmost_pos {
646        // Descend right
647        let digest = peak_digest_from_range(
648            hasher,
649            RangeInfo {
650                pos: right_pos,
651                two_h: range_info.two_h >> 1,
652                leftmost_pos: range_info.leftmost_pos,
653                rightmost_pos: range_info.rightmost_pos,
654            },
655            elements,
656            sibling_digests,
657            collected_digests.as_deref_mut(),
658        )?;
659        right_digest = Some(digest);
660    }
661
662    if left_digest.is_none() {
663        match sibling_digests.next() {
664            Some(hash) => left_digest = Some(*hash),
665            None => return Err(ReconstructionError::MissingDigests),
666        }
667    }
668    if right_digest.is_none() {
669        match sibling_digests.next() {
670            Some(hash) => right_digest = Some(*hash),
671            None => return Err(ReconstructionError::MissingDigests),
672        }
673    }
674
675    if let Some(ref mut collected_digests) = collected_digests {
676        collected_digests.push((
677            left_pos,
678            left_digest.expect("left_digest guaranteed to be Some after checks above"),
679        ));
680        collected_digests.push((
681            right_pos,
682            right_digest.expect("right_digest guaranteed to be Some after checks above"),
683        ));
684    }
685
686    Ok(hasher.node_digest(
687        range_info.pos,
688        &left_digest.expect("left_digest guaranteed to be Some after checks above"),
689        &right_digest.expect("right_digest guaranteed to be Some after checks above"),
690    ))
691}
692
693#[cfg(test)]
694mod tests {
695    use super::*;
696    use crate::mmr::{
697        hasher::Standard, location::LocationRangeExt as _, mem::CleanMmr, MAX_LOCATION,
698    };
699    use bytes::Bytes;
700    use commonware_codec::{Decode, Encode};
701    use commonware_cryptography::{sha256::Digest, Hasher, Sha256};
702    use commonware_macros::test_traced;
703
704    fn test_digest(v: u8) -> Digest {
705        Sha256::hash(&[v])
706    }
707
708    #[test]
709    fn test_proving_proof() {
710        // Test that an empty proof authenticates an empty MMR.
711        let mut hasher: Standard<Sha256> = Standard::new();
712        let mmr = CleanMmr::new(&mut hasher);
713        let root = mmr.root();
714        let proof = Proof::default();
715        assert!(proof.verify_range_inclusion(
716            &mut hasher,
717            &[] as &[Digest],
718            Location::new_unchecked(0),
719            root
720        ));
721
722        // Any starting position other than 0 should fail to verify.
723        assert!(!proof.verify_range_inclusion(
724            &mut hasher,
725            &[] as &[Digest],
726            Location::new_unchecked(1),
727            root
728        ));
729
730        // Invalid root should fail to verify.
731        let test_digest = test_digest(0);
732        assert!(!proof.verify_range_inclusion(
733            &mut hasher,
734            &[] as &[Digest],
735            Location::new_unchecked(0),
736            &test_digest
737        ));
738
739        // Non-empty elements list should fail to verify.
740        assert!(!proof.verify_range_inclusion(
741            &mut hasher,
742            &[test_digest],
743            Location::new_unchecked(0),
744            root
745        ));
746    }
747
748    #[test]
749    fn test_proving_verify_element() {
750        // create an 11 element MMR over which we'll test single-element inclusion proofs
751        let element = Digest::from(*b"01234567012345670123456701234567");
752        let mut hasher: Standard<Sha256> = Standard::new();
753        let mut mmr = CleanMmr::new(&mut hasher);
754        for _ in 0..11 {
755            mmr.add(&mut hasher, &element);
756        }
757        let root = mmr.root();
758
759        // confirm the proof of inclusion for each leaf successfully verifies
760        for leaf in 0u64..11 {
761            let leaf = Location::new_unchecked(leaf);
762            let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
763            assert!(
764                proof.verify_element_inclusion(&mut hasher, &element, leaf, root),
765                "valid proof should verify successfully"
766            );
767        }
768
769        // Create a valid proof, then confirm various mangling of the proof or proof args results in
770        // verification failure.
771        const LEAF: Location = Location::new_unchecked(10);
772        let proof = mmr.proof(LEAF).unwrap();
773        assert!(
774            proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
775            "proof verification should be successful"
776        );
777        let wrong_sizes = [0, 16, 17, 18, 20, u64::MAX - 100];
778        for sz in wrong_sizes {
779            let mut wrong_size_proof = proof.clone();
780            wrong_size_proof.size = Position::new(sz);
781            assert!(
782                !wrong_size_proof.verify_element_inclusion(&mut hasher, &element, LEAF, root),
783                "proof with wrong size should fail verification"
784            );
785        }
786        assert!(
787            !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, root),
788            "proof verification should fail with incorrect element position"
789        );
790        assert!(
791            !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, root),
792            "proof verification should fail with incorrect element position 2"
793        );
794        assert!(
795            !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, root),
796            "proof verification should fail with mangled element"
797        );
798        let root2 = test_digest(0);
799        assert!(
800            !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
801            "proof verification should fail with mangled root"
802        );
803        let mut proof2 = proof.clone();
804        proof2.digests[0] = test_digest(0);
805        assert!(
806            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
807            "proof verification should fail with mangled proof hash"
808        );
809        proof2 = proof.clone();
810        proof2.size = Position::new(10);
811        assert!(
812            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
813            "proof verification should fail with incorrect size"
814        );
815        proof2 = proof.clone();
816        proof2.digests.push(test_digest(0));
817        assert!(
818            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
819            "proof verification should fail with extra hash"
820        );
821        proof2 = proof.clone();
822        while !proof2.digests.is_empty() {
823            proof2.digests.pop();
824            assert!(
825                !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
826                "proof verification should fail with missing digests"
827            );
828        }
829        proof2 = proof.clone();
830        proof2.digests.clear();
831        const PEAK_COUNT: usize = 3;
832        proof2
833            .digests
834            .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
835        // sneak in an extra hash that won't be used in the computation and make sure it's
836        // detected
837        proof2.digests.push(test_digest(0));
838        proof2
839            .digests
840            .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
841        assert!(
842            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, root),
843            "proof verification should fail with extra hash even if it's unused by the computation"
844        );
845    }
846
847    #[test]
848    fn test_proving_verify_range() {
849        // create a new MMR and add a non-trivial amount (49) of elements
850        let mut hasher: Standard<Sha256> = Standard::new();
851        let mut mmr = CleanMmr::new(&mut hasher);
852        let mut elements = Vec::new();
853        for i in 0..49 {
854            elements.push(test_digest(i));
855            mmr.add(&mut hasher, elements.last().unwrap());
856        }
857        // test range proofs over all possible ranges of at least 2 elements
858        let root = mmr.root();
859
860        for i in 0..elements.len() {
861            for j in i + 1..elements.len() {
862                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
863                let range_proof = mmr.range_proof(range.clone()).unwrap();
864                assert!(
865                    range_proof.verify_range_inclusion(
866                        &mut hasher,
867                        &elements[range.to_usize_range()],
868                        range.start,
869                        root,
870                    ),
871                    "valid range proof should verify successfully {i}:{j}",
872                );
873            }
874        }
875
876        // Create a proof over a range of elements, confirm it verifies successfully, then mangle
877        // the proof & proof input in various ways, confirming verification fails.
878        let range = Location::new_unchecked(33)..Location::new_unchecked(40);
879        let range_proof = mmr.range_proof(range.clone()).unwrap();
880        let valid_elements = &elements[range.to_usize_range()];
881        assert!(
882            range_proof.verify_range_inclusion(&mut hasher, valid_elements, range.start, root),
883            "valid range proof should verify successfully"
884        );
885        // Remove digests from the proof until it's empty, confirming proof verification fails for
886        // each.
887        let mut invalid_proof = range_proof.clone();
888        for _i in 0..range_proof.digests.len() {
889            invalid_proof.digests.remove(0);
890            assert!(
891                !invalid_proof.verify_range_inclusion(
892                    &mut hasher,
893                    valid_elements,
894                    range.start,
895                    root,
896                ),
897                "range proof with removed elements should fail"
898            );
899        }
900        // Confirm proof verification fails when providing an element range different than the one
901        // used to generate the proof.
902        for i in 0..elements.len() {
903            for j in i + 1..elements.len() {
904                if Location::from(i) == range.start && Location::from(j) == range.end {
905                    // skip the valid range
906                    continue;
907                }
908                assert!(
909                    !range_proof.verify_range_inclusion(
910                        &mut hasher,
911                        &elements[i..j],
912                        range.start,
913                        root,
914                    ),
915                    "range proof with invalid element range should fail {i}:{j}",
916                );
917            }
918        }
919        // Confirm proof fails to verify with an invalid root.
920        let invalid_root = test_digest(1);
921        assert!(
922            !range_proof.verify_range_inclusion(
923                &mut hasher,
924                valid_elements,
925                range.start,
926                &invalid_root,
927            ),
928            "range proof with invalid root should fail"
929        );
930        // Mangle each element of the proof and confirm it fails to verify.
931        for i in 0..range_proof.digests.len() {
932            let mut invalid_proof = range_proof.clone();
933            invalid_proof.digests[i] = test_digest(0);
934
935            assert!(
936                !invalid_proof.verify_range_inclusion(
937                    &mut hasher,
938                    valid_elements,
939                    range.start,
940                    root,
941                ),
942                "mangled range proof should fail verification"
943            );
944        }
945        // Inserting elements into the proof should also cause it to fail (malleability check)
946        for i in 0..range_proof.digests.len() {
947            let mut invalid_proof = range_proof.clone();
948            invalid_proof.digests.insert(i, test_digest(0));
949            assert!(
950                !invalid_proof.verify_range_inclusion(
951                    &mut hasher,
952                    valid_elements,
953                    range.start,
954                    root,
955                ),
956                "mangled range proof should fail verification. inserted element at: {i}",
957            );
958        }
959        // Bad start_loc should cause verification to fail.
960        for loc in 0..elements.len() {
961            let loc = Location::new_unchecked(loc as u64);
962            if loc == range.start {
963                continue;
964            }
965            assert!(
966                !range_proof.verify_range_inclusion(&mut hasher, valid_elements, loc, root),
967                "bad start_loc should fail verification {loc}",
968            );
969        }
970    }
971
972    #[test_traced]
973    fn test_proving_retained_nodes_provable_after_pruning() {
974        // create a new MMR and add a non-trivial amount (49) of elements
975        let mut hasher: Standard<Sha256> = Standard::new();
976        let mut mmr = CleanMmr::new(&mut hasher);
977        let mut elements = Vec::new();
978        for i in 0..49 {
979            elements.push(test_digest(i));
980            mmr.add(&mut hasher, elements.last().unwrap());
981        }
982
983        // Confirm we can successfully prove all retained elements in the MMR after pruning.
984        let root = *mmr.root();
985        for i in 1..*mmr.size() {
986            mmr.prune_to_pos(Position::new(i));
987            let pruned_root = mmr.root();
988            assert_eq!(root, *pruned_root);
989            for loc in 0..elements.len() {
990                let loc = Location::new_unchecked(loc as u64);
991                let proof = mmr.proof(loc);
992                if Position::try_from(loc).unwrap() < Position::new(i) {
993                    continue;
994                }
995                assert!(proof.is_ok());
996                assert!(proof.unwrap().verify_element_inclusion(
997                    &mut hasher,
998                    &elements[*loc as usize],
999                    loc,
1000                    &root
1001                ));
1002            }
1003        }
1004    }
1005
1006    #[test]
1007    fn test_proving_ranges_provable_after_pruning() {
1008        // create a new MMR and add a non-trivial amount (49) of elements
1009        let mut hasher: Standard<Sha256> = Standard::new();
1010        let mut mmr = CleanMmr::new(&mut hasher);
1011        let mut elements = Vec::new();
1012        for i in 0..49 {
1013            elements.push(test_digest(i));
1014            mmr.add(&mut hasher, elements.last().unwrap());
1015        }
1016
1017        // prune up to the first peak
1018        const PRUNE_POS: Position = Position::new(62);
1019        mmr.prune_to_pos(PRUNE_POS);
1020        assert_eq!(mmr.oldest_retained_pos().unwrap(), PRUNE_POS);
1021
1022        // Test range proofs over all possible ranges of at least 2 elements
1023        let root = mmr.root();
1024        for i in 0..elements.len() - 1 {
1025            if Position::try_from(Location::new_unchecked(i as u64)).unwrap() < PRUNE_POS {
1026                continue;
1027            }
1028            for j in (i + 2)..elements.len() {
1029                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1030                let range_proof = mmr.range_proof(range.clone()).unwrap();
1031                assert!(
1032                    range_proof.verify_range_inclusion(
1033                        &mut hasher,
1034                        &elements[range.to_usize_range()],
1035                        range.start,
1036                        root,
1037                    ),
1038                    "valid range proof over remaining elements should verify successfully",
1039                );
1040            }
1041        }
1042
1043        // Add a few more nodes, prune again, and test again to make sure repeated pruning doesn't
1044        // break proof verification.
1045        for i in 0..37 {
1046            elements.push(test_digest(i));
1047            mmr.add(&mut hasher, elements.last().unwrap());
1048        }
1049        mmr.prune_to_pos(Position::new(130)); // a bit after the new highest peak
1050        assert_eq!(mmr.oldest_retained_pos().unwrap(), 130);
1051
1052        let updated_root = mmr.root();
1053        let range = Location::new_unchecked(elements.len() as u64 - 10)
1054            ..Location::new_unchecked(elements.len() as u64);
1055        let range_proof = mmr.range_proof(range.clone()).unwrap();
1056        assert!(
1057                range_proof.verify_range_inclusion(
1058                    &mut hasher,
1059                    &elements[range.to_usize_range()],
1060                    range.start,
1061                    updated_root,
1062                ),
1063                "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1064            );
1065    }
1066
1067    #[test]
1068    fn test_proving_proof_serialization() {
1069        // create a new MMR and add a non-trivial amount of elements
1070        let mut hasher: Standard<Sha256> = Standard::new();
1071        let mut mmr = CleanMmr::new(&mut hasher);
1072        let mut elements = Vec::new();
1073        for i in 0..25 {
1074            elements.push(test_digest(i));
1075            mmr.add(&mut hasher, elements.last().unwrap());
1076        }
1077
1078        // Generate proofs over all possible ranges of elements and confirm each
1079        // serializes=>deserializes correctly.
1080        for i in 0..elements.len() {
1081            for j in i + 1..elements.len() {
1082                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1083                let proof = mmr.range_proof(range).unwrap();
1084
1085                let expected_size = proof.encode_size();
1086                let serialized_proof = proof.encode().freeze();
1087                assert_eq!(
1088                    serialized_proof.len(),
1089                    expected_size,
1090                    "serialized proof should have expected size"
1091                );
1092                let max_digests = proof.digests.len();
1093                let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_digests).unwrap();
1094                assert_eq!(
1095                    proof, deserialized_proof,
1096                    "deserialized proof should match source proof"
1097                );
1098
1099                // Remove one byte from the end of the serialized
1100                // proof and confirm it fails to deserialize.
1101                let serialized_proof = proof.encode().freeze();
1102                let serialized_proof: Bytes = serialized_proof.slice(0..serialized_proof.len() - 1);
1103                assert!(
1104                    Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1105                    "proof should not deserialize with truncated data"
1106                );
1107
1108                // Add 1 byte of extra data to the end of the serialized
1109                // proof and confirm it fails to deserialize.
1110                let mut serialized_proof = proof.encode();
1111                serialized_proof.extend_from_slice(&[0; 10]);
1112                let serialized_proof = serialized_proof.freeze();
1113
1114                assert!(
1115                    Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1116                    "proof should not deserialize with extra data"
1117                );
1118
1119                // Confirm deserialization fails when max length is exceeded.
1120                if max_digests > 0 {
1121                    let serialized_proof = proof.encode().freeze();
1122                    assert!(
1123                        Proof::<Digest>::decode_cfg(serialized_proof, &(max_digests - 1)).is_err(),
1124                        "proof should not deserialize with max length exceeded"
1125                    );
1126                }
1127            }
1128        }
1129    }
1130
1131    #[test_traced]
1132    fn test_proving_extract_pinned_nodes() {
1133        // Test for every number of elements from 1 to 255
1134        for num_elements in 1u64..255 {
1135            // Build MMR with the specified number of elements
1136            let mut hasher: Standard<Sha256> = Standard::new();
1137            let mut mmr = CleanMmr::new(&mut hasher);
1138
1139            for i in 0..num_elements {
1140                let digest = test_digest(i as u8);
1141                mmr.add(&mut hasher, &digest);
1142            }
1143
1144            // Test pruning to each leaf.
1145            for leaf in 0..num_elements {
1146                // Test with a few different end positions to get good coverage
1147                let test_end_locs = if num_elements == 1 {
1148                    // Single element case
1149                    vec![leaf + 1]
1150                } else {
1151                    // Multi-element case: test with various end positions
1152                    let mut ends = vec![leaf + 1]; // Single element proof
1153
1154                    // Add a few more end positions if available
1155                    if leaf + 2 <= num_elements {
1156                        ends.push(leaf + 2);
1157                    }
1158                    if leaf + 3 <= num_elements {
1159                        ends.push(leaf + 3);
1160                    }
1161                    // Always test with the last element if different
1162                    if ends.last().unwrap() != &num_elements {
1163                        ends.push(num_elements);
1164                    }
1165
1166                    ends.into_iter()
1167                        .collect::<BTreeSet<_>>()
1168                        .into_iter()
1169                        .collect()
1170                };
1171
1172                for end_loc in test_end_locs {
1173                    // Generate proof for the range
1174                    let range = Location::new_unchecked(leaf)..Location::new_unchecked(end_loc);
1175                    let proof_result = mmr.range_proof(range.clone());
1176                    let proof = proof_result.unwrap();
1177
1178                    // Extract pinned nodes
1179                    let extract_result = proof.extract_pinned_nodes(range.clone());
1180                    assert!(
1181                            extract_result.is_ok(),
1182                            "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1183                        );
1184
1185                    let pinned_nodes = extract_result.unwrap();
1186                    let leaf_loc = Location::new_unchecked(leaf);
1187                    let leaf_pos = Position::try_from(leaf_loc).unwrap();
1188                    let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1189
1190                    // Verify count matches expected
1191                    assert_eq!(
1192                            pinned_nodes.len(),
1193                            expected_pinned.len(),
1194                            "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1195                        );
1196
1197                    // Verify extracted hashes match actual node values
1198                    // The pinned_nodes Vec is in the same order as expected_pinned
1199                    for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1200                        let extracted_hash = pinned_nodes[i];
1201                        let actual_hash = mmr.get_node(expected_pos).unwrap();
1202                        assert_eq!(
1203                                extracted_hash, actual_hash,
1204                                "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1205                            );
1206                    }
1207                }
1208            }
1209        }
1210    }
1211
1212    #[test]
1213    fn test_proving_extract_pinned_nodes_invalid_size() {
1214        // Test that extract_pinned_nodes returns an error for invalid MMR size
1215        let mut hasher: Standard<Sha256> = Standard::new();
1216        let mut mmr = CleanMmr::new(&mut hasher);
1217
1218        // Build MMR with 10 elements
1219        for i in 0..10 {
1220            let digest = test_digest(i);
1221            mmr.add(&mut hasher, &digest);
1222        }
1223
1224        // Generate a valid proof
1225        let range = Location::new_unchecked(5)..Location::new_unchecked(8);
1226        let mut proof = mmr.range_proof(range.clone()).unwrap();
1227
1228        // Verify the proof works with valid size
1229        assert!(proof.extract_pinned_nodes(range.clone()).is_ok());
1230
1231        // Test various invalid sizes
1232        const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1233        for invalid_size in INVALID_SIZES {
1234            proof.size = Position::new(invalid_size);
1235            let result = proof.extract_pinned_nodes(range.clone());
1236            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == invalid_size));
1237        }
1238
1239        // Test with MSB set (invalid due to leading zero requirement)
1240        proof.size = Position::new(1u64 << 63);
1241        let result = proof.extract_pinned_nodes(range);
1242        assert!(matches!(result, Err(Error::InvalidSize(s)) if s == 1u64 << 63));
1243    }
1244
1245    #[test]
1246    fn test_proving_digests_from_range() {
1247        // create a new MMR and add a non-trivial amount (49) of elements
1248        let mut hasher: Standard<Sha256> = Standard::new();
1249        let mut mmr = CleanMmr::new(&mut hasher);
1250        let mut elements = Vec::new();
1251        let mut element_positions = Vec::new();
1252        for i in 0..49 {
1253            elements.push(test_digest(i));
1254            element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1255        }
1256        let root = mmr.root();
1257
1258        // Test 1: compute_digests over the entire range should contain a digest for every node
1259        // in the tree.
1260        let proof = mmr
1261            .range_proof(Location::new_unchecked(0)..mmr.leaves())
1262            .unwrap();
1263        let mut node_digests = proof
1264            .verify_range_inclusion_and_extract_digests(
1265                &mut hasher,
1266                &elements,
1267                Location::new_unchecked(0),
1268                root,
1269            )
1270            .unwrap();
1271        assert_eq!(node_digests.len() as u64, mmr.size());
1272        node_digests.sort_by_key(|(pos, _)| *pos);
1273        for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1274            assert_eq!(pos, i as u64);
1275            assert_eq!(mmr.get_node(pos).unwrap(), d);
1276        }
1277        // Make sure the wrong root fails.
1278        let wrong_root = elements[0]; // any other digest will do
1279        assert!(matches!(
1280            proof.verify_range_inclusion_and_extract_digests(
1281                &mut hasher,
1282                &elements,
1283                Location::new_unchecked(0),
1284                &wrong_root
1285            ),
1286            Err(Error::RootMismatch)
1287        ));
1288
1289        // Test 2: Single element range (first element)
1290        let range = Location::new_unchecked(0)..Location::new_unchecked(1);
1291        let single_proof = mmr.range_proof(range.clone()).unwrap();
1292        let range_start = range.start;
1293        let single_digests = single_proof
1294            .verify_range_inclusion_and_extract_digests(
1295                &mut hasher,
1296                &elements[range.to_usize_range()],
1297                range_start,
1298                root,
1299            )
1300            .unwrap();
1301        assert!(single_digests.len() > 1);
1302
1303        // Test 3: Single element range (middle element)
1304        let mid_idx = 24;
1305        let range = Location::new_unchecked(mid_idx)..Location::new_unchecked(mid_idx + 1);
1306        let range_start = range.start;
1307        let mid_proof = mmr.range_proof(range.clone()).unwrap();
1308        let mid_digests = mid_proof
1309            .verify_range_inclusion_and_extract_digests(
1310                &mut hasher,
1311                &elements[range.to_usize_range()],
1312                range_start,
1313                root,
1314            )
1315            .unwrap();
1316        assert!(mid_digests.len() > 1);
1317
1318        // Test 4: Single element range (last element)
1319        let last_idx = elements.len() as u64 - 1;
1320        let range = Location::new_unchecked(last_idx)..Location::new_unchecked(last_idx + 1);
1321        let range_start = range.start;
1322        let last_proof = mmr.range_proof(range.clone()).unwrap();
1323        let last_digests = last_proof
1324            .verify_range_inclusion_and_extract_digests(
1325                &mut hasher,
1326                &elements[range.to_usize_range()],
1327                range_start,
1328                root,
1329            )
1330            .unwrap();
1331        assert!(last_digests.len() > 1);
1332
1333        // Test 5: Small range at the beginning
1334        let range = Location::new_unchecked(0)..Location::new_unchecked(5);
1335        let range_start = range.start;
1336        let small_proof = mmr.range_proof(range.clone()).unwrap();
1337        let small_digests = small_proof
1338            .verify_range_inclusion_and_extract_digests(
1339                &mut hasher,
1340                &elements[range.to_usize_range()],
1341                range_start,
1342                root,
1343            )
1344            .unwrap();
1345        // Verify that we get digests for the range elements and their ancestors
1346        assert!(small_digests.len() > 5);
1347
1348        // Test 6: Medium range in the middle
1349        let range = Location::new_unchecked(10)..Location::new_unchecked(31);
1350        let range_start = range.start;
1351        let mid_range_proof = mmr.range_proof(range.clone()).unwrap();
1352        let mid_range_digests = mid_range_proof
1353            .verify_range_inclusion_and_extract_digests(
1354                &mut hasher,
1355                &elements[range.to_usize_range()],
1356                range_start,
1357                root,
1358            )
1359            .unwrap();
1360        let num_elements = range.end - range.start;
1361        assert!(mid_range_digests.len() as u64 > num_elements);
1362    }
1363
1364    #[test]
1365    fn test_proving_multi_proof_generation_and_verify() {
1366        // Create an MMR with multiple elements
1367        let mut hasher: Standard<Sha256> = Standard::new();
1368        let mut mmr = CleanMmr::new(&mut hasher);
1369        let mut elements = Vec::new();
1370
1371        for i in 0..20 {
1372            elements.push(test_digest(i));
1373            mmr.add(&mut hasher, &elements[i as usize]);
1374        }
1375
1376        let root = mmr.root();
1377
1378        // Generate proof for non-contiguous single elements
1379        let locations = &[
1380            Location::new_unchecked(0),
1381            Location::new_unchecked(5),
1382            Location::new_unchecked(10),
1383        ];
1384        let nodes_for_multi_proof =
1385            nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1386        let digests = nodes_for_multi_proof
1387            .into_iter()
1388            .map(|pos| mmr.get_node(pos).unwrap())
1389            .collect();
1390        let multi_proof = Proof {
1391            size: mmr.size(),
1392            digests,
1393        };
1394
1395        assert_eq!(multi_proof.size, mmr.size());
1396
1397        // Verify the proof
1398        assert!(multi_proof.verify_multi_inclusion(
1399            &mut hasher,
1400            &[
1401                (elements[0], Location::new_unchecked(0)),
1402                (elements[5], Location::new_unchecked(5)),
1403                (elements[10], Location::new_unchecked(10)),
1404            ],
1405            root
1406        ));
1407
1408        // Verify in different order
1409        assert!(multi_proof.verify_multi_inclusion(
1410            &mut hasher,
1411            &[
1412                (elements[10], Location::new_unchecked(10)),
1413                (elements[5], Location::new_unchecked(5)),
1414                (elements[0], Location::new_unchecked(0)),
1415            ],
1416            root
1417        ));
1418
1419        // Verify with duplicate items
1420        assert!(multi_proof.verify_multi_inclusion(
1421            &mut hasher,
1422            &[
1423                (elements[0], Location::new_unchecked(0)),
1424                (elements[0], Location::new_unchecked(0)),
1425                (elements[10], Location::new_unchecked(10)),
1426                (elements[5], Location::new_unchecked(5)),
1427            ],
1428            root
1429        ));
1430
1431        // Verify mangling the size to something invalid should fail. Test three cases: valid MMR
1432        // size but wrong value, invalid MMR size, and overflowing MMR size.
1433        let wrong_sizes = [
1434            Position::new(16),
1435            Position::new(17),
1436            Position::new(u64::MAX - 100),
1437        ];
1438        for sz in wrong_sizes {
1439            let mut wrong_size_proof = multi_proof.clone();
1440            wrong_size_proof.size = sz;
1441            assert!(!wrong_size_proof.verify_multi_inclusion(
1442                &mut hasher,
1443                &[
1444                    (elements[0], Location::new_unchecked(0)),
1445                    (elements[5], Location::new_unchecked(5)),
1446                    (elements[10], Location::new_unchecked(10)),
1447                ],
1448                root,
1449            ));
1450        }
1451
1452        // Verify with wrong positions
1453        assert!(!multi_proof.verify_multi_inclusion(
1454            &mut hasher,
1455            &[
1456                (elements[0], Location::new_unchecked(1)),
1457                (elements[5], Location::new_unchecked(6)),
1458                (elements[10], Location::new_unchecked(11)),
1459            ],
1460            root,
1461        ));
1462
1463        // Verify with wrong elements
1464        let wrong_elements = [
1465            vec![255u8, 254u8, 253u8],
1466            vec![252u8, 251u8, 250u8],
1467            vec![249u8, 248u8, 247u8],
1468        ];
1469        let wrong_verification = multi_proof.verify_multi_inclusion(
1470            &mut hasher,
1471            &[
1472                (wrong_elements[0].as_slice(), Location::new_unchecked(0)),
1473                (wrong_elements[1].as_slice(), Location::new_unchecked(5)),
1474                (wrong_elements[2].as_slice(), Location::new_unchecked(10)),
1475            ],
1476            root,
1477        );
1478        assert!(!wrong_verification, "Should fail with wrong elements");
1479
1480        // Verify with out of range element
1481        let wrong_verification = multi_proof.verify_multi_inclusion(
1482            &mut hasher,
1483            &[
1484                (elements[0], Location::new_unchecked(0)),
1485                (elements[5], Location::new_unchecked(5)),
1486                (elements[10], Location::new_unchecked(1000)),
1487            ],
1488            root,
1489        );
1490        assert!(
1491            !wrong_verification,
1492            "Should fail with out of range elements"
1493        );
1494
1495        // Verify with wrong root should fail
1496        let wrong_root = test_digest(99);
1497        assert!(!multi_proof.verify_multi_inclusion(
1498            &mut hasher,
1499            &[
1500                (elements[0], Location::new_unchecked(0)),
1501                (elements[5], Location::new_unchecked(5)),
1502                (elements[10], Location::new_unchecked(10)),
1503            ],
1504            &wrong_root
1505        ));
1506
1507        // Empty multi-proof
1508        let mut hasher: Standard<Sha256> = Standard::new();
1509        let empty_mmr = CleanMmr::new(&mut hasher);
1510        let empty_root = empty_mmr.root();
1511        let empty_proof = Proof::default();
1512        assert!(empty_proof.verify_multi_inclusion(
1513            &mut hasher,
1514            &[] as &[(Digest, Location)],
1515            empty_root
1516        ));
1517    }
1518
1519    #[test]
1520    fn test_proving_multi_proof_deduplication() {
1521        let mut hasher: Standard<Sha256> = Standard::new();
1522        let mut mmr = CleanMmr::new(&mut hasher);
1523        let mut elements = Vec::new();
1524
1525        // Create an MMR with enough elements to have shared digests
1526        for i in 0..30 {
1527            elements.push(test_digest(i));
1528            mmr.add(&mut hasher, &elements[i as usize]);
1529        }
1530
1531        // Get individual proofs that will share some digests (elements in same subtree)
1532        let proof1 = mmr.proof(Location::new_unchecked(0)).unwrap();
1533        let proof2 = mmr.proof(Location::new_unchecked(1)).unwrap();
1534        let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1535
1536        // Generate multi-proof for the same positions
1537        let locations = &[Location::new_unchecked(0), Location::new_unchecked(1)];
1538        let multi_proof =
1539            nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1540        let digests = multi_proof
1541            .into_iter()
1542            .map(|pos| mmr.get_node(pos).unwrap())
1543            .collect();
1544        let multi_proof = Proof {
1545            size: mmr.size(),
1546            digests,
1547        };
1548
1549        // The combined proof should have fewer digests due to deduplication
1550        assert!(multi_proof.digests.len() < total_digests_separate);
1551
1552        // Verify it still works
1553        let root = mmr.root();
1554        assert!(multi_proof.verify_multi_inclusion(
1555            &mut hasher,
1556            &[
1557                (elements[0], Location::new_unchecked(0)),
1558                (elements[1], Location::new_unchecked(1))
1559            ],
1560            root
1561        ));
1562    }
1563
1564    #[test]
1565    fn test_max_location_is_provable() {
1566        // Test that the validation logic accepts MAX_LOCATION as a valid location
1567        // We use the maximum valid MMR size (2^63 - 1) which can hold up to 2^62 leaves
1568        let max_mmr_size = Position::new((1u64 << 63) - 1);
1569
1570        let max_loc = Location::new_unchecked(MAX_LOCATION);
1571        let max_loc_plus_1 = Location::new_unchecked(MAX_LOCATION + 1);
1572        let max_loc_plus_2 = Location::new_unchecked(MAX_LOCATION + 2);
1573
1574        // MAX_LOCATION should be accepted by the validation logic
1575        // (The range MAX_LOCATION..MAX_LOCATION+1 proves a single element at MAX_LOCATION)
1576        let result = nodes_required_for_range_proof(max_mmr_size, max_loc..max_loc_plus_1);
1577
1578        // This should succeed - MAX_LOCATION is a valid location
1579        assert!(result.is_ok(), "Should be able to prove MAX_LOCATION");
1580
1581        // MAX_LOCATION + 1 should be rejected (exceeds MAX_LOCATION)
1582        let result_overflow =
1583            nodes_required_for_range_proof(max_mmr_size, max_loc_plus_1..max_loc_plus_2);
1584        assert!(
1585            result_overflow.is_err(),
1586            "Should reject location > MAX_LOCATION"
1587        );
1588        matches!(result_overflow, Err(Error::LocationOverflow(_)));
1589    }
1590
1591    #[test]
1592    fn test_max_location_multi_proof() {
1593        // Test that multi_proof can handle MAX_LOCATION
1594        let max_mmr_size = Position::new((1u64 << 63) - 1);
1595        let max_loc = Location::new_unchecked(MAX_LOCATION);
1596
1597        // Should be able to generate multi-proof for MAX_LOCATION
1598        let result = nodes_required_for_multi_proof(max_mmr_size, &[max_loc]);
1599        assert!(
1600            result.is_ok(),
1601            "Should be able to generate multi-proof for MAX_LOCATION"
1602        );
1603
1604        // Should reject MAX_LOCATION + 1
1605        let invalid_loc = Location::new_unchecked(MAX_LOCATION + 1);
1606        let result_overflow = nodes_required_for_multi_proof(max_mmr_size, &[invalid_loc]);
1607        assert!(
1608            result_overflow.is_err(),
1609            "Should reject location > MAX_LOCATION in multi-proof"
1610        );
1611    }
1612
1613    #[test]
1614    fn test_invalid_size_validation() {
1615        // Test that invalid sizes are rejected
1616        let loc = Location::new_unchecked(0);
1617        let range = loc..loc + 1;
1618
1619        const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1620        for size in INVALID_SIZES {
1621            let result = nodes_required_for_range_proof(Position::new(size), range.clone());
1622            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1623
1624            let result = nodes_required_for_multi_proof(Position::new(size), &[loc]);
1625            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1626        }
1627    }
1628
    /// Codec conformance checks for [Proof], compiled only when the `arbitrary` feature is
    /// enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;
        use commonware_cryptography::sha256::Digest as Sha256Digest;

        // Presumably generates a codec conformance test for `Proof` instantiated with SHA-256
        // digests. NOTE(review): the macro may derive test or fixture identifiers from the type
        // tokens, so keep the type spelled exactly as-is unless the fixtures are regenerated.
        commonware_conformance::conformance_tests! {
            CodecConformance<Proof<Sha256Digest>>,
        }
    }
1639}