Skip to main content

miden_crypto/merkle/
sparse_path.rs

1use alloc::{borrow::Cow, vec::Vec};
2use core::{
3    iter::{self, FusedIterator},
4    num::NonZero,
5};
6
7use super::{
8    EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Word, smt::SMT_MAX_DEPTH,
9};
10use crate::{
11    hash::poseidon2::Poseidon2,
12    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
13};
14
15/// A different representation of [`MerklePath`] designed for memory efficiency for Merkle paths
16/// with empty nodes.
17///
18/// Empty nodes in the path are stored only as their position, represented with a bitmask. A
19/// maximum of 64 nodes (`SMT_MAX_DEPTH`) can be stored (empty and non-empty). The more nodes in a
20/// path are empty, the less memory this struct will use. This type calculates empty nodes on-demand
21/// when iterated through, converted to a [MerklePath], or an empty node is retrieved with
22/// [`SparseMerklePath::at_depth()`], which will incur overhead.
23///
24/// NOTE: This type assumes that Merkle paths always span from the root of the tree to a leaf.
25/// Partial paths are not supported.
26#[derive(Clone, Debug, Default, PartialEq, Eq)]
27#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
28pub struct SparseMerklePath {
29    /// A bitmask representing empty nodes. The set bit corresponds to the depth of an empty node.
30    /// The least significant bit (bit 0) describes depth 1 node (root's children).
31    /// The `bit index + 1` is equal to node's depth.
32    empty_nodes_mask: u64,
33    /// The non-empty nodes, stored in depth-order, but not contiguous across depth.
34    nodes: Vec<Word>,
35}
36
37impl SparseMerklePath {
38    /// Constructs a new sparse Merkle path from a bitmask of empty nodes and a vector of non-empty
39    /// nodes.
40    ///
41    /// The `empty_nodes_mask` is a bitmask where each set bit indicates that the node at that
42    /// depth is empty. The least significant bit (bit 0) describes depth 1 node (root's children).
43    /// The `bit index + 1` is equal to node's depth.
44    /// The `nodes` vector must contain the non-empty nodes in depth order.
45    ///
46    /// # Errors
47    /// - [MerkleError::InvalidPathLength] if the provided `nodes` vector is shorter than the
48    ///   minimum length required by the `empty_nodes_mask`.
49    /// - [MerkleError::DepthTooBig] if the total depth of the path (calculated from the
50    ///   `empty_nodes_mask` and `nodes`) is greater than [SMT_MAX_DEPTH].
51    pub fn from_parts(empty_nodes_mask: u64, nodes: Vec<Word>) -> Result<Self, MerkleError> {
52        // The most significant set bit in the mask marks the minimum length of the path.
53        // For every zero bit before the first set bit, there must be a corresponding node in
54        // `nodes`.
55        // For example, if the mask is `0b1100`, this means that the first two nodes
56        // (depths 1 and 2) are non-empty, and the next two nodes (depths 3 and 4) are empty.
57        // The minimum length of the path is 4, and the `nodes` vector must contain at least 2
58        // nodes to account for the first two zeroes in the mask (depths 1 and 2).
59        let min_path_len = u64::BITS - empty_nodes_mask.leading_zeros();
60        let empty_nodes_count = empty_nodes_mask.count_ones();
61        let min_non_empty_nodes = (min_path_len - empty_nodes_count) as usize;
62
63        if nodes.len() < min_non_empty_nodes {
64            return Err(MerkleError::InvalidPathLength(min_non_empty_nodes));
65        }
66
67        let depth = Self::depth_from_parts(empty_nodes_mask, &nodes) as u8;
68        if depth > SMT_MAX_DEPTH {
69            return Err(MerkleError::DepthTooBig(depth as u64));
70        }
71
72        Ok(Self { empty_nodes_mask, nodes })
73    }
74
75    /// Constructs a sparse Merkle path from an iterator over Merkle nodes that also knows its
76    /// exact size (such as iterators created with [Vec::into_iter]). The iterator must be in order
77    /// of deepest to shallowest.
78    ///
79    /// Knowing the size is necessary to calculate the depth of the tree, which is needed to detect
80    /// which nodes are empty nodes.
81    ///
82    /// # Errors
83    /// Returns [MerkleError::DepthTooBig] if `tree_depth` is greater than [SMT_MAX_DEPTH].
84    pub fn from_sized_iter<I>(iterator: I) -> Result<Self, MerkleError>
85    where
86        I: IntoIterator<IntoIter: ExactSizeIterator, Item = Word>,
87    {
88        let iterator = iterator.into_iter();
89        let tree_depth = iterator.len() as u8;
90
91        if tree_depth > SMT_MAX_DEPTH {
92            return Err(MerkleError::DepthTooBig(tree_depth as u64));
93        }
94
95        let mut empty_nodes_mask: u64 = 0;
96        let mut nodes: Vec<Word> = Default::default();
97
98        for (depth, node) in iter::zip(path_depth_iter(tree_depth), iterator) {
99            let &equivalent_empty_node = EmptySubtreeRoots::entry(tree_depth, depth.get());
100            let is_empty = node == equivalent_empty_node;
101            let node = if is_empty { None } else { Some(node) };
102
103            match node {
104                Some(node) => nodes.push(node),
105                None => empty_nodes_mask |= Self::bitmask_for_depth(depth),
106            }
107        }
108
109        Ok(SparseMerklePath { nodes, empty_nodes_mask })
110    }
111
112    /// Returns the total depth of this path, i.e., the number of nodes this path represents.
113    pub fn depth(&self) -> u8 {
114        Self::depth_from_parts(self.empty_nodes_mask, &self.nodes) as u8
115    }
116
117    /// Get a specific node in this path at a given depth.
118    ///
119    /// The `depth` parameter is defined in terms of `self.depth()`. Merkle paths conventionally do
120    /// not include the root, so the shallowest depth is `1`, and the deepest depth is
121    /// `self.depth()`.
122    ///
123    /// # Errors
124    /// Returns [MerkleError::DepthTooBig] if `node_depth` is greater than the total depth of this
125    /// path.
126    pub fn at_depth(&self, node_depth: NonZero<u8>) -> Result<Word, MerkleError> {
127        if node_depth.get() > self.depth() {
128            return Err(MerkleError::DepthTooBig(node_depth.get().into()));
129        }
130
131        let node = if let Some(nonempty_index) = self.get_nonempty_index(node_depth) {
132            self.nodes[nonempty_index]
133        } else {
134            *EmptySubtreeRoots::entry(self.depth(), node_depth.get())
135        };
136
137        Ok(node)
138    }
139
140    /// Deconstructs this path into its component parts.
141    ///
142    /// Returns a tuple containing:
143    /// - a bitmask where each set bit indicates that the node at that depth is empty. The least
144    ///   significant bit (bit 0) describes depth 1 node (root's children).
145    /// - a vector of non-empty nodes in depth order.
146    pub fn into_parts(self) -> (u64, Vec<Word>) {
147        (self.empty_nodes_mask, self.nodes)
148    }
149
150    // PROVIDERS
151    // ============================================================================================
152
153    /// Constructs a borrowing iterator over the nodes in this path.
154    /// Starts from the leaf and iterates toward the root (excluding the root).
155    pub fn iter(&self) -> impl ExactSizeIterator<Item = Word> {
156        self.into_iter()
157    }
158
159    /// Computes the Merkle root for this opening.
160    pub fn compute_root(&self, index: u64, node_to_prove: Word) -> Result<Word, MerkleError> {
161        let mut index = NodeIndex::new(self.depth(), index)?;
162        let root = self.iter().fold(node_to_prove, |node, sibling| {
163            // Compute the node and move to the next iteration.
164            let children = index.build_node(node, sibling);
165            index.move_up();
166            Poseidon2::merge(&children)
167        });
168
169        Ok(root)
170    }
171
172    /// Verifies the Merkle opening proof towards the provided root.
173    ///
174    /// # Errors
175    /// Returns an error if:
176    /// - provided node index is invalid.
177    /// - root calculated during the verification differs from the provided one.
178    pub fn verify(&self, index: u64, node: Word, &expected_root: &Word) -> Result<(), MerkleError> {
179        let computed_root = self.compute_root(index, node)?;
180        if computed_root != expected_root {
181            return Err(MerkleError::ConflictingRoots {
182                expected_root,
183                actual_root: computed_root,
184            });
185        }
186
187        Ok(())
188    }
189
190    /// Given the node this path opens to, return an iterator of all the nodes that are known via
191    /// this path.
192    ///
193    /// Each item in the iterator is an [InnerNodeInfo], containing the hash of a node as `.value`,
194    /// and its two children as `.left` and `.right`. The very first item in that iterator will be
195    /// the parent of `node_to_prove` as stored in this [SparseMerklePath].
196    ///
197    /// From there, the iterator will continue to yield every further parent and both of its
198    /// children, up to and including the root node.
199    ///
200    /// If `node_to_prove` is not the node this path is an opening to, or `index` is not the
201    /// correct index for that node, the returned nodes will be meaningless.
202    ///
203    /// # Errors
204    /// Returns an error if the specified index is not valid for this path.
205    pub fn authenticated_nodes(
206        &self,
207        index: u64,
208        node_to_prove: Word,
209    ) -> Result<InnerNodeIterator<'_>, MerkleError> {
210        let index = NodeIndex::new(self.depth(), index)?;
211        Ok(InnerNodeIterator { path: self, index, value: node_to_prove })
212    }
213
214    // PRIVATE HELPERS
215    // ============================================================================================
216
217    const fn bitmask_for_depth(node_depth: NonZero<u8>) -> u64 {
218        // - 1 because paths do not include the root.
219        1 << (node_depth.get() - 1)
220    }
221
222    const fn is_depth_empty(&self, node_depth: NonZero<u8>) -> bool {
223        (self.empty_nodes_mask & Self::bitmask_for_depth(node_depth)) != 0
224    }
225
226    /// Index of the non-empty node in the `self.nodes` vector. If the specified depth is
227    /// empty, None is returned.
228    fn get_nonempty_index(&self, node_depth: NonZero<u8>) -> Option<usize> {
229        if self.is_depth_empty(node_depth) {
230            return None;
231        }
232
233        let bit_index = node_depth.get() - 1;
234        let without_shallower = self.empty_nodes_mask >> bit_index;
235        let empty_deeper = without_shallower.count_ones() as usize;
236        // The vec index we would use if we didn't have any empty nodes to account for...
237        let normal_index = (self.depth() - node_depth.get()) as usize;
238        // subtracted by the number of empty nodes that are deeper than us.
239        Some(normal_index - empty_deeper)
240    }
241
242    /// Returns the total depth of this path from its parts.
243    fn depth_from_parts(empty_nodes_mask: u64, nodes: &[Word]) -> usize {
244        nodes.len() + empty_nodes_mask.count_ones() as usize
245    }
246}
247
248// SERIALIZATION
249// ================================================================================================
250
251impl Serializable for SparseMerklePath {
252    fn write_into<W: ByteWriter>(&self, target: &mut W) {
253        target.write_u8(self.depth());
254        target.write_u64(self.empty_nodes_mask);
255        target.write_many(&self.nodes);
256    }
257}
258
259impl Deserializable for SparseMerklePath {
260    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
261        let depth = source.read_u8()?;
262        if depth > SMT_MAX_DEPTH {
263            return Err(DeserializationError::InvalidValue(format!(
264                "SparseMerklePath max depth exceeded ({depth} > {SMT_MAX_DEPTH})",
265            )));
266        }
267        let empty_nodes_mask = source.read_u64()?;
268        let empty_nodes_count = empty_nodes_mask.count_ones();
269        if empty_nodes_count > depth as u32 {
270            return Err(DeserializationError::InvalidValue(format!(
271                "SparseMerklePath has more empty nodes ({empty_nodes_count}) than its full length ({depth})",
272            )));
273        }
274        let count = depth as u32 - empty_nodes_count;
275        let nodes: Vec<Word> = source.read_many_iter(count as usize)?.collect::<Result<_, _>>()?;
276        Ok(Self { empty_nodes_mask, nodes })
277    }
278}
279
280// CONVERSIONS
281// ================================================================================================
282
283impl From<SparseMerklePath> for MerklePath {
284    fn from(sparse_path: SparseMerklePath) -> Self {
285        MerklePath::from_iter(sparse_path)
286    }
287}
288
289impl TryFrom<MerklePath> for SparseMerklePath {
290    type Error = MerkleError;
291
292    /// # Errors
293    ///
294    /// This conversion returns [MerkleError::DepthTooBig] if the path length is greater than
295    /// [`SMT_MAX_DEPTH`].
296    fn try_from(path: MerklePath) -> Result<Self, MerkleError> {
297        SparseMerklePath::from_sized_iter(path)
298    }
299}
300
301impl From<SparseMerklePath> for Vec<Word> {
302    fn from(path: SparseMerklePath) -> Self {
303        Vec::from_iter(path)
304    }
305}
306
307// ITERATORS
308// ================================================================================================
309
310/// Iterator for [`SparseMerklePath`]. Starts from the leaf and iterates toward the root (excluding
311/// the root).
312pub struct SparseMerklePathIter<'p> {
313    /// The "inner" value we're iterating over.
314    path: Cow<'p, SparseMerklePath>,
315
316    /// The depth a `next()` call will get. `next_depth == 0` indicates that the iterator has been
317    /// exhausted.
318    next_depth: u8,
319}
320
321impl Iterator for SparseMerklePathIter<'_> {
322    type Item = Word;
323
324    fn next(&mut self) -> Option<Word> {
325        let this_depth = self.next_depth;
326        // Paths don't include the root, so if `this_depth` is 0 then we keep returning `None`.
327        let this_depth = NonZero::new(this_depth)?;
328        self.next_depth = this_depth.get() - 1;
329
330        // `this_depth` is only ever decreasing, so it can't ever exceed `self.path.depth()`.
331        let node = self
332            .path
333            .at_depth(this_depth)
334            .expect("current depth should never exceed the path depth");
335        Some(node)
336    }
337
338    // SparseMerkleIter always knows its exact size.
339    fn size_hint(&self) -> (usize, Option<usize>) {
340        let remaining = ExactSizeIterator::len(self);
341        (remaining, Some(remaining))
342    }
343}
344
345impl ExactSizeIterator for SparseMerklePathIter<'_> {
346    fn len(&self) -> usize {
347        self.next_depth as usize
348    }
349}
350
351impl FusedIterator for SparseMerklePathIter<'_> {}
352
353// TODO: impl DoubleEndedIterator.
354
355impl IntoIterator for SparseMerklePath {
356    type IntoIter = SparseMerklePathIter<'static>;
357    type Item = <Self::IntoIter as Iterator>::Item;
358
359    fn into_iter(self) -> SparseMerklePathIter<'static> {
360        let tree_depth = self.depth();
361        SparseMerklePathIter {
362            path: Cow::Owned(self),
363            next_depth: tree_depth,
364        }
365    }
366}
367
368impl<'p> IntoIterator for &'p SparseMerklePath {
369    type Item = <SparseMerklePathIter<'p> as Iterator>::Item;
370    type IntoIter = SparseMerklePathIter<'p>;
371
372    fn into_iter(self) -> SparseMerklePathIter<'p> {
373        let tree_depth = self.depth();
374        SparseMerklePathIter {
375            path: Cow::Borrowed(self),
376            next_depth: tree_depth,
377        }
378    }
379}
380
381/// An iterator over nodes known by a [SparseMerklePath]. See
382/// [`SparseMerklePath::authenticated_nodes()`].
383pub struct InnerNodeIterator<'p> {
384    path: &'p SparseMerklePath,
385    index: NodeIndex,
386    value: Word,
387}
388
389impl Iterator for InnerNodeIterator<'_> {
390    type Item = InnerNodeInfo;
391
392    fn next(&mut self) -> Option<Self::Item> {
393        if self.index.is_root() {
394            return None;
395        }
396
397        let index_depth = NonZero::new(self.index.depth()).expect("non-root depth cannot be 0");
398        let path_node = self.path.at_depth(index_depth).unwrap();
399
400        let children = self.index.build_node(self.value, path_node);
401        self.value = Poseidon2::merge(&children);
402        self.index.move_up();
403
404        Some(InnerNodeInfo {
405            value: self.value,
406            left: children[0],
407            right: children[1],
408        })
409    }
410}
411
412// COMPARISONS
413// ================================================================================================
414impl PartialEq<MerklePath> for SparseMerklePath {
415    fn eq(&self, rhs: &MerklePath) -> bool {
416        if self.depth() != rhs.depth() {
417            return false;
418        }
419
420        for (node, &rhs_node) in iter::zip(self, rhs.iter()) {
421            if node != rhs_node {
422                return false;
423            }
424        }
425
426        true
427    }
428}
429
430impl PartialEq<SparseMerklePath> for MerklePath {
431    fn eq(&self, rhs: &SparseMerklePath) -> bool {
432        rhs == self
433    }
434}
435
436// HELPERS
437// ================================================================================================
438
/// Iterator over path depths, starting at the deepest level of the tree (`tree_depth`) and ending
/// at the shallowest level below the root (depth 1). Yields nothing when `tree_depth` is 0.
fn path_depth_iter(tree_depth: u8) -> impl ExactSizeIterator<Item = NonZero<u8>> {
    (1..=tree_depth)
        // The range starts at 1, so `NonZero::new` always succeeds; no `unsafe` is needed.
        .map(|depth| NonZero::new(depth).expect("path depths start at 1 and are never 0"))
        // Reverse the top-down range to walk bottom-up, deepest depth first.
        .rev()
}
452
453// TESTS
454// ================================================================================================
455#[cfg(test)]
456mod tests {
457    use alloc::vec::Vec;
458    use core::num::NonZero;
459
460    use assert_matches::assert_matches;
461
462    use super::SparseMerklePath;
463    use crate::{
464        Felt, ONE, Word,
465        merkle::{
466            EmptySubtreeRoots, MerkleError, MerklePath, MerkleTree, NodeIndex,
467            smt::{LeafIndex, SMT_MAX_DEPTH, SimpleSmt, Smt, SparseMerkleTree},
468            sparse_path::path_depth_iter,
469        },
470    };
471
472    fn make_smt(pair_count: u64) -> Smt {
473        let entries: Vec<(Word, Word)> = (0..pair_count)
474            .map(|n| {
475                let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
476                let key = Word::new([ONE, ONE, Felt::new(n), Felt::new(leaf_index)]);
477                let value = Word::new([ONE, ONE, ONE, ONE]);
478                (key, value)
479            })
480            .collect();
481
482        Smt::with_entries(entries).unwrap()
483    }
484
485    /// Manually test the exact bit patterns for a sample path of 8 nodes, including both empty and
486    /// non-empty nodes.
487    ///
488    /// This also offers an overview of what each part of the bit-math involved means and
489    /// represents.
490    #[test]
491    fn test_sparse_bits() {
492        const DEPTH: u8 = 8;
493        let raw_nodes: [Word; DEPTH as usize] = [
494            // Depth 8.
495            ([8u8, 8, 8, 8].into()),
496            // Depth 7.
497            *EmptySubtreeRoots::entry(DEPTH, 7),
498            // Depth 6.
499            *EmptySubtreeRoots::entry(DEPTH, 6),
500            // Depth 5.
501            [5u8, 5, 5, 5].into(),
502            // Depth 4.
503            [4u8, 4, 4, 4].into(),
504            // Depth 3.
505            *EmptySubtreeRoots::entry(DEPTH, 3),
506            // Depth 2.
507            *EmptySubtreeRoots::entry(DEPTH, 2),
508            // Depth 1.
509            *EmptySubtreeRoots::entry(DEPTH, 1),
510            // Root is not included.
511        ];
512
513        let sparse_nodes: [Option<Word>; DEPTH as usize] = [
514            // Depth 8.
515            Some([8u8, 8, 8, 8].into()),
516            // Depth 7.
517            None,
518            // Depth 6.
519            None,
520            // Depth 5.
521            Some([5u8, 5, 5, 5].into()),
522            // Depth 4.
523            Some([4u8, 4, 4, 4].into()),
524            // Depth 3.
525            None,
526            // Depth 2.
527            None,
528            // Depth 1.
529            None,
530            // Root is not included.
531        ];
532
533        const EMPTY_BITS: u64 = 0b0110_0111;
534
535        let sparse_path = SparseMerklePath::from_sized_iter(raw_nodes).unwrap();
536
537        assert_eq!(sparse_path.empty_nodes_mask, EMPTY_BITS);
538
539        // Keep track of how many non-empty nodes we have seen
540        let mut nonempty_idx = 0;
541
542        // Test starting from the deepest nodes (depth 8)
543        for depth in (1..=8).rev() {
544            let idx = (sparse_path.depth() - depth) as usize;
545            let bit = 1 << (depth - 1);
546
547            // Check that the depth bit is set correctly...
548            let is_set = (sparse_path.empty_nodes_mask & bit) != 0;
549            assert_eq!(is_set, sparse_nodes.get(idx).unwrap().is_none());
550
551            if is_set {
552                // Check that we don't return digests for empty nodes
553                let &test_node = sparse_nodes.get(idx).unwrap();
554                assert_eq!(test_node, None);
555            } else {
556                // Check that we can calculate non-empty indices correctly.
557                let control_node = raw_nodes.get(idx).unwrap();
558                assert_eq!(
559                    sparse_path.get_nonempty_index(NonZero::new(depth).unwrap()).unwrap(),
560                    nonempty_idx
561                );
562                let test_node = sparse_path.nodes.get(nonempty_idx).unwrap();
563                assert_eq!(test_node, control_node);
564
565                nonempty_idx += 1;
566            }
567        }
568    }
569
570    #[test]
571    fn from_parts() {
572        const DEPTH: u8 = 8;
573        let raw_nodes: [Word; DEPTH as usize] = [
574            // Depth 8.
575            ([8u8, 8, 8, 8].into()),
576            // Depth 7.
577            *EmptySubtreeRoots::entry(DEPTH, 7),
578            // Depth 6.
579            *EmptySubtreeRoots::entry(DEPTH, 6),
580            // Depth 5.
581            [5u8, 5, 5, 5].into(),
582            // Depth 4.
583            [4u8, 4, 4, 4].into(),
584            // Depth 3.
585            *EmptySubtreeRoots::entry(DEPTH, 3),
586            // Depth 2.
587            *EmptySubtreeRoots::entry(DEPTH, 2),
588            // Depth 1.
589            *EmptySubtreeRoots::entry(DEPTH, 1),
590            // Root is not included.
591        ];
592
593        let empty_nodes_mask = 0b0110_0111;
594        let nodes = vec![[8u8, 8, 8, 8].into(), [5u8, 5, 5, 5].into(), [4u8, 4, 4, 4].into()];
595        let insufficient_nodes = vec![[4u8, 4, 4, 4].into()];
596
597        let error = SparseMerklePath::from_parts(empty_nodes_mask, insufficient_nodes).unwrap_err();
598        assert_matches!(error, MerkleError::InvalidPathLength(2));
599
600        let iter_sparse_path = SparseMerklePath::from_sized_iter(raw_nodes).unwrap();
601        let sparse_path = SparseMerklePath::from_parts(empty_nodes_mask, nodes).unwrap();
602
603        assert_eq!(sparse_path, iter_sparse_path);
604    }
605
606    #[test]
607    fn from_sized_iter() {
608        let tree = make_smt(8192);
609
610        for (key, _value) in tree.entries() {
611            let index = NodeIndex::from(Smt::key_to_leaf_index(key));
612            let sparse_path = tree.get_path(key);
613            for (sparse_node, proof_idx) in
614                itertools::zip_eq(sparse_path.clone(), index.proof_indices())
615            {
616                let proof_node = tree.get_node_hash(proof_idx);
617                assert_eq!(sparse_node, proof_node);
618            }
619        }
620    }
621
622    #[test]
623    fn test_zero_sized() {
624        let nodes: Vec<Word> = Default::default();
625
626        // Sparse paths that don't actually contain any nodes should still be well behaved.
627        let sparse_path = SparseMerklePath::from_sized_iter(nodes).unwrap();
628        assert_eq!(sparse_path.depth(), 0);
629        assert_matches!(
630            sparse_path.at_depth(NonZero::new(1).unwrap()),
631            Err(MerkleError::DepthTooBig(1))
632        );
633        assert_eq!(sparse_path.iter().next(), None);
634        assert_eq!(sparse_path.into_iter().next(), None);
635    }
636
637    use proptest::prelude::*;
638
639    // Arbitrary instance for MerklePath
640    impl Arbitrary for MerklePath {
641        type Parameters = ();
642        type Strategy = BoxedStrategy<Self>;
643
644        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
645            prop::collection::vec(any::<Word>(), 0..=SMT_MAX_DEPTH as usize)
646                .prop_map(MerklePath::new)
647                .boxed()
648        }
649    }
650
651    // Arbitrary instance for SparseMerklePath
652    impl Arbitrary for SparseMerklePath {
653        type Parameters = ();
654        type Strategy = BoxedStrategy<Self>;
655
656        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
657            (0..=SMT_MAX_DEPTH as usize)
658                .prop_flat_map(|depth| {
659                    // Generate a bitmask for empty nodes - avoid overflow
660                    let max_mask = if depth > 0 && depth < 64 {
661                        (1u64 << depth) - 1
662                    } else if depth == 64 {
663                        u64::MAX
664                    } else {
665                        0
666                    };
667                    let empty_nodes_mask =
668                        prop::num::u64::ANY.prop_map(move |mask| mask & max_mask);
669
670                    // Generate non-empty nodes based on the mask
671                    empty_nodes_mask.prop_flat_map(move |mask| {
672                        let empty_count = mask.count_ones() as usize;
673                        let non_empty_count = depth.saturating_sub(empty_count);
674
675                        prop::collection::vec(any::<Word>(), non_empty_count).prop_map(
676                            move |nodes| SparseMerklePath::from_parts(mask, nodes).unwrap(),
677                        )
678                    })
679                })
680                .boxed()
681        }
682    }
683
684    proptest! {
685        #[test]
686        fn sparse_merkle_path_roundtrip_equivalence(path in any::<MerklePath>()) {
687            // Convert MerklePath to SparseMerklePath and back
688            let sparse_result = SparseMerklePath::try_from(path.clone());
689            if path.depth() <= SMT_MAX_DEPTH {
690                let sparse = sparse_result.unwrap();
691                let reconstructed = MerklePath::from(sparse);
692                prop_assert_eq!(path, reconstructed);
693            } else {
694                prop_assert!(sparse_result.is_err());
695            }
696        }
697    }
698    proptest! {
699
700        #[test]
701        fn merkle_path_roundtrip_equivalence(sparse in any::<SparseMerklePath>()) {
702            // Convert SparseMerklePath to MerklePath and back
703            let merkle = MerklePath::from(sparse.clone());
704            let reconstructed = SparseMerklePath::try_from(merkle.clone()).unwrap();
705            prop_assert_eq!(sparse, reconstructed);
706        }
707    }
708    proptest! {
709
710        #[test]
711        fn path_equivalence_tests(path in any::<MerklePath>(), path2 in any::<MerklePath>()) {
712            if path.depth() > SMT_MAX_DEPTH {
713                return Ok(());
714            }
715
716            let sparse = SparseMerklePath::try_from(path.clone()).unwrap();
717
718            // Depth consistency
719            prop_assert_eq!(path.depth(), sparse.depth());
720
721            // Node access consistency including path_depth_iter
722            if path.depth() > 0 {
723                for depth in path_depth_iter(path.depth()) {
724                    let merkle_node = path.at_depth(depth);
725                    let sparse_node = sparse.at_depth(depth);
726
727                    match (merkle_node, sparse_node) {
728                        (Some(m), Ok(s)) => prop_assert_eq!(m, s),
729                        (None, Err(_)) => {},
730                        _ => prop_assert!(false, "Inconsistent node access at depth {}", depth.get()),
731                    }
732                }
733            }
734
735            // Iterator consistency
736            if path.depth() > 0 {
737                let merkle_nodes: Vec<_> = path.iter().collect();
738                let sparse_nodes: Vec<_> = sparse.iter().collect();
739
740                prop_assert_eq!(merkle_nodes.len(), sparse_nodes.len());
741                for (m, s) in merkle_nodes.iter().zip(sparse_nodes.iter()) {
742                    prop_assert_eq!(*m, s);
743                }
744            }
745
746            // Test equality between different representations
747            if path2.depth() <= SMT_MAX_DEPTH {
748                let sparse2 = SparseMerklePath::try_from(path2.clone()).unwrap();
749                prop_assert_eq!(path == path2, sparse == sparse2);
750                prop_assert_eq!(path == sparse2, sparse == path2);
751            }
752        }
753    }
754    // rather heavy tests
755    proptest! {
756        #![proptest_config(ProptestConfig::with_cases(100))]
757
758        #[test]
759        fn compute_root_consistency(
760            tree_data in any::<RandomMerkleTree>(),
761            node in any::<Word>()
762        ) {
763            let RandomMerkleTree { tree, leaves: _,  indices } = tree_data;
764
765            for &leaf_index in indices.iter() {
766                let path = tree.get_path(NodeIndex::new(tree.depth(), leaf_index).unwrap()).unwrap();
767                let sparse = SparseMerklePath::from_sized_iter(path.clone().into_iter()).unwrap();
768
769                let merkle_root = path.compute_root(leaf_index, node);
770                let sparse_root = sparse.compute_root(leaf_index, node);
771
772                match (merkle_root, sparse_root) {
773                    (Ok(m), Ok(s)) => prop_assert_eq!(m, s),
774                    (Err(e1), Err(e2)) => {
775                        // Both should have the same error type
776                        prop_assert_eq!(format!("{:?}", e1), format!("{:?}", e2));
777                    },
778                    _ => prop_assert!(false, "Inconsistent compute_root results"),
779                }
780            }
781        }
782
783        #[test]
784        fn verify_consistency(
785            tree_data in any::<RandomMerkleTree>(),
786            node in any::<Word>()
787        ) {
788            let RandomMerkleTree { tree, leaves, indices } = tree_data;
789
790            for (i, &leaf_index) in indices.iter().enumerate() {
791                let leaf = leaves[i];
792                let path = tree.get_path(NodeIndex::new(tree.depth(), leaf_index).unwrap()).unwrap();
793                let sparse = SparseMerklePath::from_sized_iter(path.clone().into_iter()).unwrap();
794
795                let root = tree.root();
796
797                let merkle_verify = path.verify(leaf_index, leaf, &root);
798                let sparse_verify = sparse.verify(leaf_index, leaf, &root);
799
800                match (merkle_verify, sparse_verify) {
801                    (Ok(()), Ok(())) => {},
802                    (Err(e1), Err(e2)) => {
803                        // Both should have the same error type
804                        prop_assert_eq!(format!("{:?}", e1), format!("{:?}", e2));
805                    },
806                    _ => prop_assert!(false, "Inconsistent verify results"),
807                }
808
809                // Test with wrong node - both should fail
810                let wrong_verify = path.verify(leaf_index, node, &root);
811                let wrong_sparse_verify = sparse.verify(leaf_index, node, &root);
812
813                match (wrong_verify, wrong_sparse_verify) {
814                    (Ok(()), Ok(())) => prop_assert!(false, "Verification should have failed with wrong node"),
815                    (Err(_), Err(_)) => {},
816                    _ => prop_assert!(false, "Inconsistent verification results with wrong node"),
817                }
818            }
819        }
820
821        #[test]
822        fn authenticated_nodes_consistency(
823            tree_data in any::<RandomMerkleTree>()
824        ) {
825            let RandomMerkleTree { tree, leaves, indices } = tree_data;
826
827            for (i, &leaf_index) in indices.iter().enumerate() {
828                let leaf = leaves[i];
829                let path = tree.get_path(NodeIndex::new(tree.depth(), leaf_index).unwrap()).unwrap();
830                let sparse = SparseMerklePath::from_sized_iter(path.clone().into_iter()).unwrap();
831
832                let merkle_result = path.authenticated_nodes(leaf_index, leaf);
833                let sparse_result = sparse.authenticated_nodes(leaf_index, leaf);
834
835                match (merkle_result, sparse_result) {
836                    (Ok(m_iter), Ok(s_iter)) => {
837                        let merkle_nodes: Vec<_> = m_iter.collect();
838                        let sparse_nodes: Vec<_> = s_iter.collect();
839                        prop_assert_eq!(merkle_nodes.len(), sparse_nodes.len());
840                        for (m, s) in merkle_nodes.iter().zip(sparse_nodes.iter()) {
841                            prop_assert_eq!(m, s);
842                        }
843                    },
844                    (Err(e1), Err(e2)) => {
845                        prop_assert_eq!(format!("{:?}", e1), format!("{:?}", e2));
846                    },
847                    _ => prop_assert!(false, "Inconsistent authenticated_nodes results"),
848                }
849            }
850        }
851    }
852
853    #[test]
854    fn test_api_differences() {
855        // This test documents API differences between MerklePath and SparseMerklePath
856
857        // 1. MerklePath has Deref/DerefMut to Vec<Word> - SparseMerklePath does not
858        let merkle = MerklePath::new(vec![Word::default(); 3]);
859        let _vec_ref: &Vec<Word> = &merkle; // This works due to Deref
860        let _vec_mut: &mut Vec<Word> = &mut merkle.clone(); // This works due to DerefMut
861
862        // 2. SparseMerklePath has from_parts() - MerklePath uses new() or from_iter()
863        let sparse = SparseMerklePath::from_parts(0b101, vec![Word::default(); 2]).unwrap();
864        assert_eq!(sparse.depth(), 4); // depth is 4 because mask has bits set up to depth 4
865
866        // 3. SparseMerklePath has from_sized_iter() - MerklePath uses from_iter()
867        let nodes = vec![Word::default(); 3];
868        let sparse_from_iter = SparseMerklePath::from_sized_iter(nodes.clone()).unwrap();
869        let merkle_from_iter = MerklePath::from_iter(nodes);
870        assert_eq!(sparse_from_iter.depth(), merkle_from_iter.depth());
871    }
872
    /// Random `MerkleTree` fixture for property tests, produced by its `Arbitrary` impl.
    #[derive(Debug, Clone)]
    struct RandomMerkleTree {
        /// Tree built over `leaves`.
        tree: MerkleTree,
        /// Leaf values; `leaves[i]` is paired with `indices[i]` by the tests.
        leaves: Vec<Word>,
        /// Leaf indices to exercise; the `Arbitrary` impl fills this with `0..leaves.len()`.
        indices: Vec<u64>,
    }
