light_merkle_tree_reference/
lib.rs

1pub mod indexed;
2pub mod sparse_merkle_tree;
3
4use std::marker::PhantomData;
5
6use light_hasher::{errors::HasherError, Hasher};
7use light_indexed_array::errors::IndexedArrayError;
8use thiserror::Error;
9
10#[derive(Debug, Error, PartialEq)]
11pub enum ReferenceMerkleTreeError {
12    #[error("Leaf {0} does not exist")]
13    LeafDoesNotExist(usize),
14    #[error("Hasher error: {0}")]
15    Hasher(#[from] HasherError),
16    #[error("Invalid proof length provided: {0} required {1}")]
17    InvalidProofLength(usize, usize),
18    #[error("IndexedArray error: {0}")]
19    IndexedArray(#[from] IndexedArrayError),
20    #[error("RootHistoryArrayLenNotSet")]
21    RootHistoryArrayLenNotSet,
22}
23
24#[derive(Debug, Clone)]
25pub struct MerkleTree<H>
26where
27    H: Hasher,
28{
29    pub height: usize,
30    pub capacity: usize,
31    pub canopy_depth: usize,
32    pub layers: Vec<Vec<[u8; 32]>>,
33    pub roots: Vec<[u8; 32]>,
34    pub rightmost_index: usize,
35    pub num_root_updates: usize,
36    pub sequence_number: usize,
37    pub root_history_start_offset: usize,
38    pub root_history_array_len: Option<usize>,
39    // pub batch_size: Option<usize>,
40    _hasher: PhantomData<H>,
41}
42
43impl<H> MerkleTree<H>
44where
45    H: Hasher,
46{
47    pub fn new(height: usize, canopy_depth: usize) -> Self {
48        Self {
49            height,
50            capacity: 1 << height,
51            canopy_depth,
52            layers: vec![Vec::new(); height],
53            roots: vec![H::zero_bytes()[height]],
54            rightmost_index: 0,
55            sequence_number: 0,
56            root_history_start_offset: 0,
57            root_history_array_len: None,
58            num_root_updates: 0,
59            _hasher: PhantomData,
60        }
61    }
62
63    pub fn new_with_history(
64        height: usize,
65        canopy_depth: usize,
66        root_history_start_offset: usize,
67        root_history_array_len: usize,
68    ) -> Self {
69        Self {
70            height,
71            capacity: 1 << height,
72            canopy_depth,
73            layers: vec![Vec::new(); height],
74            roots: vec![H::zero_bytes()[height]],
75            rightmost_index: 0,
76            sequence_number: 0,
77            root_history_start_offset,
78            root_history_array_len: Some(root_history_array_len),
79            num_root_updates: 0,
80            _hasher: PhantomData,
81        }
82    }
83
84    pub fn get_history_root_index(&self) -> Result<u16, ReferenceMerkleTreeError> {
85        if let Some(root_history_array_len) = self.root_history_array_len {
86            println!("root_history_array_len {}", root_history_array_len);
87            println!("rightmost_index {}", self.rightmost_index);
88            println!(
89                "root_history_start_offset {}",
90                self.root_history_start_offset
91            );
92            Ok(
93                ((self.rightmost_index - self.root_history_start_offset) % root_history_array_len)
94                    .try_into()
95                    .unwrap(),
96            )
97        } else {
98            Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
99        }
100    }
101
102    /// Get root history index for v2 (batched) Merkle trees.
103    pub fn get_history_root_index_v2(&self) -> Result<u16, ReferenceMerkleTreeError> {
104        if let Some(root_history_array_len) = self.root_history_array_len {
105            println!("root_history_array_len {}", root_history_array_len);
106            println!("rightmost_index {}", self.rightmost_index);
107            println!("num_root_updates {}", self.num_root_updates);
108            Ok(((self.num_root_updates) % root_history_array_len)
109                .try_into()
110                .unwrap())
111        } else {
112            Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
113        }
114    }
115
116    /// Number of nodes to include in canopy, based on `canopy_depth`.
117    pub fn canopy_size(&self) -> usize {
118        (1 << (self.canopy_depth + 1)) - 2
119    }
120
121    fn update_upper_layers(&mut self, mut i: usize) -> Result<(), HasherError> {
122        for level in 1..self.height {
123            i /= 2;
124
125            let left_index = i * 2;
126            let right_index = i * 2 + 1;
127
128            let left_child = self.layers[level - 1]
129                .get(left_index)
130                .cloned()
131                .unwrap_or(H::zero_bytes()[level - 1]);
132            let right_child = self.layers[level - 1]
133                .get(right_index)
134                .cloned()
135                .unwrap_or(H::zero_bytes()[level - 1]);
136
137            let node = H::hashv(&[&left_child[..], &right_child[..]])?;
138            if self.layers[level].len() > i {
139                // A node already exists and we are overwriting it.
140                self.layers[level][i] = node;
141            } else {
142                // A node didn't exist before.
143                self.layers[level].push(node);
144            }
145        }
146
147        let left_child = &self.layers[self.height - 1]
148            .first()
149            .cloned()
150            .unwrap_or(H::zero_bytes()[self.height - 1]);
151        let right_child = &self.layers[self.height - 1]
152            .get(1)
153            .cloned()
154            .unwrap_or(H::zero_bytes()[self.height - 1]);
155        let root = H::hashv(&[&left_child[..], &right_child[..]])?;
156
157        self.roots.push(root);
158
159        Ok(())
160    }
161
162    pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(), HasherError> {
163        self.layers[0].push(*leaf);
164
165        let i = self.rightmost_index;
166        if self.rightmost_index == self.capacity {
167            println!("Merkle tree full");
168            return Err(HasherError::IntegerOverflow);
169        }
170        self.rightmost_index += 1;
171
172        self.update_upper_layers(i)?;
173
174        self.sequence_number += 1;
175        Ok(())
176    }
177
178    pub fn append_batch(&mut self, leaves: &[&[u8; 32]]) -> Result<(), HasherError> {
179        for leaf in leaves {
180            self.append(leaf)?;
181        }
182        Ok(())
183    }
184
185    pub fn update(
186        &mut self,
187        leaf: &[u8; 32],
188        leaf_index: usize,
189    ) -> Result<(), ReferenceMerkleTreeError> {
190        *self.layers[0]
191            .get_mut(leaf_index)
192            .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index))? = *leaf;
193
194        self.update_upper_layers(leaf_index)?;
195
196        self.sequence_number += 1;
197        Ok(())
198    }
199
200    pub fn root(&self) -> [u8; 32] {
201        // PANICS: We always initialize the Merkle tree with a
202        // root (from zero bytes), so the following should never
203        // panic.
204        self.roots.last().cloned().unwrap()
205    }
206
207    pub fn get_path_of_leaf(
208        &self,
209        mut index: usize,
210        full: bool,
211    ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
212        let mut path = Vec::with_capacity(self.height);
213        let limit = match full {
214            true => self.height,
215            false => self.height - self.canopy_depth,
216        };
217
218        for level in 0..limit {
219            let node = self.layers[level]
220                .get(index)
221                .cloned()
222                .unwrap_or(H::zero_bytes()[level]);
223            path.push(node);
224
225            index /= 2;
226        }
227
228        Ok(path)
229    }
230
231    pub fn get_proof_of_leaf(
232        &self,
233        mut index: usize,
234        full: bool,
235    ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
236        let mut proof = Vec::with_capacity(self.height);
237        let limit = match full {
238            true => self.height,
239            false => self.height - self.canopy_depth,
240        };
241
242        for level in 0..limit {
243            let is_left = index % 2 == 0;
244
245            let sibling_index = if is_left { index + 1 } else { index - 1 };
246            let node = self.layers[level]
247                .get(sibling_index)
248                .cloned()
249                .unwrap_or(H::zero_bytes()[level]);
250            proof.push(node);
251
252            index /= 2;
253        }
254
255        Ok(proof)
256    }
257
258    pub fn get_proof_by_indices(&self, indices: &[i32]) -> Vec<Vec<[u8; 32]>> {
259        let mut proofs = Vec::new();
260        for &index in indices {
261            let mut index = index as usize;
262            let mut proof = Vec::with_capacity(self.height);
263
264            for level in 0..self.height {
265                let is_left = index % 2 == 0;
266                let sibling_index = if is_left { index + 1 } else { index - 1 };
267                let node = self.layers[level]
268                    .get(sibling_index)
269                    .cloned()
270                    .unwrap_or(H::zero_bytes()[level]);
271                proof.push(node);
272                index /= 2;
273            }
274            proofs.push(proof);
275        }
276        proofs
277    }
278
279    pub fn get_canopy(&self) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
280        if self.canopy_depth == 0 {
281            return Ok(Vec::with_capacity(0));
282        }
283        let mut canopy = Vec::with_capacity(self.canopy_size());
284
285        let mut num_nodes_in_level = 2;
286        for i in 0..self.canopy_depth {
287            let level = self.height - 1 - i;
288            for j in 0..num_nodes_in_level {
289                let node = self.layers[level]
290                    .get(j)
291                    .cloned()
292                    .unwrap_or(H::zero_bytes()[level]);
293                canopy.push(node);
294            }
295            num_nodes_in_level *= 2;
296        }
297
298        Ok(canopy)
299    }
300
301    pub fn leaf(&self, leaf_index: usize) -> [u8; 32] {
302        self.layers[0]
303            .get(leaf_index)
304            .cloned()
305            .unwrap_or(H::zero_bytes()[0])
306    }
307
308    pub fn get_leaf_index(&self, leaf: &[u8; 32]) -> Option<usize> {
309        self.layers[0].iter().position(|node| node == leaf)
310    }
311
312    pub fn leaves(&self) -> &[[u8; 32]] {
313        self.layers[0].as_slice()
314    }
315
316    pub fn verify(
317        &self,
318        leaf: &[u8; 32],
319        proof: &[[u8; 32]],
320        leaf_index: usize,
321    ) -> Result<bool, ReferenceMerkleTreeError> {
322        if leaf_index >= self.capacity {
323            return Err(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index));
324        }
325        if proof.len() != self.height {
326            return Err(ReferenceMerkleTreeError::InvalidProofLength(
327                proof.len(),
328                self.height,
329            ));
330        }
331
332        let mut computed_hash = *leaf;
333        let mut current_index = leaf_index;
334
335        for sibling_hash in proof.iter() {
336            let is_left = current_index % 2 == 0;
337            let hashes = if is_left {
338                [&computed_hash[..], &sibling_hash[..]]
339            } else {
340                [&sibling_hash[..], &computed_hash[..]]
341            };
342
343            computed_hash = H::hashv(&hashes)?;
344            // Move to the parent index for the next iteration
345            current_index /= 2;
346        }
347
348        // Compare the computed hash to the last known root
349        Ok(computed_hash == self.root())
350    }
351
352    /// Returns the filled subtrees of the Merkle tree.
353    /// Subtrees are the rightmost left node of each level.
354    /// Subtrees can be used for efficient append operations.
355    pub fn get_subtrees(&self) -> Vec<[u8; 32]> {
356        let mut subtrees = H::zero_bytes()[0..self.height].to_vec();
357        if self.layers.last().and_then(|layer| layer.first()).is_some() {
358            for level in (0..self.height).rev() {
359                if let Some(left_child) = self.layers.get(level).and_then(|layer| {
360                    if layer.len() % 2 == 0 {
361                        layer.get(layer.len() - 2)
362                    } else {
363                        layer.last()
364                    }
365                }) {
366                    subtrees[level] = *left_child;
367                }
368            }
369        }
370        subtrees
371    }
372
373    pub fn get_next_index(&self) -> usize {
374        self.rightmost_index + 1
375    }
376
377    pub fn get_leaf(&self, index: usize) -> Result<[u8; 32], ReferenceMerkleTreeError> {
378        self.layers[0]
379            .get(index)
380            .cloned()
381            .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(index))
382    }
383}
384
#[cfg(test)]
mod tests {
    use light_hasher::{zero_bytes::poseidon::ZERO_BYTES, Poseidon};

