// commonware_storage/merkle/proof.rs
//! Defines the generic inclusion [Proof] structure for Merkle-family data structures.
//!
//! The [Proof] struct is parameterized by a [`Family`] marker and a [`Digest`] type. Each Merkle
//! family (MMR, MMB, etc.) reuses the shared verification and reconstruction logic in this module,
//! while retaining any family-specific proof helpers in its submodule.

use crate::merkle::{hasher::Hasher, Error, Family, Location, Position};
use alloc::{
    collections::{BTreeMap, BTreeSet},
    vec,
    vec::Vec,
};
use bytes::{Buf, BufMut};
use commonware_codec::{EncodeSize, ReadExt, ReadRangeExt, Write};
use commonware_cryptography::Digest;
use core::ops::Range;
/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
#[derive(thiserror::Error, Debug)]
pub enum ReconstructionError {
    /// The proof did not contain enough digests to cover the claimed range.
    #[error("missing digests in proof")]
    MissingDigests,
    /// The proof contained digests (or elements) that reconstruction never consumed.
    #[error("extra digests in proof")]
    ExtraDigests,
    /// The provided start location is not a valid leaf index.
    #[error("start location is out of bounds")]
    InvalidStartLoc,
    /// The computed end location overflowed or exceeds the structure's leaf count.
    #[error("end location is out of bounds")]
    InvalidEndLoc,
    /// Elements were required but none (or too few) were provided.
    #[error("missing elements")]
    MissingElements,
    /// The leaf count could not be mapped to a valid structure (blueprint construction failed).
    #[error("invalid size")]
    InvalidSize,
}
34
/// Contains the information necessary for proving the inclusion of an element, or some range of
/// elements, in a Merkle-family data structure from its root digest.
///
/// The `digests` vector uses a fold-based layout:
///
/// 1. If there are peaks entirely before the proven range (fold prefix), the first digest is
///    a single accumulator produced by folding those peaks: `fold(fold(..., peak0), peak1)`.
///    If there are no such peaks, this entry is absent.
///
/// 2. The digests of peaks entirely after the proven range, in peak iteration order.
///
/// 3. The sibling digests needed to reconstruct each range-peak digest from the proven elements,
///    in depth-first (forward consumption) order for each range peak.
// NOTE(review): `PartialEq` is implemented manually below rather than derived —
// presumably to avoid placing an equality bound on the `F` marker type; confirm.
#[derive(Clone, Debug, Eq)]
pub struct Proof<F: Family, D: Digest> {
    /// The total number of leaves in the data structure. For MMR proofs, this is the number of
    /// leaves in the MMR, though other authenticated data structures may override the meaning of
    /// this field. For example, the authenticated [crate::AuthenticatedBitMap] stores the number
    /// of bits in the bitmap within this field.
    pub leaves: Location<F>,
    /// The digests necessary for proving inclusion.
    pub digests: Vec<D>,
}
58
// Manual implementation: proofs are equal when both the leaf count and the digest
// list match. The `F` marker contributes no runtime data to compare.
impl<F: Family, D: Digest> PartialEq for Proof<F, D> {
    fn eq(&self, other: &Self) -> bool {
        self.leaves == other.leaves && self.digests == other.digests
    }
}
64
65impl<F: Family, D: Digest> EncodeSize for Proof<F, D> {
66    fn encode_size(&self) -> usize {
67        self.leaves.encode_size() + self.digests.encode_size()
68    }
69}
70
impl<F: Family, D: Digest> Write for Proof<F, D> {
    /// Serialize the proof as the leaf count followed by the digest list.
    /// The field order must match [`commonware_codec::Read::read_cfg`] below.
    fn write(&self, buf: &mut impl BufMut) {
        self.leaves.write(buf);
        self.digests.write(buf);
    }
}
77
impl<F: Family, D: Digest> commonware_codec::Read for Proof<F, D> {
    /// The maximum number of digests in the proof.
    type Cfg = usize;

    /// Decode a proof: the leaf count followed by at most `max_digests` digests.
    /// Bounding the digest count guards against unbounded allocation from
    /// untrusted input.
    fn read_cfg(
        buf: &mut impl Buf,
        max_digests: &Self::Cfg,
    ) -> Result<Self, commonware_codec::Error> {
        let leaves = Location::<F>::read(buf)?;
        let digests = Vec::<D>::read_range(buf, ..=*max_digests)?;
        Ok(Self { leaves, digests })
    }
}
91
92impl<F: Family, D: Digest> Default for Proof<F, D> {
93    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
94    /// (`leaves == 0`) data structure.
95    fn default() -> Self {
96        Self {
97            leaves: Location::new(0),
98            digests: vec![],
99        }
100    }
101}
102
impl<F: Family, D: Digest> Proof<F, D> {
    /// Return true if this proof proves that `element` appears at location `loc` within the
    /// structure with root digest `root`.
    pub fn verify_element_inclusion<H>(
        &self,
        hasher: &H,
        element: &[u8],
        loc: Location<F>,
        root: &D,
    ) -> bool
    where
        H: Hasher<F, Digest = D>,
    {
        // A single element is just a range of length one.
        self.verify_range_inclusion(hasher, &[element], loc, root)
    }