880
881    impl Arbitrary for RandomMerkleTree {
882        type Parameters = ();
883        type Strategy = BoxedStrategy<Self>;
884
885        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
886            // Generate trees with power-of-2 leaves up to 1024 (2^10)
887            prop::sample::select(&[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024])
888                .prop_flat_map(|num_leaves| {
889                    prop::collection::vec(any::<Word>(), num_leaves).prop_map(|leaves| {
890                        let tree = MerkleTree::new(leaves.clone()).unwrap();
891                        let indices: Vec<u64> = (0..leaves.len() as u64).collect();
892                        RandomMerkleTree { tree, leaves, indices }
893                    })
894                })
895                .boxed()
896        }
897    }
898
    /// Random `SimpleSmt` fixture for property tests, produced by its `Arbitrary` impl.
    #[derive(Debug, Clone)]
    struct RandomSimpleSmt {
        tree: SimpleSmt<10>, // Depth 10 = 1024 leaves
        /// `(leaf index, value)` pairs inserted into `tree`; leaf indices are unique.
        entries: Vec<(u64, Word)>,
    }
905
906    impl Arbitrary for RandomSimpleSmt {
907        type Parameters = ();
908        type Strategy = BoxedStrategy<Self>;
909
910        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
911            (1..=100usize) // 1-100 entries in an 1024-leaf tree
912                .prop_flat_map(|num_entries| {
913                    prop::collection::vec(
914                        (
915                            0..1024u64, // Valid indices for 1024-leaf tree
916                            any::<Word>(),
917                        ),
918                        num_entries,
919                    )
920                    .prop_map(|mut entries| {
921                        // Ensure unique indices to avoid duplicates
922                        let mut seen = alloc::collections::BTreeSet::new();
923                        entries.retain(|(idx, _)| seen.insert(*idx));
924
925                        let mut tree = SimpleSmt::new().unwrap();
926                        for (idx, value) in &entries {
927                            let leaf_idx = LeafIndex::new(*idx).unwrap();
928                            tree.insert(leaf_idx, *value);
929                        }
930                        RandomSimpleSmt { tree, entries }
931                    })
932                })
933                .boxed()
934        }
935    }
936
    /// Random `Smt` fixture for property tests, produced by its `Arbitrary` impl.
    #[derive(Debug, Clone)]
    struct RandomSmt {
        /// Tree built from `entries` via `Smt::with_entries`.
        tree: Smt,
        /// `(key, value)` pairs the tree was built from.
        entries: Vec<(Word, Word)>,
    }
