// miden_crypto/merkle/smt/partial/mod.rs

1use p3_field::PrimeField64;
2
3use super::{EmptySubtreeRoots, LeafIndex, SMT_DEPTH};
4use crate::{
5    EMPTY_WORD, Word,
6    merkle::{
7        InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath,
8        smt::{InnerNode, InnerNodes, Leaves, SmtLeaf, SmtLeafError, SmtProof},
9    },
10    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
// Unit tests for this module; compiled only in test builds.
#[cfg(test)]
mod tests;
15
/// A partial version of an [`super::Smt`].
///
/// This type can track a subset of the key-value pairs of a full [`super::Smt`] and allows for
/// updating those pairs to compute the new root of the tree, as if the updates had been done on the
/// full tree. This is useful so that not all leaves have to be present and loaded into memory to
/// compute an update.
///
/// A key is considered "tracked" if either:
/// 1. Its leaf is explicitly stored, i.e. its merkle path was added to the tree (via
///    [`PartialSmt::add_path`] or [`PartialSmt::add_proof`]) or the leaf was created through
///    [`PartialSmt::insert`], or
/// 2. The path from the root to the leaf enters an empty subtree that is consistent with the
///    stored inner nodes (provably empty with zero hash computations).
///
/// The second condition allows updating keys in empty subtrees without explicitly adding their
/// merkle paths. It is verified by walking down from the root towards the leaf through the stored
/// inner nodes until one of them has an empty subtree root as the child on the path. If the root
/// itself is the empty tree root, every key is tracked.
///
/// An important caveat is that only tracked keys can be updated. Attempting to update an
/// untracked key will result in an error. See [`PartialSmt::insert`] for more details.
///
/// Once a partial SMT has been constructed, its root only changes through [`PartialSmt::insert`].
/// All subsequently added proofs or merkle paths must match the current root, otherwise an error
/// is returned.
#[derive(Debug, Clone, PartialEq, Eq)]
// NOTE(review): the serde `Deserialize` derive fills the fields directly and does not run the
// consistency checks that the `Deserializable` impl performs — confirm this is acceptable for
// untrusted input.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialSmt {
    /// Current root of the tree; recomputed by `insert`.
    root: Word,
    /// Number of key-value pairs across all stored leaves.
    num_entries: usize,
    /// Explicitly stored leaves keyed by leaf position; the mutators never keep empty leaves here.
    leaves: Leaves<SmtLeaf>,
    /// Stored inner nodes keyed by index; the mutators drop nodes equal to the empty subtree node
    /// at their depth.
    inner_nodes: InnerNodes,
}
46
47impl PartialSmt {
48    // CONSTANTS
49    // --------------------------------------------------------------------------------------------
50
51    /// The default value used to compute the hash of empty leaves.
52    pub const EMPTY_VALUE: Word = EMPTY_WORD;
53
54    /// The root of an empty tree.
55    pub const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
56
57    // CONSTRUCTORS
58    // --------------------------------------------------------------------------------------------
59
60    /// Constructs a [`PartialSmt`] from a root.
61    ///
62    /// All subsequently added proofs or paths must have the same root.
63    pub fn new(root: Word) -> Self {
64        Self {
65            root,
66            num_entries: 0,
67            leaves: Leaves::<SmtLeaf>::default(),
68            inner_nodes: InnerNodes::default(),
69        }
70    }
71
72    /// Instantiates a new [`PartialSmt`] by calling [`PartialSmt::add_proof`] for all [`SmtProof`]s
73    /// in the provided iterator.
74    ///
75    /// If the provided iterator is empty, an empty [`PartialSmt`] is returned.
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if:
80    /// - the roots of the provided proofs are not the same.
81    pub fn from_proofs<I>(proofs: I) -> Result<Self, MerkleError>
82    where
83        I: IntoIterator<Item = SmtProof>,
84    {
85        let mut proofs = proofs.into_iter();
86
87        let Some(first_proof) = proofs.next() else {
88            return Ok(Self::default());
89        };
90
91        // Add the first path to an empty partial SMT without checking that the existing root
92        // matches the new one. This sets the expected root to the root of the first proof and all
93        // subsequently added proofs must match it.
94        let mut partial_smt = Self::default();
95        let (path, leaf) = first_proof.into_parts();
96        let path_root = partial_smt.add_path_unchecked(leaf, path);
97        partial_smt.root = path_root;
98
99        for proof in proofs {
100            partial_smt.add_proof(proof)?;
101        }
102
103        Ok(partial_smt)
104    }
105
106    // PUBLIC ACCESSORS
107    // --------------------------------------------------------------------------------------------
108
109    /// Returns the root of the tree.
110    pub fn root(&self) -> Word {
111        self.root
112    }
113
114    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
115    /// path to the leaf, as well as the leaf itself.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if:
120    /// - the key is not tracked by this partial SMT.
121    pub fn open(&self, key: &Word) -> Result<SmtProof, MerkleError> {
122        let leaf = self.get_leaf(key)?;
123        let merkle_path = self.get_path(key);
124        Ok(SmtProof::new_unchecked(merkle_path, leaf))
125    }
126
127    /// Returns the leaf to which `key` maps.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if:
132    /// - the key is not tracked by this partial SMT.
133    pub fn get_leaf(&self, key: &Word) -> Result<SmtLeaf, MerkleError> {
134        self.get_tracked_leaf(key).ok_or(MerkleError::UntrackedKey(*key))
135    }
136
137    /// Returns the value associated with `key`.
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if:
142    /// - the key is not tracked by this partial SMT.
143    pub fn get_value(&self, key: &Word) -> Result<Word, MerkleError> {
144        self.get_tracked_leaf(key)
145            .map(|leaf| leaf.get_value(key).unwrap_or_default())
146            .ok_or(MerkleError::UntrackedKey(*key))
147    }
148
149    /// Returns an iterator over the inner nodes of the [`PartialSmt`].
150    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
151        self.inner_nodes.values().map(|e| InnerNodeInfo {
152            value: e.hash(),
153            left: e.left,
154            right: e.right,
155        })
156    }
157
158    /// Returns an iterator over the [`InnerNode`] and the respective [`NodeIndex`] of the
159    /// [`PartialSmt`].
160    pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
161        self.inner_nodes.iter().map(|(idx, inner)| (*idx, inner.clone()))
162    }
163
164    /// Returns an iterator over the explicitly stored leaves of the [`PartialSmt`] in arbitrary
165    /// order.
166    ///
167    /// Note: This only returns leaves that were explicitly added via [`Self::add_path`] or
168    /// [`Self::add_proof`], or created through [`Self::insert`]. It does not include implicitly
169    /// trackable leaves in empty subtrees.
170    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
171        self.leaves
172            .iter()
173            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
174    }
175
176    /// Returns an iterator over the tracked, non-empty key-value pairs of the [`PartialSmt`] in
177    /// arbitrary order.
178    pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
179        self.leaves().flat_map(|(_, leaf)| leaf.entries())
180    }
181
182    /// Returns the number of non-empty leaves in this tree.
183    ///
184    /// Note that this may return a different value from [Self::num_entries()] as a single leaf may
185    /// contain more than one key-value pair.
186    pub fn num_leaves(&self) -> usize {
187        self.leaves.len()
188    }
189
190    /// Returns the number of tracked, non-empty key-value pairs in this tree.
191    ///
192    /// Note that this may return a different value from [Self::num_leaves()] as a single leaf may
193    /// contain more than one key-value pair.
194    pub fn num_entries(&self) -> usize {
195        self.num_entries
196    }
197
198    /// Returns a boolean value indicating whether the [`PartialSmt`] tracks any leaves.
199    ///
200    /// Note that if a partial SMT does not track leaves, its root is not necessarily the empty SMT
201    /// root, since it could have been constructed from a different root but without tracking any
202    /// leaves.
203    pub fn tracks_leaves(&self) -> bool {
204        !self.leaves.is_empty()
205    }
206
207    // STATE MUTATORS
208    // --------------------------------------------------------------------------------------------
209
210    /// Inserts a value at the specified key, returning the previous value associated with that key.
211    /// Recall that by definition, any key that hasn't been updated is associated with
212    /// [`Self::EMPTY_VALUE`].
213    ///
214    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
215    /// updating the root itself.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if:
220    /// - the key is not tracked (see the type documentation for the definition of "tracked"). If an
221    ///   error is returned the tree is in the same state as before.
222    /// - inserting the key-value pair would exceed [`super::MAX_LEAF_ENTRIES`] (1024 entries) in
223    ///   the leaf.
224    pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
225        let current_leaf = self.get_tracked_leaf(&key).ok_or(MerkleError::UntrackedKey(key))?;
226        let leaf_index = current_leaf.index();
227        let previous_value = current_leaf.get_value(&key).unwrap_or(EMPTY_WORD);
228        let prev_entries = current_leaf.num_entries();
229
230        let leaf = self
231            .leaves
232            .entry(leaf_index.value())
233            .or_insert_with(|| SmtLeaf::new_empty(leaf_index));
234
235        if value != EMPTY_WORD {
236            leaf.insert(key, value).map_err(|e| match e {
237                SmtLeafError::TooManyLeafEntries { actual } => {
238                    MerkleError::TooManyLeafEntries { actual }
239                },
240                other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
241            })?;
242        } else {
243            leaf.remove(key);
244        }
245        let current_entries = leaf.num_entries();
246        let new_leaf_hash = leaf.hash();
247        self.num_entries = self.num_entries + current_entries - prev_entries;
248
249        // Remove empty leaf
250        if current_entries == 0 {
251            self.leaves.remove(&leaf_index.value());
252        }
253
254        // Recompute the path from leaf to root
255        self.recompute_nodes_from_leaf_to_root(leaf_index, new_leaf_hash);
256
257        Ok(previous_value)
258    }
259
260    /// Adds an [`SmtProof`] to this [`PartialSmt`].
261    ///
262    /// This is a convenience method which calls [`Self::add_path`] on the proof. See its
263    /// documentation for details on errors.
264    pub fn add_proof(&mut self, proof: SmtProof) -> Result<(), MerkleError> {
265        let (path, leaf) = proof.into_parts();
266        self.add_path(leaf, path)
267    }
268
269    /// Adds a leaf and its sparse merkle path to this [`PartialSmt`].
270    ///
271    /// If this function was called, any key that is part of the `leaf` can subsequently be updated
272    /// to a new value and produce a correct new tree root.
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if:
277    /// - the new root after the insertion of the leaf and the path does not match the existing
278    ///   root. If an error is returned, the tree is left in an inconsistent state.
279    pub fn add_path(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Result<(), MerkleError> {
280        let path_root = self.add_path_unchecked(leaf, path);
281
282        // Check if the newly added merkle path is consistent with the existing tree. If not, the
283        // merkle path was invalid or computed against another tree.
284        if self.root() != path_root {
285            return Err(MerkleError::ConflictingRoots {
286                expected_root: self.root(),
287                actual_root: path_root,
288            });
289        }
290
291        Ok(())
292    }
293
294    // PRIVATE HELPERS
295    // --------------------------------------------------------------------------------------------
296
297    /// Adds a leaf and its sparse merkle path to this [`PartialSmt`] and returns the root of the
298    /// inserted path.
299    ///
300    /// This does not check that the path root matches the existing root of the tree and if so, the
301    /// tree is left in an inconsistent state. This state can be made consistent again by setting
302    /// the root of the SMT to the path root.
303    fn add_path_unchecked(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Word {
304        let mut current_index = leaf.index().index;
305
306        let mut node_hash_at_current_index = leaf.hash();
307
308        let prev_entries = self
309            .leaves
310            .get(&current_index.value())
311            .map(|leaf| leaf.num_entries())
312            .unwrap_or(0);
313        let current_entries = leaf.num_entries();
314        // Only store non-empty leaves
315        if current_entries > 0 {
316            self.leaves.insert(current_index.value(), leaf);
317        } else {
318            self.leaves.remove(&current_index.value());
319        }
320
321        // Guaranteed not to over/underflow. All variables are <= MAX_LEAF_ENTRIES and result > 0.
322        self.num_entries = self.num_entries + current_entries - prev_entries;
323
324        for sibling_hash in path {
325            // Find the index of the sibling node and compute whether it is a left or right child.
326            let is_sibling_right = current_index.sibling().is_value_odd();
327
328            // Move the index up so it points to the parent of the current index and the sibling.
329            current_index.move_up();
330
331            // Construct the new parent node from the child that was updated and the sibling from
332            // the merkle path.
333            let new_parent_node = if is_sibling_right {
334                InnerNode {
335                    left: node_hash_at_current_index,
336                    right: sibling_hash,
337                }
338            } else {
339                InnerNode {
340                    left: sibling_hash,
341                    right: node_hash_at_current_index,
342                }
343            };
344
345            node_hash_at_current_index = new_parent_node.hash();
346
347            self.insert_inner_node(current_index, new_parent_node);
348        }
349
350        node_hash_at_current_index
351    }
352
353    /// Returns the leaf for a key if it can be tracked.
354    ///
355    /// A key is trackable if:
356    /// 1. It was explicitly added via `add_path`/`add_proof`, OR
357    /// 2. The path to the leaf goes through empty subtrees (provably empty)
358    ///
359    /// Returns `None` if the key cannot be tracked (path goes through non-empty
360    /// subtrees we don't have data for).
361    fn get_tracked_leaf(&self, key: &Word) -> Option<SmtLeaf> {
362        let leaf_index = Self::key_to_leaf_index(key);
363
364        // Explicitly stored leaves are always trackable
365        if let Some(leaf) = self.leaves.get(&leaf_index.value()) {
366            return Some(leaf.clone());
367        }
368
369        // Empty tree - all leaves implicitly trackable
370        if self.root == Self::EMPTY_ROOT {
371            return Some(SmtLeaf::new_empty(leaf_index));
372        }
373
374        // Walk from root down towards the leaf
375        let target: NodeIndex = leaf_index.into();
376        let mut index = NodeIndex::root();
377
378        for i in (0..SMT_DEPTH).rev() {
379            let inner_node = self.get_inner_node(index)?;
380
381            let is_right = target.is_nth_bit_odd(i);
382            let child_hash = if is_right { inner_node.right } else { inner_node.left };
383
384            // If child is empty subtree root, leaf is implicitly trackable
385            if child_hash == *EmptySubtreeRoots::entry(SMT_DEPTH, SMT_DEPTH - i) {
386                return Some(SmtLeaf::new_empty(leaf_index));
387            }
388
389            index = if is_right {
390                index.right_child()
391            } else {
392                index.left_child()
393            };
394        }
395
396        // Reached leaf level without finding empty subtree - can't track
397        None
398    }
399
400    /// Converts a key to a leaf index.
401    fn key_to_leaf_index(key: &Word) -> LeafIndex<SMT_DEPTH> {
402        let most_significant_felt = key[3];
403        LeafIndex::new_max_depth(most_significant_felt.as_canonical_u64())
404    }
405
406    /// Returns the inner node at the specified index, or `None` if not stored.
407    fn get_inner_node(&self, index: NodeIndex) -> Option<InnerNode> {
408        self.inner_nodes.get(&index).cloned()
409    }
410
411    /// Returns the inner node at the specified index, falling back to the empty subtree root
412    /// if not stored.
413    fn get_inner_node_or_empty(&self, index: NodeIndex) -> InnerNode {
414        self.get_inner_node(index)
415            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
416    }
417
418    /// Inserts an inner node at the specified index, or removes it if it equals the empty
419    /// subtree root.
420    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
421        if inner_node == EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()) {
422            self.inner_nodes.remove(&index);
423        } else {
424            self.inner_nodes.insert(index, inner_node);
425        }
426    }
427
428    /// Returns the merkle path for a key by walking up the tree from the leaf.
429    fn get_path(&self, key: &Word) -> SparseMerklePath {
430        let index = NodeIndex::from(Self::key_to_leaf_index(key));
431
432        // Use proof_indices to get sibling indices from leaf to root,
433        // and get each sibling's hash
434        SparseMerklePath::from_sized_iter(index.proof_indices().map(|idx| self.get_node_hash(idx)))
435            .expect("path should be valid since it's from a valid SMT")
436    }
437
438    /// Get the hash of a node at an arbitrary index, including the root or leaf hashes.
439    ///
440    /// The root index simply returns the root. Other hashes are retrieved by looking at
441    /// the parent inner node and returning the respective child hash.
442    fn get_node_hash(&self, index: NodeIndex) -> Word {
443        if index.is_root() {
444            return self.root;
445        }
446
447        let InnerNode { left, right } = self.get_inner_node_or_empty(index.parent());
448
449        if index.is_value_odd() { right } else { left }
450    }
451
452    /// Recomputes all inner nodes from a leaf up to the root after a leaf value change.
453    fn recompute_nodes_from_leaf_to_root(
454        &mut self,
455        leaf_index: LeafIndex<SMT_DEPTH>,
456        leaf_hash: Word,
457    ) {
458        use crate::hash::rpo::Rpo256;
459
460        let mut index: NodeIndex = leaf_index.into();
461        let mut node_hash = leaf_hash;
462
463        for _ in (0..index.depth()).rev() {
464            let is_right = index.is_value_odd();
465            index.move_up();
466            let InnerNode { left, right } = self.get_inner_node_or_empty(index);
467            let (left, right) = if is_right {
468                (left, node_hash)
469            } else {
470                (node_hash, right)
471            };
472            node_hash = Rpo256::merge(&[left, right]);
473
474            // insert_inner_node handles removing empty subtree roots
475            self.insert_inner_node(index, InnerNode { left, right });
476        }
477        self.root = node_hash;
478    }
479
480    /// Validates the internal structure during deserialization.
481    ///
482    /// Checks that:
483    /// - Each inner node's hash is consistent with its parent.
484    /// - Each leaf's hash is consistent with its parent inner node's left/right child.
485    fn validate(&self) -> Result<(), DeserializationError> {
486        // Validate each inner node is consistent with its parent
487        for (&idx, node) in &self.inner_nodes {
488            let node_hash = node.hash();
489            let expected_hash = self.get_node_hash(idx);
490
491            if node_hash != expected_hash {
492                return Err(DeserializationError::InvalidValue(
493                    "inner node hash is inconsistent with parent".into(),
494                ));
495            }
496        }
497
498        // Validate each leaf's hash is consistent with its parent inner node
499        for (&leaf_pos, leaf) in &self.leaves {
500            let leaf_index = LeafIndex::<SMT_DEPTH>::new_max_depth(leaf_pos);
501            let node_index: NodeIndex = leaf_index.into();
502            let leaf_hash = leaf.hash();
503            let expected_hash = self.get_node_hash(node_index);
504
505            if leaf_hash != expected_hash {
506                return Err(DeserializationError::InvalidValue(
507                    "leaf hash is inconsistent with parent inner node".into(),
508                ));
509            }
510        }
511
512        Ok(())
513    }
514}
515
516impl Default for PartialSmt {
517    /// Returns a new, empty [`PartialSmt`].
518    ///
519    /// All leaves in the returned tree are set to [`Self::EMPTY_VALUE`].
520    fn default() -> Self {
521        Self::new(Self::EMPTY_ROOT)
522    }
523}
524
525// CONVERSIONS
526// ================================================================================================
527
528impl From<super::Smt> for PartialSmt {
529    fn from(smt: super::Smt) -> Self {
530        Self {
531            root: smt.root(),
532            num_entries: smt.num_entries(),
533            leaves: smt.leaves().map(|(idx, leaf)| (idx.value(), leaf.clone())).collect(),
534            inner_nodes: smt.inner_node_indices().collect(),
535        }
536    }
537}
538
539// SERIALIZATION
540// ================================================================================================
541
542impl Serializable for PartialSmt {
543    fn write_into<W: ByteWriter>(&self, target: &mut W) {
544        target.write(self.root());
545        target.write_usize(self.leaves.len());
546        for (i, leaf) in &self.leaves {
547            target.write_u64(*i);
548            target.write(leaf);
549        }
550        target.write_usize(self.inner_nodes.len());
551        for (idx, node) in &self.inner_nodes {
552            target.write(idx);
553            target.write(node);
554        }
555    }
556}
557
558impl Deserializable for PartialSmt {
559    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
560        let root: Word = source.read()?;
561
562        let mut leaves = Leaves::<SmtLeaf>::default();
563        for _ in 0..source.read_usize()? {
564            let pos: u64 = source.read()?;
565            let leaf: SmtLeaf = source.read()?;
566            leaves.insert(pos, leaf);
567        }
568
569        let mut inner_nodes = InnerNodes::default();
570        for _ in 0..source.read_usize()? {
571            let idx: NodeIndex = source.read()?;
572            let node: InnerNode = source.read()?;
573            inner_nodes.insert(idx, node);
574        }
575
576        let num_entries = leaves.values().map(|leaf| leaf.num_entries()).sum();
577
578        let partial = Self { root, num_entries, leaves, inner_nodes };
579        partial.validate()?;
580
581        Ok(partial)
582    }
583}