    /// Return true if this proof proves that the `elements` appear consecutively starting at
    /// `start_loc` within the structure with root digest `root`.
    pub fn verify_range_inclusion<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
        root: &D,
    ) -> bool
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        match self.reconstruct_root(hasher, elements, start_loc) {
            Ok(reconstructed_root) => *root == reconstructed_root,
            // Malformed input is not an error from the caller's perspective; it simply
            // fails verification (logged for debugging when `std` is available).
            Err(_error) => {
                #[cfg(feature = "std")]
                tracing::debug!(error = ?_error, "invalid proof input");
                false
            }
        }
    }

    /// Return true if this proof proves that the elements at the specified locations are included
    /// in the structure with root digest `root`. A malformed proof will return false.
    ///
    /// The order of the elements does not affect the output.
    pub fn verify_multi_inclusion<H, E>(
        &self,
        hasher: &H,
        elements: &[(E, Location<F>)],
        root: &D,
    ) -> bool
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        // Empty proof is valid only for an empty tree with no extra digest data.
        if elements.is_empty() {
            return self.digests.is_empty()
                && self.leaves == Location::new(0)
                && *root == hasher.root(Location::new(0), core::iter::empty());
        }

        // Collect all required positions with deduplication, and blueprints per element.
        let mut node_positions = BTreeSet::new();
        let mut blueprints = BTreeMap::new();

        for (_, loc) in elements {
            if !loc.is_valid_index() {
                return false;
            }
            // `loc` is valid so it won't overflow from +1
            let Ok(bp) = Blueprint::new(self.leaves, *loc..*loc + 1) else {
                return false;
            };
            node_positions.extend(&bp.fold_prefix);
            node_positions.extend(&bp.fetch_nodes);
            blueprints.insert(*loc, bp);
        }

        // Verify we have the exact number of digests needed
        if node_positions.len() != self.digests.len() {
            return false;
        }

        // Build position to digest mapping once.
        // NOTE(review): this pairs the sorted (BTreeSet) positions with `self.digests`
        // in order — matching the ascending-position order produced by
        // `nodes_required_for_multi_proof` on the prover side; confirm provers use it.
        let node_digests: BTreeMap<Position<F>, D> = node_positions
            .iter()
            .zip(self.digests.iter())
            .map(|(&pos, digest)| (pos, *digest))
            .collect();

        // Verify each element by constructing its sub-proof in fold-based format
        for (element, loc) in elements {
            let bp = &blueprints[loc];

            let mut digests = Vec::with_capacity(
                if bp.fold_prefix.is_empty() { 0 } else { 1 } + bp.fetch_nodes.len(),
            );
            // Fold-prefix peaks collapse into a single accumulator digest (entry 1 of
            // the proof layout documented on [Proof]).
            if let Some((&first_pos, rest)) = bp.fold_prefix.split_first() {
                let first = *node_digests
                    .get(&first_pos)
                    .expect("must exist by construction");
                let acc = rest.iter().fold(first, |acc, &pos| {
                    let d = node_digests.get(&pos).expect("must exist by construction");
                    hasher.fold(&acc, d)
                });
                digests.push(acc);
            }
            // Then the after-peaks and DFS siblings, in fetch order.
            for &pos in &bp.fetch_nodes {
                let d = node_digests.get(&pos).expect("must exist by construction");
                digests.push(*d);
            }
            let proof = Self {
                leaves: self.leaves,
                digests,
            };

            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
                return false;
            }
        }

        true
    }

    /// Reconstruct the root digest from this proof and the given consecutive elements,
    /// or return a `ReconstructionError` if the input data is invalid.
    pub fn reconstruct_root<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
    ) -> Result<D, ReconstructionError>
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        // Delegate to the collecting variant without a collection sink.
        self.reconstruct_root_collecting(hasher, elements, start_loc, None)
    }

    /// Reconstructs the root digest from the digests in the proof and the provided range
    /// of elements, returning the (position,digest) of every node whose digest was required by the
    /// process (including those from the proof itself). Returns [Error::InvalidProof] if the
    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
    /// root.
    pub fn verify_range_inclusion_and_extract_digests<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
        root: &D,
    ) -> Result<Vec<(Position<F>, D)>, Error<F>>
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        let mut collected_digests = Vec::new();
        let Ok(reconstructed_root) = self.reconstruct_root_collecting(
            hasher,
            elements,
            start_loc,
            Some(&mut collected_digests),
        ) else {
            return Err(Error::InvalidProof);
        };

        if reconstructed_root != *root {
            return Err(Error::RootMismatch);
        }

        Ok(collected_digests)
    }

    /// Verify that both the proof and the pinned nodes are valid with respect to `root`.
    ///
    /// The `pinned_nodes` are the peak digests of the sub-structure at `start_loc`, in the order
    /// returned by `Family::nodes_to_pin`. Each pinned node is either:
    ///
    /// - A peak that precedes the proven range (fold-prefix peak). These are verified by
    ///   refolding them and comparing against the proof's fold-prefix accumulator.
    /// - A sibling node within a range peak's reconstruction. These are verified against the
    ///   digests extracted during proof verification.
    ///
    /// Returns `true` only if the proof reconstructs to `root` and every pinned node digest is
    /// accounted for. When `start_loc` is 0, `pinned_nodes` must be empty.
    pub fn verify_proof_and_pinned_nodes<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
        pinned_nodes: &[D],
        root: &D,
    ) -> bool
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        // First, the proof itself must verify and yield the extracted node digests.
        let collected = match self
            .verify_range_inclusion_and_extract_digests(hasher, elements, start_loc, root)
        {
            Ok(c) => c,
            Err(_) => return false,
        };

        if elements.is_empty() {
            return pinned_nodes.is_empty();
        }

        if !start_loc.is_valid() || start_loc > self.leaves {
            return false;
        }

        // Every pinned position must be matched one-to-one with a provided digest.
        let pinned_positions: alloc::vec::Vec<_> = F::nodes_to_pin(start_loc).collect();
        if pinned_positions.len() != pinned_nodes.len() {
            return false;
        }

        let Ok(fold_prefix) = Blueprint::fold_prefix(self.leaves, start_loc) else {
            return false;
        };

        let mut pinned_map: alloc::collections::BTreeMap<Position<F>, D> = pinned_positions
            .into_iter()
            .zip(pinned_nodes.iter().copied())
            .collect();

        // Verify fold-prefix pinned nodes by recomputing the accumulator (without the leaf
        // count, which is hashed into the final root independently).
        if !fold_prefix.is_empty() {
            if self.digests.is_empty() {
                return false;
            }
            let Some(first) = pinned_map.remove(&fold_prefix[0]) else {
                return false;
            };
            let mut acc = first;
            for pos in &fold_prefix[1..] {
                let Some(digest) = pinned_map.remove(pos) else {
                    return false;
                };
                acc = hasher.fold(&acc, &digest);
            }
            // The refolded accumulator must equal the proof's fold-prefix entry (digests[0]).
            if acc != self.digests[0] {
                return false;
            }
        }

        // Verify remaining pinned nodes (siblings) against the extracted digests.
        let extracted: alloc::collections::BTreeMap<Position<F>, D> =
            collected.into_iter().collect();
        for (pos, digest) in pinned_map {
            if extracted.get(&pos) != Some(&digest) {
                return false;
            }
        }

        true
    }

    /// Like [`reconstruct_root`](Self::reconstruct_root), but if `collected` is `Some`, every
    /// `(position, digest)` pair encountered during reconstruction is appended.
    pub(crate) fn reconstruct_root_collecting<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
        mut collected: Option<&mut Vec<(Position<F>, D)>>,
    ) -> Result<D, ReconstructionError>
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        // An empty element list is only valid at location 0 with no digests: the root
        // then commits to the leaf count alone.
        if elements.is_empty() {
            if start_loc == 0 {
                return if self.digests.is_empty() {
                    Ok(hasher.digest(&self.leaves.to_be_bytes()))
                } else {
                    Err(ReconstructionError::ExtraDigests)
                };
            }
            return Err(ReconstructionError::MissingElements);
        }
        if !start_loc.is_valid_index() {
            return Err(ReconstructionError::InvalidStartLoc);
        }
        let end_loc = start_loc
            .checked_add(elements.len() as u64)
            .ok_or(ReconstructionError::InvalidEndLoc)?;
        if end_loc > self.leaves {
            return Err(ReconstructionError::InvalidEndLoc);
        }
        let range = start_loc..end_loc;

        let bp =
            Blueprint::new(self.leaves, range).map_err(|_| ReconstructionError::InvalidSize)?;

        // Slice self.digests into [folded_prefix? | after_peaks... | siblings...]
        let prefix_digests = usize::from(!bp.fold_prefix.is_empty());
        let expected_min = prefix_digests + bp.fetch_nodes.len();
        if self.digests.len() < expected_min {
            return Err(ReconstructionError::MissingDigests);
        }

        // Blueprint's fetch_nodes contains after_peaks then the DFS sibling digests. We need to
        // know how many after_peaks there are to skip over them.
        let after_start = prefix_digests;
        let after_peaks_count = bp.after_peaks.len();
        let after_end = after_start + after_peaks_count;
        let siblings = &self.digests[after_end..];

        // Collect all peak digests to provide to hasher.root().
        let mut peak_digests = Vec::new();
        if !bp.fold_prefix.is_empty() {
            peak_digests.push(self.digests[0]);
        }

        // Reconstruct each range peak, consuming elements and sibling digests in the
        // same left-first DFS order the prover emitted them.
        let mut sibling_cursor = 0usize;
        let mut elements_iter = elements.iter();
        for &peak in &bp.range_peaks {
            let peak_digest = reconstruct_peak_from_range(
                hasher,
                peak,
                &bp.range,
                &mut elements_iter,
                siblings,
                &mut sibling_cursor,
                collected.as_deref_mut(),
            )?;
            if let Some(ref mut cd) = collected {
                cd.push((peak.pos, peak_digest));
            }
            peak_digests.push(peak_digest);
        }

        // After-peak digests are taken from the proof verbatim.
        for (i, &after_peak_pos) in bp.after_peaks.iter().enumerate() {
            let digest = self.digests[after_start + i];
            if let Some(ref mut cd) = collected {
                cd.push((after_peak_pos, digest));
            }
            peak_digests.push(digest);
        }

        // Verify all elements were consumed.
        if elements_iter.next().is_some() {
            return Err(ReconstructionError::ExtraDigests);
        }

        // Verify all siblings were consumed.
        if sibling_cursor != siblings.len() {
            return Err(ReconstructionError::ExtraDigests);
        }

        Ok(hasher.root(self.leaves, peak_digests.iter()))
    }
}
456
/// A perfect binary subtree within a peak, identified by its root position, height,
/// and the first leaf location it covers.
#[derive(Copy, Clone)]
pub(crate) struct Subtree<F: Family> {
    /// Position of the subtree root node.
    pub pos: Position<F>,
    /// Height of the subtree root (0 for a leaf; a subtree covers `2^height` leaves).
    pub height: u32,
    /// Location of the first (left-most) leaf covered by this subtree.
    pub leaf_start: Location<F>,
}
466
467impl<F: Family> Subtree<F> {
468    fn leaf_end(&self) -> Location<F> {
469        self.leaf_start + (1u64 << self.height)
470    }
471
472    fn children(&self) -> (Self, Self) {
473        let (left_pos, right_pos) = F::children(self.pos, self.height);
474        let child_height = self.height - 1;
475        let mid = self.leaf_start + (1u64 << child_height);
476        (
477            Self {
478                pos: left_pos,
479                height: child_height,
480                leaf_start: self.leaf_start,
481            },
482            Self {
483                pos: right_pos,
484                height: child_height,
485                leaf_start: mid,
486            },
487        )
488    }
489}
490
/// Blueprint for a range proof, separating fold-prefix peaks from nodes that must be fetched.
pub(crate) struct Blueprint<F: Family> {
    /// Total number of leaves in the structure this blueprint was built for.
    leaves: Location<F>,
    /// The location range this blueprint was built for.
    pub range: Range<Location<F>>,
    /// Peak positions that precede the proven range (to be folded into a single accumulator).
    pub fold_prefix: Vec<Position<F>>,
    /// Peak positions entirely after the proven range.
    pub after_peaks: Vec<Position<F>>,
    /// The peaks that overlap the proven range.
    pub range_peaks: Vec<Subtree<F>>,
    /// Node positions to include in the proof: after-peaks followed by DFS siblings.
    pub fetch_nodes: Vec<Position<F>>,
}
506
impl<F: Family> Blueprint<F> {
    /// Efficiently compute just the fold prefix for a given starting location.
    ///
    /// Walks the peaks left to right, collecting those whose leaves all precede
    /// `start_loc`, stopping at the first peak that reaches into the range.
    pub(crate) fn fold_prefix(
        leaves: Location<F>,
        start_loc: Location<F>,
    ) -> Result<Vec<Position<F>>, super::Error<F>> {
        let size = Position::<F>::try_from(leaves)?;
        let mut fold_prefix = Vec::new();
        let mut leaf_cursor = Location::new(0);

        for (peak_pos, height) in F::peaks(size) {
            // A peak of height `h` covers 2^h leaves starting at `leaf_cursor`.
            let leaf_end = leaf_cursor + (1u64 << height);
            if leaf_end <= start_loc {
                fold_prefix.push(peak_pos);
            } else {
                break;
            }
            leaf_cursor = leaf_end;
        }

        Ok(fold_prefix)
    }

