1pub mod indexed;
2
3use std::marker::PhantomData;
4
5use light_hasher::{errors::HasherError, Hasher};
6use light_indexed_array::errors::IndexedArrayError;
7use thiserror::Error;
8
/// Errors returned by the reference [`MerkleTree`].
#[derive(Debug, Error, PartialEq)]
pub enum ReferenceMerkleTreeError {
    /// The requested leaf index has no stored leaf, or lies at/after the tree's capacity.
    #[error("Leaf {0} does not exist")]
    LeafDoesNotExist(usize),
    /// An error propagated from the hasher (converted automatically via `?`).
    #[error("Hasher error: {0}")]
    Hasher(#[from] HasherError),
    /// A proof had the wrong number of nodes; fields are `(provided, required)`.
    #[error("Invalid proof length provided: {0} required {1}")]
    InvalidProofLength(usize, usize),
    /// An error propagated from the indexed array (converted automatically via `?`).
    #[error("IndexedArray error: {0}")]
    IndexedArray(#[from] IndexedArrayError),
    /// A root-history index was requested on a tree that has no root-history length set
    /// (e.g. one built with `MerkleTree::new`).
    #[error("RootHistoryArrayLenNotSet")]
    RootHistoryArrayLenNotSet,
}
22
/// In-memory reference Merkle tree that stores every computed node layer explicitly.
///
/// `H` selects the hash used to combine two children. Nodes that have not been written yet
/// are read as the per-level zero hash from `H::zero_bytes()`.
#[derive(Debug, Clone)]
pub struct MerkleTree<H>
where
    H: Hasher,
{
    /// Number of layers below the root; `layers` has this many entries.
    pub height: usize,
    /// Maximum number of leaves, `1 << height`.
    pub capacity: usize,
    /// Number of topmost layers (below the root) returned by `get_canopy` and omitted from
    /// non-`full` proofs and paths.
    pub canopy_depth: usize,
    /// `layers[0]` holds the leaves; `layers[level]` holds the nodes of that level computed so
    /// far. The root itself is not stored here (see `roots`).
    pub layers: Vec<Vec<[u8; 32]>>,
    /// Every root the tree has had, oldest first; starts with `H::zero_bytes()[height]`.
    pub roots: Vec<[u8; 32]>,
    /// Number of leaves appended so far, i.e. the index the next appended leaf receives
    /// (despite the name).
    pub rightmost_index: usize,
    /// Counter read by `get_history_root_index_v2`.
    /// NOTE(review): never incremented by the methods in this file — confirm who maintains it.
    pub num_root_updates: usize,
    /// Incremented once per successful `append` or `update`.
    pub sequence_number: usize,
    /// Offset subtracted from `rightmost_index` in `get_history_root_index`.
    pub root_history_start_offset: usize,
    /// Length of the root-history array that history indices wrap around; `None` for trees
    /// built with `new`.
    pub root_history_array_len: Option<usize>,
    _hasher: PhantomData<H>,
}
41
42impl<H> MerkleTree<H>
43where
44 H: Hasher,
45{
46 pub fn new(height: usize, canopy_depth: usize) -> Self {
47 Self {
48 height,
49 capacity: 1 << height,
50 canopy_depth,
51 layers: vec![Vec::new(); height],
52 roots: vec![H::zero_bytes()[height]],
53 rightmost_index: 0,
54 sequence_number: 0,
55 root_history_start_offset: 0,
56 root_history_array_len: None,
57 num_root_updates: 0,
58 _hasher: PhantomData,
59 }
60 }
61
62 pub fn new_with_history(
63 height: usize,
64 canopy_depth: usize,
65 root_history_start_offset: usize,
66 root_history_array_len: usize,
67 ) -> Self {
68 Self {
69 height,
70 capacity: 1 << height,
71 canopy_depth,
72 layers: vec![Vec::new(); height],
73 roots: vec![H::zero_bytes()[height]],
74 rightmost_index: 0,
75 sequence_number: 0,
76 root_history_start_offset,
77 root_history_array_len: Some(root_history_array_len),
78 num_root_updates: 0,
79 _hasher: PhantomData,
80 }
81 }
82
83 pub fn get_history_root_index(&self) -> Result<u16, ReferenceMerkleTreeError> {
84 if let Some(root_history_array_len) = self.root_history_array_len {
85 Ok(
86 ((self.rightmost_index - self.root_history_start_offset) % root_history_array_len)
87 .try_into()
88 .unwrap(),
89 )
90 } else {
91 Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
92 }
93 }
94
95 pub fn get_history_root_index_v2(&self) -> Result<u16, ReferenceMerkleTreeError> {
97 if let Some(root_history_array_len) = self.root_history_array_len {
98 Ok(((self.num_root_updates) % root_history_array_len)
99 .try_into()
100 .unwrap())
101 } else {
102 Err(ReferenceMerkleTreeError::RootHistoryArrayLenNotSet)
103 }
104 }
105
106 pub fn canopy_size(&self) -> usize {
108 (1 << (self.canopy_depth + 1)) - 2
109 }
110
111 fn update_upper_layers(&mut self, mut i: usize) -> Result<(), HasherError> {
112 for level in 1..self.height {
113 i /= 2;
114
115 let left_index = i * 2;
116 let right_index = i * 2 + 1;
117
118 let left_child = self.layers[level - 1]
119 .get(left_index)
120 .cloned()
121 .unwrap_or(H::zero_bytes()[level - 1]);
122 let right_child = self.layers[level - 1]
123 .get(right_index)
124 .cloned()
125 .unwrap_or(H::zero_bytes()[level - 1]);
126
127 let node = H::hashv(&[&left_child[..], &right_child[..]])?;
128 if self.layers[level].len() > i {
129 self.layers[level][i] = node;
131 } else {
132 self.layers[level].push(node);
134 }
135 }
136
137 let left_child = &self.layers[self.height - 1]
138 .first()
139 .cloned()
140 .unwrap_or(H::zero_bytes()[self.height - 1]);
141 let right_child = &self.layers[self.height - 1]
142 .get(1)
143 .cloned()
144 .unwrap_or(H::zero_bytes()[self.height - 1]);
145 let root = H::hashv(&[&left_child[..], &right_child[..]])?;
146
147 self.roots.push(root);
148
149 Ok(())
150 }
151
152 pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(), HasherError> {
153 self.layers[0].push(*leaf);
154
155 let i = self.rightmost_index;
156 if self.rightmost_index == self.capacity {
157 println!("Merkle tree full");
158 return Err(HasherError::IntegerOverflow);
159 }
160 self.rightmost_index += 1;
161
162 self.update_upper_layers(i)?;
163
164 self.sequence_number += 1;
165 Ok(())
166 }
167
168 pub fn append_batch(&mut self, leaves: &[&[u8; 32]]) -> Result<(), HasherError> {
169 for leaf in leaves {
170 self.append(leaf)?;
171 }
172 Ok(())
173 }
174
175 pub fn update(
176 &mut self,
177 leaf: &[u8; 32],
178 leaf_index: usize,
179 ) -> Result<(), ReferenceMerkleTreeError> {
180 *self.layers[0]
181 .get_mut(leaf_index)
182 .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index))? = *leaf;
183
184 self.update_upper_layers(leaf_index)?;
185
186 self.sequence_number += 1;
187 Ok(())
188 }
189
190 pub fn root(&self) -> [u8; 32] {
191 self.roots.last().cloned().unwrap()
195 }
196
197 pub fn get_path_of_leaf(
198 &self,
199 mut index: usize,
200 full: bool,
201 ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
202 let mut path = Vec::with_capacity(self.height);
203 let limit = match full {
204 true => self.height,
205 false => self.height - self.canopy_depth,
206 };
207
208 for level in 0..limit {
209 let node = self.layers[level]
210 .get(index)
211 .cloned()
212 .unwrap_or(H::zero_bytes()[level]);
213 path.push(node);
214
215 index /= 2;
216 }
217
218 Ok(path)
219 }
220
221 pub fn get_proof_of_leaf(
222 &self,
223 mut index: usize,
224 full: bool,
225 ) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
226 let mut proof = Vec::with_capacity(self.height);
227 let limit = match full {
228 true => self.height,
229 false => self.height - self.canopy_depth,
230 };
231
232 for level in 0..limit {
233 #[allow(clippy::manual_is_multiple_of)]
234 let is_left = index % 2 == 0;
235
236 let sibling_index = if is_left { index + 1 } else { index - 1 };
237 let node = self.layers[level]
238 .get(sibling_index)
239 .cloned()
240 .unwrap_or(H::zero_bytes()[level]);
241 proof.push(node);
242
243 index /= 2;
244 }
245
246 Ok(proof)
247 }
248
249 pub fn get_proof_by_indices(&self, indices: &[i32]) -> Vec<Vec<[u8; 32]>> {
250 let mut proofs = Vec::new();
251 for &index in indices {
252 let mut index = index as usize;
253 let mut proof = Vec::with_capacity(self.height);
254
255 for level in 0..self.height {
256 #[allow(clippy::manual_is_multiple_of)]
257 let is_left = index % 2 == 0;
258 let sibling_index = if is_left { index + 1 } else { index - 1 };
259 let node = self.layers[level]
260 .get(sibling_index)
261 .cloned()
262 .unwrap_or(H::zero_bytes()[level]);
263 proof.push(node);
264 index /= 2;
265 }
266 proofs.push(proof);
267 }
268 proofs
269 }
270
271 pub fn get_canopy(&self) -> Result<Vec<[u8; 32]>, ReferenceMerkleTreeError> {
272 if self.canopy_depth == 0 {
273 return Ok(Vec::with_capacity(0));
274 }
275 let mut canopy = Vec::with_capacity(self.canopy_size());
276
277 let mut num_nodes_in_level = 2;
278 for i in 0..self.canopy_depth {
279 let level = self.height - 1 - i;
280 for j in 0..num_nodes_in_level {
281 let node = self.layers[level]
282 .get(j)
283 .cloned()
284 .unwrap_or(H::zero_bytes()[level]);
285 canopy.push(node);
286 }
287 num_nodes_in_level *= 2;
288 }
289
290 Ok(canopy)
291 }
292
293 pub fn leaf(&self, leaf_index: usize) -> [u8; 32] {
294 self.layers[0]
295 .get(leaf_index)
296 .cloned()
297 .unwrap_or(H::zero_bytes()[0])
298 }
299
300 pub fn get_leaf_index(&self, leaf: &[u8; 32]) -> Option<usize> {
301 self.layers[0].iter().position(|node| node == leaf)
302 }
303
304 pub fn leaves(&self) -> &[[u8; 32]] {
305 self.layers[0].as_slice()
306 }
307
308 pub fn verify(
309 &self,
310 leaf: &[u8; 32],
311 proof: &[[u8; 32]],
312 leaf_index: usize,
313 ) -> Result<bool, ReferenceMerkleTreeError> {
314 if leaf_index >= self.capacity {
315 return Err(ReferenceMerkleTreeError::LeafDoesNotExist(leaf_index));
316 }
317 if proof.len() != self.height {
318 return Err(ReferenceMerkleTreeError::InvalidProofLength(
319 proof.len(),
320 self.height,
321 ));
322 }
323
324 let mut computed_hash = *leaf;
325 let mut current_index = leaf_index;
326
327 for sibling_hash in proof.iter() {
328 #[allow(clippy::manual_is_multiple_of)]
329 let is_left = current_index % 2 == 0;
330 let hashes = if is_left {
331 [&computed_hash[..], &sibling_hash[..]]
332 } else {
333 [&sibling_hash[..], &computed_hash[..]]
334 };
335
336 computed_hash = H::hashv(&hashes)?;
337 current_index /= 2;
339 }
340
341 Ok(computed_hash == self.root())
343 }
344
345 pub fn get_subtrees(&self) -> Vec<[u8; 32]> {
349 let mut subtrees = H::zero_bytes()[0..self.height].to_vec();
350 if self.layers.last().and_then(|layer| layer.first()).is_some() {
351 for level in (0..self.height).rev() {
352 if let Some(left_child) = self.layers.get(level).and_then(|layer| {
353 if layer.len() % 2 == 0 {
354 layer.get(layer.len() - 2)
355 } else {
356 layer.last()
357 }
358 }) {
359 subtrees[level] = *left_child;
360 }
361 }
362 }
363 subtrees
364 }
365
366 pub fn get_next_index(&self) -> usize {
367 self.rightmost_index + 1
368 }
369
370 pub fn get_leaf(&self, index: usize) -> Result<[u8; 32], ReferenceMerkleTreeError> {
371 self.layers[0]
372 .get(index)
373 .cloned()
374 .ok_or(ReferenceMerkleTreeError::LeafDoesNotExist(index))
375 }
376}
377
#[cfg(test)]
mod tests {
    use light_hasher::{zero_bytes::poseidon::ZERO_BYTES, Poseidon};