    use super::*;

    /// Expected subtrees of a depth-4 Poseidon tree after appending leaf `1` twice.
    const TREE_AFTER_1_UPDATE: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 1,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            4, 163, 62, 195, 162, 201, 237, 49, 131, 153, 66, 155, 106, 112, 192, 40, 76, 131, 230,
            239, 224, 130, 106, 36, 128, 57, 172, 107, 60, 247, 103, 194,
        ],
        [
            7, 118, 172, 114, 242, 52, 137, 62, 111, 106, 113, 139, 123, 161, 39, 255, 86, 13, 105,
            167, 223, 52, 15, 29, 137, 37, 106, 178, 49, 44, 226, 75,
        ],
    ];

    /// Expected subtrees after additionally appending leaf `2` twice.
    const TREE_AFTER_2_UPDATES: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 2,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            18, 102, 129, 25, 152, 42, 192, 218, 100, 215, 169, 202, 77, 24, 100, 133, 45, 152, 17,
            121, 103, 9, 187, 226, 182, 36, 35, 35, 126, 255, 244, 140,
        ],
        [
            11, 230, 92, 56, 65, 91, 231, 137, 40, 92, 11, 193, 90, 225, 123, 79, 82, 17, 212, 147,
            43, 41, 126, 223, 49, 2, 139, 211, 249, 138, 7, 12,
        ],
    ];

    /// Builds a 32-byte leaf that is all zeros except for its last byte.
    fn leaf_ending_in(value: u8) -> [u8; 32] {
        let mut leaf = [0u8; 32];
        leaf[31] = value;
        leaf
    }

    /// Asserts, level by level, that `tree`'s subtrees match `expected`.
    fn assert_subtrees(tree: &MerkleTree<Poseidon>, expected: &[[u8; 32]]) {
        for (level, node) in tree.get_subtrees().iter().enumerate() {
            assert_eq!(*node, expected[level]);
        }
    }

    #[test]
    fn test_subtrees() {
        let depth = 4;
        let mut tree = MerkleTree::<Poseidon>::new(depth, 0);

        // An empty tree reports the per-level zero values.
        assert_subtrees(&tree, &ZERO_BYTES);

        let first = leaf_ending_in(1);
        for _ in 0..2 {
            tree.append(&first).unwrap();
        }
        assert_subtrees(&tree, &TREE_AFTER_1_UPDATE);

        let second = leaf_ending_in(2);
        for _ in 0..2 {
            tree.append(&second).unwrap();
        }
        assert_subtrees(&tree, &TREE_AFTER_2_UPDATES);
    }
}