    /// Return a blueprint for building a range proof over the given leaf `range` in a
    /// structure with `leaves` total leaves.
    pub(crate) fn new(
        leaves: Location<F>,
        range: Range<Location<F>>,
    ) -> Result<Self, super::Error<F>> {
        if range.is_empty() {
            return Err(super::Error::Empty);
        }
        let end_minus_one = range
            .end
            .checked_sub(1)
            .expect("can't underflow because range is non-empty");
        if end_minus_one >= leaves {
            return Err(super::Error::RangeOutOfBounds(range.end));
        }

        let size = Position::try_from(leaves)?;

        // Partition the peaks into: entirely-before (fold), entirely-after, and overlapping.
        let mut fold_prefix = Vec::new();
        let mut after_peaks = Vec::new();
        let mut range_peaks = Vec::new();
        let mut leaf_cursor = Location::new(0);

        for (peak_pos, height) in F::peaks(size) {
            let leaf_start = leaf_cursor;
            let leaf_end = leaf_start + (1u64 << height);

            if leaf_end <= range.start {
                fold_prefix.push(peak_pos);
            } else if leaf_start >= range.end {
                after_peaks.push(peak_pos);
            } else {
                range_peaks.push(Subtree {
                    pos: peak_pos,
                    height,
                    leaf_start,
                });
            }
            leaf_cursor = leaf_end;
        }

        // A non-empty, in-bounds range must intersect at least one peak.
        assert!(
            !range_peaks.is_empty(),
            "at least one peak must contain range elements"
        );

        // fetch_nodes layout: after-peaks first, then each range peak's DFS siblings.
        let mut fetch_nodes = after_peaks.clone();
        for &peak in &range_peaks {
            collect_siblings_dfs(peak, &range, &mut fetch_nodes);
        }

        Ok(Self {
            leaves,
            range,
            fold_prefix,
            after_peaks,
            range_peaks,
            fetch_nodes,
        })
    }

