miden_crypto/merkle/
merkle_tree.rs

1use alloc::{string::String, vec::Vec};
2use core::{fmt, slice};
3
4use super::{InnerNodeInfo, MerkleError, MerklePath, NodeIndex, Rpo256, Word};
5use crate::utils::{uninit_vector, word_to_hex};
6
7// MERKLE TREE
8// ================================================================================================
9
10/// A fully-balanced binary Merkle tree (i.e., a tree where the number of leaves is a power of two).
11#[derive(Debug, Clone, PartialEq, Eq)]
12#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
13pub struct MerkleTree {
14    nodes: Vec<Word>,
15}
16
17impl MerkleTree {
18    // CONSTRUCTOR
19    // --------------------------------------------------------------------------------------------
20    /// Returns a Merkle tree instantiated from the provided leaves.
21    ///
22    /// # Errors
23    /// Returns an error if the number of leaves is smaller than two or is not a power of two.
24    pub fn new<T>(leaves: T) -> Result<Self, MerkleError>
25    where
26        T: AsRef<[Word]>,
27    {
28        let leaves = leaves.as_ref();
29        let n = leaves.len();
30        if n <= 1 {
31            return Err(MerkleError::DepthTooSmall(n as u8));
32        } else if !n.is_power_of_two() {
33            return Err(MerkleError::NumLeavesNotPowerOfTwo(n));
34        }
35
36        // create un-initialized vector to hold all tree nodes
37        let mut nodes = unsafe { uninit_vector(2 * n) };
38        nodes[0] = Word::default();
39
40        // copy leaves into the second part of the nodes vector
41        nodes[n..].iter_mut().zip(leaves).for_each(|(node, leaf)| {
42            *node = *leaf;
43        });
44
45        // re-interpret nodes as an array of two nodes fused together
46        // Safety: `nodes` will never move here as it is not bound to an external lifetime (i.e.
47        // `self`).
48        let ptr = nodes.as_ptr() as *const [Word; 2];
49        let pairs = unsafe { slice::from_raw_parts(ptr, n) };
50
51        // calculate all internal tree nodes
52        for i in (1..n).rev() {
53            nodes[i] = Rpo256::merge(&pairs[i]);
54        }
55
56        Ok(Self { nodes })
57    }
58
59    // PUBLIC ACCESSORS
60    // --------------------------------------------------------------------------------------------
61
62    /// Returns the root of this Merkle tree.
63    pub fn root(&self) -> Word {
64        self.nodes[1]
65    }
66
67    /// Returns the depth of this Merkle tree.
68    ///
69    /// Merkle tree of depth 1 has two leaves, depth 2 has four leaves etc.
70    pub fn depth(&self) -> u8 {
71        (self.nodes.len() / 2).ilog2() as u8
72    }
73
74    /// Returns a node at the specified depth and index value.
75    ///
76    /// # Errors
77    /// Returns an error if:
78    /// * The specified depth is greater than the depth of the tree.
79    /// * The specified index is not valid for the specified depth.
80    pub fn get_node(&self, index: NodeIndex) -> Result<Word, MerkleError> {
81        if index.is_root() {
82            return Err(MerkleError::DepthTooSmall(index.depth()));
83        } else if index.depth() > self.depth() {
84            return Err(MerkleError::DepthTooBig(index.depth() as u64));
85        }
86
87        let pos = index.to_scalar_index() as usize;
88        Ok(self.nodes[pos])
89    }
90
91    /// Returns a Merkle path to the node at the specified depth and index value. The node itself
92    /// is not included in the path.
93    ///
94    /// # Errors
95    /// Returns an error if:
96    /// * The specified depth is greater than the depth of the tree.
97    /// * The specified value is not valid for the specified depth.
98    pub fn get_path(&self, index: NodeIndex) -> Result<MerklePath, MerkleError> {
99        if index.is_root() {
100            return Err(MerkleError::DepthTooSmall(index.depth()));
101        } else if index.depth() > self.depth() {
102            return Err(MerkleError::DepthTooBig(index.depth() as u64));
103        }
104
105        Ok(MerklePath::from(Vec::from_iter(
106            index.proof_indices().map(|index| self.get_node(index).unwrap()),
107        )))
108    }
109
110    // ITERATORS
111    // --------------------------------------------------------------------------------------------
112
113    /// Returns an iterator over the leaves of this [MerkleTree].
114    pub fn leaves(&self) -> impl Iterator<Item = (u64, &Word)> {
115        let leaves_start = self.nodes.len() / 2;
116        self.nodes.iter().skip(leaves_start).enumerate().map(|(i, v)| (i as u64, v))
117    }
118
119    /// Returns n iterator over every inner node of this [MerkleTree].
120    ///
121    /// The iterator order is unspecified.
122    pub fn inner_nodes(&self) -> InnerNodeIterator<'_> {
123        InnerNodeIterator {
124            nodes: &self.nodes,
125            index: 1, // index 0 is just padding, start at 1
126        }
127    }
128
129    // STATE MUTATORS
130    // --------------------------------------------------------------------------------------------
131
132    /// Replaces the leaf at the specified index with the provided value.
133    ///
134    /// # Errors
135    /// Returns an error if the specified index value is not a valid leaf value for this tree.
136    pub fn update_leaf<'a>(&'a mut self, index_value: u64, value: Word) -> Result<(), MerkleError> {
137        let mut index = NodeIndex::new(self.depth(), index_value)?;
138
139        // we don't need to copy the pairs into a new address as we are logically guaranteed to not
140        // overlap write instructions. however, it's important to bind the lifetime of pairs to
141        // `self.nodes` so the compiler will never move one without moving the other.
142        debug_assert_eq!(self.nodes.len() & 1, 0);
143        let n = self.nodes.len() / 2;
144
145        // Safety: the length of nodes is guaranteed to contain pairs of words; hence, pairs of
146        // digests. we explicitly bind the lifetime here so we add an extra layer of guarantee that
147        // `self.nodes` will be moved only if `pairs` is moved as well. also, the algorithm is
148        // logically guaranteed to not overlap write positions as the write index is always half
149        // the index from which we read the digest input.
150        let ptr = self.nodes.as_ptr() as *const [Word; 2];
151        let pairs: &'a [[Word; 2]] = unsafe { slice::from_raw_parts(ptr, n) };
152
153        // update the current node
154        let pos = index.to_scalar_index() as usize;
155        self.nodes[pos] = value;
156
157        // traverse to the root, updating each node with the merged values of its parents
158        for _ in 0..index.depth() {
159            index.move_up();
160            let pos = index.to_scalar_index() as usize;
161            let value = Rpo256::merge(&pairs[pos]);
162            self.nodes[pos] = value;
163        }
164
165        Ok(())
166    }
167}
168
169// CONVERSIONS
170// ================================================================================================
171
172impl TryFrom<&[Word]> for MerkleTree {
173    type Error = MerkleError;
174
175    fn try_from(value: &[Word]) -> Result<Self, Self::Error> {
176        MerkleTree::new(value)
177    }
178}
179
180// ITERATORS
181// ================================================================================================
182
183/// An iterator over every inner node of the [MerkleTree].
184///
185/// Use this to extract the data of the tree, there is no guarantee on the order of the elements.
186pub struct InnerNodeIterator<'a> {
187    nodes: &'a Vec<Word>,
188    index: usize,
189}
190
191impl Iterator for InnerNodeIterator<'_> {
192    type Item = InnerNodeInfo;
193
194    fn next(&mut self) -> Option<Self::Item> {
195        if self.index < self.nodes.len() / 2 {
196            let value = self.index;
197            let left = self.index * 2;
198            let right = left + 1;
199
200            self.index += 1;
201
202            Some(InnerNodeInfo {
203                value: self.nodes[value],
204                left: self.nodes[left],
205                right: self.nodes[right],
206            })
207        } else {
208            None
209        }
210    }
211}
212
213// UTILITY FUNCTIONS
214// ================================================================================================
215
216/// Utility to visualize a [MerkleTree] in text.
217pub fn tree_to_text(tree: &MerkleTree) -> Result<String, fmt::Error> {
218    let indent = "  ";
219    let mut s = String::new();
220    s.push_str(&word_to_hex(&tree.root())?);
221    s.push('\n');
222    for d in 1..=tree.depth() {
223        let entries = 2u64.pow(d.into());
224        for i in 0..entries {
225            let index = NodeIndex::new(d, i).expect("The index must always be valid");
226            let node = tree.get_node(index).expect("The node must always be found");
227
228            for _ in 0..d {
229                s.push_str(indent);
230            }
231            s.push_str(&word_to_hex(&node)?);
232            s.push('\n');
233        }
234    }
235
236    Ok(s)
237}
238
239/// Utility to visualize a [MerklePath] in text.
240pub fn path_to_text(path: &MerklePath) -> Result<String, fmt::Error> {
241    let mut s = String::new();
242    s.push('[');
243
244    for el in path.iter() {
245        s.push_str(&word_to_hex(el)?);
246        s.push_str(", ");
247    }
248
249    // remove the last ", "
250    if !path.is_empty() {
251        s.pop();
252        s.pop();
253    }
254    s.push(']');
255
256    Ok(s)
257}
258
259// TESTS
260// ================================================================================================
261
#[cfg(test)]
mod tests {
    use core::mem::size_of;

