Skip to main content

miden_crypto/merkle/smt/partial/
mod.rs

1use alloc::{collections::VecDeque, string::ToString, vec::Vec};
2
3use super::{EmptySubtreeRoots, LeafIndex, SMT_DEPTH};
4use crate::{
5    EMPTY_WORD, Map, Set, Word,
6    merkle::{
7        InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath,
8        smt::{InnerNode, InnerNodes, Leaves, SmtLeaf, SmtLeafError, SmtProof},
9    },
10    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
11};
12
13mod serialization;
14#[cfg(test)]
15mod tests;
16
17pub use serialization::{NodeValue, UniqueNodes};
18
/// A partial version of an [`super::Smt`].
///
/// This type can track a subset of the key-value pairs of a full [`super::Smt`] and allows for
/// updating those pairs to compute the new root of the tree, as if the updates had been done on the
/// full tree. This is useful so that not all leaves have to be present and loaded into memory to
/// compute an update.
///
/// A key is considered "tracked" if either:
/// 1. Its merkle path was explicitly added to the tree (via [`PartialSmt::add_path`] or
///    [`PartialSmt::add_proof`]), or
/// 2. The path from the root to the leaf goes through an empty subtree that is consistent with the
///    stored inner nodes (provably empty with zero hash computations).
///
/// The second condition allows updating keys in empty subtrees without explicitly adding their
/// merkle paths. It is verified by walking down from the root towards the key's leaf:
/// - every inner node visited on the way must be stored;
/// - the walk succeeds as soon as the child on the key's path equals the empty subtree root for
///   its depth.
///
/// If the root itself is the root of an empty tree, every key is tracked.
///
/// An important caveat is that only tracked keys can be updated. Attempting to update an
/// untracked key will result in an error. See [`PartialSmt::insert`] for more details.
///
/// Once a partial SMT has been constructed, its root is set in stone. All subsequently added proofs
/// or merkle paths must match that root, otherwise an error is returned.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialSmt {
    // The root that every added path must resolve to; updated by `insert`.
    root: Word,
    // Total number of key-value pairs held in the stored leaves.
    num_entries: usize,
    // Explicitly stored leaves, keyed by leaf position; leaves without entries are not kept.
    leaves: Leaves<SmtLeaf>,
    // Stored inner nodes; a node equal to the empty subtree node at its depth is never stored.
    inner_nodes: InnerNodes,
}
49
50impl PartialSmt {
51    // CONSTANTS
52    // --------------------------------------------------------------------------------------------
53
    /// The value implicitly associated with every key that has not been set.
    ///
    /// Inserting this value for a key removes the key from its leaf (see [`Self::insert`]).
    pub const EMPTY_VALUE: Word = EMPTY_WORD;

    /// The root of a tree in which every leaf is empty, i.e. the empty subtree root at depth 0.
    ///
    /// While a [`PartialSmt`] has this root, every key is considered tracked.
    pub const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
59
60    // CONSTRUCTORS
61    // --------------------------------------------------------------------------------------------
62
63    /// Constructs a [`PartialSmt`] from a root.
64    ///
65    /// All subsequently added proofs or paths must have the same root.
66    pub fn new(root: Word) -> Self {
67        Self {
68            root,
69            num_entries: 0,
70            leaves: Leaves::<SmtLeaf>::default(),
71            inner_nodes: InnerNodes::default(),
72        }
73    }
74
75    /// Instantiates a new [`PartialSmt`] by calling [`PartialSmt::add_proof`] for all [`SmtProof`]s
76    /// in the provided iterator.
77    ///
78    /// If the provided iterator is empty, an empty [`PartialSmt`] is returned.
79    ///
80    /// # Errors
81    ///
82    /// Returns an error if:
83    /// - the roots of the provided proofs are not the same.
84    pub fn from_proofs<I>(proofs: I) -> Result<Self, MerkleError>
85    where
86        I: IntoIterator<Item = SmtProof>,
87    {
88        let mut proofs = proofs.into_iter();
89
90        let Some(first_proof) = proofs.next() else {
91            return Ok(Self::default());
92        };
93
94        // Add the first path to an empty partial SMT without checking that the existing root
95        // matches the new one. This sets the expected root to the root of the first proof and all
96        // subsequently added proofs must match it.
97        let mut partial_smt = Self::default();
98        let (path, leaf) = first_proof.into_parts();
99        let path_root = partial_smt.add_path_unchecked(leaf, path);
100        partial_smt.root = path_root;
101
102        for proof in proofs {
103            partial_smt.add_proof(proof)?;
104        }
105
106        Ok(partial_smt)
107    }
108
109    // PUBLIC ACCESSORS
110    // --------------------------------------------------------------------------------------------
111
112    /// Returns the root of the tree.
113    pub fn root(&self) -> Word {
114        self.root
115    }
116
117    /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
118    /// path to the leaf, as well as the leaf itself.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if:
123    /// - the key is not tracked by this partial SMT.
124    pub fn open(&self, key: &Word) -> Result<SmtProof, MerkleError> {
125        let leaf = self.get_leaf(key)?;
126        let merkle_path = self.get_path(key);
127        Ok(SmtProof::new_unchecked(merkle_path, leaf))
128    }
129
130    /// Returns the leaf to which `key` maps.
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if:
135    /// - the key is not tracked by this partial SMT.
136    pub fn get_leaf(&self, key: &Word) -> Result<SmtLeaf, MerkleError> {
137        self.get_tracked_leaf(key).ok_or(MerkleError::UntrackedKey(*key))
138    }
139
140    /// Returns the value associated with `key`.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if:
145    /// - the key is not tracked by this partial SMT.
146    pub fn get_value(&self, key: &Word) -> Result<Word, MerkleError> {
147        self.get_tracked_leaf(key)
148            .map(|leaf| leaf.get_value(key).unwrap_or_default())
149            .ok_or(MerkleError::UntrackedKey(*key))
150    }
151
152    /// Returns an iterator over the inner nodes of the [`PartialSmt`].
153    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
154        self.inner_nodes.values().map(|e| InnerNodeInfo {
155            value: e.hash(),
156            left: e.left,
157            right: e.right,
158        })
159    }
160
161    /// Returns an iterator over the [`InnerNode`] and the respective [`NodeIndex`] of the
162    /// [`PartialSmt`].
163    pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
164        self.inner_nodes.iter().map(|(idx, inner)| (*idx, inner.clone()))
165    }
166
167    /// Returns an iterator over the explicitly stored leaves of the [`PartialSmt`] in arbitrary
168    /// order.
169    ///
170    /// Note: This only returns leaves that were explicitly added via [`Self::add_path`] or
171    /// [`Self::add_proof`], or created through [`Self::insert`]. It does not include implicitly
172    /// trackable leaves in empty subtrees.
173    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
174        self.leaves
175            .iter()
176            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
177    }
178
179    /// Returns an iterator over the tracked, non-empty key-value pairs of the [`PartialSmt`] in
180    /// arbitrary order.
181    pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
182        self.leaves().flat_map(|(_, leaf)| leaf.entries())
183    }
184
185    /// Returns the number of non-empty leaves in this tree.
186    ///
187    /// Note that this may return a different value from [Self::num_entries()] as a single leaf may
188    /// contain more than one key-value pair.
189    pub fn num_leaves(&self) -> usize {
190        self.leaves.len()
191    }
192
193    /// Returns the number of tracked, non-empty key-value pairs in this tree.
194    ///
195    /// Note that this may return a different value from [Self::num_leaves()] as a single leaf may
196    /// contain more than one key-value pair.
197    pub fn num_entries(&self) -> usize {
198        self.num_entries
199    }
200
201    /// Returns a boolean value indicating whether the [`PartialSmt`] tracks any leaves.
202    ///
203    /// Note that if a partial SMT does not track leaves, its root is not necessarily the empty SMT
204    /// root, since it could have been constructed from a different root but without tracking any
205    /// leaves.
206    pub fn tracks_leaves(&self) -> bool {
207        !self.leaves.is_empty()
208    }
209
210    // STATE MUTATORS
211    // --------------------------------------------------------------------------------------------
212
213    /// Inserts a value at the specified key, returning the previous value associated with that key.
214    /// Recall that by definition, any key that hasn't been updated is associated with
215    /// [`Self::EMPTY_VALUE`].
216    ///
217    /// This also recomputes all hashes between the leaf (associated with the key) and the root,
218    /// updating the root itself.
219    ///
220    /// # Errors
221    ///
222    /// Returns an error if:
223    /// - the key is not tracked (see the type documentation for the definition of "tracked"). If an
224    ///   error is returned the tree is in the same state as before.
225    /// - inserting the key-value pair would exceed [`super::MAX_LEAF_ENTRIES`] (1024 entries) in
226    ///   the leaf.
227    pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
228        let current_leaf = self.get_tracked_leaf(&key).ok_or(MerkleError::UntrackedKey(key))?;
229        let leaf_index = current_leaf.index();
230        let previous_value = current_leaf.get_value(&key).unwrap_or(EMPTY_WORD);
231        let prev_entries = current_leaf.num_entries();
232
233        let leaf = self
234            .leaves
235            .entry(leaf_index.position())
236            .or_insert_with(|| SmtLeaf::new_empty(leaf_index));
237
238        if value != EMPTY_WORD {
239            leaf.insert(key, value).map_err(|e| match e {
240                SmtLeafError::TooManyLeafEntries { actual } => {
241                    MerkleError::TooManyLeafEntries { actual }
242                },
243                other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
244            })?;
245        } else {
246            leaf.remove(key);
247        }
248        let current_entries = leaf.num_entries();
249        let new_leaf_hash = leaf.hash();
250        self.num_entries = self.num_entries + current_entries - prev_entries;
251
252        // Remove empty leaf
253        if current_entries == 0 {
254            self.leaves.remove(&leaf_index.position());
255        }
256
257        // Recompute the path from leaf to root
258        self.recompute_nodes_from_leaf_to_root(leaf_index, new_leaf_hash);
259
260        Ok(previous_value)
261    }
262
263    /// Adds an [`SmtProof`] to this [`PartialSmt`].
264    ///
265    /// This is a convenience method which calls [`Self::add_path`] on the proof. See its
266    /// documentation for details on errors.
267    pub fn add_proof(&mut self, proof: SmtProof) -> Result<(), MerkleError> {
268        let (path, leaf) = proof.into_parts();
269        self.add_path(leaf, path)
270    }
271
272    /// Adds a leaf and its sparse merkle path to this [`PartialSmt`].
273    ///
274    /// If this function was called, any key that is part of the `leaf` can subsequently be updated
275    /// to a new value and produce a correct new tree root.
276    ///
277    /// # Errors
278    ///
279    /// Returns an error if:
280    /// - the new root after the insertion of the leaf and the path does not match the existing
281    ///   root. If an error is returned, the tree is left in an inconsistent state.
282    pub fn add_path(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Result<(), MerkleError> {
283        let path_root = self.add_path_unchecked(leaf, path);
284
285        // Check if the newly added merkle path is consistent with the existing tree. If not, the
286        // merkle path was invalid or computed against another tree.
287        if self.root() != path_root {
288            return Err(MerkleError::ConflictingRoots {
289                expected_root: self.root(),
290                actual_root: path_root,
291            });
292        }
293
294        Ok(())
295    }
296
297    // UNIQUE NODES
298    // --------------------------------------------------------------------------------------------
299
    /// Converts `self` into the [`UniqueNodes`] serialization representation for compact
    /// serialization.
    ///
    /// This method assumes that the `PartialSmt` is in a valid state.
    ///
    /// # Reconstructible Sets
    ///
    /// We define the notion of a reconstructible set as one which stores the minimum amount of
    /// information necessary in order to reconstruct the full state of the tree. We build this set
    /// as follows:
    ///
    /// 1. Start at the leaves and traverse toward the root.
    /// 2. Wherever a node's value is determined solely by children already implicitly contained
    ///    within the set, store no new information. If additional information is required (e.g. a
    ///    sibling node) store that.
    /// 3. Repeat until the root is reached.
    ///
    /// To reconstruct the tree, we just start at the leaves and compute all intermediary nodes from
    /// the data stored in the reconstructible set.
    ///
    /// # Panics
    ///
    /// Panics if two different values would be recorded for the same non-stored child node, which
    /// indicates a programmer bug.
    pub fn to_unique_nodes(&self) -> UniqueNodes {
        // We start by getting all the known leaves, as these give us the starting point for the
        // reconstruction.
        let leaf_nodes = self
            .leaves()
            .map(|(k, v)| (k, v.clone()))
            .collect::<Map<LeafIndex<SMT_DEPTH>, SmtLeaf>>();

        // We also create storage for the nodes necessary for reconstruction of the tree...
        let mut needed_nodes: Map<NodeIndex, NodeValue> = Map::new();

        // ... and grab the full set of inner nodes to work from as a queue for easy use. We sort
        // them from the bottom of the tree to the top, but retain the standard left-to-right
        // ordering. The classification below only reads `self`, so the order mainly fixes the
        // (deterministic) order in which `value_only_leaves` entries are emitted.
        let mut inner_nodes = self.inner_node_indices().collect::<Vec<(NodeIndex, InnerNode)>>();
        inner_nodes.sort_by(|(il, _), (ir, _)| {
            ir.depth().cmp(&il.depth()).then(il.position().cmp(&ir.position()))
        });
        let mut inner_nodes = inner_nodes.into_iter().collect::<VecDeque<(NodeIndex, InnerNode)>>();

        // We also need to store the values for leaves where we ONLY have the hash value, rather
        // than the proper leaf value.
        let mut value_only_leaves = Vec::new();

        // We then need to iterate over all the nodes to work out which ones are reconstructible,
        // and which need us to store additional data to be reconstructible.
        while let Some((ix, v)) = inner_nodes.pop_front() {
            // There must be data available for both of the node's children for it to be
            // reconstructible.
            for (child, val) in [(ix.left_child(), v.left), (ix.right_child(), v.right)] {
                if child.depth() != SMT_DEPTH {
                    // A child of the node `v` can be in one of three states:
                    //
                    // 1. The child does not exist as a physical node in `self`, but its value as
                    //    stored in `v` is real.
                    // 2. The child does not exist as a physical node in `self`, but its value is
                    //    the default empty subtree root.
                    // 3. The child does exist as a physical node in `self`. By induction, as this
                    //    algorithm runs bottom-up, the data to reconstruct the node already exists.
                    if self.get_inner_node(child).is_none() {
                        // In this case, the node does not exist physically, so we have to work out
                        // which of the other cases it is.
                        let new = if val == *EmptySubtreeRoots::entry(SMT_DEPTH, child.depth()) {
                            NodeValue::EmptySubtreeRoot
                        } else {
                            NodeValue::Present(val)
                        };

                        // We allow overwriting existing inserts for algorithmic simplicity, but we
                        // always check that it is the same value if an overwrite occurs as this
                        // indicates a programmer bug.
                        if let Some(v) = needed_nodes.insert(child, new.clone())
                            && v != new
                        {
                            panic!("Overwrite occurred with a different value ")
                        }
                    } else {
                        // Here, the node exists physically, so by induction, it is reconstructible.
                        // We fall-through with an explicit `continue` for algorithmic clarity.
                        continue;
                    }
                } else {
                    // Here the child is a leaf node. Leaf nodes can be in one of three states:
                    //
                    // 1. A node that has the default empty value, in which case we encode it using
                    //    absence in the compact representation.
                    // 2. A node that has a hash value, but that does not exist in the physical
                    //    leaves in the PartialSmt. These are encoded using an auxiliary buffer to
                    //    aid in reconstruction.
                    // 3. A node that exists in fully-materialized form. These are encoded with
                    //    their full content.
                    //
                    // Cases 1 and 3 require no special handling here, as they are encoded with the
                    // leaves below. Case 2 needs us to take action here.
                    let empty_leaf_hash =
                        SmtLeaf::new_empty(LeafIndex::new_max_depth(child.position())).hash();

                    if val != empty_leaf_hash && !self.leaves.contains_key(&child.position()) {
                        // We are in case 2 here, as the value is not that of the empty leaf, nor is
                        // there a physical leaf stored in the tree for this. We store this leaf
                        // value in the auxiliary buffer so we can reconstruct correctly in this
                        // scenario.
                        value_only_leaves.push((child.position(), val));
                    }
                }
            }
        }

        // With all the data gathered, we can convert our types as necessary to create our output.
        // Needed nodes are grouped by depth, each entry carrying its position within that depth.
        let leaves = leaf_nodes.into_iter().map(|(i, l)| (i.position(), l)).collect::<Vec<_>>();
        let mut nodes: Map<u8, Vec<(u64, NodeValue)>> = Map::new();

        for (ix, value) in needed_nodes {
            nodes.entry(ix.depth()).or_default().push((ix.position(), value));
        }

        UniqueNodes {
            root: self.root(),
            leaves,
            nodes,
            value_only_leaves,
        }
    }
422
423    /// Constructs a new `PartialSmt` from the provided `unique_nodes`, reconstituting the full data
424    /// from the compact representation.
425    ///
426    /// This method assumes that the `unique_nodes` represent a valid `PartialSmt` instance.
427    ///
428    /// See the documentation of [`Self::to_unique_nodes`] for the reconstruction algorithm.
429    ///
430    /// # Errors
431    ///
432    /// - [`MerkleError::NodeIndexNotFoundInStore`] if any node necessary for reconstruction is not
433    ///   available in the provided `unique_nodes` data.
434    pub fn from_unique_nodes(unique_nodes: UniqueNodes) -> Result<Self, DeserializationError> {
435        // We perform our transformation by directly mutating a new instance of `Self`.
436        let mut smt = Self::new(unique_nodes.root);
437
438        // We rely on a minimal set of node values and leaf values to reconstruct the tree, so we
439        // have to be able to perform lookups.
440        let nodes = unique_nodes
441            .nodes
442            .into_iter()
443            .flat_map(|(depth, nodes)| {
444                nodes.into_iter().map(move |(ix, val)| Ok((NodeIndex::new(depth, ix)?, val)))
445            })
446            .collect::<Result<Map<NodeIndex, NodeValue>, MerkleError>>()
447            .map_err(|e| DeserializationError::InvalidValue(e.to_string()))?;
448        let all_leaves = unique_nodes
449            .leaves
450            .into_iter()
451            .map(|(ix, l)| {
452                let node_index = NodeIndex::new(SMT_DEPTH, ix)
453                    .map_err(|e| DeserializationError::InvalidValue(e.to_string()))?;
454                if node_index != l.index().index {
455                    Err(DeserializationError::InvalidValue(format!(
456                        "Node index {ix} did not match the embedded leaf index {}",
457                        l.index().index
458                    )))
459                } else {
460                    Ok((
461                        NodeIndex::new(SMT_DEPTH, ix)
462                            .map_err(|e| DeserializationError::InvalidValue(e.to_string()))?,
463                        l,
464                    ))
465                }
466            })
467            .collect::<Result<Map<_, _>, DeserializationError>>()?;
468
469        // We also need to grab the buffer of the additional leaf values, and we convert it into a
470        // map for easy lookup. It is safe to use `new_unchecked` here as, while this comes from
471        // untrusted input, `ix` can correctly take the value of any `u64`.
472        let value_only_leaves = unique_nodes
473            .value_only_leaves
474            .into_iter()
475            .map(|(ix, v)| (NodeIndex::new_unchecked(SMT_DEPTH, ix), v))
476            .collect::<Map<_, _>>();
477
478        // We then want to process leaf by leaf, with a queue of parent nodes that need visiting.
479        // Rather than trying to de-duplicate on the fly, we instead just discard nodes that have
480        // already been processed when we see them.
481        //
482        // It must be ensured that at no point an index that is lower in the tree than any index
483        // preceding it is inserted.
484        let leaf_based_starting_nodes =
485            all_leaves.keys().map(|k| k.parent()).collect::<VecDeque<_>>();
486
487        // We also, however, need to account for inner nodes which are not reachable in a parent
488        // chain from a leaf, such as those from an exclusion proof. These are all nodes that do not
489        // have a (present) child in the set of nodes or leaves, so to enforce our layering
490        // invariant we add them in sorted order from bottom to top, left to right.
491        //
492        // We process these after the leaf-based nodes to avoid issues with the layering invariant.
493        let mut additional_nodes = nodes.keys().map(|ix| ix.parent()).collect::<Vec<_>>();
494        additional_nodes
495            .sort_by(|il, ir| ir.depth().cmp(&il.depth()).then(il.position().cmp(&ir.position())));
496        let additional_nodes = additional_nodes.into_iter().collect::<VecDeque<_>>();
497
498        // We also track the nodes we have seen to avoid re-doing unnecessary work.
499        let mut seen_nodes = Set::new();
500
501        for mut active_nodes in [leaf_based_starting_nodes, additional_nodes] {
502            seen_nodes.clear();
503            while let Some(ix) = active_nodes.pop_front() {
504                // To avoid re-doing work we immediately discard a node that is already in our tree.
505                if smt.inner_nodes.contains_key(&ix) {
506                    continue;
507                }
508
509                if ix.depth() + 1 == SMT_DEPTH {
510                    // We have to handle the case where the children are the leaves specially.
511                    //
512                    // If no corresponding leaf is present, then either it was a default value, or
513                    // it exists in the value-only leaves buffer, so we have to check both.
514                    let left_child = ix.left_child();
515                    let left = all_leaves
516                        .get(&left_child)
517                        .map(SmtLeaf::hash)
518                        .or_else(|| value_only_leaves.get(&left_child).copied())
519                        .unwrap_or(
520                            SmtLeaf::new_empty(LeafIndex::new_max_depth(left_child.position()))
521                                .hash(),
522                        );
523                    let right_child = ix.right_child();
524                    let right = all_leaves
525                        .get(&right_child)
526                        .map(SmtLeaf::hash)
527                        .or_else(|| value_only_leaves.get(&right_child).copied())
528                        .unwrap_or(
529                            SmtLeaf::new_empty(LeafIndex::new_max_depth(right_child.position()))
530                                .hash(),
531                        );
532
533                    smt.insert_inner_node(ix, InnerNode { left, right })
534                } else {
535                    // If the children are not in the leaves, they can be either in the tree already
536                    // (having been reconstructed) or as a value in the nodes from the unique nodes
537                    // structure.
538                    let [left, right] = [ix.left_child(), ix.right_child()].map(|ix| {
539                        smt.get_inner_node(ix).map(|n| Ok(n.hash())).unwrap_or_else(|| match nodes
540                            .get(&ix)
541                            .ok_or_else(|| {
542                                DeserializationError::InvalidValue(format!(
543                                    "Node at {ix} not found but is required"
544                                ))
545                            })? {
546                            NodeValue::EmptySubtreeRoot => {
547                                Ok(*EmptySubtreeRoots::entry(SMT_DEPTH, ix.depth()))
548                            },
549                            NodeValue::Present(v) => Ok(*v),
550                        })
551                    });
552                    let left = left?;
553                    let right = right?;
554
555                    smt.insert_inner_node(ix, InnerNode { left, right });
556                }
557
558                // Finally, we push the node's parent into the queue if we have not already visited
559                // it. While it would be correct to do unconditionally, we operate over untrusted
560                // input and hence we have to be careful.
561                let parent = ix.parent();
562                if !seen_nodes.contains(&parent) {
563                    active_nodes.push_back(parent);
564                    seen_nodes.insert(parent);
565                }
566            }
567        }
568
569        // With that done, we simply have to write the remaining keys into the tree.
570        all_leaves.into_iter().for_each(|(ix, leaf)| {
571            smt.num_entries += leaf.num_entries();
572            smt.leaves.insert(ix.position(), leaf);
573        });
574
575        smt.validate()?;
576
577        Ok(smt)
578    }
579
580    // PRIVATE HELPERS
581    // --------------------------------------------------------------------------------------------
582
583    /// Adds a leaf and its sparse merkle path to this [`PartialSmt`] and returns the root of the
584    /// inserted path.
585    ///
586    /// This does not check that the path root matches the existing root of the tree and if so, the
587    /// tree is left in an inconsistent state. This state can be made consistent again by setting
588    /// the root of the SMT to the path root.
589    fn add_path_unchecked(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Word {
590        let mut current_index = leaf.index().index;
591
592        let mut node_hash_at_current_index = leaf.hash();
593
594        let prev_entries = self
595            .leaves
596            .get(&current_index.position())
597            .map(SmtLeaf::num_entries)
598            .unwrap_or(0);
599        let current_entries = leaf.num_entries();
600        // Only store non-empty leaves
601        if current_entries > 0 {
602            self.leaves.insert(current_index.position(), leaf);
603        } else {
604            self.leaves.remove(&current_index.position());
605        }
606
607        // Guaranteed not to over/underflow. All variables are <= MAX_LEAF_ENTRIES and result > 0.
608        self.num_entries = self.num_entries + current_entries - prev_entries;
609
610        for sibling_hash in path {
611            // Find the index of the sibling node and compute whether it is a left or right child.
612            let is_sibling_right = current_index.sibling().is_position_odd();
613
614            // Move the index up so it points to the parent of the current index and the sibling.
615            current_index.move_up();
616
617            // Construct the new parent node from the child that was updated and the sibling from
618            // the merkle path.
619            let new_parent_node = if is_sibling_right {
620                InnerNode {
621                    left: node_hash_at_current_index,
622                    right: sibling_hash,
623                }
624            } else {
625                InnerNode {
626                    left: sibling_hash,
627                    right: node_hash_at_current_index,
628                }
629            };
630
631            node_hash_at_current_index = new_parent_node.hash();
632
633            self.insert_inner_node(current_index, new_parent_node);
634        }
635
636        node_hash_at_current_index
637    }
638
639    /// Returns the leaf for a key if it can be tracked.
640    ///
641    /// A key is trackable if:
642    /// 1. It was explicitly added via `add_path`/`add_proof`, OR
643    /// 2. The path to the leaf goes through empty subtrees (provably empty)
644    ///
645    /// Returns `None` if the key cannot be tracked (path goes through non-empty
646    /// subtrees we don't have data for).
647    fn get_tracked_leaf(&self, key: &Word) -> Option<SmtLeaf> {
648        let leaf_index = Self::key_to_leaf_index(key);
649
650        // Explicitly stored leaves are always trackable
651        if let Some(leaf) = self.leaves.get(&leaf_index.position()) {
652            return Some(leaf.clone());
653        }
654
655        // Empty tree - all leaves implicitly trackable
656        if self.root == Self::EMPTY_ROOT {
657            return Some(SmtLeaf::new_empty(leaf_index));
658        }
659
660        // Walk from root down towards the leaf
661        let target: NodeIndex = leaf_index.into();
662        let mut index = NodeIndex::root();
663
664        for i in (0..SMT_DEPTH).rev() {
665            let inner_node = self.get_inner_node(index)?;
666
667            let is_right = target.is_nth_bit_odd(i);
668            let child_hash = if is_right { inner_node.right } else { inner_node.left };
669
670            // If child is empty subtree root, leaf is implicitly trackable
671            if child_hash == *EmptySubtreeRoots::entry(SMT_DEPTH, SMT_DEPTH - i) {
672                return Some(SmtLeaf::new_empty(leaf_index));
673            }
674
675            index = if is_right {
676                index.right_child()
677            } else {
678                index.left_child()
679            };
680        }
681
682        // Reached leaf level without finding empty subtree - can't track
683        None
684    }
685
686    /// Converts a key to a leaf index.
687    fn key_to_leaf_index(key: &Word) -> LeafIndex<SMT_DEPTH> {
688        let most_significant_felt = key[3];
689        LeafIndex::new_max_depth(most_significant_felt.as_canonical_u64())
690    }
691
692    /// Returns the inner node at the specified index, or `None` if not stored.
693    fn get_inner_node(&self, index: NodeIndex) -> Option<InnerNode> {
694        self.inner_nodes.get(&index).cloned()
695    }
696
697    /// Returns the inner node at the specified index, falling back to the empty subtree root
698    /// if not stored.
699    fn get_inner_node_or_empty(&self, index: NodeIndex) -> InnerNode {
700        self.get_inner_node(index)
701            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
702    }
703
704    /// Inserts an inner node at the specified index, or removes it if it equals the empty
705    /// subtree root.
706    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
707        if inner_node == EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()) {
708            self.inner_nodes.remove(&index);
709        } else {
710            self.inner_nodes.insert(index, inner_node);
711        }
712    }
713
714    /// Returns the merkle path for a key by walking up the tree from the leaf.
715    fn get_path(&self, key: &Word) -> SparseMerklePath {
716        let index = NodeIndex::from(Self::key_to_leaf_index(key));
717
718        // Use proof_indices to get sibling indices from leaf to root,
719        // and get each sibling's hash
720        SparseMerklePath::from_sized_iter(index.proof_indices().map(|idx| self.get_node_hash(idx)))
721            .expect("path should be valid since it's from a valid SMT")
722    }
723
724    /// Get the hash of a node at an arbitrary index, including the root or leaf hashes.
725    ///
726    /// The root index simply returns the root. Other hashes are retrieved by looking at
727    /// the parent inner node and returning the respective child hash.
728    fn get_node_hash(&self, index: NodeIndex) -> Word {
729        if index.is_root() {
730            return self.root;
731        }
732
733        let InnerNode { left, right } = self.get_inner_node_or_empty(index.parent());
734
735        if index.is_position_odd() { right } else { left }
736    }
737
738    /// Recomputes all inner nodes from a leaf up to the root after a leaf value change.
739    fn recompute_nodes_from_leaf_to_root(
740        &mut self,
741        leaf_index: LeafIndex<SMT_DEPTH>,
742        leaf_hash: Word,
743    ) {
744        use crate::hash::poseidon2::Poseidon2;
745
746        let mut index: NodeIndex = leaf_index.into();
747        let mut node_hash = leaf_hash;
748
749        for _ in (0..index.depth()).rev() {
750            let is_right = index.is_position_odd();
751            index.move_up();
752            let InnerNode { left, right } = self.get_inner_node_or_empty(index);
753            let (left, right) = if is_right {
754                (left, node_hash)
755            } else {
756                (node_hash, right)
757            };
758            node_hash = Poseidon2::merge(&[left, right]);
759
760            // insert_inner_node handles removing empty subtree roots
761            self.insert_inner_node(index, InnerNode { left, right });
762        }
763        self.root = node_hash;
764    }
765
766    /// Validates the internal structure during deserialization.
767    ///
768    /// Checks that:
769    /// - Each inner node's hash is consistent with its parent.
770    /// - Each leaf's hash is consistent with its parent inner node's left/right child.
771    fn validate(&self) -> Result<(), DeserializationError> {
772        // Validate each inner node is consistent with its parent
773        for (&idx, node) in &self.inner_nodes {
774            let node_hash = node.hash();
775            let expected_hash = self.get_node_hash(idx);
776
777            if node_hash != expected_hash {
778                return Err(DeserializationError::InvalidValue(
779                    "inner node hash is inconsistent with parent".into(),
780                ));
781            }
782        }
783
784        // Validate each leaf's hash is consistent with its parent inner node
785        for (&leaf_pos, leaf) in &self.leaves {
786            let leaf_index = LeafIndex::<SMT_DEPTH>::new_max_depth(leaf_pos);
787            let node_index: NodeIndex = leaf_index.into();
788            let leaf_hash = leaf.hash();
789            let expected_hash = self.get_node_hash(node_index);
790
791            if leaf_hash != expected_hash {
792                return Err(DeserializationError::InvalidValue(
793                    "leaf hash is inconsistent with parent inner node".into(),
794                ));
795            }
796        }
797
798        Ok(())
799    }
800}
801
802impl Default for PartialSmt {
803    /// Returns a new, empty [`PartialSmt`].
804    ///
805    /// All leaves in the returned tree are set to [`Self::EMPTY_VALUE`].
806    fn default() -> Self {
807        Self::new(Self::EMPTY_ROOT)
808    }
809}
810
811// CONVERSIONS
812// ================================================================================================
813
814impl From<super::Smt> for PartialSmt {
815    fn from(smt: super::Smt) -> Self {
816        Self {
817            root: smt.root(),
818            num_entries: smt.num_entries(),
819            leaves: smt.leaves().map(|(idx, leaf)| (idx.position(), leaf.clone())).collect(),
820            inner_nodes: smt.inner_node_indices().collect(),
821        }
822    }
823}
824
825// SERIALIZATION
826// ================================================================================================
827
828impl Serializable for PartialSmt {
829    fn write_into<W: ByteWriter>(&self, target: &mut W) {
830        let unique_rep = self.to_unique_nodes();
831        unique_rep.write_into(target);
832    }
833}
834
835impl Deserializable for PartialSmt {
836    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
837        let unique_rep = UniqueNodes::read_from(source)?;
838        PartialSmt::from_unique_nodes(unique_rep)
839            .map_err(|e| DeserializationError::InvalidValue(format!("{e}")))
840    }
841}