    /// Build a range proof from this blueprint and a node-fetching closure.
    ///
    /// The prover folds prefix peak digests into a single accumulator. The resulting proof
    /// contains: `[fold_acc? | after_peaks... | siblings_dfs...]`.
    ///
    /// Returns an error via `element_pruned` if `get_node` returns `None` for any required
    /// position.
    pub(crate) fn build_proof<D, H, E>(
        self,
        hasher: &H,
        get_node: impl Fn(Position<F>) -> Option<D>,
        element_pruned: impl Fn(Position<F>) -> E,
    ) -> Result<Proof<F, D>, E>
    where
        D: Digest,
        H: Hasher<F, Digest = D>,
    {
        // At most one accumulator entry plus one digest per fetched node.
        let mut digests = Vec::with_capacity(
            if self.fold_prefix.is_empty() { 0 } else { 1 } + self.fetch_nodes.len(),
        );

        if let Some((&first_pos, rest)) = self.fold_prefix.split_first() {
            let first = get_node(first_pos).ok_or_else(|| element_pruned(first_pos))?;
            let acc = rest.iter().try_fold(first, |acc, &pos| {
                let d = get_node(pos).ok_or_else(|| element_pruned(pos))?;
                Ok(hasher.fold(&acc, &d))
            })?;
            digests.push(acc);
        }

        for &pos in &self.fetch_nodes {
            digests.push(get_node(pos).ok_or_else(|| element_pruned(pos))?);
        }

        Ok(Proof {
            leaves: self.leaves,
            digests,
        })
    }
}
632
/// The maximum number of digests in a proof per element being proven.
///
/// This accounts for the worst case proof size, in an MMR/MMB with 62 peaks. The
/// left-most leaf in such a tree requires 122 digests, for 61 path siblings
/// and 61 peak digests.
pub const MAX_PROOF_DIGESTS_PER_ELEMENT: usize = 122;
639
640/// Build a range proof from a node-fetching closure. This is the generic implementation
641/// shared by all Merkle families. The `element_pruned` closure is called when `get_node`
642/// returns `None` for a required position.
643pub(crate) fn build_range_proof<F, D, H, E>(
644    hasher: &H,
645    leaves: Location<F>,
646    range: Range<Location<F>>,
647    get_node: impl Fn(Position<F>) -> Option<D>,
648    element_pruned: impl Fn(Position<F>) -> E,
649) -> Result<Proof<F, D>, E>
650where
651    F: Family,
652    D: Digest,
653    H: Hasher<F, Digest = D>,
654    E: From<super::Error<F>>,
655{
656    Blueprint::new(leaves, range)?.build_proof(hasher, get_node, element_pruned)
657}
658
/// Returns the positions of the minimal set of nodes whose digests are required to prove the
/// inclusion of the elements at the specified `locations`. This is the generic implementation
/// shared by all Merkle families.
#[cfg(any(feature = "std", test))]
pub(crate) fn nodes_required_for_multi_proof<F: Family>(
    leaves: Location<F>,
    locations: &[Location<F>],
) -> Result<BTreeSet<Position<F>>, super::Error<F>> {
    if locations.is_empty() {
        return Err(super::Error::Empty);
    }
    // The set deduplicates positions shared between per-location blueprints.
    let mut required = BTreeSet::new();
    for loc in locations {
        if !loc.is_valid_index() {
            return Err(super::Error::LocationOverflow(*loc));
        }
        let bp = Blueprint::new(leaves, *loc..*loc + 1)?;
        required.extend(bp.fold_prefix);
        required.extend(bp.fetch_nodes);
    }
    Ok(required)
}
680
681/// Collect sibling positions needed to reconstruct a peak digest from a range of elements, in
682/// left-first DFS order. This mirrors the traversal order of [`reconstruct_peak_from_range`].
683///
684/// At each node: if the subtree is entirely outside the range, its root position is emitted. If
685/// it's a leaf in the range, nothing is emitted. Otherwise, recurse into children via
686/// [`Family::children`].
687pub(crate) fn collect_siblings_dfs<F: Family>(
688    node: Subtree<F>,
689    range: &Range<Location<F>>,
690    out: &mut Vec<Position<F>>,
691) {
692    if node.leaf_end() <= range.start || node.leaf_start >= range.end {
693        out.push(node.pos);
694        return;
695    }
696
697    if node.height > 0 {
698        let (left, right) = node.children();
699        collect_siblings_dfs::<F>(left, range, out);
700        collect_siblings_dfs::<F>(right, range, out);
701    }
702}
703
/// Reconstruct the digest of a peak subtree from a range of elements and sibling digests, consuming
/// both in left-first DFS order matching [`collect_siblings_dfs`].
///
/// At each node:
/// - If the subtree is entirely outside the range: consume a sibling digest.
/// - If it's a leaf in the range: hash the next element.
/// - Otherwise: recurse into children via [`Family::children`] and compute the node digest.
///
/// If `collected` is `Some`, every child `(position, digest)` pair encountered during
/// reconstruction is appended to the vector.
pub(crate) fn reconstruct_peak_from_range<F, D, H, E>(
    hasher: &H,
    node: Subtree<F>,
    range: &Range<Location<F>>,
    elements: &mut E,
    siblings: &[D],
    cursor: &mut usize,
    mut collected: Option<&mut Vec<(Position<F>, D)>>,
) -> Result<D, ReconstructionError>
where
    F: Family,
    D: Digest,
    H: Hasher<F, Digest = D>,
    E: Iterator<Item: AsRef<[u8]>>,
{
    // Entirely outside the range: consume a sibling digest.
    if node.leaf_end() <= range.start || node.leaf_start >= range.end {
        let Some(digest) = siblings.get(*cursor).copied() else {
            return Err(ReconstructionError::MissingDigests);
        };
        *cursor += 1;
        return Ok(digest);
    }

    // Leaf in range: hash the next element.
    if node.height == 0 {
        let elem = elements
            .next()
            .ok_or(ReconstructionError::MissingElements)?;
        return Ok(hasher.leaf_digest(node.pos, elem.as_ref()));
    }

    // Recurse into children (left first, to match the prover's emission order).
    let (left, right) = node.children();
    let left_pos = left.pos;
    let right_pos = right.pos;

    let left_d = reconstruct_peak_from_range::<F, D, H, E>(
        hasher,
        left,
        range,
        elements,
        siblings,
        cursor,
        collected.as_deref_mut(),
    )?;
    let right_d = reconstruct_peak_from_range::<F, D, H, E>(
        hasher,
        right,
        range,
        elements,
        siblings,
        cursor,
        collected.as_deref_mut(),
    )?;

    if let Some(ref mut cd) = collected {
        cd.push((left_pos, left_d));
        cd.push((right_pos, right_d));
    }

    Ok(hasher.node_digest(node.pos, &left_d, &right_d))
}
777
#[cfg(feature = "arbitrary")]
impl<F: Family, D: Digest> arbitrary::Arbitrary<'_> for Proof<F, D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Generate a structurally arbitrary proof (not necessarily verifiable) for fuzzing.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let leaves = u.arbitrary()?;
        let digests = u.arbitrary()?;
        Ok(Self { leaves, digests })
    }
}
790
791#[cfg(test)]
792mod tests {
793    use super::*;
794    use crate::merkle::{
795        hasher::Standard,
796        mem::Mem,
797        mmb, mmr,
798        proof::{nodes_required_for_multi_proof, Blueprint, Proof},
799        Family, Location, LocationRangeExt as _,
800    };
801    use alloc::vec;
802    use commonware_codec::{Decode, Encode, EncodeSize};
803    use commonware_cryptography::{sha256, Sha256};
804    use commonware_macros::test_traced;
805
806    type D = sha256::Digest;
807    type H = Standard<Sha256>;
808
    /// Deterministic digest derived from a single byte; used to mangle proofs in tests.
    fn test_digest(v: u8) -> D {
        <Sha256 as commonware_cryptography::Hasher>::hash(&[v])
    }
812
    /// Build an in-memory Merkle structure with `n` elements (element i = i.to_be_bytes()).
    fn build_raw<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
        let mut mem = Mem::new(hasher);
        // NOTE(review): the inner scope presumably ends the batch builder's use of `mem`
        // before `apply_batch` — confirm against `Mem`'s API.
        let batch = {
            let mut batch = mem.new_batch();
            for i in 0..n {
                batch = batch.add(hasher, &i.to_be_bytes());
            }
            batch.merkleize(&mem, hasher)
        };
        mem.apply_batch(&batch).unwrap();
        mem
    }
826
    fn empty_proof<F: Family>() {
        // Test that an empty proof authenticates an empty structure.
        let hasher = H::new();
        let mem = Mem::<F, D>::new(&hasher);
        let root = mem.root();
        let proof: Proof<F, D> = Proof::default();
        assert!(proof.verify_range_inclusion(&hasher, &[] as &[D], Location::new(0), root));

        // Any starting position other than 0 should fail to verify.
        assert!(!proof.verify_range_inclusion(&hasher, &[] as &[D], Location::new(1), root));

        // Invalid root should fail to verify.
        let td = test_digest(0);
        assert!(!proof.verify_range_inclusion(&hasher, &[] as &[D], Location::new(0), &td));

        // Non-empty elements list should fail to verify.
        assert!(!proof.verify_range_inclusion(&hasher, &[td], Location::new(0), root));
    }
