1pub mod indexed;
2pub mod sparse_merkle_tree;
3
4use std::marker::PhantomData;
5
6use light_hasher::{errors::HasherError, Hasher};
7use light_indexed_array::errors::IndexedArrayError;
8use thiserror::Error;
9
10#[derive(Debug, Error, PartialEq)]
11pub enum ReferenceMerkleTreeError {
12 #[error("Leaf {0} does not exist")]
13 LeafDoesNotExist(usize),
14 #[error("Hasher error: {0}")]
15 Hasher(#[from] HasherError),
16 #[error("Invalid proof length provided: {0} required {1}")]
17 InvalidProofLength(usize, usize),
18 #[error("IndexedArray error: {0}")]
19 IndexedArray(#[from] IndexedArrayError),
20 #[error("RootHistoryArrayLenNotSet")]
21 RootHistoryArrayLenNotSet,
22}
23
24#[derive(Debug, Clone)]
25pub struct MerkleTree<H>
26where
27 H: Hasher,
28{
29 pub height: usize,
30 pub capacity: usize,
31 pub canopy_depth: usize,
32 pub layers: Vec<Vec<[u8; 32]>>,
33 pub roots: Vec<[u8; 32]>,
34 pub rightmost_index: usize,
35 pub num_root_updates: usize,
36 pub sequence_number: usize,
37 pub root_history_start_offset: usize,
38 pub root_history_array_len: Option<usize>,
39 _hasher: PhantomData<H>,
41}
42
43impl<H> MerkleTree<H>
44where
45 H: Hasher,
46{
47 pub fn new(height: usize, canopy_depth: usize) -> Self {
48 Self {
49 height,
50 capacity: 1 << height,
51 canopy_depth,
52 layers: vec![Vec::new(); height],
53 roots: vec![H::zero_bytes()[height]],
54 rightmost_index: 0,
55 sequence_number: 0,
56 root_history_start_offset: 0,
57 root_history_array_len: None,
58 num_root_updates: 0,
59 _hasher: PhantomData,
60 }
61 }
62
63 pub fn new_with_history(
64 height: usize,
65 canopy_depth: usize,
66 root_history_start_offset: usize,
67 root_history_array_len: usize,
68 ) -> Self {
69 Self {
70 height,
71 capacity: 1 << height,
72 canopy_depth,
73 layers: vec![Vec::new(); height],
74 roots: vec![H::zero_bytes()[height]],
75 rightmost_index: 0,
76 sequence_number: 0,
77 root_history_start_offset,
78 root_history_array_len: Some(root_history_array_len),
79 num_root_updates: 0,
80 _hasher: PhantomData,
81 }
82 }
83
84 pub fn get_history_root_index(&self) -> Result<u16, ReferenceMerkleTreeError> {
85 if let Some(root_history_array_len) = self.root_history_array_len {
86 println!("root_history_array_len {}", root_history_array_len);
87 println!("rightmost_index {}", self.rightmost_index);
88 println!(
89 "root_history_start_offset {}",
90 self.root_history_start_offset
91 );
92 Ok(
93 ((self.rightmost_index - self.root_history_start_offset) % root_history_array_len)
94 .try_into()
95 .unwrap(),
96 )
97 } else {
98 Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
99 }
100 }
101
102 pub fn get_history_root_index_v2(&self) -> Result<u16, ReferenceMerkleTreeError> {
104 if let Some(root_history_array_len) = self.root_history_array_len {
105 println!("root_history_array_len {}", root_history_array_len);
106 println!("rightmost_index {}", self.rightmost_index);
107 println!("num_root_updates {}", self.num_root_updates);
108 Ok(((self.num_root_updates) % root_history_array_len)
109 .try_into()
110 .unwrap())
111 } else {
112 Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
113 }
114 }
115
116 pub fn canopy_size(&self) -> usize {
118 (1 << (self.canopy_depth + 1)) - 2
119 }
120
121 fn update_upper_layers(&mut self, mut i: usize) -> Result<(), HasherError> {
122 for level in 1..self.height {
123 i /= 2;
124
125 let left_index = i * 2;
126 let right_index = i * 2 + 1;
127
128 let left_child = self.layers[level - 1]
129 .get(left_index)
130 .cloned()
131 .unwrap_or(H::zero_bytes()[level - 1]);
132 let right_child = self.layers[level - 1]
133 .get(right_index)
134 .cloned()
135 .unwrap_or(H::zero_bytes()[level - 1]);
136
137 let node = H::hashv(&[&left_child[..], &right_child[..]])?;
138 if self.layers[level].len() > i {
139 self.layers[level][i] = node;
141 } else {
142 self.layers[level].push(node);
144 }
145 }
146
147 let left_child = &self.layers[self.height - 1]
148 .first()
149 .cloned()
150 .unwrap_or(H::zero_bytes()[self.height - 1]);
151 let right_child = &self.layers[self.height - 1]
152 .get(1)
153 .cloned()
154 .unwrap_or(H::zero_bytes()[self.height - 1]);
155 let root = H::hashv(&[&left_child[..], &right_child[..]])?;
156
157 self.roots.push(root);
158
159 Ok(())
160 }
161
162 pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(), HasherError> {
163 self.layers[0].push(*leaf);
164
165 let i = self.rightmost_index;
166 if self.rightmost_index == self.capacity {
167 println!("Merkle tree full");
168 return Err(HasherError::IntegerOverflow);
169 }
170 self.rightmost_index += 1;
171
172 self.update_upper_layers(i)?;
173
174 self.sequence_number += 1;
175 Ok(())
176 }
177
178 pub fn append_batch(&mut self, leaves: &[&[u8; 32]]) -> Result<(), HasherError> {
179 for leaf in leaves {
180 self.append(leaf)?;
181 }
182 Ok(())
183 }
184
185 pub fn update(
186 &mut self,
187 leaf: &[u8; 32],
188 leaf_index: usize,
189 ) -> Result<(), ReferenceMerkleTreeError> {
190 *self.layers[0]
191 .get_mut(leaf_index)
192 .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index))? = *leaf;
193
194 self.update_upper_layers(leaf_index)?;
195
196 self.sequence_number += 1;
197 Ok(())
198 }
199
200 pub fn root(&self) -> [u8; 32] {
201 self.roots.last().cloned().unwrap()
205 }
206
207 pub fn get_path_of_leaf(
208 &self,
209 mut index: usize,
210 full: bool,
211 ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
212 let mut path = Vec::with_capacity(self.height);
213 let limit = match full {
214 true => self.height,
215 false => self.height - self.canopy_depth,
216 };
217
218 for level in 0..limit {
219 let node = self.layers[level]
220 .get(index)
221 .cloned()
222 .unwrap_or(H::zero_bytes()[level]);
223 path.push(node);
224
225 index /= 2;
226 }
227
228 Ok(path)
229 }
230
231 pub fn get_proof_of_leaf(
232 &self,
233 mut index: usize,
234 full: bool,
235 ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
236 let mut proof = Vec::with_capacity(self.height);
237 let limit = match full {
238 true => self.height,
239 false => self.height - self.canopy_depth,
240 };
241
242 for level in 0..limit {
243 let is_left = index % 2 == 0;
244
245 let sibling_index = if is_left { index + 1 } else { index - 1 };
246 let node = self.layers[level]
247 .get(sibling_index)
248 .cloned()
249 .unwrap_or(H::zero_bytes()[level]);
250 proof.push(node);
251
252 index /= 2;
253 }
254
255 Ok(proof)
256 }
257
258 pub fn get_proof_by_indices(&self, indices: &[i32]) -> Vec<Vec<[u8; 32]>> {
259 let mut proofs = Vec::new();
260 for &index in indices {
261 let mut index = index as usize;
262 let mut proof = Vec::with_capacity(self.height);
263
264 for level in 0..self.height {
265 let is_left = index % 2 == 0;
266 let sibling_index = if is_left { index + 1 } else { index - 1 };
267 let node = self.layers[level]
268 .get(sibling_index)
269 .cloned()
270 .unwrap_or(H::zero_bytes()[level]);
271 proof.push(node);
272 index /= 2;
273 }
274 proofs.push(proof);
275 }
276 proofs
277 }
278
279 pub fn get_canopy(&self) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
280 if self.canopy_depth == 0 {
281 return Ok(Vec::with_capacity(0));
282 }
283 let mut canopy = Vec::with_capacity(self.canopy_size());
284
285 let mut num_nodes_in_level = 2;
286 for i in 0..self.canopy_depth {
287 let level = self.height - 1 - i;
288 for j in 0..num_nodes_in_level {
289 let node = self.layers[level]
290 .get(j)
291 .cloned()
292 .unwrap_or(H::zero_bytes()[level]);
293 canopy.push(node);
294 }
295 num_nodes_in_level *= 2;
296 }
297
298 Ok(canopy)
299 }
300
301 pub fn leaf(&self, leaf_index: usize) -> [u8; 32] {
302 self.layers[0]
303 .get(leaf_index)
304 .cloned()
305 .unwrap_or(H::zero_bytes()[0])
306 }
307
308 pub fn get_leaf_index(&self, leaf: &[u8; 32]) -> Option<usize> {
309 self.layers[0].iter().position(|node| node == leaf)
310 }
311
312 pub fn leaves(&self) -> &[[u8; 32]] {
313 self.layers[0].as_slice()
314 }
315
316 pub fn verify(
317 &self,
318 leaf: &[u8; 32],
319 proof: &[[u8; 32]],
320 leaf_index: usize,
321 ) -> Result<bool, ReferenceMerkleTreeError> {
322 if leaf_index >= self.capacity {
323 return Err(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index));
324 }
325 if proof.len() != self.height {
326 return Err(ReferenceMerkleTreeError::InvalidProofLength(
327 proof.len(),
328 self.height,
329 ));
330 }
331
332 let mut computed_hash = *leaf;
333 let mut current_index = leaf_index;
334
335 for sibling_hash in proof.iter() {
336 let is_left = current_index % 2 == 0;
337 let hashes = if is_left {
338 [&computed_hash[..], &sibling_hash[..]]
339 } else {
340 [&sibling_hash[..], &computed_hash[..]]
341 };
342
343 computed_hash = H::hashv(&hashes)?;
344 current_index /= 2;
346 }
347
348 Ok(computed_hash == self.root())
350 }
351
352 pub fn get_subtrees(&self) -> Vec<[u8; 32]> {
356 let mut subtrees = H::zero_bytes()[0..self.height].to_vec();
357 if self.layers.last().and_then(|layer| layer.first()).is_some() {
358 for level in (0..self.height).rev() {
359 if let Some(left_child) = self.layers.get(level).and_then(|layer| {
360 if layer.len() % 2 == 0 {
361 layer.get(layer.len() - 2)
362 } else {
363 layer.last()
364 }
365 }) {
366 subtrees[level] = *left_child;
367 }
368 }
369 }
370 subtrees
371 }
372
373 pub fn get_next_index(&self) -> usize {
374 self.rightmost_index + 1
375 }
376
377 pub fn get_leaf(&self, index: usize) -> Result<[u8; 32], ReferenceMerkleTreeError> {
378 self.layers[0]
379 .get(index)
380 .cloned()
381 .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(index))
382 }
383}
384
#[cfg(test)]
mod tests {
    use light_hasher::{zero_bytes::poseidon::ZERO_BYTES, Poseidon};

