lambdaworks_crypto/merkle_tree/
merkle.rs

1use core::fmt::Display;
2
3use alloc::vec::Vec;
4
5use super::{proof::Proof, traits::IsMerkleTreeBackend, utils::*};
6
/// Errors produced while navigating the nodes of a [`MerkleTree`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// A requested node index lies outside the tree's node storage.
    OutOfBounds,
}

impl Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Match exhaustively so that adding a variant forces a message to be written for it.
        match self {
            Error::OutOfBounds => write!(f, "Accessed node was out of bounds"),
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
19
20/// The struct for the Merkle tree, consisting of the root and the nodes.
21/// A typical tree would look like this
22///                 root
23///              /        \
24///          leaf 12     leaf 34
25///        /         \    /      \
26///    leaf 1     leaf 2 leaf 3  leaf 4
27/// The bottom leafs correspond to the hashes of the elements, while each upper
28/// layer contains the hash of the concatenation of the daughter nodes.
29#[derive(Clone)]
30#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
31pub struct MerkleTree<B: IsMerkleTreeBackend> {
32    pub root: B::Node,
33    nodes: Vec<B::Node>,
34}
35
36const ROOT: usize = 0;
37
38impl<B> MerkleTree<B>
39where
40    B: IsMerkleTreeBackend,
41{
42    /// Create a Merkle tree from a slice of data
43    pub fn build(unhashed_leaves: &[B::Data]) -> Option<Self> {
44        if unhashed_leaves.is_empty() {
45            return None;
46        }
47
48        let hashed_leaves: Vec<B::Node> = B::hash_leaves(unhashed_leaves);
49
50        //The leaf must be a power of 2 set
51        let hashed_leaves = complete_until_power_of_two(hashed_leaves);
52        let leaves_len = hashed_leaves.len();
53
54        //The length of leaves minus one inner node in the merkle tree
55        //The first elements are overwritten by build function, it doesn't matter what it's there
56        let mut nodes = vec![hashed_leaves[0].clone(); leaves_len - 1];
57        nodes.extend(hashed_leaves);
58
59        //Build the inner nodes of the tree
60        build::<B>(&mut nodes, leaves_len);
61
62        Some(MerkleTree {
63            root: nodes[ROOT].clone(),
64            nodes,
65        })
66    }
67
68    /// Returns a Merkle proof for the element/s at position pos
69    /// For example, give me an inclusion proof for the 3rd element in the
70    /// Merkle tree
71    pub fn get_proof_by_pos(&self, pos: usize) -> Option<Proof<B::Node>> {
72        let pos = pos + self.nodes.len() / 2;
73        let Ok(merkle_path) = self.build_merkle_path(pos) else {
74            return None;
75        };
76
77        self.create_proof(merkle_path)
78    }
79
80    /// Creates a proof from a Merkle pasth
81    fn create_proof(&self, merkle_path: Vec<B::Node>) -> Option<Proof<B::Node>> {
82        Some(Proof { merkle_path })
83    }
84
85    /// Returns the Merkle path for the element/s for the leaf at position pos
86    fn build_merkle_path(&self, pos: usize) -> Result<Vec<B::Node>, Error> {
87        let mut merkle_path = Vec::new();
88        let mut pos = pos;
89
90        while pos != ROOT {
91            let Some(node) = self.nodes.get(sibling_index(pos)) else {
92                // out of bounds, exit returning the current merkle_path
93                return Err(Error::OutOfBounds);
94            };
95            merkle_path.push(node.clone());
96
97            pos = parent_index(pos);
98        }
99
100        Ok(merkle_path)
101    }
102}
#[cfg(test)]
mod tests {
    use super::*;
    use lambdaworks_math::field::{element::FieldElement, fields::u64_prime_field::U64PrimeField};

    use crate::merkle_tree::{merkle::MerkleTree, test_merkle::TestBackend};

    const MODULUS: u64 = 13;
    type U64PF = U64PrimeField<MODULUS>;
    type FE = FieldElement<U64PF>;

    #[test]
    fn build_merkle_tree_from_a_power_of_two_list_of_values() {
        let values: Vec<FE> = (1..5).map(FE::new).collect();
        let merkle_tree = MerkleTree::<TestBackend<U64PF>>::build(&values).unwrap();
        assert_eq!(merkle_tree.root, FE::new(7));
    }

    #[test]
    // expected | 8 | 7 | 1 | 6 | 1 | 7 | 7 | 2 | 4 | 6 | 8 | 10 | 10 | 10 | 10 |
    fn build_merkle_tree_from_an_odd_set_of_leaves() {
        let values: Vec<FE> = (1..6).map(FE::new).collect();
        let merkle_tree = MerkleTree::<TestBackend<U64PF>>::build(&values).unwrap();
        assert_eq!(merkle_tree.root, FE::new(8));
    }

    #[test]
    fn build_merkle_tree_from_a_single_value() {
        let values: Vec<FE> = vec![FE::new(1)];
        let merkle_tree = MerkleTree::<TestBackend<U64PF>>::build(&values).unwrap();
        assert_eq!(merkle_tree.root, FE::new(2));
    }

    #[test]
    fn build_empty_tree_should_not_panic() {
        assert!(MerkleTree::<TestBackend<U64PF>>::build(&[]).is_none());
    }

    #[test]
    fn get_proof_by_pos_returns_a_proof_for_every_padded_leaf() {
        let values: Vec<FE> = (1..6).map(FE::new).collect();
        let merkle_tree = MerkleTree::<TestBackend<U64PF>>::build(&values).unwrap();
        // Five values are padded to eight leaves.
        for pos in 0..8 {
            assert!(merkle_tree.get_proof_by_pos(pos).is_some());
        }
    }

    #[test]
    fn get_proof_by_pos_out_of_range_is_none() {
        let values: Vec<FE> = (1..5).map(FE::new).collect();
        let merkle_tree = MerkleTree::<TestBackend<U64PF>>::build(&values).unwrap();
        assert!(merkle_tree.get_proof_by_pos(4).is_none());
        assert!(merkle_tree.get_proof_by_pos(100).is_none());
    }

    #[test]
    fn get_proof_by_pos_on_a_single_leaf_tree() {
        let values: Vec<FE> = vec![FE::new(1)];
        let merkle_tree = MerkleTree::<TestBackend<U64PF>>::build(&values).unwrap();
        assert!(merkle_tree.get_proof_by_pos(0).is_some());
        assert!(merkle_tree.get_proof_by_pos(1).is_none());
    }
}
148}