miden_crypto/merkle/smt/
mod.rs

1use alloc::{collections::BTreeMap, vec::Vec};
2
3use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
4
5use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
6use crate::{
7    hash::rpo::{Rpo256, RpoDigest},
8    Felt, Word, EMPTY_WORD,
9};
10
11mod full;
12pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
13
14mod simple;
15pub use simple::SimpleSmt;
16
17mod partial;
18pub use partial::PartialSmt;
19
20// CONSTANTS
21// ================================================================================================
22
23/// Minimum supported tree depth; `LeafIndex::new()` rejects any `DEPTH` below this value.
24pub const SMT_MIN_DEPTH: u8 = 1;
25
26/// Maximum supported tree depth; this is the depth used by `LeafIndex::new_max_depth()`.
27pub const SMT_MAX_DEPTH: u8 = 64;
28
29// SPARSE MERKLE TREE
30// ================================================================================================
31
32/// An abstract description of a sparse Merkle tree.
33///
34/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
35/// stored at a given key in the tree. It is viewed as always being fully populated. If a leaf's
36/// value was not explicitly set, then its value is the default value. Typically, the vast majority
37/// of leaves will store the default value (hence it is "sparse"), and therefore the internal
38/// representation of the tree will only keep track of the leaves that have a different value from
39/// the default.
40///
41/// All leaves sit at the same depth. The deeper the tree, the more leaves it has; but also the
42/// longer its proofs are - exactly `depth` nodes long. A tree cannot have depth 0, since such a
43/// tree is just a single value, and is probably a programming mistake.
44///
45/// Every key maps to one leaf. If there are as many keys as there are leaves, then
46/// [Self::Leaf] should be the same type as [Self::Value], as is the case with
47/// [crate::merkle::SimpleSmt]. However, if there are more keys than leaves, then [`Self::Leaf`]
48/// must accommodate all keys that map to the same leaf.
49///
50/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
51pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
52    /// The type of a key; every key maps to exactly one leaf via `key_to_leaf_index()`.
53    type Key: Clone + Ord;
54    /// The type of a value stored under a key.
55    type Value: Clone + PartialEq;
56    /// The type of a leaf, i.e. what `get_leaf()` returns and `hash_leaf()` hashes.
57    type Leaf: Clone;
58    /// The type of an opening (i.e. a "proof") of a leaf, as built by `path_and_leaf_to_opening()`.
59    type Opening;
60
61    /// The default value; `insert()` reports it as the old value of a key that had none stored.
62    const EMPTY_VALUE: Self::Value;
63
64    /// The root of the empty tree with the provided `DEPTH`.
65    const EMPTY_ROOT: RpoDigest;
66
67    // PROVIDED METHODS
68    // ---------------------------------------------------------------------------------------------
69
70    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
71    /// path to the leaf, as well as the leaf itself.
72    fn open(&self, key: &Self::Key) -> Self::Opening {
73        let leaf = self.get_leaf(key);
74
75        let mut index: NodeIndex = {
76            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
77            leaf_index.into()
78        };
79
80        let merkle_path = {
81            let mut path = Vec::with_capacity(index.depth() as usize);
82            for _ in 0..index.depth() {
83                let is_right = index.is_value_odd();
84                index.move_up();
85                let InnerNode { left, right } = self.get_inner_node(index);
86                let value = if is_right { left } else { right };
87                path.push(value);
88            }
89
90            MerklePath::new(path)
91        };
92
93        Self::path_and_leaf_to_opening(merkle_path, leaf)
94    }
95
96    /// Inserts a value at the specified key, returning the previous value associated with that key.
97    /// Recall that by definition, any key that hasn't been updated is associated with
98    /// [`Self::EMPTY_VALUE`].
99    ///
100    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
101    /// updating the root itself.
102    fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
103        let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
104
105        // if the old value and new value are the same, there is nothing to update
106        if value == old_value {
107            return value;
108        }
109
110        let leaf = self.get_leaf(&key);
111        let node_index = {
112            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
113            leaf_index.into()
114        };
115
116        self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
117
118        old_value
119    }
120
121    /// Recomputes the branch nodes (including the root) from `index` all the way to the root.
122    /// `node_hash_at_index` is the hash of the node stored at index.
123    fn recompute_nodes_from_index_to_root(
124        &mut self,
125        mut index: NodeIndex,
126        node_hash_at_index: RpoDigest,
127    ) {
128        let mut node_hash = node_hash_at_index;
129        for node_depth in (0..index.depth()).rev() {
130            let is_right = index.is_value_odd();
131            index.move_up();
132            let InnerNode { left, right } = self.get_inner_node(index);
133            let (left, right) = if is_right {
134                (left, node_hash)
135            } else {
136                (node_hash, right)
137            };
138            node_hash = Rpo256::merge(&[left, right]);
139
140            if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
141                // If a subtree is empty, then can remove the inner node, since it's equal to the
142                // default value
143                self.remove_inner_node(index);
144            } else {
145                self.insert_inner_node(index, InnerNode { left, right });
146            }
147        }
148        self.set_root(node_hash);
149    }
150
151    /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
152    /// tree, allowing for validation before applying those changes.
153    ///
154    /// This method returns a [`MutationSet`], which contains all the information for inserting
155    /// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
156    /// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
157    /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
158    /// the Merkle tree, or [`drop()`] to discard them.
159    fn compute_mutations(
160        &self,
161        kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
162    ) -> MutationSet<DEPTH, Self::Key, Self::Value> {
163        use NodeMutation::*;
164
165        let mut new_root = self.root();
166        let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
167        let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
168
169        for (key, value) in kv_pairs {
170            // If the old value and the new value are the same, there is nothing to update.
171            // For the unusual case that kv_pairs has multiple values at the same key, we'll have
172            // to check the key-value pairs we've already seen to get the "effective" old value.
173            let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
174            if value == old_value {
175                continue;
176            }
177
178            let leaf_index = Self::key_to_leaf_index(&key);
179            let mut node_index = NodeIndex::from(leaf_index);
180
181            // We need the current leaf's hash to calculate the new leaf, but in the rare case that
182            // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
183            // part of the "current leaf".
184            let old_leaf = {
185                let pairs_at_index = new_pairs
186                    .iter()
187                    .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
188
189                pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
190                    // Most of the time `pairs_at_index` should only contain a single entry (or
191                    // none at all), as multi-leaves should be really rare.
192                    let existing_leaf = acc.clone();
193                    self.construct_prospective_leaf(existing_leaf, k, v)
194                })
195            };
196
197            let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
198
199            let mut new_child_hash = Self::hash_leaf(&new_leaf);
200
201            for node_depth in (0..node_index.depth()).rev() {
202                // Whether the node we're replacing is the right child or the left child.
203                let is_right = node_index.is_value_odd();
204                node_index.move_up();
205
206                let old_node = node_mutations
207                    .get(&node_index)
208                    .map(|mutation| match mutation {
209                        Addition(node) => node.clone(),
210                        Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
211                    })
212                    .unwrap_or_else(|| self.get_inner_node(node_index));
213
214                let new_node = if is_right {
215                    InnerNode {
216                        left: old_node.left,
217                        right: new_child_hash,
218                    }
219                } else {
220                    InnerNode {
221                        left: new_child_hash,
222                        right: old_node.right,
223                    }
224                };
225
226                // The next iteration will operate on this new node's hash.
227                new_child_hash = new_node.hash();
228
229                let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
230                let is_removal = new_child_hash == equivalent_empty_hash;
231                let new_entry = if is_removal { Removal } else { Addition(new_node) };
232                node_mutations.insert(node_index, new_entry);
233            }
234
235            // Once we're at depth 0, the last node we made is the new root.
236            new_root = new_child_hash;
237            // And then we're done with this pair; on to the next one.
238            new_pairs.insert(key, value);
239        }
240
241        MutationSet {
242            old_root: self.root(),
243            new_root,
244            node_mutations,
245            new_pairs,
246        }
247    }
248
249    /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
250    /// this tree.
251    ///
252    /// # Errors
253    /// If `mutations` was computed on a tree with a different root than this one, returns
254    /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
255    /// the `mutations` were computed against, and the second item is the actual current root of
256    /// this tree.
257    fn apply_mutations(
258        &mut self,
259        mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
260    ) -> Result<(), MerkleError>
261    where
262        Self: Sized,
263    {
264        use NodeMutation::*;
265        let MutationSet {
266            old_root,
267            node_mutations,
268            new_pairs,
269            new_root,
270        } = mutations;
271
272        // Guard against accidentally trying to apply mutations that were computed against a
273        // different tree, including a stale version of this tree.
274        if old_root != self.root() {
275            return Err(MerkleError::ConflictingRoots {
276                expected_root: self.root(),
277                actual_root: old_root,
278            });
279        }
280
281        for (index, mutation) in node_mutations {
282            match mutation {
283                Removal => {
284                    self.remove_inner_node(index);
285                },
286                Addition(node) => {
287                    self.insert_inner_node(index, node);
288                },
289            }
290        }
291
292        for (key, value) in new_pairs {
293            self.insert_value(key, value);
294        }
295
296        self.set_root(new_root);
297
298        Ok(())
299    }
300
301    /// Applies the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
302    /// this tree and returns the reverse mutation set. Applying the reverse mutation sets to the
303    /// updated tree will revert the changes.
304    ///
305    /// # Errors
306    /// If `mutations` was computed on a tree with a different root than this one, returns
307    /// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
308    /// the `mutations` were computed against, and the second item is the actual current root of
309    /// this tree.
310    fn apply_mutations_with_reversion(
311        &mut self,
312        mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
313    ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
314    where
315        Self: Sized,
316    {
317        use NodeMutation::*;
318        let MutationSet {
319            old_root,
320            node_mutations,
321            new_pairs,
322            new_root,
323        } = mutations;
324
325        // Guard against accidentally trying to apply mutations that were computed against a
326        // different tree, including a stale version of this tree.
327        if old_root != self.root() {
328            return Err(MerkleError::ConflictingRoots {
329                expected_root: self.root(),
330                actual_root: old_root,
331            });
332        }
333
334        let mut reverse_mutations = BTreeMap::new();
335        for (index, mutation) in node_mutations {
336            match mutation {
337                Removal => {
338                    if let Some(node) = self.remove_inner_node(index) {
339                        reverse_mutations.insert(index, Addition(node));
340                    }
341                },
342                Addition(node) => {
343                    if let Some(old_node) = self.insert_inner_node(index, node) {
344                        reverse_mutations.insert(index, Addition(old_node));
345                    } else {
346                        reverse_mutations.insert(index, Removal);
347                    }
348                },
349            }
350        }
351
352        let mut reverse_pairs = BTreeMap::new();
353        for (key, value) in new_pairs {
354            if let Some(old_value) = self.insert_value(key.clone(), value) {
355                reverse_pairs.insert(key, old_value);
356            } else {
357                reverse_pairs.insert(key, Self::EMPTY_VALUE);
358            }
359        }
360
361        self.set_root(new_root);
362
363        Ok(MutationSet {
364            old_root: new_root,
365            node_mutations: reverse_mutations,
366            new_pairs: reverse_pairs,
367            new_root: old_root,
368        })
369    }
370
371    // REQUIRED METHODS
372    // ---------------------------------------------------------------------------------------------
373
374    /// Returns the current root of the tree.
375    fn root(&self) -> RpoDigest;
376
377    /// Sets the root of the tree to `root`.
378    fn set_root(&mut self, root: RpoDigest);
379
380    /// Retrieves the inner node at the given index.
381    fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
382
383    /// Inserts an inner node at the given index; the provided methods treat `Some` as the old node.
384    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
385
386    /// Removes the inner node at the given index; the provided methods treat `Some` as that node.
387    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
388
389    /// Stores `value` under `key`, returning the value previously stored at the key, if any.
390    fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
391
392    /// Returns the value at the specified key. Recall that by definition, any key that hasn't been
393    /// updated is associated with [`Self::EMPTY_VALUE`].
394    fn get_value(&self, key: &Self::Key) -> Self::Value;
395
396    /// Returns the leaf that the specified key maps to.
397    fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
398
399    /// Returns the hash of a leaf.
400    fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
401
402    /// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
403    /// mutating the tree itself. The existing leaf can be empty.
404    ///
405    /// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
406    /// as the argument for `existing_leaf`. The return value from this function can be chained back
407    /// into this function as the first argument to continue making prospective changes.
408    ///
409    /// # Invariants
410    /// Because this method is for a prospective key-value insertion into a specific leaf,
411    /// `existing_leaf` must have the same leaf index as `key` (as determined by
412    /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
413    fn construct_prospective_leaf(
414        &self,
415        existing_leaf: Self::Leaf,
416        key: &Self::Key,
417        value: &Self::Value,
418    ) -> Self::Leaf;
419
420    /// Maps a key to the index of the leaf it belongs to.
421    fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
422
423    /// Maps a (MerklePath, Self::Leaf) to an opening.
424    ///
425    /// The length of `path` is guaranteed to be equal to `DEPTH`.
426    fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
427}
428
429// INNER NODE
430// ================================================================================================
431
432#[derive(Debug, Default, Clone, PartialEq, Eq)]
433#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
434pub struct InnerNode {
435    pub left: RpoDigest,
436    pub right: RpoDigest,
437}
438
439impl InnerNode {
440    pub fn hash(&self) -> RpoDigest {
441        Rpo256::merge(&[self.left, self.right])
442    }
443}
444
445// LEAF INDEX
446// ================================================================================================
447
448/// The index of a leaf, at a depth known at compile-time.
449#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
450#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
451pub struct LeafIndex<const DEPTH: u8> {
452    index: NodeIndex,
453}
454
455impl<const DEPTH: u8> LeafIndex<DEPTH> {
456    pub fn new(value: u64) -> Result<Self, MerkleError> {
457        if DEPTH < SMT_MIN_DEPTH {
458            return Err(MerkleError::DepthTooSmall(DEPTH));
459        }
460
461        Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
462    }
463
464    pub fn value(&self) -> u64 {
465        self.index.value()
466    }
467}
468
469impl LeafIndex<SMT_MAX_DEPTH> {
470    pub const fn new_max_depth(value: u64) -> Self {
471        LeafIndex {
472            index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
473        }
474    }
475}
476
477impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
478    fn from(value: LeafIndex<DEPTH>) -> Self {
479        value.index
480    }
481}
482
483impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
484    type Error = MerkleError;
485
486    fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
487        if node_index.depth() != DEPTH {
488            return Err(MerkleError::InvalidNodeIndexDepth {
489                expected: DEPTH,
490                provided: node_index.depth(),
491            });
492        }
493
494        Self::new(node_index.value())
495    }
496}
497
498impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
499    fn write_into<W: ByteWriter>(&self, target: &mut W) {
500        self.index.write_into(target);
501    }
502}
503
504impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
505    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
506        Ok(Self { index: source.read()? })
507    }
508}
509
510// MUTATIONS
511// ================================================================================================
512
513/// A change to an inner node of a sparse Merkle tree that hasn't yet been applied.
514/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
515/// need to occur at which node indices.
516#[derive(Debug, Clone, PartialEq, Eq)]
517pub enum NodeMutation {
518    /// The node at this index needs to be removed.
519    Removal,
520    /// The contained [`InnerNode`] needs to be inserted at this index.
521    Addition(InnerNode),
522}
523
524/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
525/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
526/// `SparseMerkleTree::apply_mutations()`.
527#[derive(Debug, Clone, PartialEq, Eq, Default)]
528pub struct MutationSet<const DEPTH: u8, K, V> {
529    /// The root of the Merkle tree this MutationSet is for, recorded at the time
530    /// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
531    /// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
532    old_root: RpoDigest,
533    /// The set of nodes that need to be removed or added. The "effective" node at an index is the
534    /// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
535    /// index overlaid, if any. Each [`NodeMutation::Addition`] corresponds to a
536    /// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
537    /// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
538    node_mutations: BTreeMap<NodeIndex, NodeMutation>,
539    /// The set of top-level key-value pairs we're prospectively adding to the tree, including
540    /// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
541    /// back to the existing value in the Merkle tree. Each entry corresponds to a
542    /// [`SparseMerkleTree::insert_value()`] call.
543    new_pairs: BTreeMap<K, V>,
544    /// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
545    /// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`] call.
546    new_root: RpoDigest,
547}
548
549impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
550    /// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
551    /// that method for more information.
552    pub fn root(&self) -> RpoDigest {
553        self.new_root
554    }
555
556    /// Returns the SMT root before the mutations were applied.
557    pub fn old_root(&self) -> RpoDigest {
558        self.old_root
559    }
560
561    /// Returns the set of inner nodes that need to be removed or added.
562    pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> {
563        &self.node_mutations
564    }
565
566    /// Returns the set of top-level key-value pairs that need to be added, updated or deleted
567    /// (i.e. set to `EMPTY_WORD`).
568    pub fn new_pairs(&self) -> &BTreeMap<K, V> {
569        &self.new_pairs
570    }
571}
572
573// SERIALIZATION
574// ================================================================================================
575
576impl Serializable for InnerNode {
577    fn write_into<W: ByteWriter>(&self, target: &mut W) {
578        self.left.write_into(target);
579        self.right.write_into(target);
580    }
581}
582
583impl Deserializable for InnerNode {
584    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
585        let left = source.read()?;
586        let right = source.read()?;
587
588        Ok(Self { left, right })
589    }
590}
591
592impl Serializable for NodeMutation {
593    fn write_into<W: ByteWriter>(&self, target: &mut W) {
594        match self {
595            NodeMutation::Removal => target.write_bool(false),
596            NodeMutation::Addition(inner_node) => {
597                target.write_bool(true);
598                inner_node.write_into(target);
599            },
600        }
601    }
602}
603
604impl Deserializable for NodeMutation {
605    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
606        if source.read_bool()? {
607            let inner_node = source.read()?;
608            return Ok(NodeMutation::Addition(inner_node));
609        }
610
611        Ok(NodeMutation::Removal)
612    }
613}
614
615impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> {
616    fn write_into<W: ByteWriter>(&self, target: &mut W) {
617        target.write(self.old_root);
618        target.write(self.new_root);
619        self.node_mutations.write_into(target);
620        self.new_pairs.write_into(target);
621    }
622}
623
624impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
625    for MutationSet<DEPTH, K, V>
626{
627    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
628        let old_root = source.read()?;
629        let new_root = source.read()?;
630        let node_mutations = source.read()?;
631        let new_pairs = source.read()?;
632
633        Ok(Self {
634            old_root,
635            node_mutations,
636            new_pairs,
637            new_root,
638        })
639    }
640}