1use std::{
2 alloc::{self, handle_alloc_error, Layout},
3 iter::Skip,
4 marker::PhantomData,
5 mem,
6};
7
8use changelog::ChangelogPath;
9use light_bounded_vec::{
10 BoundedVec, BoundedVecMetadata, CyclicBoundedVec, CyclicBoundedVecIterator,
11 CyclicBoundedVecMetadata,
12};
13pub use light_hasher;
14use light_hasher::Hasher;
15
16pub mod changelog;
17pub mod copy;
18pub mod errors;
19pub mod event;
20pub mod hash;
21pub mod offset;
22pub mod zero_copy;
23
24use crate::{
25 changelog::ChangelogEntry,
26 errors::ConcurrentMerkleTreeError,
27 hash::{compute_parent_node, compute_root},
28};
29
/// A Merkle tree of fixed `HEIGHT` that tolerates concurrent modifications:
/// a cyclic changelog of recent change paths allows a proof generated
/// against an older root to be patched (see `update_proof_from_changelog`)
/// before verification.
#[repr(C)]
#[derive(Debug)]
pub struct ConcurrentMerkleTree<H, const HEIGHT: usize>
where
    H: Hasher,
{
    /// Height of the tree; checked to equal the `HEIGHT` const generic.
    pub height: usize,
    /// Number of top tree levels cached in `canopy` (may be 0).
    pub canopy_depth: usize,

    // NOTE(review): these three scalars live behind raw pointers —
    // presumably so the `zero_copy` module can point them at account
    // memory; `new` heap-allocates them and `Drop` frees them. Confirm
    // against `zero_copy` before relying on this.
    // Index at which the next appended leaf will be stored.
    next_index: *mut usize,
    // Monotonically increasing counter, bumped once per modification.
    sequence_number: *mut usize,
    // The most recently appended leaf.
    rightmost_leaf: *mut [u8; 32],

    /// Rightmost filled subtree hash per level; used to hash appended leaves.
    pub filled_subtrees: BoundedVec<[u8; 32]>,
    /// Cyclic log of recent modifications (one `ChangelogEntry` per change).
    pub changelog: CyclicBoundedVec<ChangelogEntry<HEIGHT>>,
    /// Cyclic history of roots, one per sequence number.
    pub roots: CyclicBoundedVec<[u8; 32]>,
    /// Cached nodes of the top `canopy_depth` levels below the root.
    pub canopy: BoundedVec<[u8; 32]>,

    pub _hasher: PhantomData<H>,
}

