miden_crypto/merkle/smt/mod.rs
1use alloc::{collections::BTreeMap, vec::Vec};
2
3use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
4
5use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
6use crate::{
7 hash::rpo::{Rpo256, RpoDigest},
8 Felt, Word, EMPTY_WORD,
9};
10
11mod full;
12pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
13
14mod simple;
15pub use simple::SimpleSmt;
16
17mod partial;
18pub use partial::PartialSmt;
19
20// CONSTANTS
21// ================================================================================================
22
/// Smallest depth a sparse Merkle tree is allowed to have.
///
/// A depth-0 tree would be a single value with no internal structure, so
/// [`LeafIndex::new`] rejects any depth below this bound.
pub const SMT_MIN_DEPTH: u8 = 1;

/// Largest depth a sparse Merkle tree is allowed to have.
pub const SMT_MAX_DEPTH: u8 = 64;
28
29// SPARSE MERKLE TREE
30// ================================================================================================
31
32/// An abstract description of a sparse Merkle tree.
33///
34/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
35/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
36/// value was not explicitly set, then its value is the default value. Typically, the vast majority
37/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
38/// representation of the tree will only keep track of the leaves that have a different value from
39/// the default.
40///
41/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
42/// longer its proofs are - of exactly `log(depth)` size. A tree cannot have depth 0, since such a
43/// tree is just a single value, and is probably a programming mistake.
44///
45/// Every key maps to one leaf. If there are as many keys as there are leaves, then
46/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
47/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
48/// must accommodate all keys that map to the same leaf.
49///
50/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
51pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
52 /// The type for a key
53 type Key: Clone + Ord;
54 /// The type for a value
55 type Value: Clone + PartialEq;
56 /// The type for a leaf
57 type Leaf: Clone;
58 /// The type for an opening (i.e. a "proof") of a leaf
59 type Opening;
60
61 /// The default value used to compute the hash of empty leaves
62 const EMPTY_VALUE: Self::Value;
63
64 /// The root of the empty tree with provided DEPTH
65 const EMPTY_ROOT: RpoDigest;
66
67 // PROVIDED METHODS
68 // ---------------------------------------------------------------------------------------------
69
70 /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
71 /// path to the leaf, as well as the leaf itself.
72 fn open(&self, key: &Self::Key) -> Self::Opening {
73 let leaf = self.get_leaf(key);
74
75 let mut index: NodeIndex = {
76 let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
77 leaf_index.into()
78 };
79
80 let merkle_path = {
81 let mut path = Vec::with_capacity(index.depth() as usize);
82 for _ in 0..index.depth() {
83 let is_right = index.is_value_odd();
84 index.move_up();
85 let InnerNode { left, right } = self.get_inner_node(index);
86 let value = if is_right { left } else { right };
87 path.push(value);
88 }
89
90 MerklePath::new(path)
91 };
92
93 Self::path_and_leaf_to_opening(merkle_path, leaf)
94 }
95
96 /// Inserts a value at the specified key, returning the previous value associated with that key.
97 /// Recall that by definition, any key that hasn't been updated is associated with
98 /// [`Self::EMPTY_VALUE`].
99 ///
100 /// This also recomputes all hashes between the leaf (associated with the key) and the root,
101 /// updating the root itself.
102 fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
103 let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
104
105 // if the old value and new value are the same, there is nothing to update
106 if value == old_value {
107 return value;
108 }
109
110 let leaf = self.get_leaf(&key);
111 let node_index = {
112 let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
113 leaf_index.into()
114 };
115
116 self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
117
118 old_value
119 }
120
121 /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
122 /// `node_hash_at_index` is the hash of the node stored at index.
123 fn recompute_nodes_from_index_to_root(
124 &mut self,
125 mut index: NodeIndex,
126 node_hash_at_index: RpoDigest,
127 ) {
128 let mut node_hash = node_hash_at_index;
129 for node_depth in (0..index.depth()).rev() {
130 let is_right = index.is_value_odd();
131 index.move_up();
132 let InnerNode { left, right } = self.get_inner_node(index);
133 let (left, right) = if is_right {
134 (left, node_hash)
135 } else {
136 (node_hash, right)
137 };
138 node_hash = Rpo256::merge(&[left, right]);
139
140 if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
141 // If a subtree is empty, then can remove the inner node, since it's equal to the
142 // default value
143 self.remove_inner_node(index);
144 } else {
145 self.insert_inner_node(index, InnerNode { left, right });
146 }
147 }
148 self.set_root(node_hash);
149 }
150
151 /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
152 /// tree, allowing for validation before applying those changes.
153 ///
154 /// This method returns a [`MutationSet`], which contains all the information for inserting
155 /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
156 /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
157 /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
158 /// the Merkle tree, or [`drop()`] to discard them.
159 fn compute_mutations(
160 &self,
161 kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
162 ) -> MutationSet<DEPTH, Self::Key, Self::Value> {
163 use NodeMutation::*;
164
165 let mut new_root = self.root();
166 let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
167 let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
168
169 for (key, value) in kv_pairs {
170 // If the old value and the new value are the same, there is nothing to update.
171 // For the unusual case that kv_pairs has multiple values at the same key, we'll have
172 // to check the key-value pairs we've already seen to get the "effective" old value.
173 let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
174 if value == old_value {
175 continue;
176 }
177
178 let leaf_index = Self::key_to_leaf_index(&key);
179 let mut node_index = NodeIndex::from(leaf_index);
180
181 // We need the current leaf's hash to calculate the new leaf, but in the rare case that
182 // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
183 // part of the "current leaf".
184 let old_leaf = {
185 let pairs_at_index = new_pairs
186 .iter()
187 .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
188
189 pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
190 // Most of the time `pairs_at_index` should only contain a single entry (or
191 // none at all), as multi-leaves should be really rare.
192 let existing_leaf = acc.clone();
193 self.construct_prospective_leaf(existing_leaf, k, v)
194 })
195 };
196
197 let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
198
199 let mut new_child_hash = Self::hash_leaf(&new_leaf);
200
201 for node_depth in (0..node_index.depth()).rev() {
202 // Whether the node we're replacing is the right child or the left child.
203 let is_right = node_index.is_value_odd();
204 node_index.move_up();
205
206 let old_node = node_mutations
207 .get(&node_index)
208 .map(|mutation| match mutation {
209 Addition(node) => node.clone(),
210 Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
211 })
212 .unwrap_or_else(|| self.get_inner_node(node_index));
213
214 let new_node = if is_right {
215 InnerNode {
216 left: old_node.left,
217 right: new_child_hash,
218 }
219 } else {
220 InnerNode {
221 left: new_child_hash,
222 right: old_node.right,
223 }
224 };
225
226 // The next iteration will operate on this new node's hash.
227 new_child_hash = new_node.hash();
228
229 let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
230 let is_removal = new_child_hash == equivalent_empty_hash;
231 let new_entry = if is_removal { Removal } else { Addition(new_node) };
232 node_mutations.insert(node_index, new_entry);
233 }
234
235 // Once we're at depth 0, the last node we made is the new root.
236 new_root = new_child_hash;
237 // And then we're done with this pair; on to the next one.
238 new_pairs.insert(key, value);
239 }
240
241 MutationSet {
242 old_root: self.root(),
243 new_root,
244 node_mutations,
245 new_pairs,
246 }
247 }
248
249 /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
250 /// this tree.
251 ///
252 /// # Errors
253 /// If `mutations` was computed on a tree with a different root than this one, returns
254 /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
255 /// the `mutations` were computed against, and the second item is the actual current root of
256 /// this tree.
257 fn apply_mutations(
258 &mut self,
259 mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
260 ) -> Result<(), MerkleError>
261 where
262 Self: Sized,
263 {
264 use NodeMutation::*;
265 let MutationSet {
266 old_root,
267 node_mutations,
268 new_pairs,
269 new_root,
270 } = mutations;
271
272 // Guard against accidentally trying to apply mutations that were computed against a
273 // different tree, including a stale version of this tree.
274 if old_root != self.root() {
275 return Err(MerkleError::ConflictingRoots {
276 expected_root: self.root(),
277 actual_root: old_root,
278 });
279 }
280
281 for (index, mutation) in node_mutations {
282 match mutation {
283 Removal => {
284 self.remove_inner_node(index);
285 },
286 Addition(node) => {
287 self.insert_inner_node(index, node);
288 },
289 }
290 }
291
292 for (key, value) in new_pairs {
293 self.insert_value(key, value);
294 }
295
296 self.set_root(new_root);
297
298 Ok(())
299 }
300
301 /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
302 /// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
303 /// updated tree will revert the changes.
304 ///
305 /// # Errors
306 /// If `mutations` was computed on a tree with a different root than this one, returns
307 /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
308 /// the `mutations` were computed against, and the second item is the actual current root of
309 /// this tree.
310 fn apply_mutations_with_reversion(
311 &mut self,
312 mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
313 ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
314 where
315 Self: Sized,
316 {
317 use NodeMutation::*;
318 let MutationSet {
319 old_root,
320 node_mutations,
321 new_pairs,
322 new_root,
323 } = mutations;
324
325 // Guard against accidentally trying to apply mutations that were computed against a
326 // different tree, including a stale version of this tree.
327 if old_root != self.root() {
328 return Err(MerkleError::ConflictingRoots {
329 expected_root: self.root(),
330 actual_root: old_root,
331 });
332 }
333
334 let mut reverse_mutations = BTreeMap::new();
335 for (index, mutation) in node_mutations {
336 match mutation {
337 Removal => {
338 if let Some(node) = self.remove_inner_node(index) {
339 reverse_mutations.insert(index, Addition(node));
340 }
341 },
342 Addition(node) => {
343 if let Some(old_node) = self.insert_inner_node(index, node) {
344 reverse_mutations.insert(index, Addition(old_node));
345 } else {
346 reverse_mutations.insert(index, Removal);
347 }
348 },
349 }
350 }
351
352 let mut reverse_pairs = BTreeMap::new();
353 for (key, value) in new_pairs {
354 if let Some(old_value) = self.insert_value(key.clone(), value) {
355 reverse_pairs.insert(key, old_value);
356 } else {
357 reverse_pairs.insert(key, Self::EMPTY_VALUE);
358 }
359 }
360
361 self.set_root(new_root);
362
363 Ok(MutationSet {
364 old_root: new_root,
365 node_mutations: reverse_mutations,
366 new_pairs: reverse_pairs,
367 new_root: old_root,
368 })
369 }
370
371 // REQUIRED METHODS
372 // ---------------------------------------------------------------------------------------------
373
374 /// The root of the tree
375 fn root(&self) -> RpoDigest;
376
377 /// Sets the root of the tree
378 fn set_root(&mut self, root: RpoDigest);
379
380 /// Retrieves an inner node at the given index
381 fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
382
383 /// Inserts an inner node at the given index
384 fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
385
386 /// Removes an inner node at the given index
387 fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
388
389 /// Inserts a leaf node, and returns the value at the key if already exists
390 fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
391
392 /// Returns the value at the specified key. Recall that by definition, any key that hasn't been
393 /// updated is associated with [`Self::EMPTY_VALUE`].
394 fn get_value(&self, key: &Self::Key) -> Self::Value;
395
396 /// Returns the leaf at the specified index.
397 fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
398
399 /// Returns the hash of a leaf
400 fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
401
402 /// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
403 /// mutating the tree itself. The existing leaf can be empty.
404 ///
405 /// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
406 /// as the argument for `existing_leaf`. The return value from this function can be chained back
407 /// into this function as the first argument to continue making prospective changes.
408 ///
409 /// # Invariants
410 /// Because this method is for a prospective key-value insertion into a specific leaf,
411 /// `existing_leaf` must have the same leaf index as `key` (as determined by
412 /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
413 fn construct_prospective_leaf(
414 &self,
415 existing_leaf: Self::Leaf,
416 key: &Self::Key,
417 value: &Self::Value,
418 ) -> Self::Leaf;
419
420 /// Maps a key to a leaf index
421 fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
422
423 /// Maps a (MerklePath, Self::Leaf) to an opening.
424 ///
425 /// The length `path` is guaranteed to be equal to `DEPTH`
426 fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
427}
428
429// INNER NODE
430// ================================================================================================
431
432#[derive(Debug, Default, Clone, PartialEq, Eq)]
433#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
434pub struct InnerNode {
435 pub left: RpoDigest,
436 pub right: RpoDigest,
437}
438
439impl InnerNode {
440 pub fn hash(&self) -> RpoDigest {
441 Rpo256::merge(&[self.left, self.right])
442 }
443}
444
445// LEAF INDEX
446// ================================================================================================
447
448/// The index of a leaf, at a depth known at compile-time.
449#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
450#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
451pub struct LeafIndex<const DEPTH: u8> {
452 index: NodeIndex,
453}
454
455impl<const DEPTH: u8> LeafIndex<DEPTH> {
456 pub fn new(value: u64) -> Result<Self, MerkleError> {
457 if DEPTH < SMT_MIN_DEPTH {
458 return Err(MerkleError::DepthTooSmall(DEPTH));
459 }
460
461 Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
462 }
463
464 pub fn value(&self) -> u64 {
465 self.index.value()
466 }
467}
468
469impl LeafIndex<SMT_MAX_DEPTH> {
470 pub const fn new_max_depth(value: u64) -> Self {
471 LeafIndex {
472 index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
473 }
474 }
475}
476
477impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
478 fn from(value: LeafIndex<DEPTH>) -> Self {
479 value.index
480 }
481}
482
483impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
484 type Error = MerkleError;
485
486 fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
487 if node_index.depth() != DEPTH {
488 return Err(MerkleError::InvalidNodeIndexDepth {
489 expected: DEPTH,
490 provided: node_index.depth(),
491 });
492 }
493
494 Self::new(node_index.value())
495 }
496}
497
498impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
499 fn write_into<W: ByteWriter>(&self, target: &mut W) {
500 self.index.write_into(target);
501 }
502}
503
504impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
505 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
506 Ok(Self { index: source.read()? })
507 }
508}
509
510// MUTATIONS
511// ================================================================================================
512
513/// A change to an inner node of a sparse Merkle tree that hasn't yet been applied.
514/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
515/// need to occur at which node indices.
516#[derive(Debug, Clone, PartialEq, Eq)]
517pub enum NodeMutation {
518 /// Node needs to be removed.
519 Removal,
520 /// Node needs to be inserted.
521 Addition(InnerNode),
522}
523
524/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
525/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
526/// `SparseMerkleTree::apply_mutations()`.
527#[derive(Debug, Clone, PartialEq, Eq, Default)]
528pub struct MutationSet<const DEPTH: u8, K, V> {
529 /// The root of the Merkle tree this MutationSet is for, recorded at the time
530 /// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
531 /// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
532 old_root: RpoDigest,
533 /// The set of nodes that need to be removed or added. The "effective" node at an index is the
534 /// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
535 /// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
536 /// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
537 /// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
538 node_mutations: BTreeMap<NodeIndex, NodeMutation>,
539 /// The set of top-level key-value pairs we're prospectively adding to the tree, including
540 /// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
541 /// back to the existing value in the Merkle tree. Each entry corresponds to a
542 /// [`SparseMerkleTree::insert_value()`] call.
543 new_pairs: BTreeMap<K, V>,
544 /// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
545 /// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
546 new_root: RpoDigest,
547}
548
549impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
550 /// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
551 /// that method for more information.
552 pub fn root(&self) -> RpoDigest {
553 self.new_root
554 }
555
556 /// Returns the SMT root before the mutations were applied.
557 pub fn old_root(&self) -> RpoDigest {
558 self.old_root
559 }
560
561 /// Returns the set of inner nodes that need to be removed or added.
562 pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> {
563 &self.node_mutations
564 }
565
566 /// Returns the set of top-level key-value pairs that need to be added, updated or deleted
567 /// (i.e. set to `EMPTY_WORD`).
568 pub fn new_pairs(&self) -> &BTreeMap<K, V> {
569 &self.new_pairs
570 }
571}
572
573// SERIALIZATION
574// ================================================================================================
575
576impl Serializable for InnerNode {
577 fn write_into<W: ByteWriter>(&self, target: &mut W) {
578 self.left.write_into(target);
579 self.right.write_into(target);
580 }
581}
582
583impl Deserializable for InnerNode {
584 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
585 let left = source.read()?;
586 let right = source.read()?;
587
588 Ok(Self { left, right })
589 }
590}
591
592impl Serializable for NodeMutation {
593 fn write_into<W: ByteWriter>(&self, target: &mut W) {
594 match self {
595 NodeMutation::Removal => target.write_bool(false),
596 NodeMutation::Addition(inner_node) => {
597 target.write_bool(true);
598 inner_node.write_into(target);
599 },
600 }
601 }
602}
603
604impl Deserializable for NodeMutation {
605 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
606 if source.read_bool()? {
607 let inner_node = source.read()?;
608 return Ok(NodeMutation::Addition(inner_node));
609 }
610
611 Ok(NodeMutation::Removal)
612 }
613}
614
615impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
616 fn write_into<W: ByteWriter>(&self, target: &mut W) {
617 target.write(self.old_root);
618 target.write(self.new_root);
619 self.node_mutations.write_into(target);
620 self.new_pairs.write_into(target);
621 }
622}
623
624impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
625 for MutationSet<DEPTH, K, V>
626{
627 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
628 let old_root = source.read()?;
629 let new_root = source.read()?;
630 let node_mutations = source.read()?;
631 let new_pairs = source.read()?;
632
633 Ok(Self {
634 old_root,
635 node_mutations,
636 new_pairs,
637 new_root,
638 })
639 }
640}