    use proptest::prelude::*;

    use super::*;
    use crate::{
        Felt,
        merkle::{int_to_leaf, int_to_node},
    };

    /// Leaves of a depth-2 tree. The length is a leaf count, unrelated to the word size.
    const LEAVES4: [Word; 4] = [int_to_node(1), int_to_node(2), int_to_node(3), int_to_node(4)];

    /// Leaves of a depth-3 tree.
    const LEAVES8: [Word; 8] = [
        int_to_node(1),
        int_to_node(2),
        int_to_node(3),
        int_to_node(4),
        int_to_node(5),
        int_to_node(6),
        int_to_node(7),
        int_to_node(8),
    ];

    #[test]
    fn build_merkle_tree() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();
        assert_eq!(8, tree.nodes.len());

        // leaves were copied correctly
        for (a, b) in tree.nodes.iter().skip(4).zip(LEAVES4.iter()) {
            assert_eq!(a, b);
        }

        let (root, node2, node3) = compute_internal_nodes();

        assert_eq!(root, tree.nodes[1]);
        assert_eq!(node2, tree.nodes[2]);
        assert_eq!(node3, tree.nodes[3]);

        assert_eq!(root, tree.root());
    }

    #[test]
    fn new_rejects_invalid_leaf_counts() {
        assert!(matches!(
            MerkleTree::new(Vec::<Word>::new()),
            Err(MerkleError::DepthTooSmall(0))
        ));
        assert!(matches!(MerkleTree::new(&LEAVES4[..1]), Err(MerkleError::DepthTooSmall(1))));
        assert!(matches!(
            MerkleTree::new(&LEAVES4[..3]),
            Err(MerkleError::NumLeavesNotPowerOfTwo(3))
        ));
    }

    #[test]
    fn depth_matches_number_of_leaves() {
        assert_eq!(1, MerkleTree::new(&LEAVES4[..2]).unwrap().depth());
        assert_eq!(2, MerkleTree::new(LEAVES4).unwrap().depth());
        assert_eq!(3, MerkleTree::new(LEAVES8).unwrap().depth());
    }

    #[test]
    fn get_leaf() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();

        // check depth 2
        assert_eq!(LEAVES4[0], tree.get_node(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(LEAVES4[1], tree.get_node(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(LEAVES4[2], tree.get_node(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(LEAVES4[3], tree.get_node(NodeIndex::make(2, 3)).unwrap());

        // check depth 1
        let (_, node2, node3) = compute_internal_nodes();

        assert_eq!(node2, tree.get_node(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(node3, tree.get_node(NodeIndex::make(1, 1)).unwrap());
    }

    #[test]
    fn get_node_rejects_root_and_too_deep_indices() {
        let tree = MerkleTree::new(LEAVES4).unwrap();

        assert!(matches!(
            tree.get_node(NodeIndex::make(0, 0)),
            Err(MerkleError::DepthTooSmall(0))
        ));
        assert!(matches!(tree.get_node(NodeIndex::make(3, 0)), Err(MerkleError::DepthTooBig(3))));
    }

    #[test]
    fn get_path() {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();

        let (_, node2, node3) = compute_internal_nodes();

        // check depth 2
        assert_eq!(vec![LEAVES4[1], node3], *tree.get_path(NodeIndex::make(2, 0)).unwrap());
        assert_eq!(vec![LEAVES4[0], node3], *tree.get_path(NodeIndex::make(2, 1)).unwrap());
        assert_eq!(vec![LEAVES4[3], node2], *tree.get_path(NodeIndex::make(2, 2)).unwrap());
        assert_eq!(vec![LEAVES4[2], node2], *tree.get_path(NodeIndex::make(2, 3)).unwrap());

        // check depth 1
        assert_eq!(vec![node3], *tree.get_path(NodeIndex::make(1, 0)).unwrap());
        assert_eq!(vec![node2], *tree.get_path(NodeIndex::make(1, 1)).unwrap());
    }

    #[test]
    fn get_path_rejects_root_and_too_deep_indices() {
        let tree = MerkleTree::new(LEAVES4).unwrap();

        assert!(matches!(
            tree.get_path(NodeIndex::make(0, 0)),
            Err(MerkleError::DepthTooSmall(0))
        ));
        assert!(matches!(tree.get_path(NodeIndex::make(3, 0)), Err(MerkleError::DepthTooBig(3))));
    }

    #[test]
    fn leaves_are_yielded_in_order() {
        let tree = MerkleTree::new(LEAVES4).unwrap();

        let leaves: Vec<(u64, Word)> = tree.leaves().map(|(i, leaf)| (i, *leaf)).collect();
        let expected: Vec<(u64, Word)> =
            LEAVES4.iter().enumerate().map(|(i, leaf)| (i as u64, *leaf)).collect();
        assert_eq!(expected, leaves);
    }

    #[test]
    fn update_leaf() {
        let mut tree = super::MerkleTree::new(LEAVES8).unwrap();

        // update one leaf
        let value = 3;
        let new_node = int_to_leaf(9);
        let mut expected_leaves = LEAVES8.to_vec();
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);

        // update another leaf
        let value = 6;
        let new_node = int_to_leaf(10);
        expected_leaves[value as usize] = new_node;
        let expected_tree = super::MerkleTree::new(expected_leaves.clone()).unwrap();

        tree.update_leaf(value, new_node).unwrap();
        assert_eq!(expected_tree.nodes, tree.nodes);
    }

    #[test]
    fn update_leaf_rejects_out_of_range_index() {
        let mut tree = MerkleTree::new(LEAVES4).unwrap();

        assert!(tree.update_leaf(4, int_to_leaf(9)).is_err());

        // a rejected update must leave the tree untouched
        assert_eq!(MerkleTree::new(LEAVES4).unwrap(), tree);
    }

    #[test]
    fn nodes() -> Result<(), MerkleError> {
        let tree = super::MerkleTree::new(LEAVES4).unwrap();
        let root = tree.root();
        let l1n0 = tree.get_node(NodeIndex::make(1, 0))?;
        let l1n1 = tree.get_node(NodeIndex::make(1, 1))?;
        let l2n0 = tree.get_node(NodeIndex::make(2, 0))?;
        let l2n1 = tree.get_node(NodeIndex::make(2, 1))?;
        let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
        let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;

        let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
        let expected = vec![
            InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
            InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
            InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
        ];
        assert_eq!(nodes, expected);

        Ok(())
    }

    proptest! {
        #[test]
        fn arbitrary_word_can_be_represented_as_digest(
            a in prop::num::u64::ANY,
            b in prop::num::u64::ANY,
            c in prop::num::u64::ANY,
            d in prop::num::u64::ANY,
        ) {
            // this test asserts that a `Word` built from four field elements has the same
            // in-memory byte representation as the raw `[Felt; 4]` array it was built from.

            // build a word and copy it to another address as digest
            let word = [Felt::new(a), Felt::new(b), Felt::new(c), Felt::new(d)];
            let digest = Word::from(word);

            // assert the addresses are different
            let word_ptr = word.as_ptr() as *const u8;
            let digest_ptr = digest.as_ptr() as *const u8;
            assert_ne!(word_ptr, digest_ptr);

            // compare the bytes representation
            let word_bytes = unsafe { slice::from_raw_parts(word_ptr, size_of::<Word>()) };
            let digest_bytes = unsafe { slice::from_raw_parts(digest_ptr, size_of::<Word>()) };
            assert_eq!(word_bytes, digest_bytes);
        }
    }

    // HELPER FUNCTIONS
    // --------------------------------------------------------------------------------------------

    /// Computes the expected internal nodes of the tree built from `LEAVES4`, returned as
    /// `(root, left child of root, right child of root)`.
    fn compute_internal_nodes() -> (Word, Word, Word) {
        let node2 = Rpo256::hash_elements(&[*LEAVES4[0], *LEAVES4[1]].concat());
        let node3 = Rpo256::hash_elements(&[*LEAVES4[2], *LEAVES4[3]].concat());
        let root = Rpo256::merge(&[node2, node3]);

        (root, node2, node3)
    }
}