commonware_storage/mmr/
proof.rs

1//! Defines the inclusion [Proof] structure, and functions for verifying them against a root digest.
2//!
3//! Also provides lower-level functions for building verifiers against new or extended proof types.
4//! These lower level functions are kept outside of the [Proof] structure and not re-exported by the
5//! parent module.
6
7#[cfg(any(feature = "std", test))]
8use crate::mmr::iterator::nodes_to_pin;
9use crate::mmr::{
10    hasher::Hasher,
11    iterator::{PathIterator, PeakIterator},
12    Error, Location, Position,
13};
14use alloc::{
15    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
16    vec,
17    vec::Vec,
18};
19use bytes::{Buf, BufMut};
20use commonware_codec::{varint::UInt, EncodeSize, Read, ReadExt, ReadRangeExt, Write};
21use commonware_cryptography::{Digest, Hasher as CHasher};
22use core::ops::Range;
23#[cfg(feature = "std")]
24use tracing::debug;
25
26/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
27#[derive(Error, Debug)]
28pub enum ReconstructionError {
29    #[error("missing digests in proof")]
30    MissingDigests,
31    #[error("extra digests in proof")]
32    ExtraDigests,
33    #[error("start location is out of bounds")]
34    InvalidStartLoc,
35    #[error("end location is out of bounds")]
36    InvalidEndLoc,
37    #[error("missing elements")]
38    MissingElements,
39    #[error("invalid size")]
40    InvalidSize,
41}
42
43/// Contains the information necessary for proving the inclusion of an element, or some range of
44/// elements, in the MMR from its root digest.
45///
46/// The `digests` vector contains:
47///
48/// 1: the digests of each peak corresponding to a mountain containing no elements from the element
49/// range being proven in decreasing order of height, followed by:
50///
51/// 2: the nodes in the remaining mountains necessary for reconstructing their peak digests from the
52/// elements within the range, ordered by the position of their parent.
53#[derive(Clone, Debug, Eq)]
54pub struct Proof<D: Digest> {
55    /// The total number of nodes in the MMR for MMR proofs, though other authenticated data
56    /// structures may override the meaning of this field. For example, the authenticated
57    /// [crate::mmr::bitmap::BitMap] stores the number of bits in the bitmap within
58    /// this field.
59    pub size: Position,
60    /// The digests necessary for proving the inclusion of an element, or range of elements, in the
61    /// MMR.
62    pub digests: Vec<D>,
63}
64
65impl<D: Digest> PartialEq for Proof<D> {
66    fn eq(&self, other: &Self) -> bool {
67        self.size == other.size && self.digests == other.digests
68    }
69}
70
71impl<D: Digest> EncodeSize for Proof<D> {
72    fn encode_size(&self) -> usize {
73        UInt(*self.size).encode_size() + self.digests.encode_size()
74    }
75}
76
77impl<D: Digest> Write for Proof<D> {
78    fn write(&self, buf: &mut impl BufMut) {
79        // Write the number of nodes in the MMR as a varint
80        UInt(*self.size).write(buf);
81
82        // Write the digests
83        self.digests.write(buf);
84    }
85}
86
87impl<D: Digest> Read for Proof<D> {
88    /// The maximum number of digests in the proof.
89    type Cfg = usize;
90
91    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
92        // Read the number of nodes in the MMR
93        let size = Position::new(UInt::<u64>::read(buf)?.into());
94
95        // Read the digests
96        let range = ..=max_len;
97        let digests = Vec::<D>::read_range(buf, range)?;
98        Ok(Proof { size, digests })
99    }
100}
101
102impl<D: Digest> Default for Proof<D> {
103    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
104    /// (`size == 0`) MMR.
105    fn default() -> Self {
106        Self {
107            size: Position::new(0),
108            digests: vec![],
109        }
110    }
111}
112
113impl<D: Digest> Proof<D> {
114    /// Return true if this proof proves that `element` appears at location `loc` within the MMR
115    /// with root digest `root`.
116    pub fn verify_element_inclusion<I, H>(
117        &self,
118        hasher: &mut H,
119        element: &[u8],
120        loc: Location,
121        root: &D,
122    ) -> bool
123    where
124        I: CHasher<Digest = D>,
125        H: Hasher<I>,
126    {
127        self.verify_range_inclusion(hasher, &[element], loc, root)
128    }
129
130    /// Return true if this proof proves that the `elements` appear consecutively starting at
131    /// position `start_loc` within the MMR with root digest `root`. A malformed proof will return
132    /// false.
133    pub fn verify_range_inclusion<I, H, E>(
134        &self,
135        hasher: &mut H,
136        elements: &[E],
137        start_loc: Location,
138        root: &D,
139    ) -> bool
140    where
141        I: CHasher<Digest = D>,
142        H: Hasher<I>,
143        E: AsRef<[u8]>,
144    {
145        if !self.size.is_mmr_size() {
146            #[cfg(feature = "std")]
147            debug!(size = ?self.size, "invalid proof size");
148            return false;
149        }
150
151        match self.reconstruct_root(hasher, elements, start_loc) {
152            Ok(reconstructed_root) => *root == reconstructed_root,
153            Err(_error) => {
154                #[cfg(feature = "std")]
155                tracing::debug!(error = ?_error, "invalid proof input");
156                false
157            }
158        }
159    }
160
161    /// Return true if this proof proves that the elements at the specified locations are included
162    /// in the MMR with the root digest `root`. A malformed proof will return false.
163    ///
164    /// The order of the elements does not affect the output.
165    pub fn verify_multi_inclusion<I, H, E>(
166        &self,
167        hasher: &mut H,
168        elements: &[(E, Location)],
169        root: &D,
170    ) -> bool
171    where
172        I: CHasher<Digest = D>,
173        H: Hasher<I>,
174        E: AsRef<[u8]>,
175    {
176        // Empty proof is valid for an empty MMR
177        if elements.is_empty() {
178            return self.size == Position::new(0)
179                && *root == hasher.root(Position::new(0), core::iter::empty());
180        }
181        if !self.size.is_mmr_size() {
182            return false;
183        }
184
185        // Single pass to collect all required positions with deduplication
186        let mut node_positions = BTreeSet::new();
187        let mut nodes_required = BTreeMap::new();
188
189        for (_, loc) in elements {
190            if !loc.is_valid() {
191                return false;
192            }
193            // `loc` is valid so it won't overflow from +1
194            let Ok(required) = nodes_required_for_range_proof(self.size, *loc..*loc + 1) else {
195                return false;
196            };
197            for req_pos in &required {
198                node_positions.insert(*req_pos);
199            }
200            nodes_required.insert(*loc, required);
201        }
202
203        // Verify we have the exact number of digests needed
204        if node_positions.len() != self.digests.len() {
205            return false;
206        }
207
208        // Build position to digest mapping once
209        let node_digests: BTreeMap<Position, D> = node_positions
210            .iter()
211            .zip(self.digests.iter())
212            .map(|(&pos, digest)| (pos, *digest))
213            .collect();
214
215        // Verify each element by reconstructing its path
216        for (element, loc) in elements {
217            // Get required positions for this element
218            let required = &nodes_required[loc];
219
220            // Build proof with required digests
221            let mut digests = Vec::with_capacity(required.len());
222            for req_pos in required {
223                // There must exist a digest for each required position (by
224                // construction of `node_digests`)
225                let digest = node_digests
226                    .get(req_pos)
227                    .expect("must exist by construction of node_digests");
228                digests.push(*digest);
229            }
230            let proof = Proof {
231                size: self.size,
232                digests,
233            };
234
235            // Verify the proof
236            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
237                return false;
238            }
239        }
240
241        true
242    }
243
244    // The functions below are lower level functions that are useful to building verification
245    // functions for new or extended proof types.
246
247    /// Computes the set of pinned nodes for the pruning boundary corresponding to the start of the
248    /// given range, returning the digest of each by extracting it from the proof.
249    ///
250    /// # Arguments
251    /// * `range` - The start and end locations of the proven range, where start is also used as the
252    ///   pruning boundary.
253    ///
254    /// # Returns
255    /// A Vec of digests for all nodes in `nodes_to_pin(pruning_boundary)`, in the same order as
256    /// returned by `nodes_to_pin` (decreasing height order)
257    ///
258    /// # Errors
259    ///
260    /// Returns [Error::InvalidSize] if the proof size is not a valid MMR size.
261    /// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
262    /// Returns [Error::InvalidProofLength] if the proof digest count doesn't match the required
263    /// positions count.
264    /// Returns [Error::MissingDigest] if a pinned node is not found in the proof.
265    #[cfg(any(feature = "std", test))]
266    pub(crate) fn extract_pinned_nodes(
267        &self,
268        range: std::ops::Range<Location>,
269    ) -> Result<Vec<D>, Error> {
270        // Get the positions of all nodes that should be pinned.
271        let start_pos = Position::try_from(range.start)?;
272        let pinned_positions: Vec<Position> = nodes_to_pin(start_pos).collect();
273
274        // Get all positions required for the proof.
275        let required_positions = nodes_required_for_range_proof(self.size, range)?;
276
277        if required_positions.len() != self.digests.len() {
278            #[cfg(feature = "std")]
279            debug!(
280                digests_len = self.digests.len(),
281                required_positions_len = required_positions.len(),
282                "Proof digest count doesn't match required positions",
283            );
284            return Err(Error::InvalidProofLength);
285        }
286
287        // Happy path: we can extract the pinned nodes directly from the proof.
288        // This happens when the `end_element_pos` is the last element in the MMR.
289        if pinned_positions
290            == required_positions[required_positions.len() - pinned_positions.len()..]
291        {
292            return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
293        }
294
295        // Create a mapping from position to digest.
296        let position_to_digest: BTreeMap<Position, D> = required_positions
297            .iter()
298            .zip(self.digests.iter())
299            .map(|(&pos, &digest)| (pos, digest))
300            .collect();
301
302        // Extract the pinned nodes in the same order as nodes_to_pin.
303        let mut result = Vec::with_capacity(pinned_positions.len());
304        for pinned_pos in pinned_positions {
305            let Some(&digest) = position_to_digest.get(&pinned_pos) else {
306                #[cfg(feature = "std")]
307                debug!(?pinned_pos, "Pinned node not found in proof");
308                return Err(Error::MissingDigest(pinned_pos));
309            };
310            result.push(digest);
311        }
312        Ok(result)
313    }
314
315    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
316    /// of elements, returning the (position,digest) of every node whose digest was required by the
317    /// process (including those from the proof itself). Returns a [Error::InvalidProof] if the
318    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
319    /// root.
320    pub fn verify_range_inclusion_and_extract_digests<I, H, E>(
321        &self,
322        hasher: &mut H,
323        elements: &[E],
324        start_loc: Location,
325        root: &I::Digest,
326    ) -> Result<Vec<(Position, D)>, super::Error>
327    where
328        I: CHasher<Digest = D>,
329        H: Hasher<I>,
330        E: AsRef<[u8]>,
331    {
332        let mut collected_digests = Vec::new();
333        let Ok(peak_digests) = self.reconstruct_peak_digests(
334            hasher,
335            elements,
336            start_loc,
337            Some(&mut collected_digests),
338        ) else {
339            return Err(Error::InvalidProof);
340        };
341
342        if hasher.root(self.size, peak_digests.iter()) != *root {
343            return Err(Error::RootMismatch);
344        }
345
346        Ok(collected_digests)
347    }
348
349    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
350    /// of elements, or returns a [ReconstructionError] if the input data is invalid.
351    pub fn reconstruct_root<I, H, E>(
352        &self,
353        hasher: &mut H,
354        elements: &[E],
355        start_loc: Location,
356    ) -> Result<I::Digest, ReconstructionError>
357    where
358        I: CHasher<Digest = D>,
359        H: Hasher<I>,
360        E: AsRef<[u8]>,
361    {
362        if !self.size.is_mmr_size() {
363            return Err(ReconstructionError::InvalidSize);
364        }
365
366        let peak_digests = self.reconstruct_peak_digests(hasher, elements, start_loc, None)?;
367
368        Ok(hasher.root(self.size, peak_digests.iter()))
369    }
370
371    /// Reconstruct the peak digests of the MMR that produced this proof, returning
372    /// [ReconstructionError] if the input data is invalid.  If collected_digests is Some, then all
373    /// node digests used in the process will be added to the wrapped vector.
374    pub fn reconstruct_peak_digests<I, H, E>(
375        &self,
376        hasher: &mut H,
377        elements: &[E],
378        start_loc: Location,
379        mut collected_digests: Option<&mut Vec<(Position, I::Digest)>>,
380    ) -> Result<Vec<D>, ReconstructionError>
381    where
382        I: CHasher<Digest = D>,
383        H: Hasher<I>,
384        E: AsRef<[u8]>,
385    {
386        if elements.is_empty() {
387            if start_loc == 0 {
388                return Ok(vec![]);
389            }
390            return Err(ReconstructionError::MissingElements);
391        }
392        let start_element_pos =
393            Position::try_from(start_loc).map_err(|_| ReconstructionError::InvalidStartLoc)?;
394        let end_element_pos = if elements.len() == 1 {
395            start_element_pos
396        } else {
397            let end_loc = start_loc
398                .checked_add(elements.len() as u64 - 1)
399                .ok_or(ReconstructionError::InvalidEndLoc)?;
400            Position::try_from(end_loc).map_err(|_| ReconstructionError::InvalidEndLoc)?
401        };
402        if end_element_pos >= self.size {
403            return Err(ReconstructionError::InvalidEndLoc);
404        }
405
406        let mut proof_digests_iter = self.digests.iter();
407        let mut siblings_iter = self.digests.iter().rev();
408
409        // Include peak digests only for trees that have no elements from the range, and keep track
410        // of the starting and ending trees of those that do contain some.
411        let mut peak_digests: Vec<D> = Vec::new();
412        let mut proof_digests_used = 0;
413        let mut elements_iter = elements.iter();
414        if !Position::is_mmr_size(self.size) {
415            return Err(ReconstructionError::InvalidSize);
416        }
417        for (peak_pos, height) in PeakIterator::new(self.size) {
418            let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
419            if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
420                let hash = peak_digest_from_range(
421                    hasher,
422                    RangeInfo {
423                        pos: peak_pos,
424                        two_h: 1 << height,
425                        leftmost_pos: start_element_pos,
426                        rightmost_pos: end_element_pos,
427                    },
428                    &mut elements_iter,
429                    &mut siblings_iter,
430                    collected_digests.as_deref_mut(),
431                )?;
432                peak_digests.push(hash);
433                if let Some(ref mut collected_digests) = collected_digests {
434                    collected_digests.push((peak_pos, hash));
435                }
436            } else if let Some(hash) = proof_digests_iter.next() {
437                proof_digests_used += 1;
438                peak_digests.push(*hash);
439                if let Some(ref mut collected_digests) = collected_digests {
440                    collected_digests.push((peak_pos, *hash));
441                }
442            } else {
443                return Err(ReconstructionError::MissingDigests);
444            }
445        }
446
447        if elements_iter.next().is_some() {
448            return Err(ReconstructionError::ExtraDigests);
449        }
450        if let Some(next_sibling) = siblings_iter.next() {
451            if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
452                return Err(ReconstructionError::ExtraDigests);
453            }
454        }
455
456        Ok(peak_digests)
457    }
458}
459
460/// Return the list of node positions required by the range proof for the specified range of
461/// elements.
462///
463/// # Errors
464///
465/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
466/// Returns [Error::Empty] if the range is empty.
467/// Returns [Error::LocationOverflow] if a location in `range` > [crate::mmr::MAX_LOCATION].
468/// Returns [Error::RangeOutOfBounds] if the last element position in `range` is out of bounds
469/// (>= `size`).
470pub(crate) fn nodes_required_for_range_proof(
471    size: Position,
472    range: Range<Location>,
473) -> Result<Vec<Position>, Error> {
474    if !size.is_mmr_size() {
475        return Err(Error::InvalidSize(*size));
476    }
477    if range.is_empty() {
478        return Err(Error::Empty);
479    }
480    let start_element_pos = Position::try_from(range.start)?;
481    let end_minus_one = range
482        .end
483        .checked_sub(1)
484        .expect("can't underflow because range is non-empty");
485    let end_element_pos = Position::try_from(end_minus_one)?;
486    if end_element_pos >= size {
487        return Err(Error::RangeOutOfBounds(range.end));
488    }
489
490    // Find the mountains that contain no elements from the range. The peaks of these mountains
491    // are required to prove the range, so they are added to the result.
492    let mut start_tree_with_element: Option<(Position, u32)> = None;
493    let mut end_tree_with_element: Option<(Position, u32)> = None;
494    let mut peak_iterator = PeakIterator::new(size);
495    let mut positions = Vec::new();
496    while let Some(peak) = peak_iterator.next() {
497        if start_tree_with_element.is_none() && peak.0 >= start_element_pos {
498            // Found the first tree to contain an element in the range
499            start_tree_with_element = Some(peak);
500            if peak.0 >= end_element_pos {
501                // Start and end tree are the same
502                end_tree_with_element = Some(peak);
503                continue;
504            }
505            for peak in peak_iterator.by_ref() {
506                if peak.0 >= end_element_pos {
507                    // Found the last tree to contain an element in the range
508                    end_tree_with_element = Some(peak);
509                    break;
510                }
511            }
512        } else {
513            // Tree is outside the range, its peak is thus required.
514            positions.push(peak.0);
515        }
516    }
517
518    // We checked above that all range elements are in this MMR, so some mountain must contain
519    // the first and last elements in the range.
520    let (start_tree_peak, start_tree_height) =
521        start_tree_with_element.expect("start_tree_with_element is Some");
522    let (end_tree_peak, end_tree_height) =
523        end_tree_with_element.expect("end_tree_with_element is Some");
524
525    // Include the positions of any left-siblings of each node on the path from peak to
526    // leftmost-leaf, and right-siblings for the path from peak to rightmost-leaf. These are
527    // added in order of decreasing parent position.
528    let left_path_iter = PathIterator::new(start_element_pos, start_tree_peak, start_tree_height);
529
530    let mut siblings = Vec::new();
531    if start_element_pos == end_element_pos {
532        // For the (common) case of a single element range, the right and left path are the
533        // same so no need to process each independently.
534        siblings.extend(left_path_iter);
535    } else {
536        let right_path_iter = PathIterator::new(end_element_pos, end_tree_peak, end_tree_height);
537        // filter the right path for right siblings only
538        siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
539        // filter the left path for left siblings only
540        siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));
541
542        // If the range spans more than one tree, then the digests must already be in the correct
543        // order. Otherwise, we enforce the desired order through sorting.
544        if start_tree_peak == end_tree_peak {
545            siblings.sort_by(|a, b| b.0.cmp(&a.0));
546        }
547    }
548    positions.extend(siblings.into_iter().map(|(_, pos)| pos));
549    Ok(positions)
550}
551
/// Returns the positions of the minimal set of nodes whose digests are required to prove the
/// inclusion of the elements at the specified `locations`.
///
/// The result is the union of the nodes required by each location's single-element proof. The
/// order of `locations` does not matter, because the result is a sorted set.
///
/// # Errors
///
/// Returns [Error::InvalidSize] if `size` is not a valid MMR size.
/// Returns [Error::Empty] if locations is empty.
/// Returns [Error::LocationOverflow] if any location in `locations` > [crate::mmr::MAX_LOCATION].
/// Returns [Error::RangeOutOfBounds] if any location is out of bounds for the given `size`.
#[cfg(any(feature = "std", test))]
pub(crate) fn nodes_required_for_multi_proof(
    size: Position,
    locations: &[Location],
) -> Result<BTreeSet<Position>, Error> {
    if !size.is_mmr_size() {
        return Err(Error::InvalidSize(*size));
    }
    if locations.is_empty() {
        return Err(Error::Empty);
    }

    // TODO(#1472): Optimize this loop
    let mut required = BTreeSet::new();
    for &loc in locations {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        // A valid `loc` cannot overflow when incremented.
        let single = loc..loc + 1;
        required.extend(nodes_required_for_range_proof(size, single)?);
    }
    Ok(required)
}
587
588/// Information about the current range of nodes being traversed.
589struct RangeInfo {
590    pos: Position,           // current node position in the tree
591    two_h: u64,              // 2^height of the current node
592    leftmost_pos: Position,  // leftmost leaf in the tree to be traversed
593    rightmost_pos: Position, // rightmost leaf in the tree to be traversed
594}
595
596fn peak_digest_from_range<'a, I, H, E, S>(
597    hasher: &mut H,
598    range_info: RangeInfo,
599    elements: &mut E,
600    sibling_digests: &mut S,
601    mut collected_digests: Option<&mut Vec<(Position, I::Digest)>>,
602) -> Result<I::Digest, ReconstructionError>
603where
604    I: CHasher,
605    H: Hasher<I>,
606    E: Iterator<Item: AsRef<[u8]>>,
607    S: Iterator<Item = &'a I::Digest>,
608{
609    assert_ne!(range_info.two_h, 0);
610    if range_info.two_h == 1 {
611        match elements.next() {
612            Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
613            None => return Err(ReconstructionError::MissingDigests),
614        }
615    }
616
617    let mut left_digest: Option<I::Digest> = None;
618    let mut right_digest: Option<I::Digest> = None;
619
620    let left_pos = range_info.pos - range_info.two_h;
621    let right_pos = left_pos + range_info.two_h - 1;
622    if left_pos >= range_info.leftmost_pos {
623        // Descend left
624        let digest = peak_digest_from_range(
625            hasher,
626            RangeInfo {
627                pos: left_pos,
628                two_h: range_info.two_h >> 1,
629                leftmost_pos: range_info.leftmost_pos,
630                rightmost_pos: range_info.rightmost_pos,
631            },
632            elements,
633            sibling_digests,
634            collected_digests.as_deref_mut(),
635        )?;
636        left_digest = Some(digest);
637    }
638    if left_pos < range_info.rightmost_pos {
639        // Descend right
640        let digest = peak_digest_from_range(
641            hasher,
642            RangeInfo {
643                pos: right_pos,
644                two_h: range_info.two_h >> 1,
645                leftmost_pos: range_info.leftmost_pos,
646                rightmost_pos: range_info.rightmost_pos,
647            },
648            elements,
649            sibling_digests,
650            collected_digests.as_deref_mut(),
651        )?;
652        right_digest = Some(digest);
653    }
654
655    if left_digest.is_none() {
656        match sibling_digests.next() {
657            Some(hash) => left_digest = Some(*hash),
658            None => return Err(ReconstructionError::MissingDigests),
659        }
660    }
661    if right_digest.is_none() {
662        match sibling_digests.next() {
663            Some(hash) => right_digest = Some(*hash),
664            None => return Err(ReconstructionError::MissingDigests),
665        }
666    }
667
668    if let Some(ref mut collected_digests) = collected_digests {
669        collected_digests.push((
670            left_pos,
671            left_digest.expect("left_digest guaranteed to be Some after checks above"),
672        ));
673        collected_digests.push((
674            right_pos,
675            right_digest.expect("right_digest guaranteed to be Some after checks above"),
676        ));
677    }
678
679    Ok(hasher.node_digest(
680        range_info.pos,
681        &left_digest.expect("left_digest guaranteed to be Some after checks above"),
682        &right_digest.expect("right_digest guaranteed to be Some after checks above"),
683    ))
684}
685
686#[cfg(test)]
687mod tests {
688    use super::*;
689    use crate::mmr::{hasher::Standard, location::LocationRangeExt as _, mem::Mmr, MAX_LOCATION};
690    use bytes::Bytes;
691    use commonware_codec::{Decode, Encode};
692    use commonware_cryptography::{sha256::Digest, Sha256};
693    use commonware_macros::test_traced;
694
695    fn test_digest(v: u8) -> Digest {
696        Sha256::hash(&[v])
697    }
698
699    #[test]
700    fn test_proving_proof() {
701        // Test that an empty proof authenticates an empty MMR.
702        let mmr = Mmr::new();
703        let mut hasher: Standard<Sha256> = Standard::new();
704        let root = mmr.root(&mut hasher);
705        let proof = Proof::default();
706        assert!(proof.verify_range_inclusion(
707            &mut hasher,
708            &[] as &[Digest],
709            Location::new_unchecked(0),
710            &root
711        ));
712
713        // Any starting position other than 0 should fail to verify.
714        assert!(!proof.verify_range_inclusion(
715            &mut hasher,
716            &[] as &[Digest],
717            Location::new_unchecked(1),
718            &root
719        ));
720
721        // Invalid root should fail to verify.
722        let test_digest = test_digest(0);
723        assert!(!proof.verify_range_inclusion(
724            &mut hasher,
725            &[] as &[Digest],
726            Location::new_unchecked(0),
727            &test_digest
728        ));
729
730        // Non-empty elements list should fail to verify.
731        assert!(!proof.verify_range_inclusion(
732            &mut hasher,
733            &[test_digest],
734            Location::new_unchecked(0),
735            &root
736        ));
737    }
738
739    #[test]
740    fn test_proving_verify_element() {
741        // create an 11 element MMR over which we'll test single-element inclusion proofs
742        let mut mmr = Mmr::new();
743        let element = Digest::from(*b"01234567012345670123456701234567");
744        let mut hasher: Standard<Sha256> = Standard::new();
745        for _ in 0..11 {
746            mmr.add(&mut hasher, &element);
747        }
748
749        let root = mmr.root(&mut hasher);
750
751        // confirm the proof of inclusion for each leaf successfully verifies
752        for leaf in 0u64..11 {
753            let leaf = Location::new_unchecked(leaf);
754            let proof: Proof<Digest> = mmr.proof(leaf).unwrap();
755            assert!(
756                proof.verify_element_inclusion(&mut hasher, &element, leaf, &root),
757                "valid proof should verify successfully"
758            );
759        }
760
761        // Create a valid proof, then confirm various mangling of the proof or proof args results in
762        // verification failure.
763        const LEAF: Location = Location::new_unchecked(10);
764        let proof = mmr.proof(LEAF).unwrap();
765        assert!(
766            proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
767            "proof verification should be successful"
768        );
769        let wrong_sizes = [0, 16, 17, 18, 20, u64::MAX - 100];
770        for sz in wrong_sizes {
771            let mut wrong_size_proof = proof.clone();
772            wrong_size_proof.size = Position::new(sz);
773            assert!(
774                !wrong_size_proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
775                "proof with wrong size should fail verification"
776            );
777        }
778        assert!(
779            !proof.verify_element_inclusion(&mut hasher, &element, LEAF + 1, &root),
780            "proof verification should fail with incorrect element position"
781        );
782        assert!(
783            !proof.verify_element_inclusion(&mut hasher, &element, LEAF - 1, &root),
784            "proof verification should fail with incorrect element position 2"
785        );
786        assert!(
787            !proof.verify_element_inclusion(&mut hasher, &test_digest(0), LEAF, &root),
788            "proof verification should fail with mangled element"
789        );
790        let root2 = test_digest(0);
791        assert!(
792            !proof.verify_element_inclusion(&mut hasher, &element, LEAF, &root2),
793            "proof verification should fail with mangled root"
794        );
795        let mut proof2 = proof.clone();
796        proof2.digests[0] = test_digest(0);
797        assert!(
798            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
799            "proof verification should fail with mangled proof hash"
800        );
801        proof2 = proof.clone();
802        proof2.size = Position::new(10);
803        assert!(
804            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
805            "proof verification should fail with incorrect size"
806        );
807        proof2 = proof.clone();
808        proof2.digests.push(test_digest(0));
809        assert!(
810            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
811            "proof verification should fail with extra hash"
812        );
813        proof2 = proof.clone();
814        while !proof2.digests.is_empty() {
815            proof2.digests.pop();
816            assert!(
817                !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
818                "proof verification should fail with missing digests"
819            );
820        }
821        proof2 = proof.clone();
822        proof2.digests.clear();
823        const PEAK_COUNT: usize = 3;
824        proof2
825            .digests
826            .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
827        // sneak in an extra hash that won't be used in the computation and make sure it's
828        // detected
829        proof2.digests.push(test_digest(0));
830        proof2
831            .digests
832            .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
833        assert!(
834            !proof2.verify_element_inclusion(&mut hasher, &element, LEAF, &root),
835            "proof verification should fail with extra hash even if it's unused by the computation"
836        );
837    }
838
839    #[test]
840    fn test_proving_verify_range() {
841        // create a new MMR and add a non-trivial amount (49) of elements
842        let mut mmr = Mmr::default();
843        let mut elements = Vec::new();
844        let mut hasher: Standard<Sha256> = Standard::new();
845        for i in 0..49 {
846            elements.push(test_digest(i));
847            mmr.add(&mut hasher, elements.last().unwrap());
848        }
849        // test range proofs over all possible ranges of at least 2 elements
850        let root = mmr.root(&mut hasher);
851
852        for i in 0..elements.len() {
853            for j in i + 1..elements.len() {
854                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
855                let range_proof = mmr.range_proof(range.clone()).unwrap();
856                assert!(
857                    range_proof.verify_range_inclusion(
858                        &mut hasher,
859                        &elements[range.to_usize_range()],
860                        range.start,
861                        &root,
862                    ),
863                    "valid range proof should verify successfully {i}:{j}",
864                );
865            }
866        }
867
868        // Create a proof over a range of elements, confirm it verifies successfully, then mangle
869        // the proof & proof input in various ways, confirming verification fails.
870        let range = Location::new_unchecked(33)..Location::new_unchecked(40);
871        let range_proof = mmr.range_proof(range.clone()).unwrap();
872        let valid_elements = &elements[range.to_usize_range()];
873        assert!(
874            range_proof.verify_range_inclusion(&mut hasher, valid_elements, range.start, &root),
875            "valid range proof should verify successfully"
876        );
877        // Remove digests from the proof until it's empty, confirming proof verification fails for
878        // each.
879        let mut invalid_proof = range_proof.clone();
880        for _i in 0..range_proof.digests.len() {
881            invalid_proof.digests.remove(0);
882            assert!(
883                !invalid_proof.verify_range_inclusion(
884                    &mut hasher,
885                    valid_elements,
886                    range.start,
887                    &root,
888                ),
889                "range proof with removed elements should fail"
890            );
891        }
892        // Confirm proof verification fails when providing an element range different than the one
893        // used to generate the proof.
894        for i in 0..elements.len() {
895            for j in i + 1..elements.len() {
896                if Location::from(i) == range.start && Location::from(j) == range.end {
897                    // skip the valid range
898                    continue;
899                }
900                assert!(
901                    !range_proof.verify_range_inclusion(
902                        &mut hasher,
903                        &elements[i..j],
904                        range.start,
905                        &root,
906                    ),
907                    "range proof with invalid element range should fail {i}:{j}",
908                );
909            }
910        }
911        // Confirm proof fails to verify with an invalid root.
912        let invalid_root = test_digest(1);
913        assert!(
914            !range_proof.verify_range_inclusion(
915                &mut hasher,
916                valid_elements,
917                range.start,
918                &invalid_root,
919            ),
920            "range proof with invalid root should fail"
921        );
922        // Mangle each element of the proof and confirm it fails to verify.
923        for i in 0..range_proof.digests.len() {
924            let mut invalid_proof = range_proof.clone();
925            invalid_proof.digests[i] = test_digest(0);
926
927            assert!(
928                !invalid_proof.verify_range_inclusion(
929                    &mut hasher,
930                    valid_elements,
931                    range.start,
932                    &root,
933                ),
934                "mangled range proof should fail verification"
935            );
936        }
937        // Inserting elements into the proof should also cause it to fail (malleability check)
938        for i in 0..range_proof.digests.len() {
939            let mut invalid_proof = range_proof.clone();
940            invalid_proof.digests.insert(i, test_digest(0));
941            assert!(
942                !invalid_proof.verify_range_inclusion(
943                    &mut hasher,
944                    valid_elements,
945                    range.start,
946                    &root,
947                ),
948                "mangled range proof should fail verification. inserted element at: {i}",
949            );
950        }
951        // Bad start_loc should cause verification to fail.
952        for loc in 0..elements.len() {
953            let loc = Location::new_unchecked(loc as u64);
954            if loc == range.start {
955                continue;
956            }
957            assert!(
958                !range_proof.verify_range_inclusion(&mut hasher, valid_elements, loc, &root),
959                "bad start_loc should fail verification {loc}",
960            );
961        }
962    }
963
964    #[test_traced]
965    fn test_proving_retained_nodes_provable_after_pruning() {
966        // create a new MMR and add a non-trivial amount (49) of elements
967        let mut mmr = Mmr::default();
968        let mut elements = Vec::new();
969        let mut hasher: Standard<Sha256> = Standard::new();
970        for i in 0..49 {
971            elements.push(test_digest(i));
972            mmr.add(&mut hasher, elements.last().unwrap());
973        }
974
975        // Confirm we can successfully prove all retained elements in the MMR after pruning.
976        let root = mmr.root(&mut hasher);
977        for i in 1..*mmr.size() {
978            mmr.prune_to_pos(Position::new(i));
979            let pruned_root = mmr.root(&mut hasher);
980            assert_eq!(root, pruned_root);
981            for loc in 0..elements.len() {
982                let loc = Location::new_unchecked(loc as u64);
983                let proof = mmr.proof(loc);
984                if Position::try_from(loc).unwrap() < Position::new(i) {
985                    continue;
986                }
987                assert!(proof.is_ok());
988                assert!(proof.unwrap().verify_element_inclusion(
989                    &mut hasher,
990                    &elements[*loc as usize],
991                    loc,
992                    &root
993                ));
994            }
995        }
996    }
997
998    #[test]
999    fn test_proving_ranges_provable_after_pruning() {
1000        // create a new MMR and add a non-trivial amount (49) of elements
1001        let mut mmr = Mmr::default();
1002        let mut elements = Vec::new();
1003        let mut hasher: Standard<Sha256> = Standard::new();
1004        for i in 0..49 {
1005            elements.push(test_digest(i));
1006            mmr.add(&mut hasher, elements.last().unwrap());
1007        }
1008
1009        // prune up to the first peak
1010        const PRUNE_POS: Position = Position::new(62);
1011        mmr.prune_to_pos(PRUNE_POS);
1012        assert_eq!(mmr.oldest_retained_pos().unwrap(), PRUNE_POS);
1013
1014        // Test range proofs over all possible ranges of at least 2 elements
1015        let root = mmr.root(&mut hasher);
1016        for i in 0..elements.len() - 1 {
1017            if Position::try_from(Location::new_unchecked(i as u64)).unwrap() < PRUNE_POS {
1018                continue;
1019            }
1020            for j in (i + 2)..elements.len() {
1021                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1022                let range_proof = mmr.range_proof(range.clone()).unwrap();
1023                assert!(
1024                    range_proof.verify_range_inclusion(
1025                        &mut hasher,
1026                        &elements[range.to_usize_range()],
1027                        range.start,
1028                        &root,
1029                    ),
1030                    "valid range proof over remaining elements should verify successfully",
1031                );
1032            }
1033        }
1034
1035        // Add a few more nodes, prune again, and test again to make sure repeated pruning doesn't
1036        // break proof verification.
1037        for i in 0..37 {
1038            elements.push(test_digest(i));
1039            mmr.add(&mut hasher, elements.last().unwrap());
1040        }
1041        mmr.prune_to_pos(Position::new(130)); // a bit after the new highest peak
1042        assert_eq!(mmr.oldest_retained_pos().unwrap(), 130);
1043
1044        let updated_root = mmr.root(&mut hasher);
1045        let range = Location::new_unchecked(elements.len() as u64 - 10)
1046            ..Location::new_unchecked(elements.len() as u64);
1047        let range_proof = mmr.range_proof(range.clone()).unwrap();
1048        assert!(
1049                range_proof.verify_range_inclusion(
1050                    &mut hasher,
1051                    &elements[range.to_usize_range()],
1052                    range.start,
1053                    &updated_root,
1054                ),
1055                "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1056            );
1057    }
1058
1059    #[test]
1060    fn test_proving_proof_serialization() {
1061        // create a new MMR and add a non-trivial amount of elements
1062        let mut mmr = Mmr::default();
1063        let mut elements = Vec::new();
1064        let mut hasher: Standard<Sha256> = Standard::new();
1065        for i in 0..25 {
1066            elements.push(test_digest(i));
1067            mmr.add(&mut hasher, elements.last().unwrap());
1068        }
1069
1070        // Generate proofs over all possible ranges of elements and confirm each
1071        // serializes=>deserializes correctly.
1072        for i in 0..elements.len() {
1073            for j in i + 1..elements.len() {
1074                let range = Location::new_unchecked(i as u64)..Location::new_unchecked(j as u64);
1075                let proof = mmr.range_proof(range).unwrap();
1076
1077                let expected_size = proof.encode_size();
1078                let serialized_proof = proof.encode().freeze();
1079                assert_eq!(
1080                    serialized_proof.len(),
1081                    expected_size,
1082                    "serialized proof should have expected size"
1083                );
1084                let max_digests = proof.digests.len();
1085                let deserialized_proof = Proof::decode_cfg(serialized_proof, &max_digests).unwrap();
1086                assert_eq!(
1087                    proof, deserialized_proof,
1088                    "deserialized proof should match source proof"
1089                );
1090
1091                // Remove one byte from the end of the serialized
1092                // proof and confirm it fails to deserialize.
1093                let serialized_proof = proof.encode().freeze();
1094                let serialized_proof: Bytes = serialized_proof.slice(0..serialized_proof.len() - 1);
1095                assert!(
1096                    Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1097                    "proof should not deserialize with truncated data"
1098                );
1099
1100                // Add 1 byte of extra data to the end of the serialized
1101                // proof and confirm it fails to deserialize.
1102                let mut serialized_proof = proof.encode();
1103                serialized_proof.extend_from_slice(&[0; 10]);
1104                let serialized_proof = serialized_proof.freeze();
1105
1106                assert!(
1107                    Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1108                    "proof should not deserialize with extra data"
1109                );
1110
1111                // Confirm deserialization fails when max length is exceeded.
1112                if max_digests > 0 {
1113                    let serialized_proof = proof.encode().freeze();
1114                    assert!(
1115                        Proof::<Digest>::decode_cfg(serialized_proof, &(max_digests - 1)).is_err(),
1116                        "proof should not deserialize with max length exceeded"
1117                    );
1118                }
1119            }
1120        }
1121    }
1122
1123    #[test_traced]
1124    fn test_proving_extract_pinned_nodes() {
1125        // Test for every number of elements from 1 to 255
1126        for num_elements in 1u64..255 {
1127            // Build MMR with the specified number of elements
1128            let mut mmr = Mmr::new();
1129            let mut hasher: Standard<Sha256> = Standard::new();
1130
1131            for i in 0..num_elements {
1132                let digest = test_digest(i as u8);
1133                mmr.add(&mut hasher, &digest);
1134            }
1135
1136            // Test pruning to each leaf.
1137            for leaf in 0..num_elements {
1138                // Test with a few different end positions to get good coverage
1139                let test_end_locs = if num_elements == 1 {
1140                    // Single element case
1141                    vec![leaf + 1]
1142                } else {
1143                    // Multi-element case: test with various end positions
1144                    let mut ends = vec![leaf + 1]; // Single element proof
1145
1146                    // Add a few more end positions if available
1147                    if leaf + 2 <= num_elements {
1148                        ends.push(leaf + 2);
1149                    }
1150                    if leaf + 3 <= num_elements {
1151                        ends.push(leaf + 3);
1152                    }
1153                    // Always test with the last element if different
1154                    if ends.last().unwrap() != &num_elements {
1155                        ends.push(num_elements);
1156                    }
1157
1158                    ends.into_iter()
1159                        .collect::<BTreeSet<_>>()
1160                        .into_iter()
1161                        .collect()
1162                };
1163
1164                for end_loc in test_end_locs {
1165                    // Generate proof for the range
1166                    let range = Location::new_unchecked(leaf)..Location::new_unchecked(end_loc);
1167                    let proof_result = mmr.range_proof(range.clone());
1168                    let proof = proof_result.unwrap();
1169
1170                    // Extract pinned nodes
1171                    let extract_result = proof.extract_pinned_nodes(range.clone());
1172                    assert!(
1173                            extract_result.is_ok(),
1174                            "Failed to extract pinned nodes for {num_elements} elements, boundary={leaf}, range={}..{}", range.start, range.end
1175                        );
1176
1177                    let pinned_nodes = extract_result.unwrap();
1178                    let leaf_loc = Location::new_unchecked(leaf);
1179                    let leaf_pos = Position::try_from(leaf_loc).unwrap();
1180                    let expected_pinned: Vec<Position> = nodes_to_pin(leaf_pos).collect();
1181
1182                    // Verify count matches expected
1183                    assert_eq!(
1184                            pinned_nodes.len(),
1185                            expected_pinned.len(),
1186                            "Pinned node count mismatch for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1187                        );
1188
1189                    // Verify extracted hashes match actual node values
1190                    // The pinned_nodes Vec is in the same order as expected_pinned
1191                    for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1192                        let extracted_hash = pinned_nodes[i];
1193                        let actual_hash = mmr.get_node(expected_pos).unwrap();
1194                        assert_eq!(
1195                                extracted_hash, actual_hash,
1196                                "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={leaf}, range=[{leaf}, {end_loc}]"
1197                            );
1198                    }
1199                }
1200            }
1201        }
1202    }
1203
1204    #[test]
1205    fn test_proving_extract_pinned_nodes_invalid_size() {
1206        // Test that extract_pinned_nodes returns an error for invalid MMR size
1207        let mut mmr = Mmr::new();
1208        let mut hasher: Standard<Sha256> = Standard::new();
1209
1210        // Build MMR with 10 elements
1211        for i in 0..10 {
1212            let digest = test_digest(i);
1213            mmr.add(&mut hasher, &digest);
1214        }
1215
1216        // Generate a valid proof
1217        let range = Location::new_unchecked(5)..Location::new_unchecked(8);
1218        let mut proof = mmr.range_proof(range.clone()).unwrap();
1219
1220        // Verify the proof works with valid size
1221        assert!(proof.extract_pinned_nodes(range.clone()).is_ok());
1222
1223        // Test various invalid sizes
1224        const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1225        for invalid_size in INVALID_SIZES {
1226            proof.size = Position::new(invalid_size);
1227            let result = proof.extract_pinned_nodes(range.clone());
1228            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == invalid_size));
1229        }
1230
1231        // Test with MSB set (invalid due to leading zero requirement)
1232        proof.size = Position::new(1u64 << 63);
1233        let result = proof.extract_pinned_nodes(range);
1234        assert!(matches!(result, Err(Error::InvalidSize(s)) if s == 1u64 << 63));
1235    }
1236
1237    #[test]
1238    fn test_proving_digests_from_range() {
1239        // create a new MMR and add a non-trivial amount (49) of elements
1240        let mut mmr = Mmr::default();
1241        let mut elements = Vec::new();
1242        let mut element_positions = Vec::new();
1243        let mut hasher: Standard<Sha256> = Standard::new();
1244        for i in 0..49 {
1245            elements.push(test_digest(i));
1246            element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1247        }
1248        let root = mmr.root(&mut hasher);
1249
1250        // Test 1: compute_digests over the entire range should contain a digest for every node
1251        // in the tree.
1252        let proof = mmr
1253            .range_proof(Location::new_unchecked(0)..mmr.leaves())
1254            .unwrap();
1255        let mut node_digests = proof
1256            .verify_range_inclusion_and_extract_digests(
1257                &mut hasher,
1258                &elements,
1259                Location::new_unchecked(0),
1260                &root,
1261            )
1262            .unwrap();
1263        assert_eq!(node_digests.len() as u64, mmr.size());
1264        node_digests.sort_by_key(|(pos, _)| *pos);
1265        for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1266            assert_eq!(pos, i as u64);
1267            assert_eq!(mmr.get_node(pos).unwrap(), d);
1268        }
1269        // Make sure the wrong root fails.
1270        let wrong_root = elements[0]; // any other digest will do
1271        assert!(matches!(
1272            proof.verify_range_inclusion_and_extract_digests(
1273                &mut hasher,
1274                &elements,
1275                Location::new_unchecked(0),
1276                &wrong_root
1277            ),
1278            Err(Error::RootMismatch)
1279        ));
1280
1281        // Test 2: Single element range (first element)
1282        let range = Location::new_unchecked(0)..Location::new_unchecked(1);
1283        let single_proof = mmr.range_proof(range.clone()).unwrap();
1284        let range_start = range.start;
1285        let single_digests = single_proof
1286            .verify_range_inclusion_and_extract_digests(
1287                &mut hasher,
1288                &elements[range.to_usize_range()],
1289                range_start,
1290                &root,
1291            )
1292            .unwrap();
1293        assert!(single_digests.len() > 1);
1294
1295        // Test 3: Single element range (middle element)
1296        let mid_idx = 24;
1297        let range = Location::new_unchecked(mid_idx)..Location::new_unchecked(mid_idx + 1);
1298        let range_start = range.start;
1299        let mid_proof = mmr.range_proof(range.clone()).unwrap();
1300        let mid_digests = mid_proof
1301            .verify_range_inclusion_and_extract_digests(
1302                &mut hasher,
1303                &elements[range.to_usize_range()],
1304                range_start,
1305                &root,
1306            )
1307            .unwrap();
1308        assert!(mid_digests.len() > 1);
1309
1310        // Test 4: Single element range (last element)
1311        let last_idx = elements.len() as u64 - 1;
1312        let range = Location::new_unchecked(last_idx)..Location::new_unchecked(last_idx + 1);
1313        let range_start = range.start;
1314        let last_proof = mmr.range_proof(range.clone()).unwrap();
1315        let last_digests = last_proof
1316            .verify_range_inclusion_and_extract_digests(
1317                &mut hasher,
1318                &elements[range.to_usize_range()],
1319                range_start,
1320                &root,
1321            )
1322            .unwrap();
1323        assert!(last_digests.len() > 1);
1324
1325        // Test 5: Small range at the beginning
1326        let range = Location::new_unchecked(0)..Location::new_unchecked(5);
1327        let range_start = range.start;
1328        let small_proof = mmr.range_proof(range.clone()).unwrap();
1329        let small_digests = small_proof
1330            .verify_range_inclusion_and_extract_digests(
1331                &mut hasher,
1332                &elements[range.to_usize_range()],
1333                range_start,
1334                &root,
1335            )
1336            .unwrap();
1337        // Verify that we get digests for the range elements and their ancestors
1338        assert!(small_digests.len() > 5);
1339
1340        // Test 6: Medium range in the middle
1341        let range = Location::new_unchecked(10)..Location::new_unchecked(31);
1342        let range_start = range.start;
1343        let mid_range_proof = mmr.range_proof(range.clone()).unwrap();
1344        let mid_range_digests = mid_range_proof
1345            .verify_range_inclusion_and_extract_digests(
1346                &mut hasher,
1347                &elements[range.to_usize_range()],
1348                range_start,
1349                &root,
1350            )
1351            .unwrap();
1352        let num_elements = range.end - range.start;
1353        assert!(mid_range_digests.len() as u64 > num_elements);
1354    }
1355
1356    #[test]
1357    fn test_proving_multi_proof_generation_and_verify() {
1358        // Create an MMR with multiple elements
1359        let mut mmr = Mmr::new();
1360        let mut hasher: Standard<Sha256> = Standard::new();
1361        let mut elements = Vec::new();
1362
1363        for i in 0..20 {
1364            elements.push(test_digest(i));
1365            mmr.add(&mut hasher, &elements[i as usize]);
1366        }
1367
1368        let root = mmr.root(&mut hasher);
1369
1370        // Generate proof for non-contiguous single elements
1371        let locations = &[
1372            Location::new_unchecked(0),
1373            Location::new_unchecked(5),
1374            Location::new_unchecked(10),
1375        ];
1376        let nodes_for_multi_proof =
1377            nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1378        let digests = nodes_for_multi_proof
1379            .into_iter()
1380            .map(|pos| mmr.get_node(pos).unwrap())
1381            .collect();
1382        let multi_proof = Proof {
1383            size: mmr.size(),
1384            digests,
1385        };
1386
1387        assert_eq!(multi_proof.size, mmr.size());
1388
1389        // Verify the proof
1390        assert!(multi_proof.verify_multi_inclusion(
1391            &mut hasher,
1392            &[
1393                (elements[0], Location::new_unchecked(0)),
1394                (elements[5], Location::new_unchecked(5)),
1395                (elements[10], Location::new_unchecked(10)),
1396            ],
1397            &root
1398        ));
1399
1400        // Verify in different order
1401        assert!(multi_proof.verify_multi_inclusion(
1402            &mut hasher,
1403            &[
1404                (elements[10], Location::new_unchecked(10)),
1405                (elements[5], Location::new_unchecked(5)),
1406                (elements[0], Location::new_unchecked(0)),
1407            ],
1408            &root
1409        ));
1410
1411        // Verify with duplicate items
1412        assert!(multi_proof.verify_multi_inclusion(
1413            &mut hasher,
1414            &[
1415                (elements[0], Location::new_unchecked(0)),
1416                (elements[0], Location::new_unchecked(0)),
1417                (elements[10], Location::new_unchecked(10)),
1418                (elements[5], Location::new_unchecked(5)),
1419            ],
1420            &root
1421        ));
1422
1423        // Verify mangling the size to something invalid should fail. Test three cases: valid MMR
1424        // size but wrong value, invalid MMR size, and overflowing MMR size.
1425        let wrong_sizes = [
1426            Position::new(16),
1427            Position::new(17),
1428            Position::new(u64::MAX - 100),
1429        ];
1430        for sz in wrong_sizes {
1431            let mut wrong_size_proof = multi_proof.clone();
1432            wrong_size_proof.size = sz;
1433            assert!(!wrong_size_proof.verify_multi_inclusion(
1434                &mut hasher,
1435                &[
1436                    (elements[0], Location::new_unchecked(0)),
1437                    (elements[5], Location::new_unchecked(5)),
1438                    (elements[10], Location::new_unchecked(10)),
1439                ],
1440                &root,
1441            ));
1442        }
1443
1444        // Verify with wrong positions
1445        assert!(!multi_proof.verify_multi_inclusion(
1446            &mut hasher,
1447            &[
1448                (elements[0], Location::new_unchecked(1)),
1449                (elements[5], Location::new_unchecked(6)),
1450                (elements[10], Location::new_unchecked(11)),
1451            ],
1452            &root,
1453        ));
1454
1455        // Verify with wrong elements
1456        let wrong_elements = [
1457            vec![255u8, 254u8, 253u8],
1458            vec![252u8, 251u8, 250u8],
1459            vec![249u8, 248u8, 247u8],
1460        ];
1461        let wrong_verification = multi_proof.verify_multi_inclusion(
1462            &mut hasher,
1463            &[
1464                (wrong_elements[0].as_slice(), Location::new_unchecked(0)),
1465                (wrong_elements[1].as_slice(), Location::new_unchecked(5)),
1466                (wrong_elements[2].as_slice(), Location::new_unchecked(10)),
1467            ],
1468            &root,
1469        );
1470        assert!(!wrong_verification, "Should fail with wrong elements");
1471
1472        // Verify with out of range element
1473        let wrong_verification = multi_proof.verify_multi_inclusion(
1474            &mut hasher,
1475            &[
1476                (elements[0], Location::new_unchecked(0)),
1477                (elements[5], Location::new_unchecked(5)),
1478                (elements[10], Location::new_unchecked(1000)),
1479            ],
1480            &root,
1481        );
1482        assert!(
1483            !wrong_verification,
1484            "Should fail with out of range elements"
1485        );
1486
1487        // Verify with wrong root should fail
1488        let wrong_root = test_digest(99);
1489        assert!(!multi_proof.verify_multi_inclusion(
1490            &mut hasher,
1491            &[
1492                (elements[0], Location::new_unchecked(0)),
1493                (elements[5], Location::new_unchecked(5)),
1494                (elements[10], Location::new_unchecked(10)),
1495            ],
1496            &wrong_root
1497        ));
1498
1499        // Empty multi-proof
1500        let empty_mmr = Mmr::new();
1501        let empty_root = empty_mmr.root(&mut hasher);
1502        let empty_proof = Proof::default();
1503        assert!(empty_proof.verify_multi_inclusion(
1504            &mut hasher,
1505            &[] as &[(Digest, Location)],
1506            &empty_root
1507        ));
1508    }
1509
1510    #[test]
1511    fn test_proving_multi_proof_deduplication() {
1512        let mut mmr = Mmr::new();
1513        let mut hasher: Standard<Sha256> = Standard::new();
1514        let mut elements = Vec::new();
1515
1516        // Create an MMR with enough elements to have shared digests
1517        for i in 0..30 {
1518            elements.push(test_digest(i));
1519            mmr.add(&mut hasher, &elements[i as usize]);
1520        }
1521
1522        // Get individual proofs that will share some digests (elements in same subtree)
1523        let proof1 = mmr.proof(Location::new_unchecked(0)).unwrap();
1524        let proof2 = mmr.proof(Location::new_unchecked(1)).unwrap();
1525        let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1526
1527        // Generate multi-proof for the same positions
1528        let locations = &[Location::new_unchecked(0), Location::new_unchecked(1)];
1529        let multi_proof =
1530            nodes_required_for_multi_proof(mmr.size(), locations).expect("test locations valid");
1531        let digests = multi_proof
1532            .into_iter()
1533            .map(|pos| mmr.get_node(pos).unwrap())
1534            .collect();
1535        let multi_proof = Proof {
1536            size: mmr.size(),
1537            digests,
1538        };
1539
1540        // The combined proof should have fewer digests due to deduplication
1541        assert!(multi_proof.digests.len() < total_digests_separate);
1542
1543        // Verify it still works
1544        let root = mmr.root(&mut hasher);
1545        assert!(multi_proof.verify_multi_inclusion(
1546            &mut hasher,
1547            &[
1548                (elements[0], Location::new_unchecked(0)),
1549                (elements[1], Location::new_unchecked(1))
1550            ],
1551            &root
1552        ));
1553    }
1554
1555    #[test]
1556    fn test_max_location_is_provable() {
1557        // Test that the validation logic accepts MAX_LOCATION as a valid location
1558        // We use the maximum valid MMR size (2^63 - 1) which can hold up to 2^62 leaves
1559        let max_mmr_size = Position::new((1u64 << 63) - 1);
1560
1561        let max_loc = Location::new_unchecked(MAX_LOCATION);
1562        let max_loc_plus_1 = Location::new_unchecked(MAX_LOCATION + 1);
1563        let max_loc_plus_2 = Location::new_unchecked(MAX_LOCATION + 2);
1564
1565        // MAX_LOCATION should be accepted by the validation logic
1566        // (The range MAX_LOCATION..MAX_LOCATION+1 proves a single element at MAX_LOCATION)
1567        let result = nodes_required_for_range_proof(max_mmr_size, max_loc..max_loc_plus_1);
1568
1569        // This should succeed - MAX_LOCATION is a valid location
1570        assert!(result.is_ok(), "Should be able to prove MAX_LOCATION");
1571
1572        // MAX_LOCATION + 1 should be rejected (exceeds MAX_LOCATION)
1573        let result_overflow =
1574            nodes_required_for_range_proof(max_mmr_size, max_loc_plus_1..max_loc_plus_2);
1575        assert!(
1576            result_overflow.is_err(),
1577            "Should reject location > MAX_LOCATION"
1578        );
1579        matches!(result_overflow, Err(Error::LocationOverflow(_)));
1580    }
1581
1582    #[test]
1583    fn test_max_location_multi_proof() {
1584        // Test that multi_proof can handle MAX_LOCATION
1585        let max_mmr_size = Position::new((1u64 << 63) - 1);
1586        let max_loc = Location::new_unchecked(MAX_LOCATION);
1587
1588        // Should be able to generate multi-proof for MAX_LOCATION
1589        let result = nodes_required_for_multi_proof(max_mmr_size, &[max_loc]);
1590        assert!(
1591            result.is_ok(),
1592            "Should be able to generate multi-proof for MAX_LOCATION"
1593        );
1594
1595        // Should reject MAX_LOCATION + 1
1596        let invalid_loc = Location::new_unchecked(MAX_LOCATION + 1);
1597        let result_overflow = nodes_required_for_multi_proof(max_mmr_size, &[invalid_loc]);
1598        assert!(
1599            result_overflow.is_err(),
1600            "Should reject location > MAX_LOCATION in multi-proof"
1601        );
1602    }
1603
1604    #[test]
1605    fn test_invalid_size_validation() {
1606        // Test that invalid sizes are rejected
1607        let loc = Location::new_unchecked(0);
1608        let range = loc..loc + 1;
1609
1610        const INVALID_SIZES: [u64; 5] = [2, 5, 6, 9, u64::MAX];
1611        for size in INVALID_SIZES {
1612            let result = nodes_required_for_range_proof(Position::new(size), range.clone());
1613            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1614
1615            let result = nodes_required_for_multi_proof(Position::new(size), &[loc]);
1616            assert!(matches!(result, Err(Error::InvalidSize(s)) if s == size));
1617        }
1618    }
1619}