    use super::*;

    const TREE_AFTER_1_UPDATE: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 1,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            4, 163, 62, 195, 162, 201, 237, 49, 131, 153, 66, 155, 106, 112, 192, 40, 76, 131, 230,
            239, 224, 130, 106, 36, 128, 57, 172, 107, 60, 247, 103, 194,
        ],
        [
            7, 118, 172, 114, 242, 52, 137, 62, 111, 106, 113, 139, 123, 161, 39, 255, 86, 13, 105,
            167, 223, 52, 15, 29, 137, 37, 106, 178, 49, 44, 226, 75,
        ],
    ];

    const TREE_AFTER_2_UPDATES: [[u8; 32]; 4] = [
        [
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 2,
        ],
        [
            0, 122, 243, 70, 226, 211, 4, 39, 158, 121, 224, 169, 243, 2, 63, 119, 18, 148, 167,
            138, 203, 112, 231, 63, 144, 175, 226, 124, 173, 64, 30, 129,
        ],
        [
            18, 102, 129, 25, 152, 42, 192, 218, 100, 215, 169, 202, 77, 24, 100, 133, 45, 152, 17,
            121, 103, 9, 187, 226, 182, 36, 35, 35, 126, 255, 244, 140,
        ],
        [
            11, 230, 92, 56, 65, 91, 231, 137, 40, 92, 11, 193, 90, 225, 123, 79, 82, 17, 212, 147,
            43, 41, 126, 223, 49, 2, 139, 211, 249, 138, 7, 12,
        ],
    ];

    /// Builds a 32-byte leaf whose only non-zero byte is the last one.
    fn leaf_ending_in(last_byte: u8) -> [u8; 32] {
        let mut leaf = [0u8; 32];
        leaf[31] = last_byte;
        leaf
    }

    /// Asserts that each returned subtree equals the expected node at the same level.
    fn assert_subtrees_eq(actual: &[[u8; 32]], expected: &[[u8; 32]]) {
        for (level, node) in actual.iter().enumerate() {
            assert_eq!(*node, expected[level]);
        }
    }

    #[test]
    fn test_subtrees() {
        let height = 4;
        let mut tree = MerkleTree::<Poseidon>::new(height, 0);

        // An empty tree reports the zero hash at every level.
        assert_subtrees_eq(&tree.get_subtrees(), &ZERO_BYTES);

        // Two copies of a leaf ending in 1.
        let first = leaf_ending_in(1);
        for _ in 0..2 {
            tree.append(&first).unwrap();
        }
        assert_subtrees_eq(&tree.get_subtrees(), &TREE_AFTER_1_UPDATE);

        // Two copies of a leaf ending in 2.
        let second = leaf_ending_in(2);
        for _ in 0..2 {
            tree.append(&second).unwrap();
        }
        assert_subtrees_eq(&tree.get_subtrees(), &TREE_AFTER_2_UPDATES);
    }
}