/// Convenience alias for the commonly used height-26 tree.
pub type ConcurrentMerkleTree26<H> = ConcurrentMerkleTree<H, 26>;
71
72impl<H, const HEIGHT: usize> ConcurrentMerkleTree<H, HEIGHT>
73where
74 H: Hasher,
75{
76 #[inline(always)]
78 pub fn canopy_size(canopy_depth: usize) -> usize {
79 (1 << (canopy_depth + 1)) - 2
80 }
81
82 pub fn non_dyn_fields_size() -> usize {
85 mem::size_of::<usize>()
87 + mem::size_of::<usize>()
89 + mem::size_of::<usize>()
91 + mem::size_of::<usize>()
93 + mem::size_of::<[u8; 32]>()
95 + mem::size_of::<BoundedVecMetadata>()
97 + mem::size_of::<CyclicBoundedVecMetadata>()
99 + mem::size_of::<CyclicBoundedVecMetadata>()
101 + mem::size_of::<BoundedVecMetadata>()
103 }
104
105 pub fn size_in_account(
107 height: usize,
108 changelog_size: usize,
109 roots_size: usize,
110 canopy_depth: usize,
111 ) -> usize {
112 Self::non_dyn_fields_size()
114 + mem::size_of::<[u8; 32]>() * height
116 + mem::size_of::<ChangelogEntry<HEIGHT>>() * changelog_size
118 + mem::size_of::<[u8; 32]>() * roots_size
120 + mem::size_of::<[u8; 32]>() * Self::canopy_size(canopy_depth)
122 }
123
124 fn check_size_constraints_new(
125 height: usize,
126 changelog_size: usize,
127 roots_size: usize,
128 canopy_depth: usize,
129 ) -> Result<(), ConcurrentMerkleTreeError> {
130 if height == 0 || HEIGHT == 0 {
131 return Err(ConcurrentMerkleTreeError::HeightZero);
132 }
133 if height != HEIGHT {
134 return Err(ConcurrentMerkleTreeError::InvalidHeight(HEIGHT));
135 }
136 if canopy_depth > height {
137 return Err(ConcurrentMerkleTreeError::CanopyGeThanHeight);
138 }
139 if changelog_size == 0 {
142 return Err(ConcurrentMerkleTreeError::ChangelogZero);
143 }
144 if roots_size == 0 {
145 return Err(ConcurrentMerkleTreeError::RootsZero);
146 }
147 Ok(())
148 }
149
150 fn check_size_constraints(&self) -> Result<(), ConcurrentMerkleTreeError> {
151 Self::check_size_constraints_new(
152 self.height,
153 self.changelog.capacity(),
154 self.roots.capacity(),
155 self.canopy_depth,
156 )
157 }
158
159 pub fn new(
160 height: usize,
161 changelog_size: usize,
162 roots_size: usize,
163 canopy_depth: usize,
164 ) -> Result<Self, ConcurrentMerkleTreeError> {
165 Self::check_size_constraints_new(height, changelog_size, roots_size, canopy_depth)?;
166
167 let layout = Layout::new::<usize>();
168 let next_index = unsafe { alloc::alloc(layout) as *mut usize };
169 if next_index.is_null() {
170 handle_alloc_error(layout);
171 }
172 unsafe { *next_index = 0 };
173
174 let layout = Layout::new::<usize>();
175 let sequence_number = unsafe { alloc::alloc(layout) as *mut usize };
176 if sequence_number.is_null() {
177 handle_alloc_error(layout);
178 }
179 unsafe { *sequence_number = 0 };
180
181 let layout = Layout::new::<[u8; 32]>();
182 let rightmost_leaf = unsafe { alloc::alloc(layout) as *mut [u8; 32] };
183 if rightmost_leaf.is_null() {
184 handle_alloc_error(layout);
185 }
186 unsafe { *rightmost_leaf = [0u8; 32] };
187
188 Ok(Self {
189 height,
190 canopy_depth,
191
192 next_index,
193 sequence_number,
194 rightmost_leaf,
195
196 filled_subtrees: BoundedVec::with_capacity(height),
197 changelog: CyclicBoundedVec::with_capacity(changelog_size),
198 roots: CyclicBoundedVec::with_capacity(roots_size),
199 canopy: BoundedVec::with_capacity(Self::canopy_size(canopy_depth)),
200
201 _hasher: PhantomData,
202 })
203 }
204
205 pub fn init(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
207 self.check_size_constraints()?;
208
209 let root = H::zero_bytes()[self.height];
211 self.roots.push(root);
212
213 let path = ChangelogPath::from_fn(|i| Some(H::zero_bytes()[i]));
215 let changelog_entry = ChangelogEntry { path, index: 0 };
216 self.changelog.push(changelog_entry);
217
218 for i in 0..self.height {
220 self.filled_subtrees.push(H::zero_bytes()[i]).unwrap();
221 }
222
223 for level_i in 0..self.canopy_depth {
225 let level_nodes = 1 << (level_i + 1);
226 for _ in 0..level_nodes {
227 let node = H::zero_bytes()[self.height - level_i - 1];
228 self.canopy.push(node)?;
229 }
230 }
231
232 Ok(())
233 }
234
235 pub fn changelog_index(&self) -> usize {
237 self.changelog.last_index()
238 }
239
240 pub fn root_index(&self) -> usize {
242 self.roots.last_index()
243 }
244
245 pub fn root(&self) -> [u8; 32] {
247 self.roots[self.root_index()]
250 }
251
252 pub fn current_index(&self) -> usize {
253 let next_index = self.next_index();
254 if next_index > 0 {
255 next_index - 1
256 } else {
257 next_index
258 }
259 }
260
261 pub fn next_index(&self) -> usize {
262 unsafe { *self.next_index }
263 }
264
265 pub fn inc_next_index(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
266 unsafe {
267 *self.next_index = self
268 .next_index()
269 .checked_add(1)
270 .ok_or(ConcurrentMerkleTreeError::IntegerOverflow)?;
271 }
272 Ok(())
273 }
274
275 pub fn sequence_number(&self) -> usize {
276 unsafe { *self.sequence_number }
277 }
278
279 pub fn inc_sequence_number(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
280 unsafe {
281 *self.sequence_number = self
282 .sequence_number()
283 .checked_add(1)
284 .ok_or(ConcurrentMerkleTreeError::IntegerOverflow)?;
285 }
286 Ok(())
287 }
288
289 pub fn rightmost_leaf(&self) -> [u8; 32] {
290 unsafe { *self.rightmost_leaf }
291 }
292
293 fn set_rightmost_leaf(&mut self, leaf: &[u8; 32]) {
294 unsafe { *self.rightmost_leaf = *leaf };
295 }
296
297 pub fn update_proof_from_canopy(
298 &self,
299 leaf_index: usize,
300 proof: &mut BoundedVec<[u8; 32]>,
301 ) -> Result<(), ConcurrentMerkleTreeError> {
302 let mut node_index = ((1 << self.height) + leaf_index) >> (self.height - self.canopy_depth);
303 while node_index > 1 {
304 let canopy_index = node_index - 2;
306 let canopy_index = if canopy_index % 2 == 0 {
307 canopy_index + 1
308 } else {
309 canopy_index - 1
310 };
311 proof.push(self.canopy[canopy_index])?;
312 node_index >>= 1;
313 }
314
315 Ok(())
316 }
317
318 pub fn changelog_entries(
321 &self,
322 changelog_index: usize,
323 ) -> Result<Skip<CyclicBoundedVecIterator<'_, ChangelogEntry<HEIGHT>>>, ConcurrentMerkleTreeError>
324 {
325 Ok(self.changelog.iter_from(changelog_index)?.skip(1))
337 }
338
339 pub fn update_proof_from_changelog(
357 &self,
358 changelog_index: usize,
359 leaf_index: usize,
360 proof: &mut BoundedVec<[u8; 32]>,
361 ) -> Result<(), ConcurrentMerkleTreeError> {
362 for changelog_entry in self.changelog_entries(changelog_index)? {
368 changelog_entry.update_proof(leaf_index, proof)?;
369 }
370
371 Ok(())
372 }
373
374 pub fn validate_proof(
378 &self,
379 leaf: &[u8; 32],
380 leaf_index: usize,
381 proof: &BoundedVec<[u8; 32]>,
382 ) -> Result<(), ConcurrentMerkleTreeError> {
383 let expected_root = self.root();
384 let computed_root = compute_root::<H>(leaf, leaf_index, proof)?;
385 if computed_root == expected_root {
386 Ok(())
387 } else {
388 Err(ConcurrentMerkleTreeError::InvalidProof(
389 expected_root,
390 computed_root,
391 ))
392 }
393 }
394
395 fn update_leaf_in_tree(
411 &mut self,
412 new_leaf: &[u8; 32],
413 leaf_index: usize,
414 proof: &BoundedVec<[u8; 32]>,
415 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
416 let mut changelog_entry = ChangelogEntry::default_with_index(leaf_index);
417 let mut current_node = *new_leaf;
418 for (level, sibling) in proof.iter().enumerate() {
419 changelog_entry.path[level] = Some(current_node);
420 current_node = compute_parent_node::<H>(¤t_node, sibling, leaf_index, level)?;
421 }
422
423 self.inc_sequence_number()?;
424
425 self.roots.push(current_node);
426
427 if self.next_index() < (1 << self.height) {
429 changelog_entry.update_proof(self.next_index(), &mut self.filled_subtrees)?;
430 if leaf_index >= self.current_index() {
432 self.set_rightmost_leaf(new_leaf);
433 }
434 }
435 self.changelog.push(changelog_entry);
436
437 if self.canopy_depth > 0 {
438 self.update_canopy(self.changelog.last_index(), 1);
439 }
440
441 Ok((self.changelog.last_index(), self.sequence_number()))
442 }
443
444 #[inline(never)]
448 pub fn update(
449 &mut self,
450 changelog_index: usize,
451 old_leaf: &[u8; 32],
452 new_leaf: &[u8; 32],
453 leaf_index: usize,
454 proof: &mut BoundedVec<[u8; 32]>,
455 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
456 let expected_proof_len = self.height - self.canopy_depth;
457 if proof.len() != expected_proof_len {
458 return Err(ConcurrentMerkleTreeError::InvalidProofLength(
459 expected_proof_len,
460 proof.len(),
461 ));
462 }
463 if leaf_index >= self.next_index() {
464 return Err(ConcurrentMerkleTreeError::CannotUpdateEmpty);
465 }
466
467 if self.canopy_depth > 0 {
468 self.update_proof_from_canopy(leaf_index, proof)?;
469 }
470 if changelog_index != self.changelog_index() {
471 self.update_proof_from_changelog(changelog_index, leaf_index, proof)?;
472 }
473 self.validate_proof(old_leaf, leaf_index, proof)?;
474 self.update_leaf_in_tree(new_leaf, leaf_index, proof)
475 }
476
477 pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
479 self.append_batch(&[leaf])
480 }
481
482 pub fn append_with_proof(
485 &mut self,
486 leaf: &[u8; 32],
487 proof: &mut BoundedVec<[u8; 32]>,
488 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
489 self.append_batch_with_proofs(&[leaf], &mut [proof])
490 }
491
492 pub fn append_batch(
494 &mut self,
495 leaves: &[&[u8; 32]],
496 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
497 self.append_batch_common::<false>(leaves, None)
498 }
499
500 pub fn append_batch_with_proofs(
503 &mut self,
504 leaves: &[&[u8; 32]],
505 proofs: &mut [&mut BoundedVec<[u8; 32]>],
506 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
507 self.append_batch_common::<true>(leaves, Some(proofs))
508 }
509
510 fn append_batch_common<
516 const WITH_PROOFS: bool,
529 >(
530 &mut self,
531 leaves: &[&[u8; 32]],
532 mut proofs: Option<&mut [&mut BoundedVec<[u8; 32]>]>,
536 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
537 if leaves.is_empty() {
538 return Err(ConcurrentMerkleTreeError::EmptyLeaves);
539 }
540 if (self.next_index() + leaves.len() - 1) >= 1 << self.height {
541 return Err(ConcurrentMerkleTreeError::TreeIsFull);
542 }
543 if leaves.len() > self.changelog.capacity() {
544 return Err(ConcurrentMerkleTreeError::BatchGreaterThanChangelog(
545 leaves.len(),
546 self.changelog.capacity(),
547 ));
548 }
549
550 let first_changelog_index = (self.changelog.last_index() + 1) % self.changelog.capacity();
551 let first_sequence_number = self.sequence_number() + 1;
552
553 for (leaf_i, leaf) in leaves.iter().enumerate() {
554 let mut current_index = self.next_index();
555
556 self.changelog
557 .push(ChangelogEntry::<HEIGHT>::default_with_index(current_index));
558 let changelog_index = self.changelog_index();
559
560 let mut current_node = **leaf;
561
562 self.changelog[changelog_index].path[0] = Some(**leaf);
563
564 for i in 0..self.height {
565 let is_left = current_index % 2 == 0;
566
567 if is_left {
568 let empty_node = H::zero_bytes()[i];
580
581 if WITH_PROOFS {
582 proofs.as_mut().unwrap()[leaf_i].push(empty_node)?;
584 }
585
586 self.filled_subtrees[i] = current_node;
587
588 if leaf_i < leaves.len() - 1 {
593 break;
594 }
595
596 current_node = H::hashv(&[¤t_node, &empty_node])?;
597 } else {
598 if WITH_PROOFS {
610 proofs.as_mut().unwrap()[leaf_i].push(self.filled_subtrees[i])?;
612 }
613
614 current_node = H::hashv(&[&self.filled_subtrees[i], ¤t_node])?;
615 }
616
617 if i < self.height - 1 {
618 self.changelog[changelog_index].path[i + 1] = Some(current_node);
619 }
620
621 current_index /= 2;
622 }
623
624 if leaf_i == leaves.len() - 1 {
625 self.roots.push(current_node);
626 } else {
627 self.roots.push([0u8; 32]);
631 }
632
633 self.inc_next_index()?;
634 self.inc_sequence_number()?;
635
636 self.set_rightmost_leaf(leaf);
637 }
638
639 if self.canopy_depth > 0 {
640 self.update_canopy(first_changelog_index, leaves.len());
641 }
642
643 Ok((first_changelog_index, first_sequence_number))
644 }
645
646 fn update_canopy(&mut self, first_changelog_index: usize, num_leaves: usize) {
647 for i in 0..num_leaves {
648 let changelog_index = (first_changelog_index + i) % self.changelog.capacity();
649 for (i, path_node) in self.changelog[changelog_index]
650 .path
651 .iter()
652 .rev()
653 .take(self.canopy_depth)
654 .enumerate()
655 {
656 if let Some(path_node) = path_node {
657 let level = self.height - i - 1;
658 let index = (1 << (self.height - level))
659 + (self.changelog[changelog_index].index >> level);
660 self.canopy[(index - 2) as usize] = *path_node;
662 }
663 }
664 }
665 }
666}
667
668impl<H, const HEIGHT: usize> Drop for ConcurrentMerkleTree<H, HEIGHT>
669where
670 H: Hasher,
671{
672 fn drop(&mut self) {
673 let layout = Layout::new::<usize>();
674 unsafe { alloc::dealloc(self.next_index as *mut u8, layout) };
675
676 let layout = Layout::new::<usize>();
677 unsafe { alloc::dealloc(self.sequence_number as *mut u8, layout) };
678
679 let layout = Layout::new::<[u8; 32]>();
680 unsafe { alloc::dealloc(self.rightmost_leaf as *mut u8, layout) };
681 }
682}
683
684impl<H, const HEIGHT: usize> PartialEq for ConcurrentMerkleTree<H, HEIGHT>
685where
686 H: Hasher,
687{
688 fn eq(&self, other: &Self) -> bool {
689 self.height.eq(&other.height)
690 && self.canopy_depth.eq(&other.canopy_depth)
691 && self.next_index().eq(&other.next_index())
692 && self.sequence_number().eq(&other.sequence_number())
693 && self.rightmost_leaf().eq(&other.rightmost_leaf())
694 && self
695 .filled_subtrees
696 .as_slice()
697 .eq(other.filled_subtrees.as_slice())
698 && self.changelog.iter().eq(other.changelog.iter())
699 && self.roots.iter().eq(other.roots.iter())
700 && self.canopy.as_slice().eq(other.canopy.as_slice())
701 }
702}