miden_crypto/merkle/
merkle_tree.rs

1use alloc::{string::String, vec::Vec};
2use core::{fmt, ops::Deref, slice};
3
4use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, RpoDigest, Word};
5use crate::utils::{uninit_vector, word_to_hex};
6
7// MERKLE TREE
8// ================================================================================================
9
10/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
11#[derive(Debug, Clone, PartialEq, Eq)]
12#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
13pub struct MerkleTree {
14    nodes: Vec<RpoDigest>,
15}
16
17impl MerkleTree {
18    // CONSTRUCTOR
19    // --------------------------------------------------------------------------------------------
20    /// Returns a Merkle tree instantiated from the provided leaves.
21    ///
22    /// # Errors
23    /// Returns an error if the number of leaves is smaller than two or is not a power of two.
24    pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
25    where
26        T: AsRef<[Word]>,
27    {
28        let leaves = leaves.as_ref();
29        let n = leaves.len();
30        if n <= 1 {
31            return Err(MerkleError::DepthTooSmall(n as u8));
32        } else if !n.is_power_of_two() {
33            return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
34        }
35
36        // create un-initialized vector to hold all tree nodes
37        let mut nodes = unsafe { uninit_vector(2 * n) };
38        nodes[0] = RpoDigest::default();
39
40        // copy leaves into the second part of the nodes vector
41        nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
42            *node = RpoDigest::from(*leaf);
43        });
44
45        // re-interpret nodes as an array of two nodes fused together
46        // Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
47        // `self`).
48        let ptr = nodes.as_ptr() as *const [RpoDigest; 2];
49        let pairs = unsafe { slice::from_raw_parts(ptr, n) };
50
51        // calculate all internal tree nodes
52        for i in (1..n).rev() {
53            nodes[i] = Rpo256::merge(&pairs[i]);
54        }
55
56        Ok(Self { nodes })
57    }
58
59    // PUBLIC ACCESSORS
60    // --------------------------------------------------------------------------------------------
61
62    /// Returns the root of this Merkle tree.
63    pub fn root(&self) -> RpoDigest {
64        self.nodes[1]
65    }
66
67    /// Returns the depth of this Merkle tree.
68    ///
69    /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
70    pub fn depth(&self) -> u8 {
71        (self.nodes.len() / 2).ilog2() as u8
72    }
73
74    /// Returns a node at the specified depth and index value.
75    ///
76    /// # Errors
77    /// Returns an error if:
78    /// * The specified depth is greater than the depth of the tree.
79    /// * The specified index is not valid for the specified depth.
80    pub fn get_node(&self, index: NodeIndex) -> Result<RpoDigest, MerkleError> {
81        if index.is_root() {
82            return Err(MerkleError::DepthTooSmall(index.depth()));
83        } else if index.depth() > self.depth() {
84            return Err(MerkleError::DepthTooBig(index.depth() as u64));
85        }
86
87        let pos = index.to_scalar_index() as usize;
88        Ok(self.nodes[pos])
89    }
90
91    /// Returns a Merkle path to the node at the specified depth and index value. The node itself
92    /// is not included in the path.
93    ///
94    /// # Errors
95    /// Returns an error if:
96    /// * The specified depth is greater than the depth of the tree.
97    /// * The specified value is not valid for the specified depth.
98    pub fn get_path(&self, mut index: NodeIndex) -> Result<MerklePath, MerkleError> {
99        if index.is_root() {
100            return Err(MerkleError::DepthTooSmall(index.depth()));
101        } else if index.depth() > self.depth() {
102            return Err(MerkleError::DepthTooBig(index.depth() as u64));
103        }
104
105        // TODO should we create a helper in `NodeIndex` that will encapsulate traversal to root so
106        // we always use inlined `for` instead of `while`? the reason to use `for` is because its
107        // easier for the compiler to vectorize.
108        let mut path = Vec::with_capacity(index.depth() as usize);
109        for _ in 0..index.depth() {
110            let sibling = index.sibling().to_scalar_index() as usize;
111            path.push(self.nodes[sibling]);
112            index.move_up();
113        }
114
115        debug_assert!(index.is_root(), "the path walk must go all the way to the root");
116
117        Ok(path.into())
118    }
119
120    // ITERATORS
121    // --------------------------------------------------------------------------------------------
122
123    /// Returns an iterator over the leaves of this [MerkleTree].
124    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
125        let leaves_start = self.nodes.len() / 2;
126        self.nodes
127            .iter()
128            .skip(leaves_start)
129            .enumerate()
130            .map(|(i, v)| (i as u64, v.deref()))
131    }
132
133    /// Returns n iterator over every inner node of this [MerkleTree].
134    ///
135    /// The iterator order is unspecified.
136    pub fn inner_nodes(&self) -> InnerNodeIterator {
137        InnerNodeIterator {
138            nodes: &self.nodes,
139            index: 1, // index 0 is just padding, start at 1
140        }
141    }
142
143    // STATE MUTATORS
144    // --------------------------------------------------------------------------------------------
145
146    /// Replaces the leaf at the specified index with the provided value.
147    ///
148    /// # Errors
149    /// Returns an error if the specified index value is not a valid leaf value for this tree.
150    pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
151        let mut index = NodeIndex::new(self.depth(), index_value)?;
152
153        // we don't need to copy the pairs into a new address as we are logically guaranteed to not
154        // overlap write instructions. however, it's important to bind the lifetime of pairs to
155        // `self.nodes` so the compiler will never move one without moving the other.
156        debug_assert_eq!(self.nodes.len() & 1, 0);
157        let n = self.nodes.len() / 2;
158
159        // Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
160        // digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that
161        // `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is
162        // logically guaranteed to not overlap write positions as the write index is always half
163        // the index from which we read the digest input.
164        let ptr = self.nodes.as_ptr() as *const [RpoDigest; 2];
165        let pairs: &'a [[RpoDigest; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
166
167        // update the current node
168        let pos = index.to_scalar_index() as usize;
169        self.nodes[pos] = value.into();
170
171        // traverse to the root, updating each node with the merged values of its parents
172        for _ in 0..index.depth() {
173            index.move_up();
174            let pos = index.to_scalar_index() as usize;
175            let value = Rpo256::merge(&pairs[pos]);
176            self.nodes[pos] = value;
177        }
178
179        Ok(())
180    }
181}
182
183// CONVERSIONS
184// ================================================================================================
185
186impl TryFrom<&[Word]> for MerkleTree {
187    type Error = MerkleError;
188
189    fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
190        MerkleTree::new(value)
191    }
192}
193
194impl TryFrom<&[RpoDigest]> for MerkleTree {
195    type Error = MerkleError;
196
197    fn try_from(value: &[RpoDigest]) -> Result<Self, Self::Error> {
198        let value: Vec<Word> = value.iter().map(|v| *v.deref()).collect();
199        MerkleTree::new(value)
200    }
201}
202
203// ITERATORS
204// ================================================================================================
205
206/// An iterator over every inner node of the [MerkleTree].
207///
208/// Use this to extract the data of the tree, there is no guarantee on the order of the elements.
209pub struct InnerNodeIterator<'a> {
210    nodes: &'a Vec<RpoDigest>,
211    index: usize,
212}
213
214impl Iterator for InnerNodeIterator<'_> {
215    type Item = InnerNodeInfo;
216
217    fn next(&mut self) -> Option<Self::Item> {
218        if self.index < self.nodes.len() / 2 {
219            let value = self.index;
220            let left = self.index * 2;
221            let right = left + 1;
222
223            self.index += 1;
224
225            Some(InnerNodeInfo {
226                value: self.nodes[value],
227                left: self.nodes[left],
228                right: self.nodes[right],
229            })
230        } else {
231            None
232        }
233    }
234}
235
236// UTILITY FUNCTIONS
237// ================================================================================================
238
239/// Utility to visualize a [MerkleTree] in text.
240pub fn tree_to_text(tree: &MerkleTree) -> Result<String, fmt::Error> {
241    let indent = "  ";
242    let mut s = String::new();
243    s.push_str(&word_to_hex(&tree.root())?);
244    s.push('\n');
245    for d in 1..=tree.depth() {
246        let entries = 2u64.pow(d.into());
247        for i in 0..entries {
248            let index = NodeIndex::new(d, i).expect("The index must always be valid");
249            let node = tree.get_node(index).expect("The node must always be found");
250
251            for _ in 0..d {
252                s.push_str(indent);
253            }
254            s.push_str(&word_to_hex(&node)?);
255            s.push('\n');
256        }
257    }
258
259    Ok(s)
260}
261
262/// Utility to visualize a [MerklePath] in text.
263pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
264    let mut s = String::new();
265    s.push('[');
266
267    for el in path.iter() {
268        s.push_str(&word_to_hex(el)?);
269        s.push_str(", ");
270    }
271
272    // remove the last ", "
273    if !path.is_empty() {
274        s.pop();
275        s.pop();
276    }
277    s.push(']');
278
279    Ok(s)
280}
281
282// TESTS
283// ================================================================================================
284
#[cfg(test)]
mod tests {
    use core::mem::size_of;

