Skip to main content

miden_crypto/merkle/smt/partial/
mod.rs

1use super::{EmptySubtreeRoots, LeafIndex, SMT_DEPTH};
2use crate::{
3    EMPTY_WORD, Word,
4    merkle::{
5        InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath,
6        smt::{InnerNode, InnerNodes, Leaves, SmtLeaf, SmtLeafError, SmtProof},
7    },
8    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
9};
10
11#[cfg(test)]
12mod tests;
13
/// A partial version of an [`super::Smt`].
///
/// This type can track a subset of the key-value pairs of a full [`super::Smt`] and allows for
/// updating those pairs to compute the new root of the tree, as if the updates had been done on the
/// full tree. This is useful so that not all leaves have to be present and loaded into memory to
/// compute an update.
///
/// A key is considered "tracked" if either:
/// 1. Its merkle path was explicitly added to the tree (via [`PartialSmt::add_path`] or
///    [`PartialSmt::add_proof`]), or
/// 2. The path from the root to the leaf goes through empty subtrees that are consistent with the
///    stored inner nodes (provably empty with zero hash computations).
///
/// The second condition allows updating keys in empty subtrees without explicitly adding their
/// merkle paths. This is verified by walking down from the root towards the leaf: the key is
/// tracked as soon as a stored inner node has an empty subtree root as the child on the key's
/// path, and untracked if an inner node on the path is not stored.
///
/// An important caveat is that only tracked keys can be updated. Attempting to update an
/// untracked key will result in an error. See [`PartialSmt::insert`] for more details.
///
/// Once a partial SMT has been constructed, its root is set in stone. All subsequently added proofs
/// or merkle paths must match that root, otherwise an error is returned.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialSmt {
    /// Root that every added proof or path must resolve to; updated by [`PartialSmt::insert`].
    root: Word,
    /// Total number of key-value pairs held by the stored leaves.
    num_entries: usize,
    /// Explicitly stored, non-empty leaves keyed by the value of their leaf index.
    leaves: Leaves<SmtLeaf>,
    /// Stored inner nodes; nodes equal to the empty subtree node of their depth are not kept.
    inner_nodes: InnerNodes,
}
44
45impl PartialSmt {
46    // CONSTANTS
47    // --------------------------------------------------------------------------------------------
48
    /// The default value used to compute the hash of empty leaves.
    ///
    /// Inserting this value for a key removes the key from its leaf.
    pub const EMPTY_VALUE: Word = EMPTY_WORD;

    /// The root of an empty tree, i.e. the depth-0 entry of the empty subtree roots of a tree of
    /// depth [`SMT_DEPTH`].
    pub const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
54
55    // CONSTRUCTORS
56    // --------------------------------------------------------------------------------------------
57
58    /// Constructs a [`PartialSmt`] from a root.
59    ///
60    /// All subsequently added proofs or paths must have the same root.
61    pub fn new(root: Word) -> Self {
62        Self {
63            root,
64            num_entries: 0,
65            leaves: Leaves::<SmtLeaf>::default(),
66            inner_nodes: InnerNodes::default(),
67        }
68    }
69
70    /// Instantiates a new [`PartialSmt`] by calling [`PartialSmt::add_proof`] for all [`SmtProof`]s
71    /// in the provided iterator.
72    ///
73    /// If the provided iterator is empty, an empty [`PartialSmt`] is returned.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if:
78    /// - the roots of the provided proofs are not the same.
79    pub fn from_proofs<I>(proofs: I) -> Result<Self, MerkleError>
80    where
81        I: IntoIterator<Item = SmtProof>,
82    {
83        let mut proofs = proofs.into_iter();
84
85        let Some(first_proof) = proofs.next() else {
86            return Ok(Self::default());
87        };
88
89        // Add the first path to an empty partial SMT without checking that the existing root
90        // matches the new one. This sets the expected root to the root of the first proof and all
91        // subsequently added proofs must match it.
92        let mut partial_smt = Self::default();
93        let (path, leaf) = first_proof.into_parts();
94        let path_root = partial_smt.add_path_unchecked(leaf, path);
95        partial_smt.root = path_root;
96
97        for proof in proofs {
98            partial_smt.add_proof(proof)?;
99        }
100
101        Ok(partial_smt)
102    }
103
104    // PUBLIC ACCESSORS
105    // --------------------------------------------------------------------------------------------
106
107    /// Returns the root of the tree.
108    pub fn root(&self) -> Word {
109        self.root
110    }
111
112    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
113    /// path to the leaf, as well as the leaf itself.
114    ///
115    /// # Errors
116    ///
117    /// Returns an error if:
118    /// - the key is not tracked by this partial SMT.
119    pub fn open(&self, key: &Word) -> Result<SmtProof, MerkleError> {
120        let leaf = self.get_leaf(key)?;
121        let merkle_path = self.get_path(key);
122        Ok(SmtProof::new_unchecked(merkle_path, leaf))
123    }
124
125    /// Returns the leaf to which `key` maps.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if:
130    /// - the key is not tracked by this partial SMT.
131    pub fn get_leaf(&self, key: &Word) -> Result<SmtLeaf, MerkleError> {
132        self.get_tracked_leaf(key).ok_or(MerkleError::UntrackedKey(*key))
133    }
134
135    /// Returns the value associated with `key`.
136    ///
137    /// # Errors
138    ///
139    /// Returns an error if:
140    /// - the key is not tracked by this partial SMT.
141    pub fn get_value(&self, key: &Word) -> Result<Word, MerkleError> {
142        self.get_tracked_leaf(key)
143            .map(|leaf| leaf.get_value(key).unwrap_or_default())
144            .ok_or(MerkleError::UntrackedKey(*key))
145    }
146
147    /// Returns an iterator over the inner nodes of the [`PartialSmt`].
148    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
149        self.inner_nodes.values().map(|e| InnerNodeInfo {
150            value: e.hash(),
151            left: e.left,
152            right: e.right,
153        })
154    }
155
156    /// Returns an iterator over the [`InnerNode`] and the respective [`NodeIndex`] of the
157    /// [`PartialSmt`].
158    pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
159        self.inner_nodes.iter().map(|(idx, inner)| (*idx, inner.clone()))
160    }
161
162    /// Returns an iterator over the explicitly stored leaves of the [`PartialSmt`] in arbitrary
163    /// order.
164    ///
165    /// Note: This only returns leaves that were explicitly added via [`Self::add_path`] or
166    /// [`Self::add_proof`], or created through [`Self::insert`]. It does not include implicitly
167    /// trackable leaves in empty subtrees.
168    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
169        self.leaves
170            .iter()
171            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
172    }
173
174    /// Returns an iterator over the tracked, non-empty key-value pairs of the [`PartialSmt`] in
175    /// arbitrary order.
176    pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
177        self.leaves().flat_map(|(_, leaf)| leaf.entries())
178    }
179
180    /// Returns the number of non-empty leaves in this tree.
181    ///
182    /// Note that this may return a different value from [Self::num_entries()] as a single leaf may
183    /// contain more than one key-value pair.
184    pub fn num_leaves(&self) -> usize {
185        self.leaves.len()
186    }
187
188    /// Returns the number of tracked, non-empty key-value pairs in this tree.
189    ///
190    /// Note that this may return a different value from [Self::num_leaves()] as a single leaf may
191    /// contain more than one key-value pair.
192    pub fn num_entries(&self) -> usize {
193        self.num_entries
194    }
195
196    /// Returns a boolean value indicating whether the [`PartialSmt`] tracks any leaves.
197    ///
198    /// Note that if a partial SMT does not track leaves, its root is not necessarily the empty SMT
199    /// root, since it could have been constructed from a different root but without tracking any
200    /// leaves.
201    pub fn tracks_leaves(&self) -> bool {
202        !self.leaves.is_empty()
203    }
204
205    // STATE MUTATORS
206    // --------------------------------------------------------------------------------------------
207
208    /// Inserts a value at the specified key, returning the previous value associated with that key.
209    /// Recall that by definition, any key that hasn't been updated is associated with
210    /// [`Self::EMPTY_VALUE`].
211    ///
212    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
213    /// updating the root itself.
214    ///
215    /// # Errors
216    ///
217    /// Returns an error if:
218    /// - the key is not tracked (see the type documentation for the definition of "tracked"). If an
219    ///   error is returned the tree is in the same state as before.
220    /// - inserting the key-value pair would exceed [`super::MAX_LEAF_ENTRIES`] (1024 entries) in
221    ///   the leaf.
222    pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
223        let current_leaf = self.get_tracked_leaf(&key).ok_or(MerkleError::UntrackedKey(key))?;
224        let leaf_index = current_leaf.index();
225        let previous_value = current_leaf.get_value(&key).unwrap_or(EMPTY_WORD);
226        let prev_entries = current_leaf.num_entries();
227
228        let leaf = self
229            .leaves
230            .entry(leaf_index.value())
231            .or_insert_with(|| SmtLeaf::new_empty(leaf_index));
232
233        if value != EMPTY_WORD {
234            leaf.insert(key, value).map_err(|e| match e {
235                SmtLeafError::TooManyLeafEntries { actual } => {
236                    MerkleError::TooManyLeafEntries { actual }
237                },
238                other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
239            })?;
240        } else {
241            leaf.remove(key);
242        }
243        let current_entries = leaf.num_entries();
244        let new_leaf_hash = leaf.hash();
245        self.num_entries = self.num_entries + current_entries - prev_entries;
246
247        // Remove empty leaf
248        if current_entries == 0 {
249            self.leaves.remove(&leaf_index.value());
250        }
251
252        // Recompute the path from leaf to root
253        self.recompute_nodes_from_leaf_to_root(leaf_index, new_leaf_hash);
254
255        Ok(previous_value)
256    }
257
258    /// Adds an [`SmtProof`] to this [`PartialSmt`].
259    ///
260    /// This is a convenience method which calls [`Self::add_path`] on the proof. See its
261    /// documentation for details on errors.
262    pub fn add_proof(&mut self, proof: SmtProof) -> Result<(), MerkleError> {
263        let (path, leaf) = proof.into_parts();
264        self.add_path(leaf, path)
265    }
266
267    /// Adds a leaf and its sparse merkle path to this [`PartialSmt`].
268    ///
269    /// If this function was called, any key that is part of the `leaf` can subsequently be updated
270    /// to a new value and produce a correct new tree root.
271    ///
272    /// # Errors
273    ///
274    /// Returns an error if:
275    /// - the new root after the insertion of the leaf and the path does not match the existing
276    ///   root. If an error is returned, the tree is left in an inconsistent state.
277    pub fn add_path(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Result<(), MerkleError> {
278        let path_root = self.add_path_unchecked(leaf, path);
279
280        // Check if the newly added merkle path is consistent with the existing tree. If not, the
281        // merkle path was invalid or computed against another tree.
282        if self.root() != path_root {
283            return Err(MerkleError::ConflictingRoots {
284                expected_root: self.root(),
285                actual_root: path_root,
286            });
287        }
288
289        Ok(())
290    }
291
292    // PRIVATE HELPERS
293    // --------------------------------------------------------------------------------------------
294
295    /// Adds a leaf and its sparse merkle path to this [`PartialSmt`] and returns the root of the
296    /// inserted path.
297    ///
298    /// This does not check that the path root matches the existing root of the tree and if so, the
299    /// tree is left in an inconsistent state. This state can be made consistent again by setting
300    /// the root of the SMT to the path root.
301    fn add_path_unchecked(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Word {
302        let mut current_index = leaf.index().index;
303
304        let mut node_hash_at_current_index = leaf.hash();
305
306        let prev_entries = self
307            .leaves
308            .get(&current_index.value())
309            .map(|leaf| leaf.num_entries())
310            .unwrap_or(0);
311        let current_entries = leaf.num_entries();
312        // Only store non-empty leaves
313        if current_entries > 0 {
314            self.leaves.insert(current_index.value(), leaf);
315        } else {
316            self.leaves.remove(&current_index.value());
317        }
318
319        // Guaranteed not to over/underflow. All variables are <= MAX_LEAF_ENTRIES and result > 0.
320        self.num_entries = self.num_entries + current_entries - prev_entries;
321
322        for sibling_hash in path {
323            // Find the index of the sibling node and compute whether it is a left or right child.
324            let is_sibling_right = current_index.sibling().is_value_odd();
325
326            // Move the index up so it points to the parent of the current index and the sibling.
327            current_index.move_up();
328
329            // Construct the new parent node from the child that was updated and the sibling from
330            // the merkle path.
331            let new_parent_node = if is_sibling_right {
332                InnerNode {
333                    left: node_hash_at_current_index,
334                    right: sibling_hash,
335                }
336            } else {
337                InnerNode {
338                    left: sibling_hash,
339                    right: node_hash_at_current_index,
340                }
341            };
342
343            node_hash_at_current_index = new_parent_node.hash();
344
345            self.insert_inner_node(current_index, new_parent_node);
346        }
347
348        node_hash_at_current_index
349    }
350
351    /// Returns the leaf for a key if it can be tracked.
352    ///
353    /// A key is trackable if:
354    /// 1. It was explicitly added via `add_path`/`add_proof`, OR
355    /// 2. The path to the leaf goes through empty subtrees (provably empty)
356    ///
357    /// Returns `None` if the key cannot be tracked (path goes through non-empty
358    /// subtrees we don't have data for).
359    fn get_tracked_leaf(&self, key: &Word) -> Option<SmtLeaf> {
360        let leaf_index = Self::key_to_leaf_index(key);
361
362        // Explicitly stored leaves are always trackable
363        if let Some(leaf) = self.leaves.get(&leaf_index.value()) {
364            return Some(leaf.clone());
365        }
366
367        // Empty tree - all leaves implicitly trackable
368        if self.root == Self::EMPTY_ROOT {
369            return Some(SmtLeaf::new_empty(leaf_index));
370        }
371
372        // Walk from root down towards the leaf
373        let target: NodeIndex = leaf_index.into();
374        let mut index = NodeIndex::root();
375
376        for i in (0..SMT_DEPTH).rev() {
377            let inner_node = self.get_inner_node(index)?;
378
379            let is_right = target.is_nth_bit_odd(i);
380            let child_hash = if is_right { inner_node.right } else { inner_node.left };
381
382            // If child is empty subtree root, leaf is implicitly trackable
383            if child_hash == *EmptySubtreeRoots::entry(SMT_DEPTH, SMT_DEPTH - i) {
384                return Some(SmtLeaf::new_empty(leaf_index));
385            }
386
387            index = if is_right {
388                index.right_child()
389            } else {
390                index.left_child()
391            };
392        }
393
394        // Reached leaf level without finding empty subtree - can't track
395        None
396    }
397
398    /// Converts a key to a leaf index.
399    fn key_to_leaf_index(key: &Word) -> LeafIndex<SMT_DEPTH> {
400        let most_significant_felt = key[3];
401        LeafIndex::new_max_depth(most_significant_felt.as_canonical_u64())
402    }
403
404    /// Returns the inner node at the specified index, or `None` if not stored.
405    fn get_inner_node(&self, index: NodeIndex) -> Option<InnerNode> {
406        self.inner_nodes.get(&index).cloned()
407    }
408
409    /// Returns the inner node at the specified index, falling back to the empty subtree root
410    /// if not stored.
411    fn get_inner_node_or_empty(&self, index: NodeIndex) -> InnerNode {
412        self.get_inner_node(index)
413            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
414    }
415
416    /// Inserts an inner node at the specified index, or removes it if it equals the empty
417    /// subtree root.
418    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
419        if inner_node == EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()) {
420            self.inner_nodes.remove(&index);
421        } else {
422            self.inner_nodes.insert(index, inner_node);
423        }
424    }
425
426    /// Returns the merkle path for a key by walking up the tree from the leaf.
427    fn get_path(&self, key: &Word) -> SparseMerklePath {
428        let index = NodeIndex::from(Self::key_to_leaf_index(key));
429
430        // Use proof_indices to get sibling indices from leaf to root,
431        // and get each sibling's hash
432        SparseMerklePath::from_sized_iter(index.proof_indices().map(|idx| self.get_node_hash(idx)))
433            .expect("path should be valid since it's from a valid SMT")
434    }
435
436    /// Get the hash of a node at an arbitrary index, including the root or leaf hashes.
437    ///
438    /// The root index simply returns the root. Other hashes are retrieved by looking at
439    /// the parent inner node and returning the respective child hash.
440    fn get_node_hash(&self, index: NodeIndex) -> Word {
441        if index.is_root() {
442            return self.root;
443        }
444
445        let InnerNode { left, right } = self.get_inner_node_or_empty(index.parent());
446
447        if index.is_value_odd() { right } else { left }
448    }
449
450    /// Recomputes all inner nodes from a leaf up to the root after a leaf value change.
451    fn recompute_nodes_from_leaf_to_root(
452        &mut self,
453        leaf_index: LeafIndex<SMT_DEPTH>,
454        leaf_hash: Word,
455    ) {
456        use crate::hash::poseidon2::Poseidon2;
457
458        let mut index: NodeIndex = leaf_index.into();
459        let mut node_hash = leaf_hash;
460
461        for _ in (0..index.depth()).rev() {
462            let is_right = index.is_value_odd();
463            index.move_up();
464            let InnerNode { left, right } = self.get_inner_node_or_empty(index);
465            let (left, right) = if is_right {
466                (left, node_hash)
467            } else {
468                (node_hash, right)
469            };
470            node_hash = Poseidon2::merge(&[left, right]);
471
472            // insert_inner_node handles removing empty subtree roots
473            self.insert_inner_node(index, InnerNode { left, right });
474        }
475        self.root = node_hash;
476    }
477
478    /// Validates the internal structure during deserialization.
479    ///
480    /// Checks that:
481    /// - Each inner node's hash is consistent with its parent.
482    /// - Each leaf's hash is consistent with its parent inner node's left/right child.
483    fn validate(&self) -> Result<(), DeserializationError> {
484        // Validate each inner node is consistent with its parent
485        for (&idx, node) in &self.inner_nodes {
486            let node_hash = node.hash();
487            let expected_hash = self.get_node_hash(idx);
488
489            if node_hash != expected_hash {
490                return Err(DeserializationError::InvalidValue(
491                    "inner node hash is inconsistent with parent".into(),
492                ));
493            }
494        }
495
496        // Validate each leaf's hash is consistent with its parent inner node
497        for (&leaf_pos, leaf) in &self.leaves {
498            let leaf_index = LeafIndex::<SMT_DEPTH>::new_max_depth(leaf_pos);
499            let node_index: NodeIndex = leaf_index.into();
500            let leaf_hash = leaf.hash();
501            let expected_hash = self.get_node_hash(node_index);
502
503            if leaf_hash != expected_hash {
504                return Err(DeserializationError::InvalidValue(
505                    "leaf hash is inconsistent with parent inner node".into(),
506                ));
507            }
508        }
509
510        Ok(())
511    }
512}
513
514impl Default for PartialSmt {
515    /// Returns a new, empty [`PartialSmt`].
516    ///
517    /// All leaves in the returned tree are set to [`Self::EMPTY_VALUE`].
518    fn default() -> Self {
519        Self::new(Self::EMPTY_ROOT)
520    }
521}
522
523// CONVERSIONS
524// ================================================================================================
525
526impl From<super::Smt> for PartialSmt {
527    fn from(smt: super::Smt) -> Self {
528        Self {
529            root: smt.root(),
530            num_entries: smt.num_entries(),
531            leaves: smt.leaves().map(|(idx, leaf)| (idx.value(), leaf.clone())).collect(),
532            inner_nodes: smt.inner_node_indices().collect(),
533        }
534    }
535}
536
537// SERIALIZATION
538// ================================================================================================
539
540impl Serializable for PartialSmt {
541    fn write_into<W: ByteWriter>(&self, target: &mut W) {
542        target.write(self.root());
543        target.write_usize(self.leaves.len());
544        for (i, leaf) in &self.leaves {
545            target.write_u64(*i);
546            target.write(leaf);
547        }
548        target.write_usize(self.inner_nodes.len());
549        for (idx, node) in &self.inner_nodes {
550            target.write(idx);
551            target.write(node);
552        }
553    }
554}
555
556impl Deserializable for PartialSmt {
557    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
558        let root: Word = source.read()?;
559
560        let mut leaves = Leaves::<SmtLeaf>::default();
561        for _ in 0..source.read_usize()? {
562            let pos: u64 = source.read()?;
563            let leaf: SmtLeaf = source.read()?;
564            leaves.insert(pos, leaf);
565        }
566
567        let mut inner_nodes = InnerNodes::default();
568        for _ in 0..source.read_usize()? {
569            let idx: NodeIndex = source.read()?;
570            let node: InnerNode = source.read()?;
571            inner_nodes.insert(idx, node);
572        }
573
574        let num_entries = leaves.values().map(|leaf| leaf.num_entries()).sum();
575
576        let partial = Self { root, num_entries, leaves, inner_nodes };
577        partial.validate()?;
578
579        Ok(partial)
580    }
581}