miden_crypto/merkle/smt/simple/mod.rs

1use alloc::collections::BTreeSet;
2
3use super::{
4    super::ValuePath, EMPTY_WORD, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes,
5    LeafIndex, MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SMT_MAX_DEPTH,
6    SMT_MIN_DEPTH, SparseMerkleTree, Word,
7};
8
9#[cfg(test)]
10mod tests;
11
12// SPARSE MERKLE TREE
13// ================================================================================================
14
/// Leaf storage for [`SimpleSmt`], keyed by leaf position (a `u64` index) and holding the leaf's
/// [`Word`] value.
type Leaves = super::Leaves<Word>;

/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
///
/// Storage is sparse. A leaf that is absent from `leaves` reads back as the empty value. An inner
/// node that is absent from `inner_nodes` reads back as the corresponding empty-subtree node.
/// Writing the empty value to a leaf removes its entry instead of storing it.
///
/// `DEPTH` must lie in `SMT_MIN_DEPTH..=SMT_MAX_DEPTH`; [`SimpleSmt::new`] rejects other depths.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
    /// Root digest of the tree, as returned by `root()`.
    root: RpoDigest,
    /// Explicitly stored inner nodes, keyed by their [`NodeIndex`].
    inner_nodes: InnerNodes,
    /// Explicitly stored (non-empty) leaf values, keyed by leaf position.
    leaves: Leaves,
}
27
28impl<const DEPTH: u8> SimpleSmt<DEPTH> {
29    // CONSTANTS
30    // --------------------------------------------------------------------------------------------
31
32    /// The default value used to compute the hash of empty leaves
33    pub const EMPTY_VALUE: Word = <Self as SparseMerkleTree<DEPTH>>::EMPTY_VALUE;
34
35    // CONSTRUCTORS
36    // --------------------------------------------------------------------------------------------
37
38    /// Returns a new [SimpleSmt].
39    ///
40    /// All leaves in the returned tree are set to [ZERO; 4].
41    ///
42    /// # Errors
43    /// Returns an error if DEPTH is 0 or is greater than 64.
44    pub fn new() -> Result<Self, MerkleError> {
45        // validate the range of the depth.
46        if DEPTH < SMT_MIN_DEPTH {
47            return Err(MerkleError::DepthTooSmall(DEPTH));
48        } else if SMT_MAX_DEPTH < DEPTH {
49            return Err(MerkleError::DepthTooBig(DEPTH as u64));
50        }
51
52        let root = *EmptySubtreeRoots::entry(DEPTH, 0);
53
54        Ok(Self {
55            root,
56            inner_nodes: Default::default(),
57            leaves: Default::default(),
58        })
59    }
60
61    /// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
62    ///
63    /// All leaves omitted from the entries list are set to [ZERO; 4].
64    ///
65    /// # Errors
66    /// Returns an error if:
67    /// - If the depth is 0 or is greater than 64.
68    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
69    /// - The provided entries contain multiple values for the same key.
70    pub fn with_leaves(
71        entries: impl IntoIterator<Item = (u64, Word)>,
72    ) -> Result<Self, MerkleError> {
73        // create an empty tree
74        let mut tree = Self::new()?;
75
76        // compute the max number of entries. We use an upper bound of depth 63 because we consider
77        // passing in a vector of size 2^64 infeasible.
78        let max_num_entries = 2_usize.pow(DEPTH.min(63).into());
79
80        // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
81        // entries with the empty value need additional tracking.
82        let mut key_set_to_zero = BTreeSet::new();
83
84        for (idx, (key, value)) in entries.into_iter().enumerate() {
85            if idx >= max_num_entries {
86                return Err(MerkleError::TooManyEntries(max_num_entries));
87            }
88
89            let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
90
91            if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
92                return Err(MerkleError::DuplicateValuesForIndex(key));
93            }
94
95            if value == Self::EMPTY_VALUE {
96                key_set_to_zero.insert(key);
97            };
98        }
99        Ok(tree)
100    }
101
102    /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
103    ///
104    /// This function performs minimal consistency checking. It is the caller's responsibility to
105    /// ensure the passed arguments are correct and consistent with each other.
106    ///
107    /// # Panics
108    /// With debug assertions on, this function panics if `root` does not match the root node in
109    /// `inner_nodes`.
110    pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
111        // Our particular implementation of `from_raw_parts()` never returns `Err`.
112        <Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
113    }
114
115    /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
116    /// starting at index 0.
117    pub fn with_contiguous_leaves(
118        entries: impl IntoIterator<Item = Word>,
119    ) -> Result<Self, MerkleError> {
120        Self::with_leaves(
121            entries
122                .into_iter()
123                .enumerate()
124                .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
125        )
126    }
127
128    // PUBLIC ACCESSORS
129    // --------------------------------------------------------------------------------------------
130
131    /// Returns the depth of the tree
132    pub const fn depth(&self) -> u8 {
133        DEPTH
134    }
135
136    /// Returns the root of the tree
137    pub fn root(&self) -> RpoDigest {
138        <Self as SparseMerkleTree<DEPTH>>::root(self)
139    }
140
141    /// Returns the number of non-empty leaves in this tree.
142    pub fn num_leaves(&self) -> usize {
143        self.leaves.len()
144    }
145
146    /// Returns the leaf at the specified index.
147    pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
148        <Self as SparseMerkleTree<DEPTH>>::get_leaf(self, key)
149    }
150
151    /// Returns a node at the specified index.
152    ///
153    /// # Errors
154    /// Returns an error if the specified index has depth set to 0 or the depth is greater than
155    /// the depth of this Merkle tree.
156    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
157        if index.is_root() {
158            Err(MerkleError::DepthTooSmall(index.depth()))
159        } else if index.depth() > DEPTH {
160            Err(MerkleError::DepthTooBig(index.depth() as u64))
161        } else if index.depth() == DEPTH {
162            let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
163
164            Ok(leaf.into())
165        } else {
166            Ok(self.get_inner_node(index).hash())
167        }
168    }
169
170    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
171    /// path to the leaf, as well as the leaf itself.
172    pub fn open(&self, key: &LeafIndex<DEPTH>) -> ValuePath {
173        <Self as SparseMerkleTree<DEPTH>>::open(self, key)
174    }
175
176    /// Returns a boolean value indicating whether the SMT is empty.
177    pub fn is_empty(&self) -> bool {
178        debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
179        self.root == Self::EMPTY_ROOT
180    }
181
182    // ITERATORS
183    // --------------------------------------------------------------------------------------------
184
185    /// Returns an iterator over the leaves of this [SimpleSmt].
186    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
187        self.leaves.iter().map(|(i, w)| (*i, w))
188    }
189
190    /// Returns an iterator over the inner nodes of this [SimpleSmt].
191    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
192        self.inner_nodes.values().map(|e| InnerNodeInfo {
193            value: e.hash(),
194            left: e.left,
195            right: e.right,
196        })
197    }
198
199    // STATE MUTATORS
200    // --------------------------------------------------------------------------------------------
201
202    /// Inserts a value at the specified key, returning the previous value associated with that key.
203    /// Recall that by definition, any key that hasn't been updated is associated with
204    /// [`EMPTY_WORD`].
205    ///
206    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
207    /// updating the root itself.
208    pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
209        <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
210    }
211
212    /// Computes what changes are necessary to insert the specified key-value pairs into this
213    /// Merkle tree, allowing for validation before applying those changes.
214    ///
215    /// This method returns a [`MutationSet`], which contains all the information for inserting
216    /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
217    /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
218    /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
219    /// Merkle tree, or [`drop()`] to discard them.
220    ///
221    /// # Example
222    /// ```
223    /// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
224    /// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
225    /// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
226    /// let pair = (LeafIndex::default(), Word::default());
227    /// let mutations = smt.compute_mutations(vec![pair]);
228    /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
229    /// smt.apply_mutations(mutations);
230    /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
231    /// ```
232    pub fn compute_mutations(
233        &self,
234        kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
235    ) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
236        <Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
237    }
238
239    /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
240    /// tree.
241    ///
242    /// # Errors
243    /// If `mutations` was computed on a tree with a different root than this one, returns
244    /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
245    /// root hash the `mutations` were computed against, and the second item is the actual
246    /// current root of this tree.
247    pub fn apply_mutations(
248        &mut self,
249        mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
250    ) -> Result<(), MerkleError> {
251        <Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
252    }
253
254    /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
255    /// this tree and returns the reverse mutation set.
256    ///
257    /// Applying the reverse mutation sets to the updated tree will revert the changes.
258    ///
259    /// # Errors
260    /// If `mutations` was computed on a tree with a different root than this one, returns
261    /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
262    /// root hash the `mutations` were computed against, and the second item is the actual
263    /// current root of this tree.
264    pub fn apply_mutations_with_reversion(
265        &mut self,
266        mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
267    ) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
268        <Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
269    }
270
271    /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
272    /// computed as `DEPTH - SUBTREE_DEPTH`.
273    ///
274    /// Returns the new root.
275    pub fn set_subtree<const SUBTREE_DEPTH: u8>(
276        &mut self,
277        subtree_insertion_index: u64,
278        subtree: SimpleSmt<SUBTREE_DEPTH>,
279    ) -> Result<RpoDigest, MerkleError> {
280        if SUBTREE_DEPTH > DEPTH {
281            return Err(MerkleError::SubtreeDepthExceedsDepth {
282                subtree_depth: SUBTREE_DEPTH,
283                tree_depth: DEPTH,
284            });
285        }
286
287        // Verify that `subtree_insertion_index` is valid.
288        let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
289        let subtree_root_index =
290            NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
291
292        // add leaves
293        // --------------
294
295        // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
296        // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
297        // valid as they are. However, consider what happens when we insert at
298        // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
299        // you can see it as there's a full subtree sitting on its left. In general, for
300        // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
301        // to insert, so we need to adjust all its leaves by `i * 2^d`.
302        let leaf_index_shift: u64 = subtree_insertion_index * 2_u64.pow(SUBTREE_DEPTH.into());
303        for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
304            let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
305            debug_assert!(new_leaf_idx < 2_u64.pow(DEPTH.into()));
306
307            self.leaves.insert(new_leaf_idx, *leaf_value);
308        }
309
310        // add subtree's branch nodes (which includes the root)
311        // --------------
312        for (branch_idx, branch_node) in subtree.inner_nodes {
313            let new_branch_idx = {
314                let new_depth = subtree_root_insertion_depth + branch_idx.depth();
315                let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
316                    + branch_idx.value();
317
318                NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
319            };
320
321            self.inner_nodes.insert(new_branch_idx, branch_node);
322        }
323
324        // recompute nodes starting from subtree root
325        // --------------
326        self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
327
328        Ok(self.root)
329    }
330}
331
332impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
333    type Key = LeafIndex<DEPTH>;
334    type Value = Word;
335    type Leaf = Word;
336    type Opening = ValuePath;
337
338    const EMPTY_VALUE: Self::Value = EMPTY_WORD;
339    const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
340
341    fn from_raw_parts(
342        inner_nodes: InnerNodes,
343        leaves: Leaves,
344        root: RpoDigest,
345    ) -> Result<Self, MerkleError> {
346        if cfg!(debug_assertions) {
347            let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
348            assert_eq!(root_node.hash(), root);
349        }
350
351        Ok(Self { root, inner_nodes, leaves })
352    }
353
354    fn root(&self) -> RpoDigest {
355        self.root
356    }
357
358    fn set_root(&mut self, root: RpoDigest) {
359        self.root = root;
360    }
361
362    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
363        self.inner_nodes
364            .get(&index)
365            .cloned()
366            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
367    }
368
369    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
370        self.inner_nodes.insert(index, inner_node)
371    }
372
373    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
374        self.inner_nodes.remove(&index)
375    }
376
377    fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
378        if value == Self::EMPTY_VALUE {
379            self.leaves.remove(&key.value())
380        } else {
381            self.leaves.insert(key.value(), value)
382        }
383    }
384
385    fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
386        self.get_leaf(key)
387    }
388
389    fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
390        let leaf_pos = key.value();
391        match self.leaves.get(&leaf_pos) {
392            Some(word) => *word,
393            None => Self::EMPTY_VALUE,
394        }
395    }
396
397    fn hash_leaf(leaf: &Word) -> RpoDigest {
398        // `SimpleSmt` takes the leaf value itself as the hash
399        leaf.into()
400    }
401
402    fn construct_prospective_leaf(
403        &self,
404        _existing_leaf: Word,
405        _key: &LeafIndex<DEPTH>,
406        value: &Word,
407    ) -> Word {
408        *value
409    }
410
411    fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
412        *key
413    }
414
415    fn path_and_leaf_to_opening(path: MerklePath, leaf: Word) -> ValuePath {
416        (path, leaf).into()
417    }
418}