// miden_crypto/merkle/smt/simple/mod.rs

1use alloc::collections::BTreeSet;
2
3use super::{
4    EMPTY_WORD, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
5    MutationSet, NodeIndex, SMT_MAX_DEPTH, SMT_MIN_DEPTH, SparseMerkleTree, SparseMerkleTreeReader,
6    Word,
7};
8use crate::merkle::{SparseMerklePath, smt::SmtLeafError};
9
10mod proof;
11pub use proof::SimpleSmtProof;
12
13#[cfg(test)]
14mod tests;
15
16// SPARSE MERKLE TREE
17// ================================================================================================
18
/// Leaf storage of a [`SimpleSmt`]: the parent module's `Leaves` container specialized to
/// [`Word`] values, keyed by leaf position (`u64`). Only non-empty leaves are stored.
type Leaves = super::Leaves<Word>;
20
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
///
/// `DEPTH` must lie in `SMT_MIN_DEPTH..=SMT_MAX_DEPTH`. This is checked by [`SimpleSmt::new`]
/// and [`SimpleSmt::with_leaves`], but not by [`SimpleSmt::from_raw_parts`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
    /// Current root of the tree. It is returned as-is on reads and updated by every mutation.
    root: Word,
    /// Inner nodes that differ from an empty subtree, keyed by their [`NodeIndex`]. A missing
    /// entry is read back as the corresponding node of an empty subtree.
    inner_nodes: InnerNodes,
    /// Non-empty leaf values keyed by leaf position. A missing entry is read back as the empty
    /// value, and writing the empty value removes the entry.
    leaves: Leaves,
}
31
32impl<const DEPTH: u8> SimpleSmt<DEPTH> {
33    // CONSTANTS
34    // --------------------------------------------------------------------------------------------
35
36    /// The default value used to compute the hash of empty leaves
37    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTreeReader<DEPTH>>::EMPTY_VALUE;
38
39    // CONSTRUCTORS
40    // --------------------------------------------------------------------------------------------
41
42    /// Returns a new [SimpleSmt].
43    ///
44    /// All leaves in the returned tree are set to [ZERO; 4].
45    ///
46    /// # Errors
47    /// Returns an error if DEPTH is 0 or is greater than 64.
48    pub fn new() -> Result<Self, MerkleError> {
49        // validate the range of the depth.
50        if DEPTH < SMT_MIN_DEPTH {
51            return Err(MerkleError::DepthTooSmall(DEPTH));
52        } else if SMT_MAX_DEPTH < DEPTH {
53            return Err(MerkleError::DepthTooBig(DEPTH as u64));
54        }
55
56        let root = *EmptySubtreeRoots::entry(DEPTH, 0);
57
58        Ok(Self {
59            root,
60            inner_nodes: Default::default(),
61            leaves: Default::default(),
62        })
63    }
64
65    /// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
66    ///
67    /// All leaves omitted from the entries list are set to [ZERO; 4].
68    ///
69    /// # Errors
70    /// Returns an error if:
71    /// - If the depth is 0 or is greater than 64.
72    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
73    /// - The provided entries contain multiple values for the same key.
74    pub fn with_leaves(
75        entries: impl IntoIterator<Item = (u64, Word)>,
76    ) -> Result<Self, MerkleError> {
77        // create an empty tree
78        let mut tree = Self::new()?;
79
80        // compute the max number of entries. We use an upper bound of depth 63 because we consider
81        // passing in a vector of size 2^64 infeasible.
82        let max_num_entries = 2_u64.pow(DEPTH.min(63).into());
83
84        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
85        // entries with the empty value need additional tracking.
86        let mut key_set_to_zero = BTreeSet::new();
87
88        for (idx, (key, value)) in entries.into_iter().enumerate() {
89            if idx as u64 >= max_num_entries {
90                return Err(MerkleError::TooManyEntries(DEPTH));
91            }
92
93            let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
94
95            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
96                return Err(MerkleError::DuplicateValuesForIndex(key));
97            }
98
99            if value == Self::EMPTY_VALUE {
100                key_set_to_zero.insert(key);
101            };
102        }
103        Ok(tree)
104    }
105
106    /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
107    ///
108    /// This function performs minimal consistency checking. It is the caller's responsibility to
109    /// ensure the passed arguments are correct and consistent with each other.
110    ///
111    /// # Panics
112    /// With debug assertions on, this function panics if `root` does not match the root node in
113    /// `inner_nodes`.
114    pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: Word) -> Self {
115        if cfg!(debug_assertions) {
116            let root_node_hash = inner_nodes
117                .get(&NodeIndex::root())
118                .map(InnerNode::hash)
119                .unwrap_or(Self::EMPTY_ROOT);
120
121            assert_eq!(root_node_hash, root);
122        }
123
124        Self { root, inner_nodes, leaves }
125    }
126
127    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
128    /// starting at index 0.
129    pub fn with_contiguous_leaves(
130        entries: impl IntoIterator<Item = Word>,
131    ) -> Result<Self, MerkleError> {
132        Self::with_leaves(
133            entries
134                .into_iter()
135                .enumerate()
136                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
137        )
138    }
139
140    // PUBLIC ACCESSORS
141    // --------------------------------------------------------------------------------------------
142
143    /// Returns the depth of the tree
144    pub const fn depth(&self) -> u8 {
145        DEPTH
146    }
147
148    /// Returns the root of the tree
149    pub fn root(&self) -> Word {
150        <Self as SparseMerkleTreeReader<DEPTH>>::root(self)
151    }
152
153    /// Returns the number of non-empty leaves in this tree.
154    pub fn num_leaves(&self) -> usize {
155        self.leaves.len()
156    }
157
158    /// Returns the leaf at the specified index.
159    pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
160        <Self as SparseMerkleTreeReader<DEPTH>>::get_leaf(self, key)
161    }
162
163    /// Returns a node at the specified index.
164    ///
165    /// # Errors
166    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
167    /// the depth of this Merkle tree.
168    pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
169        if index.is_root() {
170            Err(MerkleError::DepthTooSmall(index.depth()))
171        } else if index.depth() > DEPTH {
172            Err(MerkleError::DepthTooBig(index.depth() as u64))
173        } else if index.depth() == DEPTH {
174            let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
175
176            Ok(leaf)
177        } else {
178            Ok(self.get_inner_node(index).hash())
179        }
180    }
181
182    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
183    /// path to the leaf, as well as the leaf itself.
184    pub fn open(&self, key: &LeafIndex<DEPTH>) -> SimpleSmtProof {
185        let value = self.get_value(key);
186        let nodes = key.index.proof_indices().map(|index| self.get_node_hash(index));
187        // `from_sized_iter()` returns an error if there are more nodes than `SMT_MAX_DEPTH`, but
188        // this could only happen if we have more levels than `SMT_MAX_DEPTH` ourselves, which is
189        // guarded against in `SimpleSmt::new()`.
190        let path = SparseMerklePath::from_sized_iter(nodes).unwrap();
191
192        SimpleSmtProof { value, path }
193    }
194
195    /// Returns a boolean value indicating whether the SMT is empty.
196    pub fn is_empty(&self) -> bool {
197        debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
198        self.root == Self::EMPTY_ROOT
199    }
200
201    // ITERATORS
202    // --------------------------------------------------------------------------------------------
203
204    /// Returns an iterator over the leaves of this [SimpleSmt].
205    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
206        self.leaves.iter().map(|(i, w)| (*i, w))
207    }
208
209    /// Returns an iterator over the inner nodes of this [SimpleSmt].
210    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
211        self.inner_nodes.values().map(|e| InnerNodeInfo {
212            value: e.hash(),
213            left: e.left,
214            right: e.right,
215        })
216    }
217
218    // STATE MUTATORS
219    // --------------------------------------------------------------------------------------------
220
221    /// Inserts a value at the specified key, returning the previous value associated with that key.
222    /// Recall that by definition, any key that hasn't been updated is associated with
223    /// [`EMPTY_WORD`].
224    ///
225    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
226    /// updating the root itself.
227    pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
228        // SAFETY: a SimpleSmt does not contain multi-value leaves. The underlying
229        // SimpleSmt::insert_value does not return any errors so it's safe to unwrap here.
230        <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
231            .expect("inserting a value into a simple smt never returns an error")
232    }
233
234    /// Computes what changes are necessary to insert the specified key-value pairs into this
235    /// Merkle tree, allowing for validation before applying those changes.
236    ///
237    /// This method returns a [`MutationSet`], which contains all the information for inserting
238    /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
239    /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
240    /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
241    /// Merkle tree, or [`drop()`] to discard them.
242    ///
243    /// # Errors
244    ///
245    /// - [`MerkleError::DuplicateValuesForIndex`] if the provided `kv_pairs` contain duplicate
246    ///   keys.
247    ///
248    /// # Example
249    /// ```
250    /// # use miden_crypto::{Felt, Word};
251    /// # use miden_crypto::merkle::{smt::{LeafIndex, SimpleSmt, SMT_DEPTH}, EmptySubtreeRoots};
252    /// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
253    /// let pair = (LeafIndex::default(), Word::default());
254    /// let mutations = smt.compute_mutations(vec![pair]).unwrap();
255    /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
256    /// smt.apply_mutations(mutations).unwrap();
257    /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
258    /// ```
259    pub fn compute_mutations(
260        &self,
261        kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
262    ) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
263        <Self as SparseMerkleTreeReader<DEPTH>>::compute_mutations(self, kv_pairs)
264    }
265
266    /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
267    /// tree.
268    ///
269    /// # Errors
270    /// If `mutations` was computed on a tree with a different root than this one, returns
271    /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
272    /// root hash the `mutations` were computed against, and the second item is the actual
273    /// current root of this tree.
274    pub fn apply_mutations(
275        &mut self,
276        mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
277    ) -> Result<(), MerkleError> {
278        <Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
279    }
280
281    /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
282    /// this tree and returns the reverse mutation set.
283    ///
284    /// Applying the reverse mutation sets to the updated tree will revert the changes.
285    ///
286    /// # Errors
287    /// If `mutations` was computed on a tree with a different root than this one, returns
288    /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
289    /// root hash the `mutations` were computed against, and the second item is the actual
290    /// current root of this tree.
291    pub fn apply_mutations_with_reversion(
292        &mut self,
293        mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
294    ) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
295        <Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
296    }
297
298    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
299    /// computed as `DEPTH - SUBTREE_DEPTH`.
300    ///
301    /// Returns the new root.
302    pub fn set_subtree<const SUBTREE_DEPTH: u8>(
303        &mut self,
304        subtree_insertion_index: u64,
305        subtree: SimpleSmt<SUBTREE_DEPTH>,
306    ) -> Result<Word, MerkleError> {
307        if SUBTREE_DEPTH > DEPTH {
308            return Err(MerkleError::SubtreeDepthExceedsDepth {
309                subtree_depth: SUBTREE_DEPTH,
310                tree_depth: DEPTH,
311            });
312        }
313
314        // Verify that `subtree_insertion_index` is valid.
315        let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
316        let subtree_root_index =
317            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
318
319        // remove leaves and inner nodes under the insertion root
320        // --------------
321
322        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
323        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
324        // valid as they are. However, consider what happens when we insert at
325        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
326        // you can see it as there's a full subtree sitting on its left. In general, for
327        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
328        // to insert, so we need to adjust all its leaves by `i * 2^d`.
329        let leaf_index_shift: u64 = if SUBTREE_DEPTH == SMT_MAX_DEPTH {
330            0
331        } else {
332            subtree_insertion_index << u32::from(SUBTREE_DEPTH)
333        };
334
335        self.leaves.retain(|leaf_idx, _| {
336            !Self::leaf_is_in_subtree::<SUBTREE_DEPTH>(*leaf_idx, subtree_insertion_index)
337        });
338        self.inner_nodes.retain(|node_idx, _| {
339            !Self::node_is_in_subtree(
340                *node_idx,
341                subtree_root_insertion_depth,
342                subtree_insertion_index,
343            )
344        });
345
346        // add leaves
347        // --------------
348        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
349            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
350            debug_assert!(DEPTH == SMT_MAX_DEPTH || new_leaf_idx < 2_u64.pow(DEPTH.into()));
351
352            self.leaves.insert(new_leaf_idx, *leaf_value);
353        }
354
355        // add subtree's branch nodes (which includes the root)
356        // --------------
357        for (branch_idx, branch_node) in subtree.inner_nodes {
358            let new_branch_idx = {
359                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
360                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
361                    + branch_idx.position();
362
363                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
364            };
365
366            self.inner_nodes.insert(new_branch_idx, branch_node);
367        }
368
369        // recompute nodes starting from subtree root
370        // --------------
371        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
372
373        Ok(self.root)
374    }
375
376    fn leaf_is_in_subtree<const SUBTREE_DEPTH: u8>(
377        leaf_idx: u64,
378        subtree_insertion_index: u64,
379    ) -> bool {
380        if SUBTREE_DEPTH == SMT_MAX_DEPTH {
381            true
382        } else {
383            (leaf_idx >> u32::from(SUBTREE_DEPTH)) == subtree_insertion_index
384        }
385    }
386
387    fn node_is_in_subtree(
388        node_idx: NodeIndex,
389        subtree_root_depth: u8,
390        subtree_insertion_index: u64,
391    ) -> bool {
392        if node_idx.depth() < subtree_root_depth {
393            return false;
394        }
395
396        let depth_offset = node_idx.depth() - subtree_root_depth;
397        if depth_offset == SMT_MAX_DEPTH {
398            subtree_insertion_index == 0
399        } else {
400            (node_idx.position() >> u32::from(depth_offset)) == subtree_insertion_index
401        }
402    }
403}
404
405impl<const DEPTH: u8> SparseMerkleTreeReader<DEPTH> for SimpleSmt<DEPTH> {
406    type Key = LeafIndex<DEPTH>;
407    type Value = Word;
408    type Leaf = Word;
409    type Opening = SimpleSmtProof;
410
411    const EMPTY_VALUE: Self::Value = EMPTY_WORD;
412    const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(DEPTH, 0);
413
414    fn root(&self) -> Word {
415        self.root
416    }
417
418    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
419        self.inner_nodes
420            .get(&index)
421            .cloned()
422            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
423    }
424
425    fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
426        self.get_leaf(key)
427    }
428
429    fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
430        let leaf_pos = key.position();
431        match self.leaves.get(&leaf_pos) {
432            Some(word) => *word,
433            None => Self::EMPTY_VALUE,
434        }
435    }
436
437    fn hash_leaf(leaf: &Word) -> Word {
438        // `SimpleSmt` takes the leaf value itself as the hash
439        *leaf
440    }
441
442    fn construct_prospective_leaf(
443        &self,
444        _existing_leaf: Word,
445        _key: &LeafIndex<DEPTH>,
446        value: &Word,
447    ) -> Result<Word, SmtLeafError> {
448        Ok(*value)
449    }
450
451    fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
452        *key
453    }
454
455    fn path_and_leaf_to_opening(path: SparseMerklePath, leaf: Word) -> SimpleSmtProof {
456        (path, leaf).into()
457    }
458}
459
460impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
461    fn set_root(&mut self, root: Word) {
462        self.root = root;
463    }
464
465    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
466        self.inner_nodes.insert(index, inner_node)
467    }
468
469    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
470        self.inner_nodes.remove(&index)
471    }
472
473    fn insert_value(
474        &mut self,
475        key: LeafIndex<DEPTH>,
476        value: Word,
477    ) -> Result<Option<Word>, MerkleError> {
478        let result = if value == Self::EMPTY_VALUE {
479            self.leaves.remove(&key.position())
480        } else {
481            self.leaves.insert(key.position(), value)
482        };
483        Ok(result)
484    }
485}