    use super::*;

    /// Expected `get_subtrees()` of a height-4 Poseidon tree after appending the leaf
    /// `[0, .., 0, 1]` twice.
    const TREE_AFTER_1_UPDATE: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 1,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            4, 163, 62, 195, 162, 201, 237, 49, 131, 153, 66, 155, 106, 112, 192, 40, 76, 131, 230,
            239, 224, 130, 106, 36, 128, 57, 172, 107, 60, 247, 103, 194,
        ],
        [
            7, 118, 172, 114, 242, 52, 137, 62, 111, 106, 113, 139, 123, 161, 39, 255, 86, 13, 105,
            167, 223, 52, 15, 29, 137, 37, 106, 178, 49, 44, 226, 75,
        ],
    ];

    /// Expected `get_subtrees()` of the same tree after additionally appending the leaf
    /// `[0, .., 0, 2]` twice.
    const TREE_AFTER_2_UPDATES: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 2,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            18, 102, 129, 25, 152, 42, 192, 218, 100, 215, 169, 202, 77, 24, 100, 133, 45, 152, 17,
            121, 103, 9, 187, 226, 182, 36, 35, 35, 126, 255, 244, 140,
        ],
        [
            11, 230, 92, 56, 65, 91, 231, 137, 40, 92, 11, 193, 90, 225, 123, 79, 82, 17, 212, 147,
            43, 41, 126, 223, 49, 2, 139, 211, 249, 138, 7, 12,
        ],
    ];

    /// Builds a 32-byte leaf whose last byte is `value` and whose other bytes are zero.
    fn leaf_ending_in(value: u8) -> [u8; 32] {
        let mut leaf = [0u8; 32];
        leaf[31] = value;
        leaf
    }

    /// Asserts that every subtree reported by `tree` equals `expected` at the same level.
    fn assert_subtrees_eq(tree: &MerkleTree<Poseidon>, expected: &[[u8; 32]]) {
        for (level, subtree) in tree.get_subtrees().iter().enumerate() {
            assert_eq!(*subtree, expected[level]);
        }
    }

    #[test]
    fn test_subtrees() {
        let mut tree = MerkleTree::<Poseidon>::new(4, 0);

        // A fresh tree reports the empty-subtree hash at every level.
        assert_subtrees_eq(&tree, &ZERO_BYTES);

        let first = leaf_ending_in(1);
        tree.append(&first).unwrap();
        tree.append(&first).unwrap();
        assert_subtrees_eq(&tree, &TREE_AFTER_1_UPDATE);

        let second = leaf_ending_in(2);
        tree.append(&second).unwrap();
        tree.append(&second).unwrap();
        assert_subtrees_eq(&tree, &TREE_AFTER_2_UPDATES);
    }
}