845
    fn verify_element<F: Family>() {
        // Create an 11 element structure and test single-element inclusion proofs.
        let element = D::from(*b"01234567012345670123456701234567");
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let batch = {
            let mut batch = mem.new_batch();
            for _ in 0..11 {
                batch = batch.add(&hasher, &element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = mem.root();

        // Confirm the proof of inclusion for each leaf verifies.
        for leaf in 0u64..11 {
            let leaf = Location::new(leaf);
            let proof: Proof<F, D> = mem.proof(&hasher, leaf).unwrap();
            assert!(
                proof.verify_element_inclusion(&hasher, &element, leaf, root),
                "valid proof should verify successfully"
            );
        }

        // Create a valid proof, then confirm various mangling of the proof or proof args results in
        // verification failure.
        let leaf = Location::<F>::new(10);
        let proof = mem.proof(&hasher, leaf).unwrap();
        assert!(
            proof.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should be successful"
        );
        // Shifting the claimed location in either direction must invalidate the proof.
        assert!(
            !proof.verify_element_inclusion(&hasher, &element, leaf + 1, root),
            "proof verification should fail with incorrect element position"
        );
        assert!(
            !proof.verify_element_inclusion(&hasher, &element, leaf - 1, root),
            "proof verification should fail with incorrect element position 2"
        );
        // Substituting a different element must invalidate the proof.
        assert!(
            !proof.verify_element_inclusion(&hasher, &test_digest(0), leaf, root),
            "proof verification should fail with mangled element"
        );
        // Verifying against an unrelated root must fail.
        let root2 = test_digest(0);
        assert!(
            !proof.verify_element_inclusion(&hasher, &element, leaf, &root2),
            "proof verification should fail with mangled root"
        );
        // Corrupting the first proof digest must cause failure.
        let mut proof2 = proof.clone();
        proof2.digests[0] = test_digest(0);
        assert!(
            !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should fail with mangled proof hash"
        );
        // Claiming a smaller leaf count (10 rather than the actual 11) must cause failure.
        proof2 = proof.clone();
        proof2.leaves = Location::new(10);
        assert!(
            !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should fail with incorrect leaves"
        );
        // Appending an extra trailing digest must cause failure (no unused data tolerated).
        proof2 = proof.clone();
        proof2.digests.push(test_digest(0));
        assert!(
            !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should fail with extra hash"
        );
        // Every truncation of the digest list, down to empty, must cause failure.
        proof2 = proof.clone();
        while !proof2.digests.is_empty() {
            proof2.digests.pop();
            assert!(
                !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
                "proof verification should fail with missing digests"
            );
        }
        // Inserting an extra digest in the middle should cause verification failure.
        if proof.digests.len() >= 2 {
            proof2 = proof.clone();
            proof2.digests.clear();
            proof2.digests.extend(proof.digests[0..1].iter().cloned());
            proof2.digests.push(test_digest(0));
            proof2.digests.extend(proof.digests[1..].iter().cloned());
            assert!(
                !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
                "proof verification should fail with extra hash even if it's unused by the computation"
            );
        }
    }
935
    fn verify_range<F: Family>() {
        // Create a structure and add 49 elements.
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..49).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = mem.root();

        // Test range proofs over all possible non-empty ranges (j starts at i + 1,
        // so single-element ranges are included).
        for i in 0..elements.len() {
            for j in i + 1..elements.len() {
                let range = Location::new(i as u64)..Location::new(j as u64);
                let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
                assert!(
                    range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[range.to_usize_range()],
                        range.start,
                        root,
                    ),
                    "valid range proof should verify successfully {i}:{j}",
                );
            }
        }

        // Create a proof over a range, confirm it verifies, then mangle it in various ways.
        let range = Location::new(33)..Location::new(40);
        let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
        let valid_elements = &elements[range.to_usize_range()];
        assert!(
            range_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root),
            "valid range proof should verify successfully"
        );
        // Remove digests from the proof until it's empty; each truncation must fail.
        let mut invalid_proof = range_proof.clone();
        for _i in 0..range_proof.digests.len() {
            invalid_proof.digests.remove(0);
            assert!(
                !invalid_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root,),
                "range proof with removed elements should fail"
            );
        }
        // Confirm proof verification fails when providing an element range different than the one
        // used to generate the proof.
        for i in 0..elements.len() {
            for j in i + 1..elements.len() {
                // Skip the one range the proof was actually generated for.
                if Location::<F>::from(i) == range.start && Location::<F>::from(j) == range.end {
                    continue;
                }
                assert!(
                    !range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[i..j],
                        range.start,
                        root,
                    ),
                    "range proof with invalid element range should fail {i}:{j}",
                );
            }
        }
        // Confirm proof fails to verify with an invalid root.
        let invalid_root = test_digest(1);
        assert!(
            !range_proof.verify_range_inclusion(
                &hasher,
                valid_elements,
                range.start,
                &invalid_root,
            ),
            "range proof with invalid root should fail"
        );
        // Mangle each element of the proof and confirm it fails to verify.
        for i in 0..range_proof.digests.len() {
            let mut invalid_proof = range_proof.clone();
            invalid_proof.digests[i] = test_digest(0);
            assert!(
                !invalid_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root,),
                "mangled range proof should fail verification"
            );
        }
        // Inserting elements into the proof should also cause it to fail (malleability check)
        for i in 0..range_proof.digests.len() {
            let mut invalid_proof = range_proof.clone();
            invalid_proof.digests.insert(i, test_digest(0));
            assert!(
                !invalid_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root,),
                "mangled range proof should fail verification. inserted element at: {i}",
            );
        }
        // Bad start_loc should cause verification to fail.
        for loc in 0..elements.len() {
            let loc = Location::new(loc as u64);
            if loc == range.start {
                continue;
            }
            assert!(
                !range_proof.verify_range_inclusion(&hasher, valid_elements, loc, root),
                "bad start_loc should fail verification {loc}",
            );
        }
    }
1044
1045    fn retained_nodes_provable_after_pruning<F: Family>() {
1046        // Create a structure and add 49 elements.
1047        let hasher = H::new();
1048        let mut mem = Mem::<F, D>::new(&hasher);
1049        let elements: Vec<_> = (0..49).map(test_digest).collect();
1050        let batch = {
1051            let mut batch = mem.new_batch();
1052            for element in &elements {
1053                batch = batch.add(&hasher, element);
1054            }
1055            batch.merkleize(&mem, &hasher)
1056        };
1057        mem.apply_batch(&batch).unwrap();
1058
1059        // Confirm we can successfully prove all retained elements after pruning.
1060        let root = *mem.root();
1061        for prune_leaf in 1..*mem.leaves() {
1062            let prune_loc = Location::new(prune_leaf);
1063            mem.prune(prune_loc).unwrap();
1064            let pruned_root = mem.root();
1065            assert_eq!(root, *pruned_root);
1066            for loc in 0..elements.len() {
1067                let loc = Location::new(loc as u64);
1068                let proof = mem.proof(&hasher, loc);
1069                if loc < prune_loc {
1070                    continue;
1071                }
1072                assert!(proof.is_ok());
1073                assert!(proof.unwrap().verify_element_inclusion(
1074                    &hasher,
1075                    &elements[*loc as usize],
1076                    loc,
1077                    &root
1078                ));
1079            }
1080        }
1081    }
1082
    fn ranges_provable_after_pruning<F: Family>() {
        // Create a structure and add 49 elements.
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let mut elements: Vec<_> = (0..49).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();

        // Prune up to the first peak.
        let prune_loc = Location::<F>::new(32);
        mem.prune(prune_loc).unwrap();
        assert_eq!(mem.bounds().start, prune_loc);

        // Test range proofs over all possible ranges of at least 2 elements
        // whose start lies in the retained (unpruned) region.
        let root = mem.root();
        for i in 0..elements.len() - 1 {
            if Location::<F>::new(i as u64) < prune_loc {
                continue;
            }
            for j in (i + 2)..elements.len() {
                let range = Location::new(i as u64)..Location::new(j as u64);
                let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
                assert!(
                    range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[range.to_usize_range()],
                        range.start,
                        root,
                    ),
                    "valid range proof over remaining elements should verify successfully",
                );
            }
        }

        // Add more nodes, prune again, and test again.
        let new_elements: Vec<_> = (0..37).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &new_elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        elements.extend(new_elements);
        mem.prune(Location::new(66)).unwrap();
        assert_eq!(mem.bounds().start, Location::new(66));

        // A range proof over the last 10 elements should still verify against the
        // updated root after the second pruning round.
        let updated_root = mem.root();
        let range = Location::new(elements.len() as u64 - 10)..Location::new(elements.len() as u64);
        let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
        assert!(
            range_proof.verify_range_inclusion(
                &hasher,
                &elements[range.to_usize_range()],
                range.start,
                updated_root,
            ),
            "valid range proof over remaining elements after 2 pruning rounds should verify",
        );
    }
