miden_crypto/merkle/partial_mt/mod.rs

1use alloc::{
2    collections::{BTreeMap, BTreeSet},
3    string::String,
4    vec::Vec,
5};
6use core::fmt;
7
8use super::{
9    InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, ValuePath, Word,
10    EMPTY_WORD,
11};
12use crate::utils::{
13    word_to_hex, ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
14};
15
16#[cfg(test)]
17mod tests;
18
19// CONSTANTS
20// ================================================================================================
21
22/// Index of the root node.
23const ROOT_INDEX: NodeIndex = NodeIndex::root();
24
25/// An RpoDigest consisting of 4 ZERO elements.
26const EMPTY_DIGEST: RpoDigest = RpoDigest::new(EMPTY_WORD);
27
28// PARTIAL MERKLE TREE
29// ================================================================================================
30
31/// A partial Merkle tree with NodeIndex keys and 4-element RpoDigest leaf values. Partial Merkle
32/// Tree allows to create Merkle Tree by providing Merkle paths of different lengths.
33///
34/// The root of the tree is recomputed on each new leaf update.
35#[derive(Debug, Clone, PartialEq, Eq)]
36#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
37pub struct PartialMerkleTree {
38    max_depth: u8,
39    nodes: BTreeMap<NodeIndex, RpoDigest>,
40    leaves: BTreeSet<NodeIndex>,
41}
42
43impl Default for PartialMerkleTree {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl PartialMerkleTree {
50    // CONSTANTS
51    // --------------------------------------------------------------------------------------------
52
53    /// Minimum supported depth.
54    pub const MIN_DEPTH: u8 = 1;
55
56    /// Maximum supported depth.
57    pub const MAX_DEPTH: u8 = 64;
58
59    // CONSTRUCTORS
60    // --------------------------------------------------------------------------------------------
61
62    /// Returns a new empty [PartialMerkleTree].
63    pub fn new() -> Self {
64        PartialMerkleTree {
65            max_depth: 0,
66            nodes: BTreeMap::new(),
67            leaves: BTreeSet::new(),
68        }
69    }
70
71    /// Appends the provided paths iterator into the set.
72    ///
73    /// Analogous to [Self::add_path].
74    pub fn with_paths<I>(paths: I) -> Result<Self, MerkleError>
75    where
76        I: IntoIterator<Item = (u64, RpoDigest, MerklePath)>,
77    {
78        // create an empty tree
79        let tree = PartialMerkleTree::new();
80
81        paths.into_iter().try_fold(tree, |mut tree, (index, value, path)| {
82            tree.add_path(index, value, path)?;
83            Ok(tree)
84        })
85    }
86
87    /// Returns a new [PartialMerkleTree] instantiated with leaves map as specified by the provided
88    /// entries.
89    ///
90    /// # Errors
91    /// Returns an error if:
92    /// - If the depth is 0 or is greater than 64.
93    /// - The number of entries exceeds the maximum tree capacity, that is 2^{depth}.
94    /// - The provided entries contain an insufficient set of nodes.
95    pub fn with_leaves<R, I>(entries: R) -> Result<Self, MerkleError>
96    where
97        R: IntoIterator<IntoIter = I>,
98        I: Iterator<Item = (NodeIndex, RpoDigest)> + ExactSizeIterator,
99    {
100        let mut layers: BTreeMap<u8, Vec<u64>> = BTreeMap::new();
101        let mut leaves = BTreeSet::new();
102        let mut nodes = BTreeMap::new();
103
104        // add data to the leaves and nodes maps and also fill layers map, where the key is the
105        // depth of the node and value is its index.
106        for (node_index, hash) in entries.into_iter() {
107            leaves.insert(node_index);
108            nodes.insert(node_index, hash);
109            layers
110                .entry(node_index.depth())
111                .and_modify(|layer_vec| layer_vec.push(node_index.value()))
112                .or_insert(vec![node_index.value()]);
113        }
114
115        // check if the number of leaves can be accommodated by the tree's depth; we use a min
116        // depth of 63 because we consider passing in a vector of size 2^64 infeasible.
117        let max = 2usize.pow(63);
118        if layers.len() > max {
119            return Err(MerkleError::TooManyEntries(max));
120        }
121
122        // Get maximum depth
123        let max_depth = *layers.keys().next_back().unwrap_or(&0);
124
125        // fill layers without nodes with empty vector
126        for depth in 0..max_depth {
127            layers.entry(depth).or_default();
128        }
129
130        let mut layer_iter = layers.into_values().rev();
131        let mut parent_layer = layer_iter.next().unwrap();
132        let mut current_layer;
133
134        for depth in (1..max_depth + 1).rev() {
135            // set current_layer = parent_layer and parent_layer = layer_iter.next()
136            current_layer = layer_iter.next().unwrap();
137            core::mem::swap(&mut current_layer, &mut parent_layer);
138
139            for index_value in current_layer {
140                // get the parent node index
141                let parent_node = NodeIndex::new(depth - 1, index_value / 2)?;
142
143                // Check if the parent hash was already calculated. In about half of the cases, we
144                // don't need to do anything.
145                if !parent_layer.contains(&parent_node.value()) {
146                    // create current node index
147                    let index = NodeIndex::new(depth, index_value)?;
148
149                    // get hash of the current node
150                    let node =
151                        nodes.get(&index).ok_or(MerkleError::NodeIndexNotFoundInTree(index))?;
152                    // get hash of the sibling node
153                    let sibling = nodes
154                        .get(&index.sibling())
155                        .ok_or(MerkleError::NodeIndexNotFoundInTree(index.sibling()))?;
156                    // get parent hash
157                    let parent = Rpo256::merge(&index.build_node(*node, *sibling));
158
159                    // add index value of the calculated node to the parents layer
160                    parent_layer.push(parent_node.value());
161                    // add index and hash to the nodes map
162                    nodes.insert(parent_node, parent);
163                }
164            }
165        }
166
167        Ok(PartialMerkleTree { max_depth, nodes, leaves })
168    }
169
170    // PUBLIC ACCESSORS
171    // --------------------------------------------------------------------------------------------
172
173    /// Returns the root of this Merkle tree.
174    pub fn root(&self) -> RpoDigest {
175        self.nodes.get(&ROOT_INDEX).cloned().unwrap_or(EMPTY_DIGEST)
176    }
177
178    /// Returns the depth of this Merkle tree.
179    pub fn max_depth(&self) -> u8 {
180        self.max_depth
181    }
182
183    /// Returns a node at the specified NodeIndex.
184    ///
185    /// # Errors
186    /// Returns an error if the specified NodeIndex is not contained in the nodes map.
187    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
188        self.nodes
189            .get(&index)
190            .ok_or(MerkleError::NodeIndexNotFoundInTree(index))
191            .copied()
192    }
193
194    /// Returns true if provided index contains in the leaves set, false otherwise.
195    pub fn is_leaf(&self, index: NodeIndex) -> bool {
196        self.leaves.contains(&index)
197    }
198
199    /// Returns a vector of paths from every leaf to the root.
200    pub fn to_paths(&self) -> Vec<(NodeIndex, ValuePath)> {
201        let mut paths = Vec::new();
202        self.leaves.iter().for_each(|&leaf| {
203            paths.push((
204                leaf,
205                ValuePath {
206                    value: self.get_node(leaf).expect("Failed to get leaf node"),
207                    path: self.get_path(leaf).expect("Failed to get path"),
208                },
209            ));
210        });
211        paths
212    }
213
214    /// Returns a Merkle path from the node at the specified index to the root.
215    ///
216    /// The node itself is not included in the path.
217    ///
218    /// # Errors
219    /// Returns an error if:
220    /// - the specified index has depth set to 0 or the depth is greater than the depth of this
221    ///   Merkle tree.
222    /// - the specified index is not contained in the nodes map.
223    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
224        if index.is_root() {
225            return Err(MerkleError::DepthTooSmall(index.depth()));
226        } else if index.depth() > self.max_depth() {
227            return Err(MerkleError::DepthTooBig(index.depth() as u64));
228        }
229
230        if !self.nodes.contains_key(&index) {
231            return Err(MerkleError::NodeIndexNotFoundInTree(index));
232        }
233
234        let mut path = Vec::new();
235        for _ in 0..index.depth() {
236            let sibling_index = index.sibling();
237            index.move_up();
238            let sibling =
239                self.nodes.get(&sibling_index).cloned().expect("Sibling node not in the map");
240            path.push(sibling);
241        }
242        Ok(MerklePath::new(path))
243    }
244
245    // ITERATORS
246    // --------------------------------------------------------------------------------------------
247
248    /// Returns an iterator over the leaves of this [PartialMerkleTree].
249    pub fn leaves(&self) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
250        self.leaves.iter().map(|&leaf| {
251            (
252                leaf,
253                self.get_node(leaf)
254                    .unwrap_or_else(|_| panic!("Leaf with {leaf} is not in the nodes map")),
255            )
256        })
257    }
258
259    /// Returns an iterator over the inner nodes of this Merkle tree.
260    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
261        let inner_nodes = self.nodes.iter().filter(|(index, _)| !self.leaves.contains(index));
262        inner_nodes.map(|(index, digest)| {
263            let left_hash =
264                self.nodes.get(&index.left_child()).expect("Failed to get left child hash");
265            let right_hash =
266                self.nodes.get(&index.right_child()).expect("Failed to get right child hash");
267            InnerNodeInfo {
268                value: *digest,
269                left: *left_hash,
270                right: *right_hash,
271            }
272        })
273    }
274
275    // STATE MUTATORS
276    // --------------------------------------------------------------------------------------------
277
278    /// Adds the nodes of the specified Merkle path to this [PartialMerkleTree]. The `index_value`
279    /// and `value` parameters specify the leaf node at which the path starts.
280    ///
281    /// # Errors
282    /// Returns an error if:
283    /// - The depth of the specified node_index is greater than 64 or smaller than 1.
284    /// - The specified path is not consistent with other paths in the set (i.e., resolves to a
285    ///   different root).
286    pub fn add_path(
287        &mut self,
288        index_value: u64,
289        value: RpoDigest,
290        path: MerklePath,
291    ) -> Result<(), MerkleError> {
292        let index_value = NodeIndex::new(path.len() as u8, index_value)?;
293
294        Self::check_depth(index_value.depth())?;
295        self.update_depth(index_value.depth());
296
297        // add provided node and its sibling to the leaves set
298        self.leaves.insert(index_value);
299        let sibling_node_index = index_value.sibling();
300        self.leaves.insert(sibling_node_index);
301
302        // add provided node and its sibling to the nodes map
303        self.nodes.insert(index_value, value);
304        self.nodes.insert(sibling_node_index, path[0]);
305
306        // traverse to the root, updating the nodes
307        let mut index_value = index_value;
308        let node = Rpo256::merge(&index_value.build_node(value, path[0]));
309        let root = path.iter().skip(1).copied().fold(node, |node, hash| {
310            index_value.move_up();
311            // insert calculated node to the nodes map
312            self.nodes.insert(index_value, node);
313
314            // if the calculated node was a leaf, remove it from leaves set.
315            self.leaves.remove(&index_value);
316
317            let sibling_node = index_value.sibling();
318
319            // Insert node from Merkle path to the nodes map. This sibling node becomes a leaf only
320            // if it is a new node (it wasn't in nodes map).
321            // Node can be in 3 states: internal node, leaf of the tree and not a tree node at all.
322            // - Internal node can only stay in this state -- addition of a new path can't make it
323            // a leaf or remove it from the tree.
324            // - Leaf node can stay in the same state (remain a leaf) or can become an internal
325            // node. In the first case we don't need to do anything, and the second case is handled
326            // by the call of `self.leaves.remove(&index_value);`
327            // - New node can be a calculated node or a "sibling" node from a Merkle Path:
328            // --- Calculated node, obviously, never can be a leaf.
329            // --- Sibling node can be only a leaf, because otherwise it is not a new node.
330            if self.nodes.insert(sibling_node, hash).is_none() {
331                self.leaves.insert(sibling_node);
332            }
333
334            Rpo256::merge(&index_value.build_node(node, hash))
335        });
336
337        // if the path set is empty (the root is all ZEROs), set the root to the root of the added
338        // path; otherwise, the root of the added path must be identical to the current root
339        if self.root() == EMPTY_DIGEST {
340            self.nodes.insert(ROOT_INDEX, root);
341        } else if self.root() != root {
342            return Err(MerkleError::ConflictingRoots {
343                expected_root: self.root(),
344                actual_root: root,
345            });
346        }
347
348        Ok(())
349    }
350
351    /// Updates value of the leaf at the specified index returning the old leaf value.
352    ///
353    /// By default the specified index is assumed to belong to the deepest layer. If the considered
354    /// node does not belong to the tree, the first node on the way to the root will be changed.
355    ///
356    /// This also recomputes all hashes between the leaf and the root, updating the root itself.
357    ///
358    /// # Errors
359    /// Returns an error if:
360    /// - No entry exists at the specified index.
361    /// - The specified index is greater than the maximum number of nodes on the deepest layer.
362    pub fn update_leaf(&mut self, index: u64, value: Word) -> Result<RpoDigest, MerkleError> {
363        let mut node_index = NodeIndex::new(self.max_depth(), index)?;
364
365        // proceed to the leaf
366        for _ in 0..node_index.depth() {
367            if !self.leaves.contains(&node_index) {
368                node_index.move_up();
369            }
370        }
371
372        // add node value to the nodes Map
373        let old_value = self
374            .nodes
375            .insert(node_index, value.into())
376            .ok_or(MerkleError::NodeIndexNotFoundInTree(node_index))?;
377
378        // if the old value and new value are the same, there is nothing to update
379        if value == *old_value {
380            return Ok(old_value);
381        }
382
383        let mut value = value.into();
384        for _ in 0..node_index.depth() {
385            let sibling = self.nodes.get(&node_index.sibling()).expect("sibling should exist");
386            value = Rpo256::merge(&node_index.build_node(value, *sibling));
387            node_index.move_up();
388            self.nodes.insert(node_index, value);
389        }
390
391        Ok(old_value)
392    }
393
394    // UTILITY FUNCTIONS
395    // --------------------------------------------------------------------------------------------
396
397    /// Utility to visualize a [PartialMerkleTree] in text.
398    pub fn print(&self) -> Result<String, fmt::Error> {
399        let indent = "  ";
400        let mut s = String::new();
401        s.push_str("root: ");
402        s.push_str(&word_to_hex(&self.root())?);
403        s.push('\n');
404        for d in 1..=self.max_depth() {
405            let entries = 2u64.pow(d.into());
406            for i in 0..entries {
407                let index = NodeIndex::new(d, i).expect("The index must always be valid");
408                let node = self.get_node(index);
409                let node = match node {
410                    Err(_) => continue,
411                    Ok(node) => node,
412                };
413
414                for _ in 0..d {
415                    s.push_str(indent);
416                }
417                s.push_str(&format!("({}, {}): ", index.depth(), index.value()));
418                s.push_str(&word_to_hex(&node)?);
419                s.push('\n');
420            }
421        }
422
423        Ok(s)
424    }
425
426    // HELPER METHODS
427    // --------------------------------------------------------------------------------------------
428
429    /// Updates depth value with the maximum of current and provided depth.
430    fn update_depth(&mut self, new_depth: u8) {
431        self.max_depth = new_depth.max(self.max_depth);
432    }
433
434    /// Returns an error if the depth is 0 or is greater than 64.
435    fn check_depth(depth: u8) -> Result<(), MerkleError> {
436        // validate the range of the depth.
437        if depth < Self::MIN_DEPTH {
438            return Err(MerkleError::DepthTooSmall(depth));
439        } else if Self::MAX_DEPTH < depth {
440            return Err(MerkleError::DepthTooBig(depth as u64));
441        }
442        Ok(())
443    }
444}
445
446// SERIALIZATION
447// ================================================================================================
448
449impl Serializable for PartialMerkleTree {
450    fn write_into<W: ByteWriter>(&self, target: &mut W) {
451        // write leaf nodes
452        target.write_u64(self.leaves.len() as u64);
453        for leaf_index in self.leaves.iter() {
454            leaf_index.write_into(target);
455            self.get_node(*leaf_index).expect("Leaf hash not found").write_into(target);
456        }
457    }
458}
459
460impl Deserializable for PartialMerkleTree {
461    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
462        let leaves_len = source.read_u64()? as usize;
463        let mut leaf_nodes = Vec::with_capacity(leaves_len);
464
465        // add leaf nodes to the vector
466        for _ in 0..leaves_len {
467            let index = NodeIndex::read_from(source)?;
468            let hash = RpoDigest::read_from(source)?;
469            leaf_nodes.push((index, hash));
470        }
471
472        let pmt = PartialMerkleTree::with_leaves(leaf_nodes).map_err(|_| {
473            DeserializationError::InvalidValue("Invalid data for PartialMerkleTree creation".into())
474        })?;
475
476        Ok(pmt)
477    }
478}