commonware_storage/mmr/
verification.rs

1//! Defines the inclusion [Proof] structure, functions for generating them from any MMR implementing
2//! the [Storage] trait, and functions for verifying them against a root digest.
3//!
4//! ## Historical Proof Generation
5//!
6//! This module provides both current and historical proof generation capabilities:
7//! - [Proof::range_proof] generates proofs against the current MMR state
8//! - [Proof::historical_range_proof] generates proofs against historical MMR states
9//!
10//! Historical proofs are essential for sync operations where we need to prove elements
11//! against a past state of the MMR rather than its current state.
12
13use crate::mmr::{
14    iterator::{leaf_num_to_pos, leaf_pos_to_num, PathIterator, PeakIterator},
15    storage::Storage,
16    Error, Hasher,
17};
18use bytes::{Buf, BufMut};
19use commonware_codec::{varint::UInt, EncodeSize, Read, ReadExt, ReadRangeExt, Write};
20use commonware_cryptography::{Digest, Hasher as CHasher};
21use futures::future::try_join_all;
22use std::collections::{BTreeSet, HashMap};
23use tracing::debug;
24
25/// Errors that can occur when reconstructing a digest from a proof due to invalid input.
26#[derive(Error, Debug)]
27pub(crate) enum ReconstructionError {
28    #[error("missing digests in proof")]
29    MissingDigests,
30    #[error("extra digests in proof")]
31    ExtraDigests,
32    #[error("start position is not a leaf")]
33    InvalidStartPos,
34    #[error("end position exceeds MMR size")]
35    InvalidEndPos,
36    #[error("missing elements")]
37    MissingElements,
38}
39
/// A [Storage] implementation backed only by the node digests recovered from a verified [Proof].
/// It can be used to generate proofs over any sub-range of the originally proven range.
pub struct ProofStore<D> {
    /// Node digests recovered from the proof, keyed by node position.
    digests: HashMap<u64, D>,
    /// Number of nodes in the MMR the proof was generated against.
    size: u64,
}
46
47impl<D: Digest> ProofStore<D> {
48    /// Create a new [ProofStore] from a valid [Proof] over the given range of elements. The
49    /// resulting store can be used to generate proofs over any sub-range of the original range.
50    /// Returns an error if the proof is invalid or could not be verified against the given root.
51    pub fn new<I, H, E>(
52        hasher: &mut H,
53        proof: &Proof<D>,
54        elements: &[E],
55        start_element_pos: u64,
56        root: &D,
57    ) -> Result<Self, Error>
58    where
59        I: CHasher<Digest = D>,
60        H: Hasher<I>,
61        E: AsRef<[u8]>,
62    {
63        let digests = proof.verify_range_inclusion_and_extract_digests(
64            hasher,
65            elements,
66            start_element_pos,
67            root,
68        )?;
69
70        Ok(ProofStore::new_from_digests(proof.size, digests))
71    }
72
73    /// Create a new [ProofStore] from the result of calling
74    /// [Proof::verify_range_inclusion_and_extract_digests]. The resulting store can be used to
75    /// generate proofs over any sub-range of the original range.
76    pub fn new_from_digests(size: u64, digests: Vec<(u64, D)>) -> Self {
77        Self {
78            size,
79            digests: digests.into_iter().collect(),
80        }
81    }
82}
83
84impl<D: Digest> Storage<D> for ProofStore<D> {
85    async fn get_node(&self, pos: u64) -> Result<Option<D>, Error> {
86        Ok(self.digests.get(&pos).cloned())
87    }
88
89    fn size(&self) -> u64 {
90        self.size
91    }
92}
93
94/// Contains the information necessary for proving the inclusion of an element, or some range of
95/// elements, in the MMR from its root digest.
96///
97/// The `digests` vector contains:
98///
99/// 1: the digests of each peak corresponding to a mountain containing no elements from the element
100/// range being proven in decreasing order of height, followed by:
101///
102/// 2: the nodes in the remaining mountains necessary for reconstructing their peak digests from the
103/// elements within the range, ordered by the position of their parent.
104#[derive(Clone, Debug, Eq)]
105pub struct Proof<D: Digest> {
106    /// The total number of nodes in the MMR.
107    pub size: u64,
108    /// The digests necessary for proving the inclusion of an element, or range of elements, in the
109    /// MMR.
110    pub digests: Vec<D>,
111}
112
113impl<D: Digest> PartialEq for Proof<D> {
114    fn eq(&self, other: &Self) -> bool {
115        self.size == other.size && self.digests == other.digests
116    }
117}
118
119impl<D: Digest> EncodeSize for Proof<D> {
120    fn encode_size(&self) -> usize {
121        UInt(self.size).encode_size() + self.digests.encode_size()
122    }
123}
124
125impl<D: Digest> Write for Proof<D> {
126    fn write(&self, buf: &mut impl BufMut) {
127        // Write the number of nodes in the MMR as a varint
128        UInt(self.size).write(buf);
129
130        // Write the digests
131        self.digests.write(buf);
132    }
133}
134
135impl<D: Digest> Read for Proof<D> {
136    /// The maximum number of digests in the proof.
137    type Cfg = usize;
138
139    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, commonware_codec::Error> {
140        // Read the number of nodes in the MMR
141        let size = UInt::<u64>::read(buf)?.into();
142
143        // Read the digests
144        let range = ..=max_len;
145        let digests = Vec::<D>::read_range(buf, range)?;
146        Ok(Proof { size, digests })
147    }
148}
149
150impl<D: Digest> Default for Proof<D> {
151    /// Create an empty proof. The empty proof will verify only against the root digest of an empty
152    /// (`size == 0`) MMR.
153    fn default() -> Self {
154        Self {
155            size: 0,
156            digests: vec![],
157        }
158    }
159}
160
161impl<D: Digest> Proof<D> {
162    /// Return true if `proof` proves that `element` appears at position `element_pos` within the
163    /// MMR with root digest `root`.
164    pub fn verify_element_inclusion<I, H>(
165        &self,
166        hasher: &mut H,
167        element: &[u8],
168        element_pos: u64,
169        root: &D,
170    ) -> bool
171    where
172        I: CHasher<Digest = D>,
173        H: Hasher<I>,
174    {
175        self.verify_range_inclusion(hasher, &[element], element_pos, root)
176    }
177
178    /// Return true if `proof` proves that the `elements` appear consecutively starting at position
179    /// `start_element_pos` within the MMR with root digest `root`.
180    pub fn verify_range_inclusion<I, H, E>(
181        &self,
182        hasher: &mut H,
183        elements: &[E],
184        start_element_pos: u64,
185        root: &D,
186    ) -> bool
187    where
188        I: CHasher<Digest = D>,
189        H: Hasher<I>,
190        E: AsRef<[u8]>,
191    {
192        match self.reconstruct_root(hasher, elements, start_element_pos) {
193            Ok(reconstructed_root) => *root == reconstructed_root,
194            Err(error) => {
195                debug!(error = ?error, "invalid proof input");
196                false
197            }
198        }
199    }
200
201    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
202    /// of elements, returning the (position,digest) of every node whose digest was required by the
203    /// process (including those from the proof itself). Returns a [Error::InvalidProof] if the
204    /// input data is invalid and [Error::RootMismatch] if the root does not match the computed
205    /// root.
206    pub fn verify_range_inclusion_and_extract_digests<I, H, E>(
207        &self,
208        hasher: &mut H,
209        elements: &[E],
210        start_element_pos: u64,
211        root: &D,
212    ) -> Result<Vec<(u64, D)>, Error>
213    where
214        I: CHasher<Digest = D>,
215        H: Hasher<I>,
216        E: AsRef<[u8]>,
217    {
218        let mut collected_digests = Vec::new();
219        let Ok(peak_digests) = self.reconstruct_peak_digests(
220            hasher,
221            elements,
222            start_element_pos,
223            Some(&mut collected_digests),
224        ) else {
225            return Err(Error::InvalidProof);
226        };
227
228        if hasher.root(self.size, peak_digests.iter()) != *root {
229            return Err(Error::RootMismatch);
230        }
231
232        Ok(collected_digests)
233    }
234
235    /// Reconstructs the root digest of the MMR from the digests in the proof and the provided range
236    /// of elements, or returns a [ReconstructionError] if the input data is invalid.
237    pub(crate) fn reconstruct_root<I, H, E>(
238        &self,
239        hasher: &mut H,
240        elements: &[E],
241        start_element_pos: u64,
242    ) -> Result<D, ReconstructionError>
243    where
244        I: CHasher<Digest = D>,
245        H: Hasher<I>,
246        E: AsRef<[u8]>,
247    {
248        let peak_digests =
249            self.reconstruct_peak_digests(hasher, elements, start_element_pos, None)?;
250
251        Ok(hasher.root(self.size, peak_digests.iter()))
252    }
253
254    /// Reconstruct the peak digests of the MMR that produced this proof, returning
255    /// [ReconstructionError] if the input data is invalid.  If collected_digests is Some, then all
256    /// node digests used in the process will be added to the wrapped vector.
257    fn reconstruct_peak_digests<I, H, E>(
258        &self,
259        hasher: &mut H,
260        elements: &[E],
261        start_element_pos: u64,
262        mut collected_digests: Option<&mut Vec<(u64, D)>>,
263    ) -> Result<Vec<D>, ReconstructionError>
264    where
265        I: CHasher<Digest = D>,
266        H: Hasher<I>,
267        E: AsRef<[u8]>,
268    {
269        if elements.is_empty() {
270            if start_element_pos == 0 {
271                return Ok(vec![]);
272            }
273            return Err(ReconstructionError::MissingElements);
274        }
275        let Some(start_leaf) = leaf_pos_to_num(start_element_pos) else {
276            debug!(pos = start_element_pos, "start pos is not a leaf");
277            return Err(ReconstructionError::InvalidStartPos);
278        };
279        let end_element_pos = if elements.len() == 1 {
280            start_element_pos
281        } else {
282            leaf_num_to_pos(start_leaf + elements.len() as u64 - 1)
283        };
284        if end_element_pos >= self.size {
285            return Err(ReconstructionError::InvalidEndPos);
286        }
287
288        let mut proof_digests_iter = self.digests.iter();
289        let mut siblings_iter = self.digests.iter().rev();
290
291        // Include peak digests only for trees that have no elements from the range, and keep track
292        // of the starting and ending trees of those that do contain some.
293        let mut peak_digests: Vec<D> = Vec::new();
294        let mut proof_digests_used = 0;
295        let mut elements_iter = elements.iter();
296        for (peak_pos, height) in PeakIterator::new(self.size) {
297            let leftmost_pos = peak_pos + 2 - (1 << (height + 1));
298            if peak_pos >= start_element_pos && leftmost_pos <= end_element_pos {
299                let hash = peak_digest_from_range(
300                    hasher,
301                    RangeInfo {
302                        pos: peak_pos,
303                        two_h: 1 << height,
304                        leftmost_pos: start_element_pos,
305                        rightmost_pos: end_element_pos,
306                    },
307                    &mut elements_iter,
308                    &mut siblings_iter,
309                    collected_digests.as_deref_mut(),
310                )?;
311                peak_digests.push(hash);
312                if let Some(ref mut collected_digests) = collected_digests {
313                    collected_digests.push((peak_pos, hash));
314                }
315            } else if let Some(hash) = proof_digests_iter.next() {
316                proof_digests_used += 1;
317                peak_digests.push(*hash);
318                if let Some(ref mut collected_digests) = collected_digests {
319                    collected_digests.push((peak_pos, *hash));
320                }
321            } else {
322                return Err(ReconstructionError::MissingDigests);
323            }
324        }
325
326        if elements_iter.next().is_some() {
327            return Err(ReconstructionError::ExtraDigests);
328        }
329        if let Some(next_sibling) = siblings_iter.next() {
330            if proof_digests_used == 0 || *next_sibling != self.digests[proof_digests_used - 1] {
331                return Err(ReconstructionError::ExtraDigests);
332            }
333        }
334
335        Ok(peak_digests)
336    }
337
338    /// Return the list of pruned (pos < `start_pos`) node positions that are still required for
339    /// proving any retained node.
340    ///
341    /// This set consists of every pruned node that is either (1) a peak, or (2) has no descendent
342    /// in the retained section, but its immediate parent does. (A node meeting condition (2) can be
343    /// shown to always be the left-child of its parent.)
344    ///
345    /// This set of nodes does not change with the MMR's size, only the pruning boundary. For a
346    /// given pruning boundary that happens to be a valid MMR size, one can prove that this set is
347    /// exactly the set of peaks for an MMR whose size equals the pruning boundary. If the pruning
348    /// boundary is not a valid MMR size, then the set corresponds to the peaks of the largest MMR
349    /// whose size is less than the pruning boundary.
350    pub fn nodes_to_pin(start_pos: u64) -> impl Iterator<Item = u64> {
351        PeakIterator::new(PeakIterator::to_nearest_size(start_pos)).map(|(pos, _)| pos)
352    }
353
354    /// Return the list of node positions required by the range proof for the specified range of
355    /// elements, inclusive of both endpoints.
356    pub fn nodes_required_for_range_proof(
357        size: u64,
358        start_element_pos: u64,
359        end_element_pos: u64,
360    ) -> Vec<u64> {
361        let mut positions = Vec::new();
362
363        // Find the mountains that contain no elements from the range. The peaks of these mountains
364        // are required to prove the range, so they are added to the result.
365        let mut start_tree_with_element = (u64::MAX, 0);
366        let mut end_tree_with_element = (u64::MAX, 0);
367        let mut peak_iterator = PeakIterator::new(size);
368        while let Some(peak) = peak_iterator.next() {
369            if start_tree_with_element.0 == u64::MAX && peak.0 >= start_element_pos {
370                // Found the first tree to contain an element in the range
371                start_tree_with_element = peak;
372                if peak.0 >= end_element_pos {
373                    // Start and end tree are the same
374                    end_tree_with_element = peak;
375                    continue;
376                }
377                for peak in peak_iterator.by_ref() {
378                    if peak.0 >= end_element_pos {
379                        // Found the last tree to contain an element in the range
380                        end_tree_with_element = peak;
381                        break;
382                    }
383                }
384            } else {
385                // Tree is outside the range, its peak is thus required.
386                positions.push(peak.0);
387            }
388        }
389        assert!(start_tree_with_element.0 != u64::MAX);
390        assert!(end_tree_with_element.0 != u64::MAX);
391
392        // Include the positions of any left-siblings of each node on the path from peak to
393        // leftmost-leaf, and right-siblings for the path from peak to rightmost-leaf. These are
394        // added in order of decreasing parent position.
395        let left_path_iter = PathIterator::new(
396            start_element_pos,
397            start_tree_with_element.0,
398            start_tree_with_element.1,
399        );
400
401        let mut siblings = Vec::new();
402        if start_element_pos == end_element_pos {
403            // For the (common) case of a single element range, the right and left path are the
404            // same so no need to process each independently.
405            siblings.extend(left_path_iter);
406        } else {
407            let right_path_iter = PathIterator::new(
408                end_element_pos,
409                end_tree_with_element.0,
410                end_tree_with_element.1,
411            );
412            // filter the right path for right siblings only
413            siblings.extend(right_path_iter.filter(|(parent_pos, pos)| *parent_pos == *pos + 1));
414            // filter the left path for left siblings only
415            siblings.extend(left_path_iter.filter(|(parent_pos, pos)| *parent_pos != *pos + 1));
416
417            // If the range spans more than one tree, then the digests must already be in the correct
418            // order. Otherwise, we enforce the desired order through sorting.
419            if start_tree_with_element.0 == end_tree_with_element.0 {
420                siblings.sort_by(|a, b| b.0.cmp(&a.0));
421            }
422        }
423        positions.extend(siblings.into_iter().map(|(_, pos)| pos));
424        positions
425    }
426
427    /// Return an inclusion proof for the specified range of elements, inclusive of both endpoints.
428    /// Returns ElementPruned error if some element needed to generate the proof has been pruned.
429    pub async fn range_proof<S: Storage<D>>(
430        mmr: &S,
431        start_element_pos: u64,
432        end_element_pos: u64,
433    ) -> Result<Proof<D>, Error> {
434        Self::historical_range_proof(mmr, mmr.size(), start_element_pos, end_element_pos).await
435    }
436
437    /// Analogous to range_proof but for a previous database state.
438    /// Specifically, the state when the MMR had `size` elements.
439    pub async fn historical_range_proof<S: Storage<D>>(
440        mmr: &S,
441        size: u64,
442        start_element_pos: u64,
443        end_element_pos: u64,
444    ) -> Result<Proof<D>, Error> {
445        assert!(start_element_pos <= end_element_pos);
446        assert!(start_element_pos < mmr.size());
447        assert!(end_element_pos < mmr.size());
448
449        let mut digests: Vec<D> = Vec::new();
450        let positions =
451            Self::nodes_required_for_range_proof(size, start_element_pos, end_element_pos);
452
453        let node_futures = positions.iter().map(|pos| mmr.get_node(*pos));
454        let hash_results = try_join_all(node_futures).await?;
455
456        for (i, hash_result) in hash_results.into_iter().enumerate() {
457            match hash_result {
458                Some(hash) => digests.push(hash),
459                None => return Err(Error::ElementPruned(positions[i])),
460            };
461        }
462
463        Ok(Proof { size, digests })
464    }
465
466    /// Return an inclusion proof for the specified positions. This is analogous to range_proof
467    /// but supports non-contiguous positions.
468    ///
469    /// The order of positions does not affect the output (sorted internally).
470    pub async fn multi_proof<S: Storage<D>>(mmr: &S, positions: &[u64]) -> Result<Proof<D>, Error> {
471        // If there are no positions, return an empty proof
472        if positions.is_empty() {
473            return Ok(Proof {
474                size: mmr.size(),
475                digests: vec![],
476            });
477        }
478
479        // Collect all required node positions
480        //
481        // TODO(#1472): Optimize this loop
482        let size = mmr.size();
483        let node_positions: BTreeSet<_> = positions
484            .iter()
485            .flat_map(|pos| Self::nodes_required_for_range_proof(size, *pos, *pos))
486            .collect();
487
488        // Fetch all required digests in parallel and collect with positions
489        let node_futures: Vec<_> = node_positions
490            .iter()
491            .map(|&pos| async move { mmr.get_node(pos).await.map(|digest| (pos, digest)) })
492            .collect();
493        let results = try_join_all(node_futures).await?;
494
495        // Build the proof, returning error with correct position on pruned nodes
496        let mut digests = Vec::with_capacity(results.len());
497        for (pos, digest) in results {
498            match digest {
499                Some(digest) => digests.push(digest),
500                None => return Err(Error::ElementPruned(pos)),
501            }
502        }
503
504        Ok(Proof { size, digests })
505    }
506
507    /// Return true if `proof` proves that the elements at the specified positions are included in the MMR
508    /// with the root digest `root`.
509    ///
510    /// The order of the elements does not affect the output.
511    pub fn verify_multi_inclusion<I, H, E>(
512        &self,
513        hasher: &mut H,
514        elements: &[(E, u64)],
515        root: &D,
516    ) -> bool
517    where
518        I: CHasher<Digest = D>,
519        H: Hasher<I>,
520        E: AsRef<[u8]>,
521    {
522        // Empty proof is valid for an empty MMR
523        if elements.is_empty() {
524            return self.size == 0 && *root == hasher.root(0, std::iter::empty());
525        }
526
527        // Single pass to collect all required positions with deduplication
528        let mut node_positions = BTreeSet::new();
529        let mut nodes_required = HashMap::new();
530        for (_, pos) in elements {
531            let required = Self::nodes_required_for_range_proof(self.size, *pos, *pos);
532            for req_pos in &required {
533                node_positions.insert(*req_pos);
534            }
535            nodes_required.insert(*pos, required);
536        }
537
538        // Verify we have the exact number of digests needed
539        if node_positions.len() != self.digests.len() {
540            return false;
541        }
542
543        // Build position to digest mapping once
544        let node_digests: HashMap<u64, D> = node_positions
545            .iter()
546            .zip(self.digests.iter())
547            .map(|(&pos, digest)| (pos, *digest))
548            .collect();
549
550        // Verify each element by reconstructing its path
551        for (element, pos) in elements {
552            // Get required positions for this element
553            let required = &nodes_required[pos];
554
555            // Build proof with required digests
556            let mut digests = Vec::with_capacity(required.len());
557            for req_pos in required {
558                // There must exist a digest for each required position (by
559                // construction of `node_digests`)
560                let digest = node_digests
561                    .get(req_pos)
562                    .expect("missing digest for required position");
563                digests.push(*digest);
564            }
565            let proof = Proof {
566                size: self.size,
567                digests,
568            };
569
570            // Verify the proof
571            if !proof.verify_element_inclusion(hasher, element.as_ref(), *pos, root) {
572                return false;
573            }
574        }
575
576        true
577    }
578
579    /// Extract the hashes of all nodes that should be pinned at the given pruning boundary
580    /// from a proof that proves a range starting at that boundary.
581    ///
582    /// # Arguments
583    /// * `start_element_pos` - Start of the proven range (must equal pruning_boundary)
584    /// * `end_element_pos` - End of the proven range
585    ///
586    /// # Returns
587    /// A Vec of digest values for all nodes in `nodes_to_pin(pruning_boundary)`,
588    /// in the same order as returned by `nodes_to_pin` (decreasing height order)
589    pub fn extract_pinned_nodes(
590        &self,
591        start_element_pos: u64,
592        end_element_pos: u64,
593    ) -> Result<Vec<D>, Error> {
594        // Get the positions of all nodes that should be pinned.
595        let pinned_positions: Vec<u64> = Self::nodes_to_pin(start_element_pos).collect();
596
597        // Get all positions required for the proof.
598        let required_positions =
599            Self::nodes_required_for_range_proof(self.size, start_element_pos, end_element_pos);
600        if required_positions.len() != self.digests.len() {
601            debug!(
602                digests_len = self.digests.len(),
603                required_positions_len = required_positions.len(),
604                "Proof digest count doesn't match required positions",
605            );
606            return Err(Error::InvalidProofLength);
607        }
608
609        // Happy path: we can extract the pinned nodes directly from the proof.
610        // This happens when the `end_element_pos` is the last element in the MMR.
611        if pinned_positions
612            == required_positions[required_positions.len() - pinned_positions.len()..]
613        {
614            return Ok(self.digests[required_positions.len() - pinned_positions.len()..].to_vec());
615        }
616
617        // Create a mapping from position to digest.
618        let position_to_digest: HashMap<u64, D> = required_positions
619            .iter()
620            .zip(self.digests.iter())
621            .map(|(&pos, &digest)| (pos, digest))
622            .collect();
623
624        // Extract the pinned nodes in the same order as nodes_to_pin.
625        let mut result = Vec::with_capacity(pinned_positions.len());
626        for pinned_pos in pinned_positions {
627            let Some(&digest) = position_to_digest.get(&pinned_pos) else {
628                debug!(pinned_pos, "Pinned node not found in proof");
629                return Err(Error::MissingDigest(pinned_pos));
630            };
631            result.push(digest);
632        }
633        Ok(result)
634    }
635}
636
/// Describes the subtree currently being traversed and the leaf range being reconstructed.
struct RangeInfo {
    /// Position of the subtree's root node.
    pos: u64,
    /// `2^height` of the subtree's root node.
    two_h: u64,
    /// Position of the leftmost leaf of the range being proven.
    leftmost_pos: u64,
    /// Position of the rightmost leaf of the range being proven.
    rightmost_pos: u64,
}
644
645fn peak_digest_from_range<'a, I, H, E, S>(
646    hasher: &mut H,
647    range_info: RangeInfo,
648    elements: &mut E,
649    sibling_digests: &mut S,
650    mut collected_digests: Option<&mut Vec<(u64, I::Digest)>>,
651) -> Result<I::Digest, ReconstructionError>
652where
653    I: CHasher,
654    H: Hasher<I>,
655    E: Iterator<Item: AsRef<[u8]>>,
656    S: Iterator<Item = &'a I::Digest>,
657{
658    assert_ne!(range_info.two_h, 0);
659    if range_info.two_h == 1 {
660        match elements.next() {
661            Some(element) => return Ok(hasher.leaf_digest(range_info.pos, element.as_ref())),
662            None => return Err(ReconstructionError::MissingDigests),
663        }
664    }
665
666    let mut left_digest: Option<I::Digest> = None;
667    let mut right_digest: Option<I::Digest> = None;
668
669    let left_pos = range_info.pos - range_info.two_h;
670    let right_pos = left_pos + range_info.two_h - 1;
671    if left_pos >= range_info.leftmost_pos {
672        // Descend left
673        let digest = peak_digest_from_range(
674            hasher,
675            RangeInfo {
676                pos: left_pos,
677                two_h: range_info.two_h >> 1,
678                leftmost_pos: range_info.leftmost_pos,
679                rightmost_pos: range_info.rightmost_pos,
680            },
681            elements,
682            sibling_digests,
683            collected_digests.as_deref_mut(),
684        )?;
685        left_digest = Some(digest);
686    }
687    if left_pos < range_info.rightmost_pos {
688        // Descend right
689        let digest = peak_digest_from_range(
690            hasher,
691            RangeInfo {
692                pos: right_pos,
693                two_h: range_info.two_h >> 1,
694                leftmost_pos: range_info.leftmost_pos,
695                rightmost_pos: range_info.rightmost_pos,
696            },
697            elements,
698            sibling_digests,
699            collected_digests.as_deref_mut(),
700        )?;
701        right_digest = Some(digest);
702    }
703
704    if left_digest.is_none() {
705        match sibling_digests.next() {
706            Some(hash) => left_digest = Some(*hash),
707            None => return Err(ReconstructionError::MissingDigests),
708        }
709    }
710    if right_digest.is_none() {
711        match sibling_digests.next() {
712            Some(hash) => right_digest = Some(*hash),
713            None => return Err(ReconstructionError::MissingDigests),
714        }
715    }
716
717    if let Some(ref mut collected_digests) = collected_digests {
718        collected_digests.push((left_pos, left_digest.unwrap()));
719        collected_digests.push((right_pos, right_digest.unwrap()));
720    }
721
722    Ok(hasher.node_digest(
723        range_info.pos,
724        &left_digest.unwrap(),
725        &right_digest.unwrap(),
726    ))
727}
728
729#[cfg(test)]
730mod tests {
731    use super::*;
732    use crate::mmr::{hasher::Standard, mem::Mmr};
733    use bytes::Bytes;
734    use commonware_codec::{Decode, Encode};
735    use commonware_cryptography::{sha256::Digest, Sha256};
736    use commonware_macros::test_traced;
737    use commonware_runtime::{deterministic, Runner};
738
739    fn test_digest(v: u8) -> Digest {
740        Sha256::hash(&[v])
741    }
742
743    #[test_traced]
744    fn test_verification_empty_proof() {
745        // Test that an empty proof authenticates an empty MMR.
746        let mmr = Mmr::new();
747        let mut hasher: Standard<Sha256> = Standard::new();
748        let root = mmr.root(&mut hasher);
749        let proof = Proof::default();
750        assert!(proof.verify_range_inclusion(&mut hasher, &[] as &[Digest], 0, &root));
751
752        // Any starting position other than 0 should fail to verify.
753        assert!(!proof.verify_range_inclusion(&mut hasher, &[] as &[Digest], 1, &root));
754
755        // Invalid root should fail to verify.
756        let test_digest = test_digest(0);
757        assert!(!proof.verify_range_inclusion(&mut hasher, &[] as &[Digest], 0, &test_digest));
758
759        // Non-empty elements list should fail to verify.
760        assert!(!proof.verify_range_inclusion(&mut hasher, &[test_digest], 0, &root));
761    }
762
763    #[test_traced]
764    fn test_verification_verify_element() {
765        let executor = deterministic::Runner::default();
766        executor.start(|_| async move {
767            // create an 11 element MMR over which we'll test single-element inclusion proofs
768            let mut mmr = Mmr::new();
769            let element = Digest::from(*b"01234567012345670123456701234567");
770            let mut leaves: Vec<u64> = Vec::new();
771            let mut hasher: Standard<Sha256> = Standard::new();
772            for _ in 0..11 {
773                leaves.push(mmr.add(&mut hasher, &element));
774            }
775
776            let root = mmr.root(&mut hasher);
777
778            // confirm the proof of inclusion for each leaf successfully verifies
779            for leaf in leaves.iter().by_ref() {
780                let proof: Proof<Digest> = mmr.proof(*leaf).await.unwrap();
781                assert!(
782                    proof.verify_element_inclusion(&mut hasher, &element, *leaf, &root),
783                    "valid proof should verify successfully"
784                );
785            }
786
787            // confirm mangling the proof or proof args results in failed validation
788            const POS: u64 = 18;
789            let proof = mmr.proof(POS).await.unwrap();
790            assert!(
791                proof.verify_element_inclusion(&mut hasher, &element, POS, &root),
792                "proof verification should be successful"
793            );
794            assert!(
795                !proof.verify_element_inclusion(&mut hasher, &element, POS + 1, &root),
796                "proof verification should fail with incorrect element position"
797            );
798            assert!(
799                !proof.verify_element_inclusion(&mut hasher, &element, POS - 1, &root),
800                "proof verification should fail with incorrect element position 2"
801            );
802            assert!(
803                !proof.verify_element_inclusion(&mut hasher, &test_digest(0), POS, &root),
804                "proof verification should fail with mangled element"
805            );
806            let root2 = test_digest(0);
807            assert!(
808                !proof.verify_element_inclusion(&mut hasher, &element, POS, &root2),
809                "proof verification should fail with mangled root"
810            );
811            let mut proof2 = proof.clone();
812            proof2.digests[0] = test_digest(0);
813            assert!(
814                !proof2.verify_element_inclusion(&mut hasher, &element, POS, &root),
815                "proof verification should fail with mangled proof hash"
816            );
817            proof2 = proof.clone();
818            proof2.size = 10;
819            assert!(
820                !proof2.verify_element_inclusion(&mut hasher, &element, POS, &root),
821                "proof verification should fail with incorrect size"
822            );
823            proof2 = proof.clone();
824            proof2.digests.push(test_digest(0));
825            assert!(
826                !proof2.verify_element_inclusion(&mut hasher, &element, POS, &root),
827                "proof verification should fail with extra hash"
828            );
829            proof2 = proof.clone();
830            while !proof2.digests.is_empty() {
831                proof2.digests.pop();
832                assert!(
833                    !proof2.verify_element_inclusion(&mut hasher, &element, 7, &root),
834                    "proof verification should fail with missing digests"
835                );
836            }
837            proof2 = proof.clone();
838            proof2.digests.clear();
839            const PEAK_COUNT: usize = 3;
840            proof2
841                .digests
842                .extend(proof.digests[0..PEAK_COUNT - 1].iter().cloned());
843            // sneak in an extra hash that won't be used in the computation and make sure it's
844            // detected
845            proof2.digests.push(test_digest(0));
846            proof2
847                .digests
848                .extend(proof.digests[PEAK_COUNT - 1..].iter().cloned());
849            assert!(
850                !proof2.verify_element_inclusion(&mut hasher, &element, POS, &root),
851                "proof verification should fail with extra hash even if it's unused by the computation"
852            );
853        });
854    }
855
856    #[test_traced]
857    fn test_verification_verify_range() {
858        let executor = deterministic::Runner::default();
859
860        executor.start(|_| async move {
861            // create a new MMR and add a non-trivial amount (49) of elements
862            let mut mmr = Mmr::default();
863            let mut elements = Vec::new();
864            let mut element_positions = Vec::new();
865            let mut hasher: Standard<Sha256> = Standard::new();
866            for i in 0..49 {
867                elements.push(test_digest(i));
868                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
869            }
870            // test range proofs over all possible ranges of at least 2 elements
871            let root = mmr.root(&mut hasher);
872
873            for i in 0..elements.len() {
874                for j in i + 1..elements.len() {
875                    let start_pos = element_positions[i];
876                    let end_pos = element_positions[j];
877                    let range_proof = mmr.range_proof(start_pos, end_pos).await.unwrap();
878                    assert!(
879                        range_proof.verify_range_inclusion(
880                            &mut hasher,
881                            &elements[i..j + 1],
882                            start_pos,
883                            &root,
884                        ),
885                        "valid range proof should verify successfully {i}:{j}",
886                    );
887                }
888            }
889
890            // create a test range for which we will mangle data and confirm the proof fails
891            let start_index = 33;
892            let end_index = 39;
893            let start_pos = element_positions[start_index];
894            let end_pos = element_positions[end_index];
895            let range_proof = mmr.range_proof(start_pos, end_pos).await.unwrap();
896            let valid_elements = &elements[start_index..end_index + 1];
897            assert!(
898                range_proof.verify_range_inclusion(&mut hasher, valid_elements, start_pos, &root),
899                "valid range proof should verify successfully"
900            );
901            let mut invalid_proof = range_proof.clone();
902            for _i in 0..range_proof.digests.len() {
903                invalid_proof.digests.remove(0);
904                assert!(
905                    !range_proof.verify_range_inclusion(
906                        &mut hasher,
907                        &[] as &[Digest],
908                        start_pos,
909                        &root,
910                    ),
911                    "range proof with removed elements should fail"
912                );
913            }
914            // confirm proof fails with invalid element digests
915            for i in 0..elements.len() {
916                for j in i..elements.len() {
917                    if i == start_index && j == end_index {
918                        // skip the valid range
919                        continue;
920                    }
921                    assert!(
922                        !range_proof.verify_range_inclusion(
923                            &mut hasher,
924                            &elements[i..j + 1],
925                            start_pos,
926                            &root,
927                        ),
928                        "range proof with invalid elements should fail {i}:{j}",
929                    );
930                }
931            }
932            // confirm proof fails with invalid root
933            let invalid_root = test_digest(1);
934            assert!(
935                !range_proof.verify_range_inclusion(
936                    &mut hasher,
937                    valid_elements,
938                    start_pos,
939                    &invalid_root,
940                ),
941                "range proof with invalid root should fail"
942            );
943            // mangle the proof and confirm it fails
944            let mut invalid_proof = range_proof.clone();
945            invalid_proof.digests[1] = test_digest(0);
946            assert!(
947                !invalid_proof.verify_range_inclusion(
948                    &mut hasher,
949                    valid_elements,
950                    start_pos,
951                    &root,
952                ),
953                "mangled range proof should fail verification"
954            );
955            // inserting elements into the proof should also cause it to fail (malleability check)
956            for i in 0..range_proof.digests.len() {
957                let mut invalid_proof = range_proof.clone();
958                invalid_proof.digests.insert(i, test_digest(0));
959                assert!(
960                    !invalid_proof.verify_range_inclusion(
961                        &mut hasher,
962                        valid_elements,
963                        start_pos,
964                        &root,
965                    ),
966                    "mangled range proof should fail verification. inserted element at: {i}",
967                );
968            }
969            // removing proof elements should cause verification to fail
970            let mut invalid_proof = range_proof.clone();
971            for _ in 0..range_proof.digests.len() {
972                invalid_proof.digests.remove(0);
973                assert!(
974                    !invalid_proof.verify_range_inclusion(
975                        &mut hasher,
976                        valid_elements,
977                        start_pos,
978                        &root,
979                    ),
980                    "shortened range proof should fail verification"
981                );
982            }
983            // bad element range should cause verification to fail
984            for (i, _) in elements.iter().enumerate() {
985                for (j, _) in elements.iter().enumerate() {
986                    let start_pos2 = element_positions[i];
987                    if start_pos2 == start_pos {
988                        continue;
989                    }
990                    assert!(
991                        !range_proof.verify_range_inclusion(
992                            &mut hasher,
993                            valid_elements,
994                            start_pos2,
995                            &root,
996                        ),
997                        "bad element range should fail verification {i}:{j}",
998                    );
999                }
1000            }
1001        });
1002    }
1003
1004    #[test_traced]
1005    fn test_verification_retained_nodes_provable_after_pruning() {
1006        let executor = deterministic::Runner::default();
1007        executor.start(|_| async move {
1008            // create a new MMR and add a non-trivial amount (49) of elements
1009            let mut mmr = Mmr::default();
1010            let mut elements = Vec::new();
1011            let mut element_positions = Vec::new();
1012            let mut hasher: Standard<Sha256> = Standard::new();
1013            for i in 0..49 {
1014                elements.push(test_digest(i));
1015                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1016            }
1017
1018            // Confirm we can successfully prove all retained elements in the MMR after pruning.
1019            let root = mmr.root(&mut hasher);
1020            for i in 1..mmr.size() {
1021                mmr.prune_to_pos(i);
1022                let pruned_root = mmr.root(&mut hasher);
1023                assert_eq!(root, pruned_root);
1024                for (j, pos) in element_positions.iter().enumerate() {
1025                    let proof = mmr.proof(*pos).await;
1026                    if *pos < i {
1027                        assert!(proof.is_err());
1028                    } else {
1029                        assert!(proof.is_ok());
1030                        assert!(proof.unwrap().verify_element_inclusion(
1031                            &mut hasher,
1032                            &elements[j],
1033                            *pos,
1034                            &root
1035                        ));
1036                    }
1037                }
1038            }
1039        });
1040    }
1041
1042    #[test_traced]
1043    fn test_verification_ranges_provable_after_pruning() {
1044        let executor = deterministic::Runner::default();
1045        executor.start(|_| async move {
1046            // create a new MMR and add a non-trivial amount (49) of elements
1047            let mut mmr = Mmr::default();
1048            let mut elements = Vec::new();
1049            let mut element_positions = Vec::new();
1050            let mut hasher: Standard<Sha256> = Standard::new();
1051            for i in 0..49 {
1052                elements.push(test_digest(i));
1053                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1054            }
1055
1056            // prune up to the first peak
1057            mmr.prune_to_pos(62);
1058            assert_eq!(mmr.oldest_retained_pos().unwrap(), 62);
1059            for i in 0..elements.len() {
1060                if element_positions[i] > 62 {
1061                    elements = elements[i..elements.len()].to_vec();
1062                    element_positions = element_positions[i..element_positions.len()].to_vec();
1063                    break;
1064                }
1065            }
1066
1067            // test range proofs over all possible ranges of at least 2 elements
1068            let root = mmr.root(&mut hasher);
1069            for i in 0..elements.len() {
1070                for j in i + 1..elements.len() {
1071                    let start_pos = element_positions[i];
1072                    let end_pos = element_positions[j];
1073                    let range_proof = mmr.range_proof(start_pos, end_pos).await.unwrap();
1074                    assert!(
1075                        range_proof.verify_range_inclusion(
1076                            &mut hasher,
1077                            &elements[i..j + 1],
1078                            start_pos,
1079                            &root,
1080                        ),
1081                        "valid range proof over remaining elements should verify successfully",
1082                    );
1083                }
1084            }
1085
1086            // add a few more nodes, prune again, and test again to make sure repeated pruning works
1087            for i in 0..37 {
1088                elements.push(test_digest(i));
1089                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1090            }
1091            mmr.prune_to_pos(130); // a bit after the new highest peak
1092            assert_eq!(mmr.oldest_retained_pos().unwrap(), 130);
1093
1094            let updated_root = mmr.root(&mut hasher);
1095            for i in 0..elements.len() {
1096                if element_positions[i] >= 130 {
1097                    elements = elements[i..elements.len()].to_vec();
1098                    element_positions = element_positions[i..element_positions.len()].to_vec();
1099                    break;
1100                }
1101            }
1102            let start_pos = element_positions[0];
1103            let end_pos = *element_positions.last().unwrap();
1104            let range_proof = mmr.range_proof(start_pos, end_pos).await.unwrap();
1105            assert!(
1106                range_proof.verify_range_inclusion(
1107                    &mut hasher,
1108                    &elements,
1109                        start_pos,
1110                    &updated_root,
1111                ),
1112                "valid range proof over remaining elements after 2 pruning rounds should verify successfully",
1113            );
1114        });
1115    }
1116
1117    #[test_traced]
1118    fn test_verification_proof_serialization() {
1119        let executor = deterministic::Runner::default();
1120        executor.start(|_| async move {
1121            // create a new MMR and add a non-trivial amount of elements
1122            let mut mmr = Mmr::default();
1123            let mut elements = Vec::new();
1124            let mut element_positions = Vec::new();
1125            let mut hasher: Standard<Sha256> = Standard::new();
1126            for i in 0..25 {
1127                elements.push(test_digest(i));
1128                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1129            }
1130
1131            // Generate proofs over all possible ranges of elements and confirm each
1132            // serializes=>deserializes correctly.
1133            for i in 0..elements.len() {
1134                for j in i..elements.len() {
1135                    let start_pos = element_positions[i];
1136                    let end_pos = element_positions[j];
1137                    let proof = mmr.range_proof(start_pos, end_pos).await.unwrap();
1138
1139                    let expected_size = proof.encode_size();
1140                    let serialized_proof = proof.encode().freeze();
1141                    assert_eq!(
1142                        serialized_proof.len(),
1143                        expected_size,
1144                        "serialized proof should have expected size"
1145                    );
1146                    let max_digests = proof.digests.len();
1147                    let deserialized_proof =
1148                        Proof::decode_cfg(serialized_proof, &max_digests).unwrap();
1149                    assert_eq!(
1150                        proof, deserialized_proof,
1151                        "deserialized proof should match source proof"
1152                    );
1153
1154                    // Remove one byte from the end of the serialized
1155                    // proof and confirm it fails to deserialize.
1156                    let serialized_proof = proof.encode().freeze();
1157                    let serialized_proof: Bytes =
1158                        serialized_proof.slice(0..serialized_proof.len() - 1);
1159                    assert!(
1160                        Proof::<Digest>::decode_cfg(serialized_proof, &max_digests).is_err(),
1161                        "proof should not deserialize with truncated data"
1162                    );
1163
1164                    // Add 1 byte of extra data to the end of the serialized
1165                    // proof and confirm it fails to deserialize.
1166                    let mut serialized_proof = proof.encode();
1167                    serialized_proof.extend_from_slice(&[0; 10]);
1168                    let serialized_proof = serialized_proof.freeze();
1169
1170                    assert!(
1171                        Proof::<Digest>::decode_cfg(serialized_proof, &max_digests,).is_err(),
1172                        "proof should not deserialize with extra data"
1173                    );
1174
1175                    // Confirm deserialization fails when max length is exceeded.
1176                    if max_digests > 0 {
1177                        let serialized_proof = proof.encode().freeze();
1178                        assert!(
1179                            Proof::<Digest>::decode_cfg(serialized_proof, &(max_digests - 1),)
1180                                .is_err(),
1181                            "proof should not deserialize with max length exceeded"
1182                        );
1183                    }
1184                }
1185            }
1186        });
1187    }
1188
1189    #[test]
1190    fn test_extract_pinned_nodes() {
1191        let executor = deterministic::Runner::default();
1192        executor.start(|_| async move {
1193            // Test for every number of elements from 1 to 255
1194            for num_elements in 1u8..255 {
1195                // Build MMR with the specified number of elements
1196                let mut mmr = Mmr::new();
1197                let mut hasher: Standard<Sha256> = Standard::new();
1198                let mut element_positions = Vec::new();
1199                for i in 0..num_elements {
1200                    let digest = test_digest(i);
1201                    element_positions.push(mmr.add(&mut hasher, &digest));
1202                }
1203
1204                // Test every valid pruning boundary (each element position)
1205                for &start_pos in &element_positions {
1206                    // Test with a few different end positions to get good coverage
1207                    let test_end_positions = if element_positions.len() == 1 {
1208                        // Single element case
1209                        vec![start_pos]
1210                    } else {
1211                        // Multi-element case: test with various end positions
1212                        let mut ends = vec![start_pos]; // Single element proof
1213
1214                        // Add a few more end positions if available
1215                        let start_idx = element_positions.iter().position(|&pos| pos == start_pos).unwrap();
1216                        if start_idx + 1 < element_positions.len() {
1217                            ends.push(element_positions[start_idx + 1]);
1218                        }
1219                        if start_idx + 2 < element_positions.len() {
1220                            ends.push(element_positions[start_idx + 2]);
1221                        }
1222                        // Always test with the last element if different
1223                        if *element_positions.last().unwrap() != start_pos {
1224                            ends.push(*element_positions.last().unwrap());
1225                        }
1226
1227                        ends.into_iter().collect::<std::collections::HashSet<_>>().into_iter().collect()
1228                    };
1229
1230                    for end_pos in test_end_positions {
1231                        // Generate proof for the range
1232                        let proof_result = mmr.range_proof(start_pos, end_pos).await;
1233                        let proof = proof_result.unwrap();
1234
1235                        // Extract pinned nodes
1236                        let extract_result = proof.extract_pinned_nodes(start_pos, end_pos);
1237                        assert!(
1238                            extract_result.is_ok(),
1239                            "Failed to extract pinned nodes for {num_elements} elements, boundary={start_pos}, range=[{start_pos}, {end_pos}]"
1240                        );
1241
1242                        let pinned_nodes = extract_result.unwrap();
1243                        let expected_pinned: Vec<u64> = Proof::<Digest>::nodes_to_pin(start_pos).collect();
1244
1245                        // Verify count matches expected
1246                        assert_eq!(
1247                            pinned_nodes.len(),
1248                            expected_pinned.len(),
1249                            "Pinned node count mismatch for {num_elements} elements, boundary={start_pos}, range=[{start_pos}, {end_pos}]"
1250                        );
1251
1252                        // Verify extracted hashes match actual node values
1253                        // The pinned_nodes Vec is in the same order as expected_pinned
1254                        for (i, &expected_pos) in expected_pinned.iter().enumerate() {
1255                            let extracted_hash = pinned_nodes[i];
1256                            let actual_hash = mmr.get_node(expected_pos).unwrap();
1257                            assert_eq!(
1258                                extracted_hash, actual_hash,
1259                                "Hash mismatch at position {expected_pos} (index {i}) for {num_elements} elements, boundary={start_pos}, range=[{start_pos}, {end_pos}]"
1260                            );
1261                        }
1262                    }
1263                }
1264            }
1265        });
1266    }
1267
1268    #[test_traced]
1269    fn test_historical_range_proof_basic() {
1270        let executor = deterministic::Runner::default();
1271        executor.start(|_| async move {
1272            // Create an MMR with 5 elements
1273            let mut mmr = Mmr::new();
1274            let mut hasher: Standard<Sha256> = Standard::new();
1275            let mut elements = Vec::new();
1276            let mut element_positions = Vec::new();
1277            for i in 0..5 {
1278                elements.push(test_digest(i));
1279                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1280            }
1281            let current_size = mmr.size();
1282            let current_root = mmr.root(&mut hasher);
1283
1284            {
1285                // Historical proof should match regular proof at current size
1286                let historical_proof = Proof::historical_range_proof(
1287                    &mmr,
1288                    current_size,
1289                    element_positions[0],
1290                    element_positions[2],
1291                )
1292                .await
1293                .unwrap();
1294                let regular_proof = mmr
1295                    .range_proof(element_positions[0], element_positions[2])
1296                    .await
1297                    .unwrap();
1298
1299                assert_eq!(historical_proof.size, current_size);
1300                assert_eq!(historical_proof.digests, regular_proof.digests);
1301                assert!(historical_proof.verify_range_inclusion(
1302                    &mut hasher,
1303                    &elements[0..3],
1304                    element_positions[0],
1305                    &current_root
1306                ));
1307            }
1308
1309            {
1310                // Historical proof should match regular proof at historical size
1311                let mut ref_mmr = Mmr::new();
1312                for elt in elements.iter().take(3) {
1313                    ref_mmr.add(&mut hasher, elt);
1314                }
1315                let ref_size = ref_mmr.size();
1316                let ref_root = ref_mmr.root(&mut hasher);
1317                let ref_proof = Proof::historical_range_proof(
1318                    &mmr,
1319                    ref_size,
1320                    element_positions[0],
1321                    element_positions[2],
1322                )
1323                .await
1324                .unwrap();
1325                let historical_proof = Proof::historical_range_proof(
1326                    &mmr,
1327                    ref_size,
1328                    element_positions[0],
1329                    element_positions[2],
1330                )
1331                .await
1332                .unwrap();
1333
1334                assert_eq!(ref_proof.size, ref_size);
1335                assert_eq!(ref_proof.digests, historical_proof.digests);
1336                assert!(ref_proof.verify_range_inclusion(
1337                    &mut hasher,
1338                    &elements[0..3],
1339                    element_positions[0],
1340                    &ref_root
1341                ));
1342            }
1343        });
1344    }
1345
1346    #[test_traced]
1347    fn test_historical_range_proof_large() {
1348        let executor = deterministic::Runner::default();
1349        executor.start(|_| async move {
1350            // Simulate a sync scenario: server has 1000 operations, client syncs 600-799
1351            let mut server_mmr = Mmr::new();
1352            let mut hasher: Standard<Sha256> = Standard::new();
1353            let mut elements = Vec::new();
1354            let mut element_positions = Vec::new();
1355
1356            // Add 1000 elements to server
1357            for i in 0..1000 {
1358                elements.push(test_digest((i % 256) as u8));
1359                element_positions.push(server_mmr.add(&mut hasher, elements.last().unwrap()));
1360            }
1361
1362            // Want operations 600-799
1363            let start_loc = 600;
1364            let end_loc = 799;
1365
1366            // Create historical MMR state as it would have been after end_loc elements
1367            let mut ref_mmr = Mmr::new();
1368            for elt in elements.iter().take(end_loc + 1) {
1369                ref_mmr.add(&mut hasher, elt);
1370            }
1371            let ref_size = ref_mmr.size();
1372            let ref_root = ref_mmr.root(&mut hasher);
1373
1374            // Generate proof at historical position
1375            let historical_proof = Proof::historical_range_proof(
1376                &server_mmr,
1377                ref_size,
1378                element_positions[start_loc],
1379                element_positions[end_loc],
1380            )
1381            .await
1382            .unwrap();
1383
1384            assert_eq!(historical_proof.size, ref_size);
1385
1386            // Verify the sync proof
1387            assert!(historical_proof.verify_range_inclusion(
1388                &mut hasher,
1389                &elements[start_loc..=end_loc],
1390                element_positions[start_loc],
1391                &ref_root // Compare to historical root
1392            ));
1393        });
1394    }
1395
1396    #[test_traced]
1397    fn test_historical_range_proof_singleton() {
1398        let executor = deterministic::Runner::default();
1399        executor.start(|_| async move {
1400            let mut mmr = Mmr::new();
1401            let mut single_hasher: Standard<Sha256> = Standard::new();
1402            let element = test_digest(0);
1403            mmr.add(&mut single_hasher, &element);
1404            let single_historical_size = mmr.size();
1405            let single_root = mmr.root(&mut single_hasher);
1406
1407            let single_element_proof =
1408                Proof::historical_range_proof(&mmr, single_historical_size, 0, 0)
1409                    .await
1410                    .unwrap();
1411
1412            assert!(single_element_proof.verify_range_inclusion(
1413                &mut single_hasher,
1414                &[element],
1415                0,
1416                &single_root
1417            ));
1418        });
1419    }
1420
1421    #[test_traced]
1422    fn test_verification_digests_from_range() {
1423        let executor = deterministic::Runner::default();
1424        executor.start(|_| async move {
1425            // create a new MMR and add a non-trivial amount (49) of elements
1426            let mut mmr = Mmr::default();
1427            let mut elements = Vec::new();
1428            let mut element_positions = Vec::new();
1429            let mut hasher: Standard<Sha256> = Standard::new();
1430            for i in 0..49 {
1431                elements.push(test_digest(i));
1432                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1433            }
1434            let root = mmr.root(&mut hasher);
1435
1436            // Test 1: compute_digests over the entire range should contain a digest for every node
1437            // in the tree, plus one extra for the root.
1438            let proof = mmr.range_proof(0, mmr.size() - 1).await.unwrap();
1439            let mut node_digests = proof
1440                .verify_range_inclusion_and_extract_digests(&mut hasher, &elements, 0, &root)
1441                .unwrap();
1442            assert_eq!(node_digests.len(), mmr.size() as usize);
1443            node_digests.sort_by_key(|(pos, _)| *pos);
1444            for (i, (pos, d)) in node_digests.into_iter().enumerate() {
1445                assert_eq!(pos, i as u64);
1446                assert_eq!(mmr.get_node(pos).unwrap(), d);
1447            }
1448            // Make sure the wrong root fails.
1449            let wrong_root = elements[0]; // any other digest will do
1450            assert!(matches!(
1451                proof.verify_range_inclusion_and_extract_digests(
1452                    &mut hasher,
1453                    &elements,
1454                    0,
1455                    &wrong_root
1456                ),
1457                Err(Error::RootMismatch)
1458            ));
1459
1460            // Test 2: Single element range (first element)
1461            let single_proof = mmr
1462                .range_proof(element_positions[0], element_positions[0])
1463                .await
1464                .unwrap();
1465            let single_digests = single_proof
1466                .verify_range_inclusion_and_extract_digests(
1467                    &mut hasher,
1468                    &elements[0..1],
1469                    element_positions[0],
1470                    &root,
1471                )
1472                .unwrap();
1473            assert!(single_digests.len() > 1);
1474
1475            // Test 3: Single element range (middle element)
1476            let mid_idx = 24;
1477            let mid_proof = mmr
1478                .range_proof(element_positions[mid_idx], element_positions[mid_idx])
1479                .await
1480                .unwrap();
1481            let mid_digests = mid_proof
1482                .verify_range_inclusion_and_extract_digests(
1483                    &mut hasher,
1484                    &elements[mid_idx..mid_idx + 1],
1485                    element_positions[mid_idx],
1486                    &root,
1487                )
1488                .unwrap();
1489            assert!(mid_digests.len() > 1);
1490
1491            // Test 4: Single element range (last element)
1492            let last_idx = elements.len() - 1;
1493            let last_proof = mmr
1494                .range_proof(element_positions[last_idx], element_positions[last_idx])
1495                .await
1496                .unwrap();
1497            let last_digests = last_proof
1498                .verify_range_inclusion_and_extract_digests(
1499                    &mut hasher,
1500                    &elements[last_idx..],
1501                    element_positions[last_idx],
1502                    &root,
1503                )
1504                .unwrap();
1505            assert!(last_digests.len() > 1);
1506
1507            // Test 5: Small range at the beginning
1508            let small_proof = mmr
1509                .range_proof(element_positions[0], element_positions[4])
1510                .await
1511                .unwrap();
1512            let small_digests = small_proof
1513                .verify_range_inclusion_and_extract_digests(
1514                    &mut hasher,
1515                    &elements[0..5],
1516                    element_positions[0],
1517                    &root,
1518                )
1519                .unwrap();
1520            // Verify that we get digests for the range elements and their ancestors
1521            assert!(small_digests.len() > 5);
1522
1523            // Test 6: Medium range in the middle
1524            let mid_start = 10;
1525            let mid_end = 30;
1526            let mid_range_proof = mmr
1527                .range_proof(element_positions[mid_start], element_positions[mid_end])
1528                .await
1529                .unwrap();
1530            let mid_range_digests = mid_range_proof
1531                .verify_range_inclusion_and_extract_digests(
1532                    &mut hasher,
1533                    &elements[mid_start..mid_end + 1],
1534                    element_positions[mid_start],
1535                    &root,
1536                )
1537                .unwrap();
1538            let num_elements = mid_end - mid_start + 1;
1539            assert!(mid_range_digests.len() > num_elements);
1540        });
1541    }
1542
1543    #[test_traced]
1544    fn test_multi_proof_generation_and_verify() {
1545        let executor = deterministic::Runner::default();
1546        executor.start(|_| async move {
1547            // Create an MMR with multiple elements
1548            let mut mmr = Mmr::new();
1549            let mut hasher: Standard<Sha256> = Standard::new();
1550            let mut elements = Vec::new();
1551            let mut positions = Vec::new();
1552
1553            for i in 0..20 {
1554                elements.push(test_digest(i));
1555                positions.push(mmr.add(&mut hasher, &elements[i as usize]));
1556            }
1557
1558            let root = mmr.root(&mut hasher);
1559
1560            // Generate proof for non-contiguous single elements
1561            let multi_proof =
1562                Proof::multi_proof(&mmr, &[positions[0], positions[5], positions[10]])
1563                    .await
1564                    .unwrap();
1565
1566            assert_eq!(multi_proof.size, mmr.size());
1567
1568            // Verify the proof
1569            assert!(multi_proof.verify_multi_inclusion(
1570                &mut hasher,
1571                &[
1572                    (elements[0], positions[0]),
1573                    (elements[5], positions[5]),
1574                    (elements[10], positions[10]),
1575                ],
1576                &root
1577            ));
1578
1579            // Verify in different order
1580            assert!(multi_proof.verify_multi_inclusion(
1581                &mut hasher,
1582                &[
1583                    (elements[10], positions[10]),
1584                    (elements[5], positions[5]),
1585                    (elements[0], positions[0]),
1586                ],
1587                &root
1588            ));
1589
1590            // Verify with duplicate items
1591            assert!(multi_proof.verify_multi_inclusion(
1592                &mut hasher,
1593                &[
1594                    (elements[0], positions[0]),
1595                    (elements[0], positions[0]),
1596                    (elements[10], positions[10]),
1597                    (elements[5], positions[5]),
1598                ],
1599                &root
1600            ));
1601
1602            // Verify with wrong positions
1603            assert!(!multi_proof.verify_multi_inclusion(
1604                &mut hasher,
1605                &[
1606                    (elements[0], positions[1]),
1607                    (elements[5], positions[6]),
1608                    (elements[10], positions[11]),
1609                ],
1610                &root,
1611            ));
1612
1613            // Verify with wrong elements
1614            let wrong_elements = [
1615                vec![255u8, 254u8, 253u8],
1616                vec![252u8, 251u8, 250u8],
1617                vec![249u8, 248u8, 247u8],
1618            ];
1619            let wrong_verification = multi_proof.verify_multi_inclusion(
1620                &mut hasher,
1621                &[
1622                    (wrong_elements[0].as_slice(), positions[0]),
1623                    (wrong_elements[1].as_slice(), positions[5]),
1624                    (wrong_elements[2].as_slice(), positions[10]),
1625                ],
1626                &root,
1627            );
1628            assert!(!wrong_verification, "Should fail with wrong elements");
1629
1630            // Verify with wrong root should fail
1631            let wrong_root = test_digest(99);
1632            assert!(!multi_proof.verify_multi_inclusion(
1633                &mut hasher,
1634                &[
1635                    (elements[0], positions[0]),
1636                    (elements[5], positions[5]),
1637                    (elements[10], positions[10]),
1638                ],
1639                &wrong_root
1640            ));
1641
1642            // Empty multi-proof
1643            let empty_multi = Proof::multi_proof(&mmr, &[]).await.unwrap();
1644            assert_eq!(empty_multi.size, mmr.size());
1645            assert!(empty_multi.digests.is_empty());
1646
1647            let empty_mmr = Mmr::new();
1648            let empty_root = empty_mmr.root(&mut hasher);
1649            let empty_proof = Proof::multi_proof(&empty_mmr, &[]).await.unwrap();
1650            assert!(empty_proof.verify_multi_inclusion(
1651                &mut hasher,
1652                &[] as &[(Digest, u64)],
1653                &empty_root
1654            ));
1655        });
1656    }
1657
1658    #[test_traced]
1659    fn test_multi_proof_deduplication() {
1660        let executor = deterministic::Runner::default();
1661        executor.start(|_| async move {
1662            let mut mmr = Mmr::new();
1663            let mut hasher: Standard<Sha256> = Standard::new();
1664            let mut elements = Vec::new();
1665            let mut positions = Vec::new();
1666
1667            // Create an MMR with enough elements to have shared digests
1668            for i in 0..30 {
1669                elements.push(test_digest(i));
1670                positions.push(mmr.add(&mut hasher, &elements[i as usize]));
1671            }
1672
1673            // Get individual proofs that will share some digests (elements in same subtree)
1674            let proof1 = mmr.proof(positions[0]).await.unwrap();
1675            let proof2 = mmr.proof(positions[1]).await.unwrap();
1676            let total_digests_separate = proof1.digests.len() + proof2.digests.len();
1677
1678            // Generate multi-proof for the same positions
1679            let multi_proof = Proof::multi_proof(&mmr, &[positions[0], positions[1]])
1680                .await
1681                .unwrap();
1682
1683            // The combined proof should have fewer digests due to deduplication
1684            assert!(multi_proof.digests.len() < total_digests_separate);
1685
1686            // Verify it still works
1687            let root = mmr.root(&mut hasher);
1688            assert!(multi_proof.verify_multi_inclusion(
1689                &mut hasher,
1690                &[(elements[0], positions[0]), (elements[1], positions[1])],
1691                &root
1692            ));
1693        });
1694    }
1695
1696    #[test_traced]
1697    fn test_verification_proof_store() {
1698        let executor = deterministic::Runner::default();
1699        executor.start(|_| async move {
1700            // create a new MMR and add a non-trivial amount (49) of elements
1701            let mut mmr = Mmr::default();
1702            let mut elements = Vec::new();
1703            let mut element_positions = Vec::new();
1704            let mut hasher: Standard<Sha256> = Standard::new();
1705            for i in 0..49 {
1706                elements.push(test_digest(i));
1707                element_positions.push(mmr.add(&mut hasher, elements.last().unwrap()));
1708            }
1709            let root = mmr.root(&mut hasher);
1710
1711            // Extract a ProofStore from a proof over a variety of ranges, starting with the full
1712            // range and shrinking each endpoint with each iteration.
1713            let mut range_start = 0;
1714            let mut range_end = 48;
1715            while range_start <= range_end {
1716                let range_proof = mmr
1717                    .range_proof(element_positions[range_start], element_positions[range_end])
1718                    .await
1719                    .unwrap();
1720                let proof_store = ProofStore::new(
1721                    &mut hasher,
1722                    &range_proof,
1723                    &elements[range_start..range_end + 1],
1724                    element_positions[range_start],
1725                    &root,
1726                )
1727                .unwrap();
1728
1729                // Verify that the ProofStore can be used to generate proofs over a host of sub-ranges
1730                // starting with the full range down to a range containing a single element.
1731                let mut subrange_start = range_start;
1732                let mut subrange_end = range_end;
1733                while subrange_start <= subrange_end {
1734                    // Verify a proof over a sub-range of the original range.
1735                    let sub_range_proof = Proof::<Digest>::range_proof(
1736                        &proof_store,
1737                        element_positions[subrange_start],
1738                        element_positions[subrange_end],
1739                    )
1740                    .await
1741                    .unwrap();
1742                    assert!(sub_range_proof.verify_range_inclusion(
1743                        &mut hasher,
1744                        &elements[subrange_start..subrange_end + 1],
1745                        element_positions[subrange_start],
1746                        &root
1747                    ));
1748                    subrange_start += 1;
1749                    subrange_end -= 1;
1750                }
1751                range_start += 1;
1752                range_end -= 1;
1753            }
1754        });
1755    }
1756}