1150
1151    fn proof_serialization<F: Family>() {
1152        // Create a structure and add 25 elements.
1153        let hasher = H::new();
1154        let mut mem = Mem::<F, D>::new(&hasher);
1155        let elements: Vec<_> = (0..25).map(test_digest).collect();
1156        let batch = {
1157            let mut batch = mem.new_batch();
1158            for element in &elements {
1159                batch = batch.add(&hasher, element);
1160            }
1161            batch.merkleize(&mem, &hasher)
1162        };
1163        mem.apply_batch(&batch).unwrap();
1164
1165        // Generate proofs over all possible ranges of elements and confirm each
1166        // serializes=>deserializes correctly.
1167        for i in 0..elements.len() {
1168            for j in i + 1..elements.len() {
1169                let range = Location::new(i as u64)..Location::new(j as u64);
1170                let proof = mem.range_proof(&hasher, range).unwrap();
1171
1172                let expected_size = proof.encode_size();
1173                let serialized_proof = proof.encode();
1174                assert_eq!(
1175                    serialized_proof.len(),
1176                    expected_size,
1177                    "serialized proof should have expected size"
1178                );
1179                let max_digests = proof.digests.len();
1180                let deserialized_proof =
1181                    Proof::<F, D>::decode_cfg(serialized_proof, &max_digests).unwrap();
1182                assert_eq!(
1183                    proof, deserialized_proof,
1184                    "deserialized proof should match source proof"
1185                );
1186
1187                // Remove one byte from the end and confirm it fails to deserialize.
1188                let serialized_proof = proof.encode();
1189                let serialized_proof = serialized_proof.slice(0..serialized_proof.len() - 1);
1190                assert!(
1191                    Proof::<F, D>::decode_cfg(serialized_proof, &max_digests).is_err(),
1192                    "proof should not deserialize with truncated data"
1193                );
1194
1195                // Add extra data and confirm it fails to deserialize.
1196                let mut serialized_proof = proof.encode_mut();
1197                serialized_proof.extend_from_slice(&[0; 10]);
1198                let serialized_proof = serialized_proof;
1199                assert!(
1200                    Proof::<F, D>::decode_cfg(serialized_proof, &max_digests).is_err(),
1201                    "proof should not deserialize with extra data"
1202                );
1203
1204                // Confirm deserialization fails when max_digests is too small.
1205                let actual_digests = proof.digests.len();
1206                if actual_digests > 0 {
1207                    let too_small = actual_digests - 1;
1208                    let serialized_proof = proof.encode();
1209                    assert!(
1210                        Proof::<F, D>::decode_cfg(serialized_proof, &too_small).is_err(),
1211                        "proof should not deserialize with max_digests too small"
1212                    );
1213                }
1214            }
1215        }
1216    }
1217
    fn multi_proof_generation_and_verify<F: Family>() {
        // Create a structure with 20 elements.
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..20).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();

        let root = mem.root();

        // Generate proof for non-contiguous single elements: compute the required
        // node set, then assemble the proof from those nodes' digests.
        let locations = &[Location::new(0), Location::new(5), Location::new(10)];
        let nodes_for_multi_proof =
            nodes_required_for_multi_proof(mem.leaves(), locations).expect("test locations valid");
        let digests = nodes_for_multi_proof
            .into_iter()
            .map(|pos| mem.get_node(pos).unwrap())
            .collect();
        let multi_proof = Proof {
            leaves: mem.leaves(),
            digests,
        };

        assert_eq!(multi_proof.leaves, mem.leaves());

        // Verify the proof.
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(10)),
            ],
            root
        ));

        // Verify in different order (item order should not matter).
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[10], Location::new(10)),
                (elements[5], Location::new(5)),
                (elements[0], Location::new(0)),
            ],
            root
        ));

        // Verify with duplicate items (duplicates should be tolerated).
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[0], Location::new(0)),
                (elements[10], Location::new(10)),
                (elements[5], Location::new(5)),
            ],
            root
        ));

        // Verify mangling the location to something invalid should fail.
        // (`*F::MAX_LEAVES + 2` exceeds the family's maximum leaf count.)
        let mut wrong_size_proof = multi_proof.clone();
        wrong_size_proof.leaves = Location::new(*F::MAX_LEAVES + 2);
        assert!(!wrong_size_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(10)),
            ],
            root,
        ));

        // Verify with wrong positions.
        assert!(!multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(1)),
                (elements[5], Location::new(6)),
                (elements[10], Location::new(11)),
            ],
            root,
        ));

        // Verify with wrong elements.
        let wrong_elements = [
            vec![255u8, 254u8, 253u8],
            vec![252u8, 251u8, 250u8],
            vec![249u8, 248u8, 247u8],
        ];
        let wrong_verification = multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (wrong_elements[0].as_slice(), Location::new(0)),
                (wrong_elements[1].as_slice(), Location::new(5)),
                (wrong_elements[2].as_slice(), Location::new(10)),
            ],
            root,
        );
        assert!(!wrong_verification, "Should fail with wrong elements");

        // Verify with out of range element (location 1000 exceeds the 20 leaves).
        let wrong_verification = multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(1000)),
            ],
            root,
        );
        assert!(
            !wrong_verification,
            "Should fail with out of range elements"
        );

        // Verify with wrong root should fail.
        let wrong_root = test_digest(99);
        assert!(!multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(10)),
            ],
            &wrong_root
        ));

        // Empty multi-proof must authenticate an empty structure.
        let hasher = H::new();
        let empty_mem = Mem::<F, D>::new(&hasher);
        let empty_root = empty_mem.root();
        let empty_proof: Proof<F, D> = Proof::default();
        assert!(empty_proof.verify_multi_inclusion(
            &hasher,
            &[] as &[(D, Location<F>)],
            empty_root
        ));

        // Malformed empty proof with extra digests must be rejected.
        let malformed_proof: Proof<F, D> = Proof {
            leaves: Location::new(0),
            digests: vec![test_digest(0)],
        };
        assert!(!malformed_proof.verify_multi_inclusion(
            &hasher,
            &[] as &[(D, Location<F>)],
            empty_root
        ));
    }
