miden_crypto/merkle/smt/simple/mod.rs
1use alloc::collections::BTreeSet;
2
3use super::{
4 EMPTY_WORD, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
5 MutationSet, NodeIndex, SMT_MAX_DEPTH, SMT_MIN_DEPTH, SparseMerkleTree, SparseMerkleTreeReader,
6 Word,
7};
8use crate::merkle::{SparseMerklePath, smt::SmtLeafError};
9
10mod proof;
11pub use proof::SimpleSmtProof;
12
13#[cfg(test)]
14mod tests;
15
16// SPARSE MERKLE TREE
17// ================================================================================================
18
/// Leaf storage for [`SimpleSmt`]: the generic `super::Leaves` map specialised to [`Word`]
/// values, keyed by the leaf position (`u64`).
type Leaves = super::Leaves<Word>;
20
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
///
/// The root of the tree is recomputed on each new leaf update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> {
    /// Root hash of the tree, returned as-is by `root()`.
    root: Word,
    /// Stored inner nodes keyed by their [`NodeIndex`]. Any index missing from this map is read
    /// back as the corresponding empty-subtree node (see `get_inner_node`).
    inner_nodes: InnerNodes,
    /// Leaf values keyed by leaf position. `insert_value` removes the entry when a leaf is set to
    /// the empty value, so missing positions read back as `EMPTY_VALUE`.
    leaves: Leaves,
}
31
32impl<const DEPTH: u8> SimpleSmt<DEPTH> {
33 // CONSTANTS
34 // --------------------------------------------------------------------------------------------
35
36 /// The default value used to compute the hash of empty leaves
37 pub const EMPTY_VALUE: Word = <Self as SparseMerkleTreeReader<DEPTH>>::EMPTY_VALUE;
38
39 // CONSTRUCTORS
40 // --------------------------------------------------------------------------------------------
41
42 /// Returns a new [SimpleSmt].
43 ///
44 /// All leaves in the returned tree are set to [ZERO; 4].
45 ///
46 /// # Errors
47 /// Returns an error if DEPTH is 0 or is greater than 64.
48 pub fn new() -> Result<Self, MerkleError> {
49 // validate the range of the depth.
50 if DEPTH < SMT_MIN_DEPTH {
51 return Err(MerkleError::DepthTooSmall(DEPTH));
52 } else if SMT_MAX_DEPTH < DEPTH {
53 return Err(MerkleError::DepthTooBig(DEPTH as u64));
54 }
55
56 let root = *EmptySubtreeRoots::entry(DEPTH, 0);
57
58 Ok(Self {
59 root,
60 inner_nodes: Default::default(),
61 leaves: Default::default(),
62 })
63 }
64
65 /// Returns a new [SimpleSmt] instantiated with leaves set as specified by the provided entries.
66 ///
67 /// All leaves omitted from the entries list are set to [ZERO; 4].
68 ///
69 /// # Errors
70 /// Returns an error if:
71 /// - If the depth is 0 or is greater than 64.
72 /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
73 /// - The provided entries contain multiple values for the same key.
74 pub fn with_leaves(
75 entries: impl IntoIterator<Item = (u64, Word)>,
76 ) -> Result<Self, MerkleError> {
77 // create an empty tree
78 let mut tree = Self::new()?;
79
80 // compute the max number of entries. We use an upper bound of depth 63 because we consider
81 // passing in a vector of size 2^64 infeasible.
82 let max_num_entries = 2_u64.pow(DEPTH.min(63).into());
83
84 // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
85 // entries with the empty value need additional tracking.
86 let mut key_set_to_zero = BTreeSet::new();
87
88 for (idx, (key, value)) in entries.into_iter().enumerate() {
89 if idx as u64 >= max_num_entries {
90 return Err(MerkleError::TooManyEntries(DEPTH));
91 }
92
93 let old_value = tree.insert(LeafIndex::<DEPTH>::new(key)?, value);
94
95 if old_value != Self::EMPTY_VALUE || key_set_to_zero.contains(&key) {
96 return Err(MerkleError::DuplicateValuesForIndex(key));
97 }
98
99 if value == Self::EMPTY_VALUE {
100 key_set_to_zero.insert(key);
101 };
102 }
103 Ok(tree)
104 }
105
106 /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
107 ///
108 /// This function performs minimal consistency checking. It is the caller's responsibility to
109 /// ensure the passed arguments are correct and consistent with each other.
110 ///
111 /// # Panics
112 /// With debug assertions on, this function panics if `root` does not match the root node in
113 /// `inner_nodes`.
114 pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: Word) -> Self {
115 if cfg!(debug_assertions) {
116 let root_node_hash = inner_nodes
117 .get(&NodeIndex::root())
118 .map(InnerNode::hash)
119 .unwrap_or(Self::EMPTY_ROOT);
120
121 assert_eq!(root_node_hash, root);
122 }
123
124 Self { root, inner_nodes, leaves }
125 }
126
127 /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
128 /// starting at index 0.
129 pub fn with_contiguous_leaves(
130 entries: impl IntoIterator<Item = Word>,
131 ) -> Result<Self, MerkleError> {
132 Self::with_leaves(
133 entries
134 .into_iter()
135 .enumerate()
136 .map(|(idx, word)| (idx.try_into().expect("tree max depth is 2^8"), word)),
137 )
138 }
139
140 // PUBLIC ACCESSORS
141 // --------------------------------------------------------------------------------------------
142
143 /// Returns the depth of the tree
144 pub const fn depth(&self) -> u8 {
145 DEPTH
146 }
147
148 /// Returns the root of the tree
149 pub fn root(&self) -> Word {
150 <Self as SparseMerkleTreeReader<DEPTH>>::root(self)
151 }
152
153 /// Returns the number of non-empty leaves in this tree.
154 pub fn num_leaves(&self) -> usize {
155 self.leaves.len()
156 }
157
158 /// Returns the leaf at the specified index.
159 pub fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
160 <Self as SparseMerkleTreeReader<DEPTH>>::get_leaf(self, key)
161 }
162
163 /// Returns a node at the specified index.
164 ///
165 /// # Errors
166 /// Returns an error if the specified index has depth set to 0 or the depth is greater than
167 /// the depth of this Merkle tree.
168 pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
169 if index.is_root() {
170 Err(MerkleError::DepthTooSmall(index.depth()))
171 } else if index.depth() > DEPTH {
172 Err(MerkleError::DepthTooBig(index.depth() as u64))
173 } else if index.depth() == DEPTH {
174 let leaf = self.get_leaf(&LeafIndex::<DEPTH>::try_from(index)?);
175
176 Ok(leaf)
177 } else {
178 Ok(self.get_inner_node(index).hash())
179 }
180 }
181
182 /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
183 /// path to the leaf, as well as the leaf itself.
184 pub fn open(&self, key: &LeafIndex<DEPTH>) -> SimpleSmtProof {
185 let value = self.get_value(key);
186 let nodes = key.index.proof_indices().map(|index| self.get_node_hash(index));
187 // `from_sized_iter()` returns an error if there are more nodes than `SMT_MAX_DEPTH`, but
188 // this could only happen if we have more levels than `SMT_MAX_DEPTH` ourselves, which is
189 // guarded against in `SimpleSmt::new()`.
190 let path = SparseMerklePath::from_sized_iter(nodes).unwrap();
191
192 SimpleSmtProof { value, path }
193 }
194
195 /// Returns a boolean value indicating whether the SMT is empty.
196 pub fn is_empty(&self) -> bool {
197 debug_assert_eq!(self.leaves.is_empty(), self.root == Self::EMPTY_ROOT);
198 self.root == Self::EMPTY_ROOT
199 }
200
201 // ITERATORS
202 // --------------------------------------------------------------------------------------------
203
204 /// Returns an iterator over the leaves of this [SimpleSmt].
205 pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
206 self.leaves.iter().map(|(i, w)| (*i, w))
207 }
208
209 /// Returns an iterator over the inner nodes of this [SimpleSmt].
210 pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
211 self.inner_nodes.values().map(|e| InnerNodeInfo {
212 value: e.hash(),
213 left: e.left,
214 right: e.right,
215 })
216 }
217
218 // STATE MUTATORS
219 // --------------------------------------------------------------------------------------------
220
221 /// Inserts a value at the specified key, returning the previous value associated with that key.
222 /// Recall that by definition, any key that hasn't been updated is associated with
223 /// [`EMPTY_WORD`].
224 ///
225 /// This also recomputes all hashes between the leaf (associated with the key) and the root,
226 /// updating the root itself.
227 pub fn insert(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Word {
228 // SAFETY: a SimpleSmt does not contain multi-value leaves. The underlying
229 // SimpleSmt::insert_value does not return any errors so it's safe to unwrap here.
230 <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
231 .expect("inserting a value into a simple smt never returns an error")
232 }
233
234 /// Computes what changes are necessary to insert the specified key-value pairs into this
235 /// Merkle tree, allowing for validation before applying those changes.
236 ///
237 /// This method returns a [`MutationSet`], which contains all the information for inserting
238 /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
239 /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
240 /// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
241 /// Merkle tree, or [`drop()`] to discard them.
242 ///
243 /// # Errors
244 ///
245 /// - [`MerkleError::DuplicateValuesForIndex`] if the provided `kv_pairs` contain duplicate
246 /// keys.
247 ///
248 /// # Example
249 /// ```
250 /// # use miden_crypto::{Felt, Word};
251 /// # use miden_crypto::merkle::{smt::{LeafIndex, SimpleSmt, SMT_DEPTH}, EmptySubtreeRoots};
252 /// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
253 /// let pair = (LeafIndex::default(), Word::default());
254 /// let mutations = smt.compute_mutations(vec![pair]).unwrap();
255 /// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
256 /// smt.apply_mutations(mutations).unwrap();
257 /// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
258 /// ```
259 pub fn compute_mutations(
260 &self,
261 kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
262 ) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
263 <Self as SparseMerkleTreeReader<DEPTH>>::compute_mutations(self, kv_pairs)
264 }
265
266 /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
267 /// tree.
268 ///
269 /// # Errors
270 /// If `mutations` was computed on a tree with a different root than this one, returns
271 /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
272 /// root hash the `mutations` were computed against, and the second item is the actual
273 /// current root of this tree.
274 pub fn apply_mutations(
275 &mut self,
276 mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
277 ) -> Result<(), MerkleError> {
278 <Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
279 }
280
281 /// Applies the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to
282 /// this tree and returns the reverse mutation set.
283 ///
284 /// Applying the reverse mutation sets to the updated tree will revert the changes.
285 ///
286 /// # Errors
287 /// If `mutations` was computed on a tree with a different root than this one, returns
288 /// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
289 /// root hash the `mutations` were computed against, and the second item is the actual
290 /// current root of this tree.
291 pub fn apply_mutations_with_reversion(
292 &mut self,
293 mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
294 ) -> Result<MutationSet<DEPTH, LeafIndex<DEPTH>, Word>, MerkleError> {
295 <Self as SparseMerkleTree<DEPTH>>::apply_mutations_with_reversion(self, mutations)
296 }
297
298 /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
299 /// computed as `DEPTH - SUBTREE_DEPTH`.
300 ///
301 /// Returns the new root.
302 pub fn set_subtree<const SUBTREE_DEPTH: u8>(
303 &mut self,
304 subtree_insertion_index: u64,
305 subtree: SimpleSmt<SUBTREE_DEPTH>,
306 ) -> Result<Word, MerkleError> {
307 if SUBTREE_DEPTH > DEPTH {
308 return Err(MerkleError::SubtreeDepthExceedsDepth {
309 subtree_depth: SUBTREE_DEPTH,
310 tree_depth: DEPTH,
311 });
312 }
313
314 // Verify that `subtree_insertion_index` is valid.
315 let subtree_root_insertion_depth = DEPTH - SUBTREE_DEPTH;
316 let subtree_root_index =
317 NodeIndex::new(subtree_root_insertion_depth, subtree_insertion_index)?;
318
319 // remove leaves and inner nodes under the insertion root
320 // --------------
321
322 // The subtree's leaf indices live in their own context - i.e. a subtree of depth `d`. If we
323 // insert the subtree at `subtree_insertion_index = 0`, then the subtree leaf indices are
324 // valid as they are. However, consider what happens when we insert at
325 // `subtree_insertion_index = 1`. The first leaf of our subtree now will have index `2^d`;
326 // you can see it as there's a full subtree sitting on its left. In general, for
327 // `subtree_insertion_index = i`, there are `i` subtrees sitting before the subtree we want
328 // to insert, so we need to adjust all its leaves by `i * 2^d`.
329 let leaf_index_shift: u64 = if SUBTREE_DEPTH == SMT_MAX_DEPTH {
330 0
331 } else {
332 subtree_insertion_index << u32::from(SUBTREE_DEPTH)
333 };
334
335 self.leaves.retain(|leaf_idx, _| {
336 !Self::leaf_is_in_subtree::<SUBTREE_DEPTH>(*leaf_idx, subtree_insertion_index)
337 });
338 self.inner_nodes.retain(|node_idx, _| {
339 !Self::node_is_in_subtree(
340 *node_idx,
341 subtree_root_insertion_depth,
342 subtree_insertion_index,
343 )
344 });
345
346 // add leaves
347 // --------------
348 for (subtree_leaf_idx, leaf_value) in subtree.leaves() {
349 let new_leaf_idx = leaf_index_shift + subtree_leaf_idx;
350 debug_assert!(DEPTH == SMT_MAX_DEPTH || new_leaf_idx < 2_u64.pow(DEPTH.into()));
351
352 self.leaves.insert(new_leaf_idx, *leaf_value);
353 }
354
355 // add subtree's branch nodes (which includes the root)
356 // --------------
357 for (branch_idx, branch_node) in subtree.inner_nodes {
358 let new_branch_idx = {
359 let new_depth = subtree_root_insertion_depth + branch_idx.depth();
360 let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
361 + branch_idx.position();
362
363 NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
364 };
365
366 self.inner_nodes.insert(new_branch_idx, branch_node);
367 }
368
369 // recompute nodes starting from subtree root
370 // --------------
371 self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
372
373 Ok(self.root)
374 }
375
376 fn leaf_is_in_subtree<const SUBTREE_DEPTH: u8>(
377 leaf_idx: u64,
378 subtree_insertion_index: u64,
379 ) -> bool {
380 if SUBTREE_DEPTH == SMT_MAX_DEPTH {
381 true
382 } else {
383 (leaf_idx >> u32::from(SUBTREE_DEPTH)) == subtree_insertion_index
384 }
385 }
386
387 fn node_is_in_subtree(
388 node_idx: NodeIndex,
389 subtree_root_depth: u8,
390 subtree_insertion_index: u64,
391 ) -> bool {
392 if node_idx.depth() < subtree_root_depth {
393 return false;
394 }
395
396 let depth_offset = node_idx.depth() - subtree_root_depth;
397 if depth_offset == SMT_MAX_DEPTH {
398 subtree_insertion_index == 0
399 } else {
400 (node_idx.position() >> u32::from(depth_offset)) == subtree_insertion_index
401 }
402 }
403}
404
405impl<const DEPTH: u8> SparseMerkleTreeReader<DEPTH> for SimpleSmt<DEPTH> {
406 type Key = LeafIndex<DEPTH>;
407 type Value = Word;
408 type Leaf = Word;
409 type Opening = SimpleSmtProof;
410
411 const EMPTY_VALUE: Self::Value = EMPTY_WORD;
412 const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(DEPTH, 0);
413
414 fn root(&self) -> Word {
415 self.root
416 }
417
418 fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
419 self.inner_nodes
420 .get(&index)
421 .cloned()
422 .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
423 }
424
425 fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
426 self.get_leaf(key)
427 }
428
429 fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
430 let leaf_pos = key.position();
431 match self.leaves.get(&leaf_pos) {
432 Some(word) => *word,
433 None => Self::EMPTY_VALUE,
434 }
435 }
436
437 fn hash_leaf(leaf: &Word) -> Word {
438 // `SimpleSmt` takes the leaf value itself as the hash
439 *leaf
440 }
441
442 fn construct_prospective_leaf(
443 &self,
444 _existing_leaf: Word,
445 _key: &LeafIndex<DEPTH>,
446 value: &Word,
447 ) -> Result<Word, SmtLeafError> {
448 Ok(*value)
449 }
450
451 fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
452 *key
453 }
454
455 fn path_and_leaf_to_opening(path: SparseMerklePath, leaf: Word) -> SimpleSmtProof {
456 (path, leaf).into()
457 }
458}
459
460impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
461 fn set_root(&mut self, root: Word) {
462 self.root = root;
463 }
464
465 fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
466 self.inner_nodes.insert(index, inner_node)
467 }
468
469 fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
470 self.inner_nodes.remove(&index)
471 }
472
473 fn insert_value(
474 &mut self,
475 key: LeafIndex<DEPTH>,
476 value: Word,
477 ) -> Result<Option<Word>, MerkleError> {
478 let result = if value == Self::EMPTY_VALUE {
479 self.leaves.remove(&key.position())
480 } else {
481 self.leaves.insert(key.position(), value)
482 };
483 Ok(result)
484 }
485}