Skip to main content

miden_crypto/merkle/partial_mt/
mod.rs

1use alloc::{
2    collections::{BTreeMap, BTreeSet},
3    string::String,
4    vec::Vec,
5};
6use core::fmt;
7
8use super::{
9    EMPTY_WORD, InnerNodeInfo, MerkleError, MerklePath, MerkleProof, NodeIndex, Poseidon2, Word,
10};
11use crate::utils::{
12    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, word_to_hex,
13};
14
15#[cfg(test)]
16mod tests;
17
18// CONSTANTS
19// ================================================================================================
20
21/// Index of the root node.
22const ROOT_INDEX: NodeIndex = NodeIndex::root();
23
24/// An Word consisting of 4 ZERO elements.
25const EMPTY_DIGEST: Word = EMPTY_WORD;
26
27// PARTIAL MERKLE TREE
28// ================================================================================================
29
30/// A partial Merkle tree with NodeIndex keys and 4-element [Word] leaf values. Partial Merkle
31/// Tree allows to create Merkle Tree by providing Merkle paths of different lengths.
32///
33/// The root of the tree is recomputed on each new leaf update.
34#[derive(Debug, Clone, PartialEq, Eq)]
35#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
36pub struct PartialMerkleTree {
37    max_depth: u8,
38    nodes: BTreeMap<NodeIndex, Word>,
39    leaves: BTreeSet<NodeIndex>,
40}
41
42impl Default for PartialMerkleTree {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl PartialMerkleTree {
49    // CONSTANTS
50    // --------------------------------------------------------------------------------------------
51
52    /// Minimum supported depth.
53    pub const MIN_DEPTH: u8 = 1;
54
55    /// Maximum supported depth.
56    pub const MAX_DEPTH: u8 = 64;
57
58    // CONSTRUCTORS
59    // --------------------------------------------------------------------------------------------
60
61    /// Returns a new empty [PartialMerkleTree].
62    pub fn new() -> Self {
63        PartialMerkleTree {
64            max_depth: 0,
65            nodes: BTreeMap::new(),
66            leaves: BTreeSet::new(),
67        }
68    }
69
70    /// Appends the provided paths iterator into the set.
71    ///
72    /// Analogous to [Self::add_path].
73    pub fn with_paths<I>(paths: I) -> Result<Self, MerkleError>
74    where
75        I: IntoIterator<Item = (u64, Word, MerklePath)>,
76    {
77        // create an empty tree
78        let tree = PartialMerkleTree::new();
79
80        paths.into_iter().try_fold(tree, |mut tree, (index, value, path)| {
81            tree.add_path(index, value, path)?;
82            Ok(tree)
83        })
84    }
85
86    /// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided
87    /// entries.
88    ///
89    /// # Errors
90    /// Returns an error if:
91    /// - Any entry has depth 0 or is greater than 64.
92    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
93    /// - The provided entries contain an insufficient set of nodes.
94    /// - Any entry is an ancestor of another entry (creates hash ambiguity).
95    ///
96    /// An empty input returns an empty tree.
97    pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
98    where
99        R: IntoIterator<IntoIter = I>,
100        I: Iterator<Item = (NodeIndex, Word)> + ExactSizeIterator,
101    {
102        let entries = entries.into_iter();
103        if entries.len() == 0 {
104            return Ok(PartialMerkleTree::new());
105        }
106
107        let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
108        let mut leaves = BTreeSet::new();
109        let mut nodes = BTreeMap::new();
110
111        // add data to the leaves and nodes maps and also fill layers map, where the key is the
112        // depth of the node and value is its index.
113        for (node_index, hash) in entries {
114            leaves.insert(node_index);
115            nodes.insert(node_index, hash);
116            layers
117                .entry(node_index.depth())
118                .and_modify(|layer_vec| layer_vec.push(node_index.position()))
119                .or_insert(vec![node_index.position()]);
120        }
121
122        // make sure the depth of the last layer is 64 or smaller
123        if let Some(last_layer) = layers.last_entry() {
124            let last_layer_depth = *last_layer.key();
125            if last_layer_depth > 64 {
126                return Err(MerkleError::TooManyEntries(last_layer_depth));
127            }
128        }
129
130        // Get maximum depth
131        let max_depth = *layers.keys().next_back().unwrap_or(&0);
132
133        // fill layers without nodes with empty vector
134        for depth in 0..max_depth {
135            layers.entry(depth).or_default();
136        }
137
138        let mut layer_iter = layers.into_values().rev();
139        let mut parent_layer = layer_iter.next().unwrap();
140        let mut current_layer;
141
142        for depth in (1..max_depth + 1).rev() {
143            // set current_layer = parent_layer and parent_layer = layer_iter.next()
144            current_layer = layer_iter.next().unwrap();
145            core::mem::swap(&mut current_layer, &mut parent_layer);
146
147            for index_value in current_layer {
148                // get the parent node index
149                let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;
150
151                // If parent already exists, check if it's user-provided (invalid) or computed
152                // (skip)
153                if parent_layer.contains(&parent_node.position()) {
154                    // If the parent was provided as a leaf, that's invalid - we can't have both
155                    // a node and its descendant in the input set.
156                    if leaves.contains(&parent_node) {
157                        return Err(MerkleError::EntryIsNotLeaf { node: parent_node });
158                    }
159                    continue;
160                }
161
162                // create current node index
163                let index = NodeIndex::new(depth, index_value)?;
164
165                // get hash of the current node
166                let node = nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
167                // get hash of the sibling node
168                let sibling = nodes
169                    .get(&index.sibling())
170                    .ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
171                // get parent hash
172                let parent = Poseidon2::merge(&index.build_node(*node, *sibling));
173
174                // add index value of the calculated node to the parents layer
175                parent_layer.push(parent_node.position());
176                // add index and hash to the nodes map
177                nodes.insert(parent_node, parent);
178            }
179        }
180
181        Ok(PartialMerkleTree { max_depth, nodes, leaves })
182    }
183
184    // PUBLIC ACCESSORS
185    // --------------------------------------------------------------------------------------------
186
187    /// Returns the root of this Merkle tree.
188    pub fn root(&self) -> Word {
189        self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST)
190    }
191
192    /// Returns the depth of this Merkle tree.
193    pub fn max_depth(&self) -> u8 {
194        self.max_depth
195    }
196
197    /// Returns a node at the specified NodeIndex.
198    ///
199    /// # Errors
200    /// Returns an error if the specified NodeIndex is not contained in the nodes map.
201    pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
202        self.nodes
203            .get(&index)
204            .ok_or(MerkleError::NodeIndexNotFoundInTree(index))
205            .copied()
206    }
207
208    /// Returns true if provided index contains in the leaves set, false otherwise.
209    pub fn is_leaf(&self, index: NodeIndex) -> bool {
210        self.leaves.contains(&index)
211    }
212
213    /// Returns a vector of paths from every leaf to the root.
214    pub fn to_paths(&self) -> Vec<(NodeIndex, MerkleProof)> {
215        let mut paths = Vec::new();
216        self.leaves.iter().for_each(|&leaf| {
217            paths.push((
218                leaf,
219                MerkleProof {
220                    value: self.get_node(leaf).expect("Failed to get leaf node"),
221                    path: self.get_path(leaf).expect("Failed to get path"),
222                },
223            ));
224        });
225        paths
226    }
227
228    /// Returns a Merkle path from the node at the specified index to the root.
229    ///
230    /// The node itself is not included in the path.
231    ///
232    /// # Errors
233    /// Returns an error if:
234    /// - the specified index has depth set to 0 or the depth is greater than the depth of this
235    ///   Merkle tree.
236    /// - the specified index is not contained in the nodes map.
237    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
238        if index.is_root() {
239            return Err(MerkleError::DepthTooSmall(index.depth()));
240        } else if index.depth() > self.max_depth() {
241            return Err(MerkleError::DepthTooBig(index.depth() as u64));
242        }
243
244        if !self.nodes.contains_key(&index) {
245            return Err(MerkleError::NodeIndexNotFoundInTree(index));
246        }
247
248        let mut path = Vec::new();
249        for _ in 0..index.depth() {
250            let sibling_index = index.sibling();
251            index.move_up();
252            let sibling =
253                self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map");
254            path.push(sibling);
255        }
256        Ok(MerklePath::new(path))
257    }
258
259    // ITERATORS
260    // --------------------------------------------------------------------------------------------
261
262    /// Returns an iterator over the leaves of this [PartialMerkleTree].
263    pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, Word)> + '_ {
264        self.leaves.iter().map(|&leaf| {
265            (
266                leaf,
267                self.get_node(leaf)
268                    .unwrap_or_else(|_| panic!("Leaf with {leaf} is not in the nodes map")),
269            )
270        })
271    }
272
273    /// Returns an iterator over the inner nodes of this Merkle tree.
274    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
275        let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
276        inner_nodes.map(|(index, digest)| {
277            let left_hash =
278                self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
279            let right_hash =
280                self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
281            InnerNodeInfo {
282                value: *digest,
283                left: *left_hash,
284                right: *right_hash,
285            }
286        })
287    }
288
289    // STATE MUTATORS
290    // --------------------------------------------------------------------------------------------
291
292    /// Adds the nodes of the specified Merkle path to this [PartialMerkleTree]. The `index_value`
293    /// and `value` parameters specify the leaf node at which the path starts.
294    ///
295    /// # Errors
296    /// Returns an error if:
297    /// - The depth of the specified node_index is greater than 64 or smaller than 1.
298    /// - The specified path is not consistent with other paths in the set (i.e., resolves to a
299    ///   different root).
300    pub fn add_path(
301        &mut self,
302        index_value: u64,
303        value: Word,
304        path: MerklePath,
305    ) -> Result<(), MerkleError> {
306        let index_value = NodeIndex::new(path.len() as u8, index_value)?;
307
308        Self::check_depth(index_value.depth())?;
309        self.update_depth(index_value.depth());
310
311        // add provided node and its sibling to the leaves set
312        self.leaves.insert(index_value);
313        let sibling_node_index = index_value.sibling();
314        self.leaves.insert(sibling_node_index);
315
316        // add provided node and its sibling to the nodes map
317        self.nodes.insert(index_value, value);
318        self.nodes.insert(sibling_node_index, path[0]);
319
320        // traverse to the root, updating the nodes
321        let mut index_value = index_value;
322        let node = Poseidon2::merge(&index_value.build_node(value, path[0]));
323        let root = path.iter().skip(1).copied().fold(node, |node, hash| {
324            index_value.move_up();
325            // insert calculated node to the nodes map
326            self.nodes.insert(index_value, node);
327
328            // if the calculated node was a leaf, remove it from leaves set.
329            self.leaves.remove(&index_value);
330
331            let sibling_node = index_value.sibling();
332
333            // Insert node from Merkle path to the nodes map. This sibling node becomes a leaf only
334            // if it is a new node (it wasn't in nodes map).
335            // Node can be in 3 states: internal node, leaf of the tree and not a tree node at all.
336            // - Internal node can only stay in this state -- addition of a new path can't make it
337            // a leaf or remove it from the tree.
338            // - Leaf node can stay in the same state (remain a leaf) or can become an internal
339            // node. In the first case we don't need to do anything, and the second case is handled
340            // by the call of `self.leaves.remove(&index_value);`
341            // - New node can be a calculated node or a "sibling" node from a Merkle Path:
342            // --- Calculated node, obviously, never can be a leaf.
343            // --- Sibling node can be only a leaf, because otherwise it is not a new node.
344            if self.nodes.insert(sibling_node, hash).is_none() {
345                self.leaves.insert(sibling_node);
346            }
347
348            Poseidon2::merge(&index_value.build_node(node, hash))
349        });
350
351        // if the path set is empty (the root is all ZEROs), set the root to the root of the added
352        // path; otherwise, the root of the added path must be identical to the current root
353        if self.root() == EMPTY_DIGEST {
354            self.nodes.insert(ROOT_INDEX, root);
355        } else if self.root() != root {
356            return Err(MerkleError::ConflictingRoots {
357                expected_root: self.root(),
358                actual_root: root,
359            });
360        }
361
362        Ok(())
363    }
364
365    /// Updates value of the leaf at the specified index returning the old leaf value.
366    ///
367    /// By default the specified index is assumed to belong to the deepest layer. If the considered
368    /// node does not belong to the tree, the first node on the way to the root will be changed.
369    ///
370    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
371    ///
372    /// # Errors
373    /// Returns an error if:
374    /// - No entry exists at the specified index.
375    /// - The specified index is greater than the maximum number of nodes on the deepest layer.
376    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<Word, MerkleError> {
377        let mut node_index = NodeIndex::new(self.max_depth(), index)?;
378
379        // proceed to the leaf
380        for _ in 0..node_index.depth() {
381            if !self.leaves.contains(&node_index) {
382                node_index.move_up();
383            }
384        }
385
386        // add node value to the nodes Map
387        let old_value = self
388            .nodes
389            .insert(node_index, value)
390            .ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;
391
392        // if the old value and new value are the same, there is nothing to update
393        if value == old_value {
394            return Ok(old_value);
395        }
396
397        let mut value = value;
398        for _ in 0..node_index.depth() {
399            let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
400            value = Poseidon2::merge(&node_index.build_node(value, *sibling));
401            node_index.move_up();
402            self.nodes.insert(node_index, value);
403        }
404
405        Ok(old_value)
406    }
407
408    // UTILITY FUNCTIONS
409    // --------------------------------------------------------------------------------------------
410
411    /// Utility to visualize a [PartialMerkleTree] in text.
412    pub fn print(&self) -> Result<String, fmt::Error> {
413        let indent = "  ";
414        let mut s = String::new();
415        s.push_str("root: ");
416        s.push_str(&word_to_hex(&self.root())?);
417        s.push('\n');
418        for d in 1..=self.max_depth() {
419            let entries = 2u64.pow(d.into());
420            for i in 0..entries {
421                let index = NodeIndex::new(d, i).expect("The index must always be valid");
422                let node = self.get_node(index);
423                let node = match node {
424                    Err(_) => continue,
425                    Ok(node) => node,
426                };
427
428                for _ in 0..d {
429                    s.push_str(indent);
430                }
431                s.push_str(&format!("({}, {}): ", index.depth(), index.position()));
432                s.push_str(&word_to_hex(&node)?);
433                s.push('\n');
434            }
435        }
436
437        Ok(s)
438    }
439
440    // HELPER METHODS
441    // --------------------------------------------------------------------------------------------
442
443    /// Updates depth value with the maximum of current and provided depth.
444    fn update_depth(&mut self, new_depth: u8) {
445        self.max_depth = new_depth.max(self.max_depth);
446    }
447
448    /// Returns an error if the depth is 0 or is greater than 64.
449    fn check_depth(depth: u8) -> Result<(), MerkleError> {
450        // validate the range of the depth.
451        if depth < Self::MIN_DEPTH {
452            return Err(MerkleError::DepthTooSmall(depth));
453        } else if Self::MAX_DEPTH < depth {
454            return Err(MerkleError::DepthTooBig(depth as u64));
455        }
456        Ok(())
457    }
458}
459
460// SERIALIZATION
461// ================================================================================================
462
463impl Serializable for PartialMerkleTree {
464    fn write_into<W: ByteWriter>(&self, target: &mut W) {
465        // write leaf nodes
466        target.write_u64(self.leaves.len() as u64);
467        for leaf_index in self.leaves.iter() {
468            leaf_index.write_into(target);
469            self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
470        }
471    }
472}
473
474impl Deserializable for PartialMerkleTree {
475    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
476        let leaves_len_u64 = source.read_u64()?;
477        let leaves_len = usize::try_from(leaves_len_u64).map_err(|_| {
478            DeserializationError::InvalidValue("PartialMerkleTree leaf count too large".into())
479        })?;
480
481        // Use read_many_iter to avoid eager allocation and respect BudgetedReader limits
482        let leaf_nodes: Vec<(NodeIndex, Word)> =
483            source.read_many_iter(leaves_len)?.collect::<Result<_, _>>()?;
484
485        let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
486            DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
487        })?;
488
489        Ok(pmt)
490    }
491
492    /// Minimum serialized size: u64 length prefix (0 entries).
493    fn min_serialized_size() -> usize {
494        8
495    }
496}