1use std::{
2 alloc::{self, handle_alloc_error, Layout},
3 iter::Skip,
4 marker::PhantomData,
5 mem,
6};
7
8use changelog::ChangelogPath;
9use light_bounded_vec::{
10 BoundedVec, BoundedVecMetadata, CyclicBoundedVec, CyclicBoundedVecIterator,
11 CyclicBoundedVecMetadata,
12};
13pub use light_hasher;
14use light_hasher::Hasher;
15
16pub mod changelog;
17pub mod copy;
18pub mod errors;
19pub mod event;
20pub mod hash;
21pub mod zero_copy;
22
23use crate::{
24 changelog::ChangelogEntry,
25 errors::ConcurrentMerkleTreeError,
26 hash::{compute_parent_node, compute_root},
27};
28
/// A Merkle tree that keeps a ring buffer of recent roots and change paths
/// (the changelog) so that updates produced against a slightly stale proof
/// can still be patched and applied.
///
/// NOTE(review): `next_index`, `sequence_number` and `rightmost_leaf` are raw
/// pointers; in this file they are heap-allocated in `new` and freed in
/// `Drop` — presumably so the `zero_copy` module can instead point them at
/// account memory. Confirm against the zero-copy loaders.
#[repr(C)]
#[derive(Debug)]
pub struct ConcurrentMerkleTree<H, const HEIGHT: usize>
where
    H: Hasher,
{
    // Runtime height; `check_size_constraints_new` requires it to equal HEIGHT.
    pub height: usize,
    // Number of top tree levels cached in `canopy`.
    pub canopy_depth: usize,

    // Index of the next free leaf slot (number of appended leaves).
    pub next_index: *mut usize,
    // Monotonic counter bumped on every mutation (update or append).
    pub sequence_number: *mut usize,
    // Value of the most recently appended leaf.
    pub rightmost_leaf: *mut [u8; 32],

    // Rightmost filled node per level; lets `append` work without a proof.
    pub filled_subtrees: BoundedVec<[u8; 32]>,
    // Ring buffer of recent change paths, used to patch concurrent proofs.
    pub changelog: CyclicBoundedVec<ChangelogEntry<HEIGHT>>,
    // Ring buffer of recent root hashes.
    pub roots: CyclicBoundedVec<[u8; 32]>,
    // Cached nodes of the top `canopy_depth` levels (root excluded).
    pub canopy: BoundedVec<[u8; 32]>,

    pub _hasher: PhantomData<H>,
}
68
/// Convenience alias for a tree of height 26.
pub type ConcurrentMerkleTree26<H> = ConcurrentMerkleTree<H, 26>;
70
71impl<H, const HEIGHT: usize> ConcurrentMerkleTree<H, HEIGHT>
72where
73 H: Hasher,
74{
75 #[inline(always)]
77 pub fn canopy_size(canopy_depth: usize) -> usize {
78 (1 << (canopy_depth + 1)) - 2
79 }
80
81 pub fn non_dyn_fields_size() -> usize {
84 mem::size_of::<usize>()
86 + mem::size_of::<usize>()
88 + mem::size_of::<usize>()
90 + mem::size_of::<usize>()
92 + mem::size_of::<[u8; 32]>()
94 + mem::size_of::<BoundedVecMetadata>()
96 + mem::size_of::<CyclicBoundedVecMetadata>()
98 + mem::size_of::<CyclicBoundedVecMetadata>()
100 + mem::size_of::<BoundedVecMetadata>()
102 }
103
104 pub fn size_in_account(
106 height: usize,
107 changelog_size: usize,
108 roots_size: usize,
109 canopy_depth: usize,
110 ) -> usize {
111 Self::non_dyn_fields_size()
113 + mem::size_of::<[u8; 32]>() * height
115 + mem::size_of::<ChangelogEntry<HEIGHT>>() * changelog_size
117 + mem::size_of::<[u8; 32]>() * roots_size
119 + mem::size_of::<[u8; 32]>() * Self::canopy_size(canopy_depth)
121 }
122
123 fn check_size_constraints_new(
124 height: usize,
125 changelog_size: usize,
126 roots_size: usize,
127 canopy_depth: usize,
128 ) -> Result<(), ConcurrentMerkleTreeError> {
129 if height == 0 || HEIGHT == 0 {
130 return Err(ConcurrentMerkleTreeError::HeightZero);
131 }
132 if height != HEIGHT {
133 return Err(ConcurrentMerkleTreeError::InvalidHeight(HEIGHT));
134 }
135 if canopy_depth > height {
136 return Err(ConcurrentMerkleTreeError::CanopyGeThanHeight);
137 }
138 if changelog_size == 0 {
141 return Err(ConcurrentMerkleTreeError::ChangelogZero);
142 }
143 if roots_size == 0 {
144 return Err(ConcurrentMerkleTreeError::RootsZero);
145 }
146 Ok(())
147 }
148
149 fn check_size_constraints(&self) -> Result<(), ConcurrentMerkleTreeError> {
150 Self::check_size_constraints_new(
151 self.height,
152 self.changelog.capacity(),
153 self.roots.capacity(),
154 self.canopy_depth,
155 )
156 }
157
158 pub fn new(
159 height: usize,
160 changelog_size: usize,
161 roots_size: usize,
162 canopy_depth: usize,
163 ) -> Result<Self, ConcurrentMerkleTreeError> {
164 Self::check_size_constraints_new(height, changelog_size, roots_size, canopy_depth)?;
165
166 let layout = Layout::new::<usize>();
167 let next_index = unsafe { alloc::alloc(layout) as *mut usize };
168 if next_index.is_null() {
169 handle_alloc_error(layout);
170 }
171 unsafe { *next_index = 0 };
172
173 let layout = Layout::new::<usize>();
174 let sequence_number = unsafe { alloc::alloc(layout) as *mut usize };
175 if sequence_number.is_null() {
176 handle_alloc_error(layout);
177 }
178 unsafe { *sequence_number = 0 };
179
180 let layout = Layout::new::<[u8; 32]>();
181 let rightmost_leaf = unsafe { alloc::alloc(layout) as *mut [u8; 32] };
182 if rightmost_leaf.is_null() {
183 handle_alloc_error(layout);
184 }
185 unsafe { *rightmost_leaf = [0u8; 32] };
186
187 Ok(Self {
188 height,
189 canopy_depth,
190
191 next_index,
192 sequence_number,
193 rightmost_leaf,
194
195 filled_subtrees: BoundedVec::with_capacity(height),
196 changelog: CyclicBoundedVec::with_capacity(changelog_size),
197 roots: CyclicBoundedVec::with_capacity(roots_size),
198 canopy: BoundedVec::with_capacity(Self::canopy_size(canopy_depth)),
199
200 _hasher: PhantomData,
201 })
202 }
203
204 pub fn init(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
206 self.check_size_constraints()?;
207
208 let root = H::zero_bytes()[self.height];
210 self.roots.push(root);
211
212 let path = ChangelogPath::from_fn(|i| Some(H::zero_bytes()[i]));
214 let changelog_entry = ChangelogEntry { path, index: 0 };
215 self.changelog.push(changelog_entry);
216
217 for i in 0..self.height {
219 self.filled_subtrees.push(H::zero_bytes()[i]).unwrap();
220 }
221
222 for level_i in 0..self.canopy_depth {
224 let level_nodes = 1 << (level_i + 1);
225 for _ in 0..level_nodes {
226 let node = H::zero_bytes()[self.height - level_i - 1];
227 self.canopy.push(node)?;
228 }
229 }
230
231 Ok(())
232 }
233
    /// Index of the most recent changelog entry.
    pub fn changelog_index(&self) -> usize {
        self.changelog.last_index()
    }
238
    /// Index of the most recent root in the `roots` ring buffer.
    pub fn root_index(&self) -> usize {
        self.roots.last_index()
    }
243
244 pub fn root(&self) -> [u8; 32] {
246 self.roots[self.root_index()]
249 }
250
251 pub fn current_index(&self) -> usize {
252 let next_index = self.next_index();
253 if next_index > 0 {
254 next_index - 1
255 } else {
256 next_index
257 }
258 }
259
    /// Number of appended leaves, i.e. the index of the next free leaf slot.
    pub fn next_index(&self) -> usize {
        // SAFETY: `next_index` is allocated and initialized in `new` and only
        // freed in `Drop` (or, presumably, points into account memory via the
        // zero-copy loaders — TODO confirm), so it is valid for reads.
        unsafe { *self.next_index }
    }
263
    /// Increments the leaf counter, failing on overflow instead of wrapping.
    fn inc_next_index(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
        // SAFETY: pointer is valid for the lifetime of `self` (see
        // `next_index`); `&mut self` guarantees exclusive access.
        unsafe {
            *self.next_index = self
                .next_index()
                .checked_add(1)
                .ok_or(ConcurrentMerkleTreeError::IntegerOverflow)?;
        }
        Ok(())
    }
273
    /// Current sequence number (incremented once per mutation).
    pub fn sequence_number(&self) -> usize {
        // SAFETY: same validity argument as `next_index`.
        unsafe { *self.sequence_number }
    }
277
    /// Increments the sequence number, failing on overflow instead of
    /// wrapping.
    fn inc_sequence_number(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
        // SAFETY: pointer is valid for the lifetime of `self`; `&mut self`
        // guarantees exclusive access.
        unsafe {
            *self.sequence_number = self
                .sequence_number()
                .checked_add(1)
                .ok_or(ConcurrentMerkleTreeError::IntegerOverflow)?;
        }
        Ok(())
    }
287
    /// Value of the most recently appended leaf.
    pub fn rightmost_leaf(&self) -> [u8; 32] {
        // SAFETY: same validity argument as `next_index`.
        unsafe { *self.rightmost_leaf }
    }
291
    /// Records `leaf` as the rightmost (latest appended) leaf.
    fn set_rightmost_leaf(&mut self, leaf: &[u8; 32]) {
        // SAFETY: pointer is valid for the lifetime of `self`; `&mut self`
        // guarantees exclusive access.
        unsafe { *self.rightmost_leaf = *leaf };
    }
295
    /// Appends the canopy-resident upper part of the Merkle proof for
    /// `leaf_index` to `proof` (callers only supply the sub-canopy part).
    pub fn update_proof_from_canopy(
        &self,
        leaf_index: usize,
        proof: &mut BoundedVec<[u8; 32]>,
    ) -> Result<(), ConcurrentMerkleTreeError> {
        // 1-indexed heap position of the leaf's ancestor at the bottom edge
        // of the canopy (root = 1, children of node n are 2n and 2n + 1).
        let mut node_index = ((1 << self.height) + leaf_index) >> (self.height - self.canopy_depth);
        while node_index > 1 {
            // The canopy omits the root, so heap node n is stored at n - 2.
            let canopy_index = node_index - 2;
            // The proof needs the sibling; siblings 2k / 2k + 1 sit in
            // adjacent canopy slots, so flip by parity.
            let canopy_index = if canopy_index % 2 == 0 {
                canopy_index + 1
            } else {
                canopy_index - 1
            };
            proof.push(self.canopy[canopy_index])?;
            node_index >>= 1;
        }

        Ok(())
    }
316
317 pub fn changelog_entries(
320 &self,
321 changelog_index: usize,
322 ) -> Result<Skip<CyclicBoundedVecIterator<'_, ChangelogEntry<HEIGHT>>>, ConcurrentMerkleTreeError>
323 {
324 Ok(self.changelog.iter_from(changelog_index)?.skip(1))
336 }
337
338 pub fn update_proof_from_changelog(
356 &self,
357 changelog_index: usize,
358 leaf_index: usize,
359 proof: &mut BoundedVec<[u8; 32]>,
360 ) -> Result<(), ConcurrentMerkleTreeError> {
361 for changelog_entry in self.changelog_entries(changelog_index)? {
367 changelog_entry.update_proof(leaf_index, proof)?;
368 }
369
370 Ok(())
371 }
372
373 pub fn validate_proof(
377 &self,
378 leaf: &[u8; 32],
379 leaf_index: usize,
380 proof: &BoundedVec<[u8; 32]>,
381 ) -> Result<(), ConcurrentMerkleTreeError> {
382 let expected_root = self.root();
383 let computed_root = compute_root::<H>(leaf, leaf_index, proof)?;
384 if computed_root == expected_root {
385 Ok(())
386 } else {
387 Err(ConcurrentMerkleTreeError::InvalidProof(
388 expected_root,
389 computed_root,
390 ))
391 }
392 }
393
394 fn update_leaf_in_tree(
410 &mut self,
411 new_leaf: &[u8; 32],
412 leaf_index: usize,
413 proof: &BoundedVec<[u8; 32]>,
414 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
415 let mut changelog_entry = ChangelogEntry::default_with_index(leaf_index);
416 let mut current_node = *new_leaf;
417 for (level, sibling) in proof.iter().enumerate() {
418 changelog_entry.path[level] = Some(current_node);
419 current_node = compute_parent_node::<H>(¤t_node, sibling, leaf_index, level)?;
420 }
421
422 self.inc_sequence_number()?;
423
424 self.roots.push(current_node);
425
426 if self.next_index() < (1 << self.height) {
428 changelog_entry.update_proof(self.next_index(), &mut self.filled_subtrees)?;
429 if leaf_index >= self.current_index() {
431 self.set_rightmost_leaf(new_leaf);
432 }
433 }
434 self.changelog.push(changelog_entry);
435
436 if self.canopy_depth > 0 {
437 self.update_canopy(self.changelog.last_index(), 1);
438 }
439
440 Ok((self.changelog.last_index(), self.sequence_number()))
441 }
442
    /// Replaces `old_leaf` at `leaf_index` with `new_leaf`.
    ///
    /// `changelog_index` is the changelog position the caller's `proof` was
    /// generated against; if the tree has changed since, the proof is first
    /// patched from the changelog, which is what makes concurrent updates
    /// possible.
    ///
    /// On success returns the new changelog index and sequence number.
    #[inline(never)]
    pub fn update(
        &mut self,
        changelog_index: usize,
        old_leaf: &[u8; 32],
        new_leaf: &[u8; 32],
        leaf_index: usize,
        proof: &mut BoundedVec<[u8; 32]>,
    ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
        // Callers supply only the part of the proof below the canopy.
        let expected_proof_len = self.height - self.canopy_depth;
        if proof.len() != expected_proof_len {
            return Err(ConcurrentMerkleTreeError::InvalidProofLength(
                expected_proof_len,
                proof.len(),
            ));
        }
        // Only leaves that have actually been appended can be updated.
        if leaf_index >= self.next_index() {
            return Err(ConcurrentMerkleTreeError::CannotUpdateEmpty);
        }

        // Order matters: extend the proof with canopy nodes first, then
        // patch it with changes newer than the caller's changelog_index,
        // then validate against the current root before mutating.
        if self.canopy_depth > 0 {
            self.update_proof_from_canopy(leaf_index, proof)?;
        }
        if changelog_index != self.changelog_index() {
            self.update_proof_from_changelog(changelog_index, leaf_index, proof)?;
        }
        self.validate_proof(old_leaf, leaf_index, proof)?;
        self.update_leaf_in_tree(new_leaf, leaf_index, proof)
    }
475
476 pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
478 self.append_batch(&[leaf])
479 }
480
481 pub fn append_with_proof(
484 &mut self,
485 leaf: &[u8; 32],
486 proof: &mut BoundedVec<[u8; 32]>,
487 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
488 self.append_batch_with_proofs(&[leaf], &mut [proof])
489 }
490
491 pub fn append_batch(
493 &mut self,
494 leaves: &[&[u8; 32]],
495 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
496 self.append_batch_common::<false>(leaves, None)
497 }
498
499 pub fn append_batch_with_proofs(
502 &mut self,
503 leaves: &[&[u8; 32]],
504 proofs: &mut [&mut BoundedVec<[u8; 32]>],
505 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
506 self.append_batch_common::<true>(leaves, Some(proofs))
507 }
508
509 fn append_batch_common<
515 const WITH_PROOFS: bool,
528 >(
529 &mut self,
530 leaves: &[&[u8; 32]],
531 mut proofs: Option<&mut [&mut BoundedVec<[u8; 32]>]>,
535 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
536 if leaves.is_empty() {
537 return Err(ConcurrentMerkleTreeError::EmptyLeaves);
538 }
539 if (self.next_index() + leaves.len() - 1) >= 1 << self.height {
540 return Err(ConcurrentMerkleTreeError::TreeFull);
541 }
542 if leaves.len() > self.changelog.capacity() {
543 return Err(ConcurrentMerkleTreeError::BatchGreaterThanChangelog(
544 leaves.len(),
545 self.changelog.capacity(),
546 ));
547 }
548
549 let first_changelog_index = (self.changelog.last_index() + 1) % self.changelog.capacity();
550 let first_sequence_number = self.sequence_number() + 1;
551
552 for (leaf_i, leaf) in leaves.iter().enumerate() {
553 let mut current_index = self.next_index();
554
555 self.changelog
556 .push(ChangelogEntry::<HEIGHT>::default_with_index(current_index));
557 let changelog_index = self.changelog_index();
558
559 let mut current_node = **leaf;
560
561 self.changelog[changelog_index].path[0] = Some(**leaf);
562
563 for i in 0..self.height {
564 let is_left = current_index % 2 == 0;
565
566 if is_left {
567 let empty_node = H::zero_bytes()[i];
579
580 if WITH_PROOFS {
581 proofs.as_mut().unwrap()[leaf_i].push(empty_node)?;
583 }
584
585 self.filled_subtrees[i] = current_node;
586
587 if leaf_i < leaves.len() - 1 {
592 break;
593 }
594
595 current_node = H::hashv(&[¤t_node, &empty_node])?;
596 } else {
597 if WITH_PROOFS {
609 proofs.as_mut().unwrap()[leaf_i].push(self.filled_subtrees[i])?;
611 }
612
613 current_node = H::hashv(&[&self.filled_subtrees[i], ¤t_node])?;
614 }
615
616 if i < self.height - 1 {
617 self.changelog[changelog_index].path[i + 1] = Some(current_node);
618 }
619
620 current_index /= 2;
621 }
622
623 if leaf_i == leaves.len() - 1 {
624 self.roots.push(current_node);
625 } else {
626 self.roots.push([0u8; 32]);
630 }
631
632 self.inc_next_index()?;
633 self.inc_sequence_number()?;
634
635 self.set_rightmost_leaf(leaf);
636 }
637
638 if self.canopy_depth > 0 {
639 self.update_canopy(first_changelog_index, leaves.len());
640 }
641
642 Ok((first_changelog_index, first_sequence_number))
643 }
644
645 fn update_canopy(&mut self, first_changelog_index: usize, num_leaves: usize) {
646 for i in 0..num_leaves {
647 let changelog_index = (first_changelog_index + i) % self.changelog.capacity();
648 for (i, path_node) in self.changelog[changelog_index]
649 .path
650 .iter()
651 .rev()
652 .take(self.canopy_depth)
653 .enumerate()
654 {
655 if let Some(path_node) = path_node {
656 let level = self.height - i - 1;
657 let index = (1 << (self.height - level))
658 + (self.changelog[changelog_index].index >> level);
659 self.canopy[(index - 2) as usize] = *path_node;
661 }
662 }
663 }
664 }
665}
666
impl<H, const HEIGHT: usize> Drop for ConcurrentMerkleTree<H, HEIGHT>
where
    H: Hasher,
{
    /// Frees the three heap allocations created in `new`.
    ///
    /// NOTE(review): if the zero-copy loaders can point these fields into
    /// account memory instead, dropping such an instance through this impl
    /// would deallocate memory the struct does not own — confirm zero-copy
    /// instances never reach this `drop`.
    fn drop(&mut self) {
        // SAFETY: each pointer was returned by `alloc::alloc` with exactly
        // this layout in `new`, and this is the only place it is freed.
        let layout = Layout::new::<usize>();
        unsafe { alloc::dealloc(self.next_index as *mut u8, layout) };

        let layout = Layout::new::<usize>();
        unsafe { alloc::dealloc(self.sequence_number as *mut u8, layout) };

        let layout = Layout::new::<[u8; 32]>();
        unsafe { alloc::dealloc(self.rightmost_leaf as *mut u8, layout) };
    }
}
682
683impl<H, const HEIGHT: usize> PartialEq for ConcurrentMerkleTree<H, HEIGHT>
684where
685 H: Hasher,
686{
687 fn eq(&self, other: &Self) -> bool {
688 self.height.eq(&other.height)
689 && self.canopy_depth.eq(&other.canopy_depth)
690 && self.next_index().eq(&other.next_index())
691 && self.sequence_number().eq(&other.sequence_number())
692 && self.rightmost_leaf().eq(&other.rightmost_leaf())
693 && self
694 .filled_subtrees
695 .as_slice()
696 .eq(other.filled_subtrees.as_slice())
697 && self.changelog.iter().eq(other.changelog.iter())
698 && self.roots.iter().eq(other.roots.iter())
699 && self.canopy.as_slice().eq(other.canopy.as_slice())
700 }
701}