    use proptest::prelude::*;

    use super::*;
    use crate::{
        Felt, WORD_SIZE,
        merkle::{digests_to_words, int_to_leaf, int_to_node},
    };

    /// Leaves of the four-leaf (depth 2) fixture tree.
    const LEAVES4: [RpoDigest; WORD_SIZE] =
        [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

    /// Leaves of the eight-leaf (depth 3) fixture tree.
    const LEAVES8: [RpoDigest; 8] = [
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
        int_to_node(7),
        int_to_node(8),
    ];

    #[test]
    fn build_merkle_tree() {
        let tree = MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        assert_eq!(8, tree.nodes.len());

        // leaves were copied correctly
        let stored_leaves = tree.nodes.iter().skip(4);
        for (stored, expected) in stored_leaves.zip(LEAVES4.iter()) {
            assert_eq!(stored, expected);
        }

        let (root, node2, node3) = compute_internal_nodes();

        // inner nodes sit at their expected slots
        for (slot, expected) in [(1, root), (2, node2), (3, node3)] {
            assert_eq!(expected, tree.nodes[slot]);
        }

        assert_eq!(root, tree.root());
    }

    #[test]
    fn get_leaf() {
        let tree = MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        let (_, node2, node3) = compute_internal_nodes();

        // check depth 2
        for (value, expected) in
            [(0, LEAVES4[0]), (1, LEAVES4[1]), (2, LEAVES4[2]), (3, LEAVES4[3])]
        {
            assert_eq!(expected, tree.get_node(NodeIndex::make(2, value)).unwrap());
        }

        // check depth 1
        for (value, expected) in [(0, node2), (1, node3)] {
            assert_eq!(expected, tree.get_node(NodeIndex::make(1, value)).unwrap());
        }
    }

    #[test]
    fn get_path() {
        let tree = MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        let (_, node2, node3) = compute_internal_nodes();

        // check depth 2
        let depth2_cases = [
            (0, vec![LEAVES4[1], node3]),
            (1, vec![LEAVES4[0], node3]),
            (2, vec![LEAVES4[3], node2]),
            (3, vec![LEAVES4[2], node2]),
        ];
        for (value, expected) in depth2_cases {
            assert_eq!(expected, *tree.get_path(NodeIndex::make(2, value)).unwrap());
        }

        // check depth 1
        let depth1_cases = [(0, vec![node3]), (1, vec![node2])];
        for (value, expected) in depth1_cases {
            assert_eq!(expected, *tree.get_path(NodeIndex::make(1, value)).unwrap());
        }
    }

    #[test]
    fn update_leaf() {
        let mut tree = MerkleTree::new(digests_to_words(&LEAVES8)).unwrap();
        let mut expected_leaves = digests_to_words(&LEAVES8);

        // update one leaf, then another; after each update the tree must equal a tree built
        // from scratch over the updated leaves
        for (position, new_value) in [(3, 9), (6, 10)] {
            let new_leaf = int_to_leaf(new_value);
            expected_leaves[position as usize] = new_leaf;
            let expected_tree = MerkleTree::new(expected_leaves.clone()).unwrap();

            tree.update_leaf(position, new_leaf).unwrap();
            assert_eq!(expected_tree.nodes, tree.nodes);
        }
    }

    #[test]
    fn nodes() -> Result<(), MerkleError> {
        let tree = MerkleTree::new(digests_to_words(&LEAVES4)).unwrap();
        let node_at = |depth, value| tree.get_node(NodeIndex::make(depth, value));

        let root = tree.root();
        let (l1n0, l1n1) = (node_at(1, 0)?, node_at(1, 1)?);
        let (l2n0, l2n1) = (node_at(2, 0)?, node_at(2, 1)?);
        let (l2n2, l2n3) = (node_at(2, 2)?, node_at(2, 3)?);

        let expected = vec![
            InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
            InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
            InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
        ];
        let actual: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
        assert_eq!(actual, expected);

        Ok(())
    }

    proptest! {
        #[test]
        fn arbitrary_word_can_be_represented_as_digest(
            a in prop::num::u64::ANY,
            b in prop::num::u64::ANY,
            c in prop::num::u64::ANY,
            d in prop::num::u64::ANY,
        ) {
            // this test asserts that a word and the digest built from it have the same memory
            // representation. it safeguards the `[MerkleTree::update_leaf]` implementation,
            // which assumes this equivalence.

            // build a word and copy it to another address as digest
            let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
            let digest = RpoDigest::from(word);

            // the two values must live at different addresses
            let word_ptr = word.as_ptr() as *const u8;
            let digest_ptr = digest.as_ptr() as *const u8;
            assert_ne!(word_ptr, digest_ptr);

            // compare the raw bytes of both values
            let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
            let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<RpoDigest>()) };
            assert_eq!(word_bytes, digest_bytes);
        }
    }

    // HELPER FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Computes `(root, node2, node3)` of the tree built over `LEAVES4`, where `node2` and
    /// `node3` are the left and right children of the root.
    fn compute_internal_nodes() -> (RpoDigest, RpoDigest, RpoDigest) {
        let left_pair = [Word::from(LEAVES4[0]), Word::from(LEAVES4[1])];
        let right_pair = [Word::from(LEAVES4[2]), Word::from(LEAVES4[3])];

        let node2 = Rpo256::hash_elements(&left_pair.concat());
        let node3 = Rpo256::hash_elements(&right_pair.concat());
        let root = Rpo256::merge(&[node2, node3]);

        (root, node2, node3)
    }
}