light_merkle_tree_reference/
lib.rs

1pub mod indexed;
2
3use std::marker::PhantomData;
4
5use light_hasher::{errors::HasherError, Hasher};
6use light_indexed_array::errors::IndexedArrayError;
7use thiserror::Error;
8
9#[derive(Debug, Error, PartialEq)]
10pub enum ReferenceMerkleTreeError {
11    #[error("Leaf {0} does not exist")]
12    LeafDoesNotExist(usize),
13    #[error("Hasher error: {0}")]
14    Hasher(#[from] HasherError),
15    #[error("Invalid proof length provided: {0} required {1}")]
16    InvalidProofLength(usize, usize),
17    #[error("IndexedArray error: {0}")]
18    IndexedArray(#[from] IndexedArrayError),
19    #[error("RootHistoryArrayLenNotSet")]
20    RootHistoryArrayLenNotSet,
21}
22
23#[derive(Debug, Clone)]
24pub struct MerkleTree<H>
25where
26    H: Hasher,
27{
28    pub height: usize,
29    pub capacity: usize,
30    pub canopy_depth: usize,
31    pub layers: Vec<Vec<[u8; 32]>>,
32    pub roots: Vec<[u8; 32]>,
33    pub rightmost_index: usize,
34    pub num_root_updates: usize,
35    pub sequence_number: usize,
36    pub root_history_start_offset: usize,
37    pub root_history_array_len: Option<usize>,
38    // pub batch_size: Option<usize>,
39    _hasher: PhantomData<H>,
40}
41
42impl<H> MerkleTree<H>
43where
44    H: Hasher,
45{
46    pub fn new(height: usize, canopy_depth: usize) -> Self {
47        Self {
48            height,
49            capacity: 1 << height,
50            canopy_depth,
51            layers: vec![Vec::new(); height],
52            roots: vec![H::zero_bytes()[height]],
53            rightmost_index: 0,
54            sequence_number: 0,
55            root_history_start_offset: 0,
56            root_history_array_len: None,
57            num_root_updates: 0,
58            _hasher: PhantomData,
59        }
60    }
61
62    pub fn new_with_history(
63        height: usize,
64        canopy_depth: usize,
65        root_history_start_offset: usize,
66        root_history_array_len: usize,
67    ) -> Self {
68        Self {
69            height,
70            capacity: 1 << height,
71            canopy_depth,
72            layers: vec![Vec::new(); height],
73            roots: vec![H::zero_bytes()[height]],
74            rightmost_index: 0,
75            sequence_number: 0,
76            root_history_start_offset,
77            root_history_array_len: Some(root_history_array_len),
78            num_root_updates: 0,
79            _hasher: PhantomData,
80        }
81    }
82
83    pub fn get_history_root_index(&self) -> Result<u16, ReferenceMerkleTreeError> {
84        if let Some(root_history_array_len) = self.root_history_array_len {
85            Ok(
86                ((self.rightmost_index - self.root_history_start_offset) % root_history_array_len)
87                    .try_into()
88                    .unwrap(),
89            )
90        } else {
91            Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
92        }
93    }
94
95    /// Get root history index for v2 (batched) Merkle trees.
96    pub fn get_history_root_index_v2(&self) -> Result<u16, ReferenceMerkleTreeError> {
97        if let Some(root_history_array_len) = self.root_history_array_len {
98            Ok(((self.num_root_updates) % root_history_array_len)
99                .try_into()
100                .unwrap())
101        } else {
102            Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
103        }
104    }
105
106    /// Number of nodes to include in canopy, based on `canopy_depth`.
107    pub fn canopy_size(&self) -> usize {
108        (1 << (self.canopy_depth + 1)) - 2
109    }
110
111    fn update_upper_layers(&mut self, mut i: usize) -> Result<(), HasherError> {
112        for level in 1..self.height {
113            i /= 2;
114
115            let left_index = i * 2;
116            let right_index = i * 2 + 1;
117
118            let left_child = self.layers[level - 1]
119                .get(left_index)
120                .cloned()
121                .unwrap_or(H::zero_bytes()[level - 1]);
122            let right_child = self.layers[level - 1]
123                .get(right_index)
124                .cloned()
125                .unwrap_or(H::zero_bytes()[level - 1]);
126
127            let node = H::hashv(&[&left_child[..], &right_child[..]])?;
128            if self.layers[level].len() > i {
129                // A node already exists and we are overwriting it.
130                self.layers[level][i] = node;
131            } else {
132                // A node didn't exist before.
133                self.layers[level].push(node);
134            }
135        }
136
137        let left_child = &self.layers[self.height - 1]
138            .first()
139            .cloned()
140            .unwrap_or(H::zero_bytes()[self.height - 1]);
141        let right_child = &self.layers[self.height - 1]
142            .get(1)
143            .cloned()
144            .unwrap_or(H::zero_bytes()[self.height - 1]);
145        let root = H::hashv(&[&left_child[..], &right_child[..]])?;
146
147        self.roots.push(root);
148
149        Ok(())
150    }
151
152    pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(), HasherError> {
153        self.layers[0].push(*leaf);
154
155        let i = self.rightmost_index;
156        if self.rightmost_index == self.capacity {
157            println!("Merkle tree full");
158            return Err(HasherError::IntegerOverflow);
159        }
160        self.rightmost_index += 1;
161
162        self.update_upper_layers(i)?;
163
164        self.sequence_number += 1;
165        Ok(())
166    }
167
168    pub fn append_batch(&mut self, leaves: &[&[u8; 32]]) -> Result<(), HasherError> {
169        for leaf in leaves {
170            self.append(leaf)?;
171        }
172        Ok(())
173    }
174
175    pub fn update(
176        &mut self,
177        leaf: &[u8; 32],
178        leaf_index: usize,
179    ) -> Result<(), ReferenceMerkleTreeError> {
180        *self.layers[0]
181            .get_mut(leaf_index)
182            .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index))? = *leaf;
183
184        self.update_upper_layers(leaf_index)?;
185
186        self.sequence_number += 1;
187        Ok(())
188    }
189
190    pub fn root(&self) -> [u8; 32] {
191        // PANICS: We always initialize the Merkle tree with a
192        // root (from zero bytes), so the following should never
193        // panic.
194        self.roots.last().cloned().unwrap()
195    }
196
197    pub fn get_path_of_leaf(
198        &self,
199        mut index: usize,
200        full: bool,
201    ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
202        let mut path = Vec::with_capacity(self.height);
203        let limit = match full {
204            true => self.height,
205            false => self.height - self.canopy_depth,
206        };
207
208        for level in 0..limit {
209            let node = self.layers[level]
210                .get(index)
211                .cloned()
212                .unwrap_or(H::zero_bytes()[level]);
213            path.push(node);
214
215            index /= 2;
216        }
217
218        Ok(path)
219    }
220
221    pub fn get_proof_of_leaf(
222        &self,
223        mut index: usize,
224        full: bool,
225    ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
226        let mut proof = Vec::with_capacity(self.height);
227        let limit = match full {
228            true => self.height,
229            false => self.height - self.canopy_depth,
230        };
231
232        for level in 0..limit {
233            #[allow(clippy::manual_is_multiple_of)]
234            let is_left = index % 2 == 0;
235
236            let sibling_index = if is_left { index + 1 } else { index - 1 };
237            let node = self.layers[level]
238                .get(sibling_index)
239                .cloned()
240                .unwrap_or(H::zero_bytes()[level]);
241            proof.push(node);
242
243            index /= 2;
244        }
245
246        Ok(proof)
247    }
248
249    pub fn get_proof_by_indices(&self, indices: &[i32]) -> Vec<Vec<[u8; 32]>> {
250        let mut proofs = Vec::new();
251        for &index in indices {
252            let mut index = index as usize;
253            let mut proof = Vec::with_capacity(self.height);
254
255            for level in 0..self.height {
256                #[allow(clippy::manual_is_multiple_of)]
257                let is_left = index % 2 == 0;
258                let sibling_index = if is_left { index + 1 } else { index - 1 };
259                let node = self.layers[level]
260                    .get(sibling_index)
261                    .cloned()
262                    .unwrap_or(H::zero_bytes()[level]);
263                proof.push(node);
264                index /= 2;
265            }
266            proofs.push(proof);
267        }
268        proofs
269    }
270
271    pub fn get_canopy(&self) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
272        if self.canopy_depth == 0 {
273            return Ok(Vec::with_capacity(0));
274        }
275        let mut canopy = Vec::with_capacity(self.canopy_size());
276
277        let mut num_nodes_in_level = 2;
278        for i in 0..self.canopy_depth {
279            let level = self.height - 1 - i;
280            for j in 0..num_nodes_in_level {
281                let node = self.layers[level]
282                    .get(j)
283                    .cloned()
284                    .unwrap_or(H::zero_bytes()[level]);
285                canopy.push(node);
286            }
287            num_nodes_in_level *= 2;
288        }
289
290        Ok(canopy)
291    }
292
293    pub fn leaf(&self, leaf_index: usize) -> [u8; 32] {
294        self.layers[0]
295            .get(leaf_index)
296            .cloned()
297            .unwrap_or(H::zero_bytes()[0])
298    }
299
300    pub fn get_leaf_index(&self, leaf: &[u8; 32]) -> Option<usize> {
301        self.layers[0].iter().position(|node| node == leaf)
302    }
303
304    pub fn leaves(&self) -> &[[u8; 32]] {
305        self.layers[0].as_slice()
306    }
307
308    pub fn verify(
309        &self,
310        leaf: &[u8; 32],
311        proof: &[[u8; 32]],
312        leaf_index: usize,
313    ) -> Result<bool, ReferenceMerkleTreeError> {
314        if leaf_index >= self.capacity {
315            return Err(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index));
316        }
317        if proof.len() != self.height {
318            return Err(ReferenceMerkleTreeError::InvalidProofLength(
319                proof.len(),
320                self.height,
321            ));
322        }
323
324        let mut computed_hash = *leaf;
325        let mut current_index = leaf_index;
326
327        for sibling_hash in proof.iter() {
328            #[allow(clippy::manual_is_multiple_of)]
329            let is_left = current_index % 2 == 0;
330            let hashes = if is_left {
331                [&computed_hash[..], &sibling_hash[..]]
332            } else {
333                [&sibling_hash[..], &computed_hash[..]]
334            };
335
336            computed_hash = H::hashv(&hashes)?;
337            // Move to the parent index for the next iteration
338            current_index /= 2;
339        }
340
341        // Compare the computed hash to the last known root
342        Ok(computed_hash == self.root())
343    }
344
345    /// Returns the filled subtrees of the Merkle tree.
346    /// Subtrees are the rightmost left node of each level.
347    /// Subtrees can be used for efficient append operations.
348    pub fn get_subtrees(&self) -> Vec<[u8; 32]> {
349        let mut subtrees = H::zero_bytes()[0..self.height].to_vec();
350        if self.layers.last().and_then(|layer| layer.first()).is_some() {
351            for level in (0..self.height).rev() {
352                if let Some(left_child) = self.layers.get(level).and_then(|layer| {
353                    if layer.len() % 2 == 0 {
354                        layer.get(layer.len() - 2)
355                    } else {
356                        layer.last()
357                    }
358                }) {
359                    subtrees[level] = *left_child;
360                }
361            }
362        }
363        subtrees
364    }
365
366    pub fn get_next_index(&self) -> usize {
367        self.rightmost_index + 1
368    }
369
370    pub fn get_leaf(&self, index: usize) -> Result<[u8; 32], ReferenceMerkleTreeError> {
371        self.layers[0]
372            .get(index)
373            .cloned()
374            .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(index))
375    }
376}
377
#[cfg(test)]
mod tests {
    use light_hasher::{zero_bytes::poseidon::ZERO_BYTES, Poseidon};

