// miden_crypto/merkle/smt/simple/mod.rs

1use alloc::collections::BTreeSet;
2
3use super::{
4    super::ValuePath, EMPTY_WORD, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes,
5    LeafIndex, MerkleError, MerklePath, MutationSet, NodeIndex, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
6    SparseMerkleTree, Word,
7};
8use crate::merkle::{SparseMerklePath, SparseValuePath};
9
10#[cfg(test)]
11mod tests;
12
13// SPARSE MERKLE TREE
14// ================================================================================================
15
16type Leaves = super::Leaves<Word>;
17
18/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
19///
20/// The root of the tree is recomputed on each new leaf update.
21#[derive(Debug, Clone, PartialEq, Eq)]
22#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
23pub struct SimpleSmt<const DEPTH: u8> {
24    root: Word,
25    inner_nodes: InnerNodes,
26    leaves: Leaves,
27}
28
29impl<const DEPTH: u8> SimpleSmt<DEPTH> {
30    // CONSTANTS
31    // --------------------------------------------------------------------------------------------
32
33    /// The default value used to compute the hash of empty leaves
34    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;
35
36    // CONSTRUCTORS
37    // --------------------------------------------------------------------------------------------
38
39    /// Returns a new [SimpleSmt].
40    ///
41    /// All leaves in the returned tree are set to [ZERO; 4].
42    ///
43    /// # Errors
44    /// Returns an error if DEPTH is 0 or is greater than 64.
45    pub fn new() -> Result<Self, MerkleError> {
46        // validate the range of the depth.
47        if DEPTH < SMT_MIN_DEPTH {
48            return Err(MerkleError::DepthTooSmall(DEPTH));
49        } else if SMT_MAX_DEPTH < DEPTH {
50            return Err(MerkleError::DepthTooBig(DEPTH as u64));
51        }
52
53        let root = *EmptySubtreeRoots::entry(DEPTH, 0);
54
55        Ok(Self {
56            root,
57            inner_nodes: Default::default(),
58            leaves: Default::default(),
59        })
60    }
61
62    /// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
63    ///
64    /// All leaves omitted from the entries list are set to [ZERO; 4].
65    ///
66    /// # Errors
67    /// Returns an error if:
68    /// - If the depth is 0 or is greater than 64.
69    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
70    /// - The provided entries contain multiple values for the same key.
71    pub fn with_leaves(
72        entries: impl IntoIterator<Item = (u64, Word)>,
73    ) -> Result<Self, MerkleError> {
74        // create an empty tree
75        let mut tree = Self::new()?;
76
77        // compute the max number of entries. We use an upper bound of depth 63 because we consider
78        // passing in a vector of size 2^64 infeasible.
79        let max_num_entries = 2_usize.pow(DEPTH.min(63).into());
80
81        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
82        // entries with the empty value need additional tracking.
83        let mut key_set_to_zero = BTreeSet::new();
84
85        for (idx, (key, value)) in entries.into_iter().enumerate() {
86            if idx >= max_num_entries {
87                return Err(MerkleError::TooManyEntries(max_num_entries));
88            }
89
90            let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
91
92            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
93                return Err(MerkleError::DuplicateValuesForIndex(key));
94            }
95
96            if value == Self::EMPTY_VALUE {
97                key_set_to_zero.insert(key);
98            };
99        }
100        Ok(tree)
101    }
102
103    /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
104    ///
105    /// This function performs minimal consistency checking. It is the caller's responsibility to
106    /// ensure the passed arguments are correct and consistent with each other.
107    ///
108    /// # Panics
109    /// With debug assertions on, this function panics if `root` does not match the root node in
110    /// `inner_nodes`.
111    pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: Word) -> Self {
112        // Our particular implementation of `from_raw_parts()` never returns `Err`.
113        <Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
114    }
115
116    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
117    /// starting at index 0.
118    pub fn with_contiguous_leaves(
119        entries: impl IntoIterator<Item = Word>,
120    ) -> Result<Self, MerkleError> {
121        Self::with_leaves(
122            entries
123                .into_iter()
124                .enumerate()
125                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
126        )
127    }
128
129    // PUBLIC ACCESSORS
130    // --------------------------------------------------------------------------------------------
131
132    /// Returns the depth of the tree
133    pub const fn depth(&self) -> u8 {
134        DEPTH
135    }
136
137    /// Returns the root of the tree
138    pub fn root(&self) -> Word {
139        <Self as SparseMerkleTree<DEPTH>>::root(self)
140    }
141
142    /// Returns the number of non-empty leaves in this tree.
143    pub fn num_leaves(&self) -> usize {
144        self.leaves.len()
145    }
146
147    /// Returns the leaf at the specified index.
148    pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
149        <Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
150    }
151
152    /// Returns a node at the specified index.
153    ///
154    /// # Errors
155    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
156    /// the depth of this Merkle tree.
157    pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
158        if index.is_root() {
159            Err(MerkleError::DepthTooSmall(index.depth()))
160        } else if index.depth() > DEPTH {
161            Err(MerkleError::DepthTooBig(index.depth() as u64))
162        } else if index.depth() == DEPTH {
163            let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
164
165            Ok(leaf)
166        } else {
167            Ok(self.get_inner_node(index).hash())
168        }
169    }
170
171    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
172    /// path to the leaf, as well as the leaf itself.
173    pub fn open(&self, key: &LeafIndex<DEPTH>) -> SparseValuePath {
174        let value = self.get_value(key);
175        let nodes = key.index.proof_indices().map(|index| self.get_node_hash(index));
176        // `from_sized_iter()` returns an error if there are more nodes than `SMT_MAX_DEPTH`, but
177        // this could only happen if we have more levels than `SMT_MAX_DEPTH` ourselves, which is
178        // guarded against in `SimpleSmt::new()`.
179        let path = SparseMerklePath::from_sized_iter(nodes).unwrap();
180
181        SparseValuePath { value, path }
182    }
183
184    /// Returns a boolean value indicating whether the SMT is empty.
185    pub fn is_empty(&self) -> bool {
186        debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
187        self.root == Self::EMPTY_ROOT
188    }
189
190    // ITERATORS
191    // --------------------------------------------------------------------------------------------
192
193    /// Returns an iterator over the leaves of this [SimpleSmt].
194    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
195        self.leaves.iter().map(|(i, w)| (*i, w))
196    }
197
198    /// Returns an iterator over the inner nodes of this [SimpleSmt].
199    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
200        self.inner_nodes.values().map(|e| InnerNodeInfo {
201            value: e.hash(),
202            left: e.left,
203            right: e.right,
204        })
205    }
206
207    // STATE MUTATORS
208    // --------------------------------------------------------------------------------------------
209
210    /// Inserts a value at the specified key, returning the previous value associated with that key.
211    /// Recall that by definition, any key that hasn't been updated is associated with
212    /// [`EMPTY_WORD`].
213    ///
214    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
215    /// updating the root itself.
216    pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
217        <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
218    }
219
220    /// Computes what changes are necessary to insert the specified key-value pairs into this
221    /// Merkle tree, allowing for validation before applying those changes.
222    ///
223    /// This method returns a [`MutationSet`], which contains all the information for inserting
224    /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
225    /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
226    /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
227    /// Merkle tree, or [`drop()`] to discard them.
228    ///
229    /// # Example
230    /// ```
231    /// # use miden_crypto::{Felt, Word};
232    /// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
233    /// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
234    /// let pair = (LeafIndex::default(), Word::default());
235    /// let mutations = smt.compute_mutations(vec![pair]);
236    /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
237    /// smt.apply_mutations(mutations);
238    /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
239    /// ```
240    pub fn compute_mutations(
241        &self,
242        kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
243    ) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
244        <Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
245    }
246
247    /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
248    /// tree.
249    ///
250    /// # Errors
251    /// If `mutations` was computed on a tree with a different root than this one, returns
252    /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
253    /// root hash the `mutations` were computed against, and the second item is the actual
254    /// current root of this tree.
255    pub fn apply_mutations(
256        &mut self,
257        mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
258    ) -> Result<(), MerkleError> {
259        <Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
260    }
261
262    /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
263    /// this tree and returns the reverse mutation set.
264    ///
265    /// Applying the reverse mutation sets to the updated tree will revert the changes.
266    ///
267    /// # Errors
268    /// If `mutations` was computed on a tree with a different root than this one, returns
269    /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
270    /// root hash the `mutations` were computed against, and the second item is the actual
271    /// current root of this tree.
272    pub fn apply_mutations_with_reversion(
273        &mut self,
274        mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
275    ) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
276        <Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
277    }
278
279    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
280    /// computed as `DEPTH - SUBTREE_DEPTH`.
281    ///
282    /// Returns the new root.
283    pub fn set_subtree<const SUBTREE_DEPTH: u8>(
284        &mut self,
285        subtree_insertion_index: u64,
286        subtree: SimpleSmt<SUBTREE_DEPTH>,
287    ) -> Result<Word, MerkleError> {
288        if SUBTREE_DEPTH > DEPTH {
289            return Err(MerkleError::SubtreeDepthExceedsDepth {
290                subtree_depth: SUBTREE_DEPTH,
291                tree_depth: DEPTH,
292            });
293        }
294
295        // Verify that `subtree_insertion_index` is valid.
296        let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
297        let subtree_root_index =
298            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
299
300        // add leaves
301        // --------------
302
303        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
304        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
305        // valid as they are. However, consider what happens when we insert at
306        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
307        // you can see it as there's a full subtree sitting on its left. In general, for
308        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
309        // to insert, so we need to adjust all its leaves by `i * 2^d`.
310        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
311        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
312            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
313            debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));
314
315            self.leaves.insert(new_leaf_idx, *leaf_value);
316        }
317
318        // add subtree's branch nodes (which includes the root)
319        // --------------
320        for (branch_idx, branch_node) in subtree.inner_nodes {
321            let new_branch_idx = {
322                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
323                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
324                    + branch_idx.value();
325
326                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
327            };
328
329            self.inner_nodes.insert(new_branch_idx, branch_node);
330        }
331
332        // recompute nodes starting from subtree root
333        // --------------
334        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
335
336        Ok(self.root)
337    }
338}
339
340impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
341    type Key = LeafIndex<DEPTH>;
342    type Value = Word;
343    type Leaf = Word;
344    type Opening = ValuePath;
345
346    const EMPTY_VALUE: Self::Value = EMPTY_WORD;
347    const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(DEPTH, 0);
348
349    fn from_raw_parts(
350        inner_nodes: InnerNodes,
351        leaves: Leaves,
352        root: Word,
353    ) -> Result<Self, MerkleError> {
354        if cfg!(debug_assertions) {
355            let root_node_hash = inner_nodes
356                .get(&NodeIndex::root())
357                .map(InnerNode::hash)
358                .unwrap_or(Self::EMPTY_ROOT);
359
360            assert_eq!(root_node_hash, root);
361        }
362
363        Ok(Self { root, inner_nodes, leaves })
364    }
365
366    fn root(&self) -> Word {
367        self.root
368    }
369
370    fn set_root(&mut self, root: Word) {
371        self.root = root;
372    }
373
374    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
375        self.inner_nodes
376            .get(&index)
377            .cloned()
378            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
379    }
380
381    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
382        self.inner_nodes.insert(index, inner_node)
383    }
384
385    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
386        self.inner_nodes.remove(&index)
387    }
388
389    fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
390        if value == Self::EMPTY_VALUE {
391            self.leaves.remove(&key.value())
392        } else {
393            self.leaves.insert(key.value(), value)
394        }
395    }
396
397    fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
398        self.get_leaf(key)
399    }
400
401    fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
402        let leaf_pos = key.value();
403        match self.leaves.get(&leaf_pos) {
404            Some(word) => *word,
405            None => Self::EMPTY_VALUE,
406        }
407    }
408
409    fn hash_leaf(leaf: &Word) -> Word {
410        // `SimpleSmt` takes the leaf value itself as the hash
411        *leaf
412    }
413
414    fn construct_prospective_leaf(
415        &self,
416        _existing_leaf: Word,
417        _key: &LeafIndex<DEPTH>,
418        value: &Word,
419    ) -> Word {
420        *value
421    }
422
423    fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
424        *key
425    }
426
427    fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
428        (path, leaf).into()
429    }
430}