1373
1374    fn multi_proof_deduplication<F: Family>() {
1375        let hasher = H::new();
1376        let mut mem = Mem::<F, D>::new(&hasher);
1377        let elements: Vec<_> = (0..30).map(test_digest).collect();
1378        let batch = {
1379            let mut batch = mem.new_batch();
1380            for element in &elements {
1381                batch = batch.add(&hasher, element);
1382            }
1383            batch.merkleize(&mem, &hasher)
1384        };
1385        mem.apply_batch(&batch).unwrap();
1386
1387        // Get individual proofs that will share some digests (elements in same subtree).
1388        let proof1 = mem.proof(&hasher, Location::new(0)).unwrap();
1389        let proof2 = mem.proof(&hasher, Location::new(1)).unwrap();
1390        let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1391
1392        // Generate multi-proof for the same positions.
1393        let locations = &[Location::new(0), Location::new(1)];
1394        let multi_proof_nodes =
1395            nodes_required_for_multi_proof(mem.leaves(), locations).expect("test locations valid");
1396        let digests = multi_proof_nodes
1397            .into_iter()
1398            .map(|pos| mem.get_node(pos).unwrap())
1399            .collect();
1400        let multi_proof = Proof {
1401            leaves: mem.leaves(),
1402            digests,
1403        };
1404
1405        // The combined proof should have fewer digests due to deduplication.
1406        assert!(multi_proof.digests.len() < total_digests_separate);
1407
1408        // Verify it still works.
1409        let root = mem.root();
1410        assert!(multi_proof.verify_multi_inclusion(
1411            &hasher,
1412            &[
1413                (elements[0], Location::new(0)),
1414                (elements[1], Location::new(1))
1415            ],
1416            root
1417        ));
1418    }
1419
1420    fn proof_leaves_malleability<F: Family>() {
1421        let hasher = H::new();
1422        let mut mem = Mem::<F, D>::new(&hasher);
1423
1424        // 252 leaves. Leaf 240 sits in a peak preceded by prefix peaks.
1425        let elements: Vec<D> = (0..252u16)
1426            .map(|i| <Sha256 as commonware_cryptography::Hasher>::hash(&i.to_be_bytes()))
1427            .collect();
1428        let batch = {
1429            let mut batch = mem.new_batch();
1430            for e in &elements {
1431                batch = batch.add(&hasher, e);
1432            }
1433            batch.merkleize(&mem, &hasher)
1434        };
1435        mem.apply_batch(&batch).unwrap();
1436        let root = mem.root();
1437
1438        let loc = Location::new(240);
1439        let proof = mem.proof(&hasher, loc).unwrap();
1440        assert!(proof.verify_element_inclusion(&hasher, &elements[240], loc, root));
1441
1442        // Tamper with the leaves field (249 has the same peak layout for leaf 240).
1443        let mut tampered = proof.clone();
1444        tampered.leaves = Location::new(249);
1445        assert_ne!(tampered, proof);
1446        assert!(
1447            !tampered.verify_element_inclusion(&hasher, &elements[240], loc, root),
1448            "proof with tampered leaves field must not verify"
1449        );
1450    }
1451
1452    fn blueprint_errors<F: Family>() {
1453        let leaves = Location::<F>::new(10);
1454
1455        // Empty range.
1456        assert!(matches!(
1457            Blueprint::<F>::new(leaves, Location::new(3)..Location::new(3)),
1458            Err(crate::merkle::Error::Empty)
1459        ));
1460
1461        // Out of bounds.
1462        assert!(matches!(
1463            Blueprint::<F>::new(leaves, Location::new(0)..Location::new(11)),
1464            Err(crate::merkle::Error::RangeOutOfBounds(_))
1465        ));
1466
1467        // Empty locations for multi-proof.
1468        assert!(matches!(
1469            nodes_required_for_multi_proof::<F>(leaves, &[]),
1470            Err(crate::merkle::Error::Empty)
1471        ));
1472    }
1473
1474    fn single_element_proof_reconstruction<F: Family>() {
1475        for n in 1u64..=64 {
1476            let hasher = H::new();
1477            let mem = build_raw::<F>(&hasher, n);
1478            let root = *mem.root();
1479
1480            for loc_idx in 0..n {
1481                let proof = mem
1482                    .proof(&hasher, Location::new(loc_idx))
1483                    .unwrap_or_else(|e| panic!("n={n}, loc={loc_idx}: build failed: {e:?}"));
1484
1485                let elements = [loc_idx.to_be_bytes()];
1486                let start_loc = Location::new(loc_idx);
1487
1488                let reconstructed = proof
1489                    .reconstruct_root(&hasher, &elements, start_loc)
1490                    .unwrap_or_else(|e| panic!("n={n}, loc={loc_idx}: reconstruct failed: {e:?}"));
1491                assert_eq!(reconstructed, root, "n={n}, loc={loc_idx}: root mismatch");
1492            }
1493        }
1494    }
1495
1496    fn range_proof_reconstruction<F: Family>() {
1497        for n in 2u64..=32 {
1498            let hasher = H::new();
1499            let mem = build_raw::<F>(&hasher, n);
1500            let root = *mem.root();
1501
1502            let ranges: Vec<(u64, u64)> = vec![
1503                (0, n),
1504                (0, 1),
1505                (n - 1, n),
1506                (0, n.min(3)),
1507                (n.saturating_sub(3), n),
1508            ];
1509
1510            for (start, end) in ranges {
1511                if start >= end || end > n {
1512                    continue;
1513                }
1514                let proof = mem
1515                    .range_proof(&hasher, Location::new(start)..Location::new(end))
1516                    .unwrap_or_else(|e| panic!("n={n}, range={start}..{end}: build failed: {e:?}"));
1517                let elements: Vec<_> = (start..end).map(|i| i.to_be_bytes()).collect();
1518                let start_loc = Location::new(start);
1519
1520                let reconstructed = proof
1521                    .reconstruct_root(&hasher, &elements, start_loc)
1522                    .unwrap_or_else(|e| {
1523                        panic!("n={n}, range={start}..{end}: reconstruct failed: {e}")
1524                    });
1525                assert_eq!(
1526                    reconstructed, root,
1527                    "n={n}, range={start}..{end}: root mismatch"
1528                );
1529            }
1530        }
1531    }
1532
1533    fn verify_element_inclusion<F: Family>() {
1534        for n in 1u64..=32 {
1535            let hasher = H::new();
1536            let mem = build_raw::<F>(&hasher, n);
1537            let root = *mem.root();
1538
1539            for loc_idx in 0..n {
1540                let proof = mem.proof(&hasher, Location::new(loc_idx)).unwrap();
1541                let loc = Location::new(loc_idx);
1542
1543                assert!(
1544                    proof.verify_element_inclusion(&hasher, &loc_idx.to_be_bytes(), loc, &root),
1545                    "n={n}, loc={loc_idx}: verification failed"
1546                );
1547
1548                // Wrong element should fail.
1549                assert!(
1550                    !proof.verify_element_inclusion(
1551                        &hasher,
1552                        &(loc_idx + 1000).to_be_bytes(),
1553                        loc,
1554                        &root,
1555                    ),
1556                    "n={n}, loc={loc_idx}: wrong element should not verify"
1557                );
1558            }
1559        }
1560    }
1561
1562    fn full_range<F: Family>() {
1563        for n in 1u64..=32 {
1564            let hasher = H::new();
1565            let mem = build_raw::<F>(&hasher, n);
1566            let root = *mem.root();
1567
1568            let proof = mem
1569                .range_proof(&hasher, Location::new(0)..Location::new(n))
1570                .unwrap();
1571            let elements: Vec<_> = (0..n).map(|i| i.to_be_bytes()).collect();
1572            let reconstructed = proof
1573                .reconstruct_root(&hasher, &elements, Location::new(0))
1574                .unwrap();
1575            assert_eq!(reconstructed, root, "n={n}: full range failed");
1576
1577            // Full range should have 0 digests.
1578            assert_eq!(
1579                proof.digests.len(),
1580                0,
1581                "n={n}: full range proof should have 0 digests"
1582            );
1583        }
1584    }
1585
1586    fn empty_proof_verifies_empty_tree<F: Family>() {
1587        let hasher = H::new();
1588        let mem = Mem::<F, D>::new(&hasher);
1589        let root = *mem.root();
1590        let proof = Proof::<F, D>::default();
1591
1592        // Empty proof should verify against the empty root.
1593        assert!(proof.verify_range_inclusion(&hasher, &[] as &[&[u8]], Location::new(0), &root,));
1594
1595        // Non-zero start_loc with empty elements should fail.
1596        assert!(!proof.verify_range_inclusion(&hasher, &[] as &[&[u8]], Location::new(1), &root,));
1597    }
1598
1599    fn every_element_contributes_to_root<F: Family>() {
1600        for n in [8u64, 13, 20, 32] {
1601            let hasher = H::new();
1602            let mem = build_raw::<F>(&hasher, n);
1603            let root = *mem.root();
1604
1605            let start = 1;
1606            let end = n - 1;
1607            let proof = mem
1608                .range_proof(&hasher, Location::new(start)..Location::new(end))
1609                .unwrap();
1610            let elements: Vec<_> = (start..end).map(|i| i.to_be_bytes()).collect();
1611
1612            // Valid elements verify.
1613            assert!(
1614                proof.verify_range_inclusion(&hasher, &elements, Location::new(start), &root),
1615                "n={n}: valid range should verify"
1616            );
1617
1618            // Flipping one byte in each element must cause failure.
1619            for flip_idx in 0..elements.len() {
1620                let mut tampered = elements.clone();
1621                tampered[flip_idx][0] ^= 0xFF;
1622                assert!(
1623                    !proof.verify_range_inclusion(&hasher, &tampered, Location::new(start), &root,),
1624                    "n={n}: tampered element at index {flip_idx} should not verify"
1625                );
1626            }
1627        }
1628    }
1629
1630    fn multi_proof_generation_and_verify_raw<F: Family>() {
1631        let hasher = H::new();
1632        let mem = build_raw::<F>(&hasher, 20);
1633        let root = *mem.root();
1634
1635        let locations = &[Location::new(0), Location::new(5), Location::new(10)];
1636        let nodes =
1637            nodes_required_for_multi_proof(mem.leaves(), locations).expect("valid locations");
1638        let digests = nodes
1639            .into_iter()
1640            .map(|pos| mem.get_node(pos).unwrap())
1641            .collect();
1642        let multi_proof = Proof {
1643            leaves: mem.leaves(),
1644            digests,
1645        };
1646
1647        // Verify the proof.
1648        assert!(multi_proof.verify_multi_inclusion(
1649            &hasher,
1650            &[
1651                (0u64.to_be_bytes(), Location::new(0)),
1652                (5u64.to_be_bytes(), Location::new(5)),
1653                (10u64.to_be_bytes(), Location::new(10)),
1654            ],
1655            &root
1656        ));
1657
1658        // Different order should also verify.
1659        assert!(multi_proof.verify_multi_inclusion(
1660            &hasher,
1661            &[
1662                (10u64.to_be_bytes(), Location::new(10)),
1663                (5u64.to_be_bytes(), Location::new(5)),
1664                (0u64.to_be_bytes(), Location::new(0)),
1665            ],
1666            &root
1667        ));
1668
1669        // Wrong elements should fail.
1670        assert!(!multi_proof.verify_multi_inclusion(
1671            &hasher,
1672            &[
1673                (99u64.to_be_bytes(), Location::new(0)),
1674                (5u64.to_be_bytes(), Location::new(5)),
1675                (10u64.to_be_bytes(), Location::new(10)),
1676            ],
1677            &root
1678        ));
1679
1680        // Wrong root should fail.
1681        let wrong_root = hasher.digest(b"wrong");
1682        assert!(!multi_proof.verify_multi_inclusion(
1683            &hasher,
1684            &[
1685                (0u64.to_be_bytes(), Location::new(0)),
1686                (5u64.to_be_bytes(), Location::new(5)),
1687                (10u64.to_be_bytes(), Location::new(10)),
1688            ],
1689            &wrong_root
1690        ));
1691
1692        // Empty multi-proof on empty tree.
1693        let hasher2 = H::new();
1694        let empty_mem = Mem::<F, D>::new(&hasher2);
1695        let empty_proof: Proof<F, D> = Proof::default();
1696        assert!(empty_proof.verify_multi_inclusion(
1697            &hasher2,
1698            &[] as &[([u8; 8], Location<F>)],
1699            empty_mem.root()
1700        ));
1701
1702        // Malformed empty proof with extra digests must be rejected.
1703        let malformed_proof: Proof<F, D> = Proof {
1704            leaves: Location::new(0),
1705            digests: vec![test_digest(0)],
1706        };
1707        assert!(!malformed_proof.verify_multi_inclusion(
1708            &hasher2,
1709            &[] as &[([u8; 8], Location<F>)],
1710            empty_mem.root()
1711        ));
1712    }
1713
1714    fn tampered_proof_digests_rejected<F: Family>() {
1715        for n in [8u64, 13, 20, 32] {
1716            let hasher = H::new();
1717            let mem = build_raw::<F>(&hasher, n);
1718            let root = *mem.root();
1719
1720            for loc_idx in [0, n / 2, n - 1] {
1721                let proof = mem.proof(&hasher, Location::new(loc_idx)).unwrap();
1722                let element = loc_idx.to_be_bytes();
1723                let loc = Location::new(loc_idx);
1724
1725                assert!(proof.verify_element_inclusion(&hasher, &element, loc, &root));
1726
1727                for digest_idx in 0..proof.digests.len() {
1728                    let mut tampered = proof.clone();
1729                    tampered.digests[digest_idx].0[0] ^= 1;
1730                    assert!(
1731                        !tampered.verify_element_inclusion(&hasher, &element, loc, &root),
1732                        "n={n}, loc={loc_idx}: tampered digest[{digest_idx}] should not verify"
1733                    );
1734                }
1735            }
1736        }
1737    }
1738
1739    fn no_duplicate_positions<F: Family>() {
1740        use alloc::collections::BTreeSet;
1741        for n in 1u64..=64 {
1742            let hasher = H::new();
1743            let mem = build_raw::<F>(&hasher, n);
1744            let leaves = mem.leaves();
1745            for loc in 0..n {
1746                let loc = Location::new(loc);
1747                let bp = Blueprint::<F>::new(leaves, loc..loc + 1).unwrap();
1748                let mut positions: Vec<Position<F>> = Vec::new();
1749                positions.extend(&bp.fold_prefix);
1750                positions.extend(&bp.fetch_nodes);
1751                let set: BTreeSet<_> = positions.iter().copied().collect();
1752                assert_eq!(
1753                    positions.len(),
1754                    set.len(),
1755                    "n={n}, loc={loc}: duplicate positions"
1756                );
1757            }
1758        }
1759    }
1760
1761    // ---------------------------------------------------------------------------
1762    // MMR tests
1763    // ---------------------------------------------------------------------------
1764
    // Instantiate each generic test helper above for the MMR family. Every
    // wrapper is a one-line delegation so failures report a family-specific
    // test name.
    #[test]
    fn mmr_empty_proof() {
        empty_proof::<mmr::Family>();
    }
    #[test]
    fn mmr_verify_element() {
        verify_element::<mmr::Family>();
    }
    #[test]
    fn mmr_verify_range() {
        verify_range::<mmr::Family>();
    }
    #[test_traced]
    fn mmr_retained_nodes_provable_after_pruning() {
        retained_nodes_provable_after_pruning::<mmr::Family>();
    }
    #[test]
    fn mmr_ranges_provable_after_pruning() {
        ranges_provable_after_pruning::<mmr::Family>();
    }
    #[test]
    fn mmr_proof_serialization() {
        proof_serialization::<mmr::Family>();
    }
    #[test]
    fn mmr_multi_proof_generation_and_verify() {
        multi_proof_generation_and_verify::<mmr::Family>();
    }
    #[test]
    fn mmr_multi_proof_deduplication() {
        multi_proof_deduplication::<mmr::Family>();
    }
    #[test]
    fn mmr_proof_leaves_malleability() {
        proof_leaves_malleability::<mmr::Family>();
    }
    #[test]
    fn mmr_blueprint_errors() {
        blueprint_errors::<mmr::Family>();
    }
    #[test]
    fn mmr_single_element_proof_reconstruction() {
        single_element_proof_reconstruction::<mmr::Family>();
    }
    #[test]
    fn mmr_range_proof_reconstruction() {
        range_proof_reconstruction::<mmr::Family>();
    }
    #[test]
    fn mmr_verify_element_inclusion() {
        verify_element_inclusion::<mmr::Family>();
    }
    #[test]
    fn mmr_full_range() {
        full_range::<mmr::Family>();
    }
    #[test]
    fn mmr_empty_proof_verifies_empty_tree() {
        empty_proof_verifies_empty_tree::<mmr::Family>();
    }
    #[test]
    fn mmr_every_element_contributes_to_root() {
        every_element_contributes_to_root::<mmr::Family>();
    }
    #[test]
    fn mmr_multi_proof_generation_and_verify_raw() {
        multi_proof_generation_and_verify_raw::<mmr::Family>();
    }
    #[test]
    fn mmr_tampered_proof_digests_rejected() {
        tampered_proof_digests_rejected::<mmr::Family>();
    }
    #[test]
    fn mmr_no_duplicate_positions() {
        no_duplicate_positions::<mmr::Family>();
    }