    use super::*;

    /// Expected `get_subtrees()` output of a height-4 Poseidon tree after the
    /// leaf `[0, .., 0, 1]` has been appended twice.
    const TREE_AFTER_1_UPDATE: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 1,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            4, 163, 62, 195, 162, 201, 237, 49, 131, 153, 66, 155, 106, 112, 192, 40, 76, 131, 230,
            239, 224, 130, 106, 36, 128, 57, 172, 107, 60, 247, 103, 194,
        ],
        [
            7, 118, 172, 114, 242, 52, 137, 62, 111, 106, 113, 139, 123, 161, 39, 255, 86, 13, 105,
            167, 223, 52, 15, 29, 137, 37, 106, 178, 49, 44, 226, 75,
        ],
    ];

    /// Expected `get_subtrees()` output after the leaf `[0, .., 0, 2]` has
    /// additionally been appended twice.
    const TREE_AFTER_2_UPDATES: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 2,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            18, 102, 129, 25, 152, 42, 192, 218, 100, 215, 169, 202, 77, 24, 100, 133, 45, 152, 17,
            121, 103, 9, 187, 226, 182, 36, 35, 35, 126, 255, 244, 140,
        ],
        [
            11, 230, 92, 56, 65, 91, 231, 137, 40, 92, 11, 193, 90, 225, 123, 79, 82, 17, 212, 147,
            43, 41, 126, 223, 49, 2, 139, 211, 249, 138, 7, 12,
        ],
    ];

    /// Builds a 32-byte leaf that is all zeros except for its last byte.
    fn leaf_with_last_byte(value: u8) -> [u8; 32] {
        let mut leaf = [0u8; 32];
        leaf[31] = value;
        leaf
    }

    /// Asserts that every subtree of `tree` equals the entry of `expected`
    /// at the same level.
    fn assert_subtrees(tree: &MerkleTree<Poseidon>, expected: &[[u8; 32]]) {
        for (level, subtree) in tree.get_subtrees().iter().enumerate() {
            assert_eq!(*subtree, expected[level]);
        }
    }

    #[test]
    fn test_subtrees() {
        let tree_depth = 4;
        let mut tree = MerkleTree::<Poseidon>::new(tree_depth, 0);

        // An empty tree reports the zero value of every level.
        assert_subtrees(&tree, &ZERO_BYTES[..]);

        let leaf_0 = leaf_with_last_byte(1);
        for _ in 0..2 {
            tree.append(&leaf_0).unwrap();
        }
        assert_subtrees(&tree, &TREE_AFTER_1_UPDATE[..]);

        let leaf_1 = leaf_with_last_byte(2);
        for _ in 0..2 {
            tree.append(&leaf_1).unwrap();
        }
        assert_subtrees(&tree, &TREE_AFTER_2_UPDATES[..]);
    }
}