943
944    impl Arbitrary for RandomSmt {
945        type Parameters = ();
946        type Strategy = BoxedStrategy<Self>;
947
948        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
949            (1..=100usize) // 1-100 entries in a sparse tree
950                .prop_flat_map(|num_entries| {
951                    prop::collection::vec((any::<u64>(), any::<Word>()), num_entries).prop_map(
952                        |indices_n_values| {
953                            let entries: Vec<(Word, Word)> = indices_n_values
954                                .into_iter()
955                                .enumerate()
956                                .map(|(n, (leaf_index, value))| {
957                                    // SMT uses the most significant element (index 3) as leaf index
958                                    // Ensure we use valid leaf indices for the SMT depth
959                                    let valid_leaf_index = leaf_index % (1u64 << 60); // Use large but valid range
960                                    let key = Word::new([
961                                        Felt::new(n as u64),         // element 0
962                                        Felt::new(n as u64 + 1),     // element 1
963                                        Felt::new(n as u64 + 2),     // element 2
964                                        Felt::new(valid_leaf_index), // element 3 (leaf index)
965                                    ]);
966                                    (key, value)
967                                })
968                                .collect();
969
970                            // Ensure unique keys to avoid duplicates
971                            let mut seen = alloc::collections::BTreeSet::new();
972                            let unique_entries: Vec<_> =
973                                entries.into_iter().filter(|(key, _)| seen.insert(*key)).collect();
974
975                            let tree = Smt::with_entries(unique_entries.clone()).unwrap();
976                            RandomSmt { tree, entries: unique_entries }
977                        },
978                    )
979                })
980                .boxed()
981        }
982    }
983
984    proptest! {
985        #![proptest_config(ProptestConfig::with_cases(20))]
986
987        #[test]
988        fn simple_smt_path_consistency(tree_data in any::<RandomSimpleSmt>()) {
989            let RandomSimpleSmt { tree, entries } = tree_data;
990
991            for (leaf_index, value) in &entries {
992                let merkle_path = tree.get_path(&LeafIndex::new(*leaf_index).unwrap());
993                let sparse_path = SparseMerklePath::from_sized_iter(merkle_path.clone().into_iter()).unwrap();
994
995                // Verify both paths have same depth
996                prop_assert_eq!(merkle_path.depth(), sparse_path.depth());
997
998                // Verify both paths produce same root for the same value
999                let merkle_root = merkle_path.compute_root(*leaf_index, *value).unwrap();
1000                let sparse_root = sparse_path.compute_root(*leaf_index, *value).unwrap();
1001                prop_assert_eq!(merkle_root, sparse_root);
1002
1003                // Verify both paths verify correctly
1004                let tree_root = tree.root();
1005                prop_assert!(merkle_path.verify(*leaf_index, *value, &tree_root).is_ok());
1006                prop_assert!(sparse_path.verify(*leaf_index, *value, &tree_root).is_ok());
1007
1008                // Test with random additional leaf
1009                let random_leaf = Word::new([Felt::ONE; 4]);
1010                let random_index = *leaf_index ^ 1; // Ensure it's a sibling
1011
1012                // Both should fail verification with wrong leaf
1013                let merkle_wrong = merkle_path.verify(random_index, random_leaf, &tree_root);
1014                let sparse_wrong = sparse_path.verify(random_index, random_leaf, &tree_root);
1015                prop_assert_eq!(merkle_wrong.is_err(), sparse_wrong.is_err());
1016            }
1017        }
1018
1019        #[test]
1020        fn smt_path_consistency(tree_data in any::<RandomSmt>()) {
1021            let RandomSmt { tree, entries } = tree_data;
1022
1023            for (key, _value) in &entries {
1024                let (merkle_path, leaf) = tree.open(key).into_parts();
1025                let sparse_path = SparseMerklePath::from_sized_iter(merkle_path.clone().into_iter()).unwrap();
1026
1027                let leaf_index = Smt::key_to_leaf_index(key).value();
1028                let actual_value = leaf.hash(); // Use the actual leaf hash
1029
1030                // Verify both paths have same depth
1031                prop_assert_eq!(merkle_path.depth(), sparse_path.depth());
1032
1033                // Verify both paths produce same root for the same value
1034                let merkle_root = merkle_path.compute_root(leaf_index, actual_value).unwrap();
1035                let sparse_root = sparse_path.compute_root(leaf_index, actual_value).unwrap();
1036                prop_assert_eq!(merkle_root, sparse_root);
1037
1038                // Verify both paths verify correctly
1039                let tree_root = tree.root();
1040                prop_assert!(merkle_path.verify(leaf_index, actual_value, &tree_root).is_ok());
1041                prop_assert!(sparse_path.verify(leaf_index, actual_value, &tree_root).is_ok());
1042
1043                // Test authenticated nodes consistency
1044                let merkle_auth = merkle_path.authenticated_nodes(leaf_index, actual_value).unwrap().collect::<Vec<_>>();
1045                let sparse_auth = sparse_path.authenticated_nodes(leaf_index, actual_value).unwrap().collect::<Vec<_>>();
1046                prop_assert_eq!(merkle_auth, sparse_auth);
1047            }
1048        }
1049
1050        #[test]
1051        fn reverse_conversion_from_sparse(tree_data in any::<RandomMerkleTree>()) {
1052            let RandomMerkleTree { tree, leaves, indices } = tree_data;
1053
1054            for (i, &leaf_index) in indices.iter().enumerate() {
1055                let leaf = leaves[i];
1056                let merkle_path = tree.get_path(NodeIndex::new(tree.depth(), leaf_index).unwrap()).unwrap();
1057
1058                // Create SparseMerklePath first, then convert to MerklePath
1059                let sparse_path = SparseMerklePath::from_sized_iter(merkle_path.clone().into_iter()).unwrap();
1060                let converted_merkle = MerklePath::from(sparse_path.clone());
1061
1062                // Verify conversion back and forth works
1063                let back_to_sparse = SparseMerklePath::try_from(converted_merkle.clone()).unwrap();
1064                prop_assert_eq!(sparse_path, back_to_sparse);
1065
1066                // Verify all APIs work identically
1067                prop_assert_eq!(merkle_path.depth(), converted_merkle.depth());
1068
1069                let merkle_root = merkle_path.compute_root(leaf_index, leaf).unwrap();
1070                let converted_root = converted_merkle.compute_root(leaf_index, leaf).unwrap();
1071                prop_assert_eq!(merkle_root, converted_root);
1072            }
1073        }
1074    }
1075}