1841
1842    // ---------------------------------------------------------------------------
1843    // MMB tests
1844    // ---------------------------------------------------------------------------
1845
    // Instantiate each generic test helper above for the MMB family. Every
    // wrapper is a one-line delegation so failures report a family-specific
    // test name.
    #[test]
    fn mmb_empty_proof() {
        empty_proof::<mmb::Family>();
    }
    #[test]
    fn mmb_verify_element() {
        verify_element::<mmb::Family>();
    }
    #[test]
    fn mmb_verify_range() {
        verify_range::<mmb::Family>();
    }
    #[test_traced]
    fn mmb_retained_nodes_provable_after_pruning() {
        retained_nodes_provable_after_pruning::<mmb::Family>();
    }
    #[test]
    fn mmb_ranges_provable_after_pruning() {
        ranges_provable_after_pruning::<mmb::Family>();
    }
    #[test]
    fn mmb_proof_serialization() {
        proof_serialization::<mmb::Family>();
    }
    #[test]
    fn mmb_multi_proof_generation_and_verify() {
        multi_proof_generation_and_verify::<mmb::Family>();
    }
    #[test]
    fn mmb_multi_proof_deduplication() {
        multi_proof_deduplication::<mmb::Family>();
    }
    #[test]
    fn mmb_proof_leaves_malleability() {
        proof_leaves_malleability::<mmb::Family>();
    }
    #[test]
    fn mmb_blueprint_errors() {
        blueprint_errors::<mmb::Family>();
    }
    #[test]
    fn mmb_single_element_proof_reconstruction() {
        single_element_proof_reconstruction::<mmb::Family>();
    }
    #[test]
    fn mmb_range_proof_reconstruction() {
        range_proof_reconstruction::<mmb::Family>();
    }
    #[test]
    fn mmb_verify_element_inclusion() {
        verify_element_inclusion::<mmb::Family>();
    }
    #[test]
    fn mmb_full_range() {
        full_range::<mmb::Family>();
    }
    #[test]
    fn mmb_empty_proof_verifies_empty_tree() {
        empty_proof_verifies_empty_tree::<mmb::Family>();
    }
    #[test]
    fn mmb_every_element_contributes_to_root() {
        every_element_contributes_to_root::<mmb::Family>();
    }
    #[test]
    fn mmb_multi_proof_generation_and_verify_raw() {
        multi_proof_generation_and_verify_raw::<mmb::Family>();
    }
    #[test]
    fn mmb_tampered_proof_digests_rejected() {
        tampered_proof_digests_rejected::<mmb::Family>();
    }
    #[test]
    fn mmb_no_duplicate_positions() {
        no_duplicate_positions::<mmb::Family>();
    }
1922}