1use std::{
2 alloc::{self, handle_alloc_error, Layout},
3 iter::Skip,
4 marker::PhantomData,
5 mem,
6};
7
8use changelog::ChangelogPath;
9use light_bounded_vec::{
10 BoundedVec, BoundedVecMetadata, CyclicBoundedVec, CyclicBoundedVecIterator,
11 CyclicBoundedVecMetadata,
12};
13pub use light_hasher;
14use light_hasher::Hasher;
15
16pub mod changelog;
17pub mod copy;
18pub mod errors;
19pub mod event;
20pub mod hash;
21pub mod offset;
22pub mod zero_copy;
23
24use crate::{
25 changelog::ChangelogEntry,
26 errors::ConcurrentMerkleTreeError,
27 hash::{compute_parent_node, compute_root},
28};
29
/// A Merkle tree that tolerates concurrent updates: proofs generated against
/// an older root can be patched by replaying the `changelog` ring buffer
/// instead of being rejected outright.
///
/// `#[repr(C)]`: the field layout is relied upon by the zero-copy/account
/// deserialization paths (see `zero_copy` module) — do not reorder fields.
/// The three raw-pointer fields refer to memory owned elsewhere (heap cells
/// allocated by `new()`, or an account buffer); see `new()` and `Drop`.
#[repr(C)]
#[derive(Debug)]
pub struct ConcurrentMerkleTree<H, const HEIGHT: usize>
where
    H: Hasher,
{
    pub height: usize,
    pub canopy_depth: usize,

    // Index of the next free leaf slot (== number of appended leaves).
    next_index: *mut usize,
    // Monotonically increasing counter, bumped once per update/append.
    sequence_number: *mut usize,
    // Hash of the most recently appended leaf.
    rightmost_leaf: *mut [u8; 32],

    /// Left-sibling hashes along the current append path, one per level.
    pub filled_subtrees: BoundedVec<[u8; 32]>,
    /// Ring buffer of recent update paths, used to patch stale proofs.
    pub changelog: CyclicBoundedVec<ChangelogEntry<HEIGHT>>,
    /// Ring buffer of recent roots (one pushed per sequence number).
    pub roots: CyclicBoundedVec<[u8; 32]>,
    /// Cached top `canopy_depth` levels of the tree, letting clients send
    /// shorter proofs (the upper siblings are looked up here).
    pub canopy: BoundedVec<[u8; 32]>,

    pub _hasher: PhantomData<H>,
}
69
/// Convenience alias for a tree of height 26.
pub type ConcurrentMerkleTree26<H> = ConcurrentMerkleTree<H, 26>;
71
72impl<H, const HEIGHT: usize> ConcurrentMerkleTree<H, HEIGHT>
73where
74 H: Hasher,
75{
76 #[inline(always)]
78 pub fn canopy_size(canopy_depth: usize) -> usize {
79 (1 << (canopy_depth + 1)) - 2
80 }
81
82 pub fn non_dyn_fields_size() -> usize {
85 mem::size_of::<usize>()
87 + mem::size_of::<usize>()
89 + mem::size_of::<usize>()
91 + mem::size_of::<usize>()
93 + mem::size_of::<[u8; 32]>()
95 + mem::size_of::<BoundedVecMetadata>()
97 + mem::size_of::<CyclicBoundedVecMetadata>()
99 + mem::size_of::<CyclicBoundedVecMetadata>()
101 + mem::size_of::<BoundedVecMetadata>()
103 }
104
105 pub fn size_in_account(
107 height: usize,
108 changelog_size: usize,
109 roots_size: usize,
110 canopy_depth: usize,
111 ) -> usize {
112 Self::non_dyn_fields_size()
114 + mem::size_of::<[u8; 32]>() * height
116 + mem::size_of::<ChangelogEntry<HEIGHT>>() * changelog_size
118 + mem::size_of::<[u8; 32]>() * roots_size
120 + mem::size_of::<[u8; 32]>() * Self::canopy_size(canopy_depth)
122 }
123
    /// Validates the sizing parameters of a tree before creation.
    ///
    /// Checks run in a fixed priority order; when several parameters are
    /// invalid, the first violated constraint determines the error returned.
    fn check_size_constraints_new(
        height: usize,
        changelog_size: usize,
        roots_size: usize,
        canopy_depth: usize,
    ) -> Result<(), ConcurrentMerkleTreeError> {
        if height == 0 || HEIGHT == 0 {
            return Err(ConcurrentMerkleTreeError::HeightZero);
        }
        // The runtime height must match the compile-time const generic.
        if height != HEIGHT {
            return Err(ConcurrentMerkleTreeError::InvalidHeight(HEIGHT));
        }
        if canopy_depth > height {
            return Err(ConcurrentMerkleTreeError::CanopyGeThanHeight);
        }
        // At least one changelog slot is required; `init()` and every
        // update/append push an entry unconditionally.
        if changelog_size == 0 {
            return Err(ConcurrentMerkleTreeError::ChangelogZero);
        }
        // Likewise at least one root slot.
        if roots_size == 0 {
            return Err(ConcurrentMerkleTreeError::RootsZero);
        }
        Ok(())
    }
149
    /// Validates this instance against the same constraints that
    /// `check_size_constraints_new` enforces at creation time, using the
    /// vectors' capacities as the size parameters.
    fn check_size_constraints(&self) -> Result<(), ConcurrentMerkleTreeError> {
        Self::check_size_constraints_new(
            self.height,
            self.changelog.capacity(),
            self.roots.capacity(),
            self.canopy_depth,
        )
    }
158
159 pub fn new(
160 height: usize,
161 changelog_size: usize,
162 roots_size: usize,
163 canopy_depth: usize,
164 ) -> Result<Self, ConcurrentMerkleTreeError> {
165 Self::check_size_constraints_new(height, changelog_size, roots_size, canopy_depth)?;
166
167 let layout = Layout::new::<usize>();
168 let next_index = unsafe { alloc::alloc(layout) as *mut usize };
169 if next_index.is_null() {
170 handle_alloc_error(layout);
171 }
172 unsafe { *next_index = 0 };
173
174 let layout = Layout::new::<usize>();
175 let sequence_number = unsafe { alloc::alloc(layout) as *mut usize };
176 if sequence_number.is_null() {
177 handle_alloc_error(layout);
178 }
179 unsafe { *sequence_number = 0 };
180
181 let layout = Layout::new::<[u8; 32]>();
182 let rightmost_leaf = unsafe { alloc::alloc(layout) as *mut [u8; 32] };
183 if rightmost_leaf.is_null() {
184 handle_alloc_error(layout);
185 }
186 unsafe { *rightmost_leaf = [0u8; 32] };
187
188 Ok(Self {
189 height,
190 canopy_depth,
191
192 next_index,
193 sequence_number,
194 rightmost_leaf,
195
196 filled_subtrees: BoundedVec::with_capacity(height),
197 changelog: CyclicBoundedVec::with_capacity(changelog_size),
198 roots: CyclicBoundedVec::with_capacity(roots_size),
199 canopy: BoundedVec::with_capacity(Self::canopy_size(canopy_depth)),
200
201 _hasher: PhantomData,
202 })
203 }
204
205 pub fn init(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
207 self.check_size_constraints()?;
208
209 let root = H::zero_bytes()[self.height];
211 self.roots.push(root);
212
213 let path = ChangelogPath::from_fn(|i| Some(H::zero_bytes()[i]));
215 let changelog_entry = ChangelogEntry { path, index: 0 };
216 self.changelog.push(changelog_entry);
217
218 for i in 0..self.height {
220 self.filled_subtrees.push(H::zero_bytes()[i]).unwrap();
221 }
222
223 for level_i in 0..self.canopy_depth {
225 let level_nodes = 1 << (level_i + 1);
226 for _ in 0..level_nodes {
227 let node = H::zero_bytes()[self.height - level_i - 1];
228 self.canopy.push(node)?;
229 }
230 }
231
232 Ok(())
233 }
234
    /// Index of the most recent changelog entry.
    pub fn changelog_index(&self) -> usize {
        self.changelog.last_index()
    }
239
    /// Index of the most recent root in the `roots` ring buffer.
    pub fn root_index(&self) -> usize {
        self.roots.last_index()
    }
244
    /// Returns the most recent root.
    ///
    /// NOTE(review): panics if `roots` is empty — assumes `init()` has been
    /// called; confirm all construction paths do so.
    pub fn root(&self) -> [u8; 32] {
        self.roots[self.root_index()]
    }
251
252 pub fn current_index(&self) -> usize {
253 let next_index = self.next_index();
254 if next_index > 0 {
255 next_index - 1
256 } else {
257 next_index
258 }
259 }
260
    /// Index of the next free leaf slot (== number of appended leaves).
    pub fn next_index(&self) -> usize {
        // SAFETY: `next_index` points to a cell allocated in `new()` (or to
        // backing memory another construction path must keep alive).
        unsafe { *self.next_index }
    }
264
265 pub fn inc_next_index(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
266 unsafe {
267 *self.next_index = self
268 .next_index()
269 .checked_add(1)
270 .ok_or(ConcurrentMerkleTreeError::IntegerOverflow)?;
271 }
272 Ok(())
273 }
274
    /// Current sequence number (bumped once per update/append).
    pub fn sequence_number(&self) -> usize {
        // SAFETY: `sequence_number` points to a cell allocated in `new()`
        // (or to backing memory another construction path must keep alive).
        unsafe { *self.sequence_number }
    }
278
279 pub fn inc_sequence_number(&mut self) -> Result<(), ConcurrentMerkleTreeError> {
280 unsafe {
281 *self.sequence_number = self
282 .sequence_number()
283 .checked_add(1)
284 .ok_or(ConcurrentMerkleTreeError::IntegerOverflow)?;
285 }
286 Ok(())
287 }
288
    /// Hash of the most recently appended leaf.
    pub fn rightmost_leaf(&self) -> [u8; 32] {
        // SAFETY: `rightmost_leaf` points to a cell allocated in `new()`
        // (or to backing memory another construction path must keep alive).
        unsafe { *self.rightmost_leaf }
    }
292
    /// Stores `leaf` as the rightmost (most recently appended) leaf.
    fn set_rightmost_leaf(&mut self, leaf: &[u8; 32]) {
        // SAFETY: `rightmost_leaf` points to a live cell owned by this tree.
        unsafe { *self.rightmost_leaf = *leaf };
    }
296
297 pub fn update_proof_from_canopy(
298 &self,
299 leaf_index: usize,
300 proof: &mut BoundedVec<[u8; 32]>,
301 ) -> Result<(), ConcurrentMerkleTreeError> {
302 let mut node_index = ((1 << self.height) + leaf_index) >> (self.height - self.canopy_depth);
303 while node_index > 1 {
304 let canopy_index = node_index - 2;
306 #[allow(clippy::manual_is_multiple_of)]
307 let canopy_index = if canopy_index % 2 == 0 {
308 canopy_index + 1
309 } else {
310 canopy_index - 1
311 };
312 proof.push(self.canopy[canopy_index])?;
313 node_index >>= 1;
314 }
315
316 Ok(())
317 }
318
    /// Iterates over changelog entries strictly newer than `changelog_index`.
    ///
    /// The entry at `changelog_index` itself is skipped (`.skip(1)`) — it is
    /// the state the caller's proof already reflects.
    pub fn changelog_entries(
        &self,
        changelog_index: usize,
    ) -> Result<Skip<CyclicBoundedVecIterator<'_, ChangelogEntry<HEIGHT>>>, ConcurrentMerkleTreeError>
    {
        Ok(self.changelog.iter_from(changelog_index)?.skip(1))
    }
339
340 pub fn update_proof_from_changelog(
358 &self,
359 changelog_index: usize,
360 leaf_index: usize,
361 proof: &mut BoundedVec<[u8; 32]>,
362 ) -> Result<(), ConcurrentMerkleTreeError> {
363 for changelog_entry in self.changelog_entries(changelog_index)? {
369 changelog_entry.update_proof(leaf_index, proof)?;
370 }
371
372 Ok(())
373 }
374
375 pub fn validate_proof(
379 &self,
380 leaf: &[u8; 32],
381 leaf_index: usize,
382 proof: &BoundedVec<[u8; 32]>,
383 ) -> Result<(), ConcurrentMerkleTreeError> {
384 let expected_root = self.root();
385 let computed_root = compute_root::<H>(leaf, leaf_index, proof)?;
386 if computed_root == expected_root {
387 Ok(())
388 } else {
389 Err(ConcurrentMerkleTreeError::InvalidProof(
390 expected_root,
391 computed_root,
392 ))
393 }
394 }
395
396 fn update_leaf_in_tree(
412 &mut self,
413 new_leaf: &[u8; 32],
414 leaf_index: usize,
415 proof: &BoundedVec<[u8; 32]>,
416 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
417 let mut changelog_entry = ChangelogEntry::default_with_index(leaf_index);
418 let mut current_node = *new_leaf;
419 for (level, sibling) in proof.iter().enumerate() {
420 changelog_entry.path[level] = Some(current_node);
421 current_node = compute_parent_node::<H>(¤t_node, sibling, leaf_index, level)?;
422 }
423
424 self.inc_sequence_number()?;
425
426 self.roots.push(current_node);
427
428 if self.next_index() < (1 << self.height) {
430 changelog_entry.update_proof(self.next_index(), &mut self.filled_subtrees)?;
431 if leaf_index >= self.current_index() {
433 self.set_rightmost_leaf(new_leaf);
434 }
435 }
436 self.changelog.push(changelog_entry);
437
438 if self.canopy_depth > 0 {
439 self.update_canopy(self.changelog.last_index(), 1);
440 }
441
442 Ok((self.changelog.last_index(), self.sequence_number()))
443 }
444
    /// Replaces `old_leaf` with `new_leaf` at `leaf_index`.
    ///
    /// `changelog_index` identifies the tree state the caller's `proof` was
    /// generated against; if the tree has advanced since, the proof is first
    /// completed from the canopy and patched from the changelog, then
    /// validated before the leaf is written.
    ///
    /// Returns `(changelog_index, sequence_number)` of this update.
    ///
    /// # Errors
    /// `InvalidProofLength` when `proof` does not cover exactly the
    /// below-canopy levels; `CannotUpdateEmpty` when `leaf_index` was never
    /// appended; plus any proof-patching or validation error.
    #[inline(never)]
    pub fn update(
        &mut self,
        changelog_index: usize,
        old_leaf: &[u8; 32],
        new_leaf: &[u8; 32],
        leaf_index: usize,
        proof: &mut BoundedVec<[u8; 32]>,
    ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
        // Client proofs cover only the levels below the canopy.
        let expected_proof_len = self.height - self.canopy_depth;
        if proof.len() != expected_proof_len {
            return Err(ConcurrentMerkleTreeError::InvalidProofLength(
                expected_proof_len,
                proof.len(),
            ));
        }
        if leaf_index >= self.next_index() {
            return Err(ConcurrentMerkleTreeError::CannotUpdateEmpty);
        }

        if self.canopy_depth > 0 {
            self.update_proof_from_canopy(leaf_index, proof)?;
        }
        // Only patch when the proof is stale; a current proof needs no work.
        if changelog_index != self.changelog_index() {
            self.update_proof_from_changelog(changelog_index, leaf_index, proof)?;
        }
        self.validate_proof(old_leaf, leaf_index, proof)?;
        self.update_leaf_in_tree(new_leaf, leaf_index, proof)
    }
477
    /// Appends a single leaf; equivalent to `append_batch` with one leaf.
    pub fn append(&mut self, leaf: &[u8; 32]) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
        self.append_batch(&[leaf])
    }
482
    /// Appends a single leaf and fills `proof` with the Merkle proof of the
    /// newly appended leaf.
    pub fn append_with_proof(
        &mut self,
        leaf: &[u8; 32],
        proof: &mut BoundedVec<[u8; 32]>,
    ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
        self.append_batch_with_proofs(&[leaf], &mut [proof])
    }
492
    /// Appends multiple leaves in one batch, without generating proofs.
    pub fn append_batch(
        &mut self,
        leaves: &[&[u8; 32]],
    ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
        self.append_batch_common::<false>(leaves, None)
    }
500
    /// Appends multiple leaves and fills one proof buffer per leaf.
    ///
    /// `proofs` must contain at least as many buffers as `leaves`; the
    /// shared implementation indexes `proofs[leaf_i]` and panics otherwise.
    pub fn append_batch_with_proofs(
        &mut self,
        leaves: &[&[u8; 32]],
        proofs: &mut [&mut BoundedVec<[u8; 32]>],
    ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
        self.append_batch_common::<true>(leaves, Some(proofs))
    }
510
511 fn append_batch_common<
517 const WITH_PROOFS: bool,
530 >(
531 &mut self,
532 leaves: &[&[u8; 32]],
533 mut proofs: Option<&mut [&mut BoundedVec<[u8; 32]>]>,
537 ) -> Result<(usize, usize), ConcurrentMerkleTreeError> {
538 if leaves.is_empty() {
539 return Err(ConcurrentMerkleTreeError::EmptyLeaves);
540 }
541 if (self.next_index() + leaves.len() - 1) >= 1 << self.height {
542 return Err(ConcurrentMerkleTreeError::TreeIsFull);
543 }
544 if leaves.len() > self.changelog.capacity() {
545 return Err(ConcurrentMerkleTreeError::BatchGreaterThanChangelog(
546 leaves.len(),
547 self.changelog.capacity(),
548 ));
549 }
550
551 let first_changelog_index = (self.changelog.last_index() + 1) % self.changelog.capacity();
552 let first_sequence_number = self.sequence_number() + 1;
553
554 for (leaf_i, leaf) in leaves.iter().enumerate() {
555 let mut current_index = self.next_index();
556
557 self.changelog
558 .push(ChangelogEntry::<HEIGHT>::default_with_index(current_index));
559 let changelog_index = self.changelog_index();
560
561 let mut current_node = **leaf;
562
563 self.changelog[changelog_index].path[0] = Some(**leaf);
564
565 for i in 0..self.height {
566 #[allow(clippy::manual_is_multiple_of)]
567 let is_left = current_index % 2 == 0;
568
569 if is_left {
570 let empty_node = H::zero_bytes()[i];
582
583 if WITH_PROOFS {
584 proofs.as_mut().unwrap()[leaf_i].push(empty_node)?;
586 }
587
588 self.filled_subtrees[i] = current_node;
589
590 if leaf_i < leaves.len() - 1 {
595 break;
596 }
597
598 current_node = H::hashv(&[¤t_node, &empty_node])?;
599 } else {
600 if WITH_PROOFS {
612 proofs.as_mut().unwrap()[leaf_i].push(self.filled_subtrees[i])?;
614 }
615
616 current_node = H::hashv(&[&self.filled_subtrees[i], ¤t_node])?;
617 }
618
619 if i < self.height - 1 {
620 self.changelog[changelog_index].path[i + 1] = Some(current_node);
621 }
622
623 current_index /= 2;
624 }
625
626 if leaf_i == leaves.len() - 1 {
627 self.roots.push(current_node);
628 } else {
629 self.roots.push([0u8; 32]);
633 }
634
635 self.inc_next_index()?;
636 self.inc_sequence_number()?;
637
638 self.set_rightmost_leaf(leaf);
639 }
640
641 if self.canopy_depth > 0 {
642 self.update_canopy(first_changelog_index, leaves.len());
643 }
644
645 Ok((first_changelog_index, first_sequence_number))
646 }
647
648 fn update_canopy(&mut self, first_changelog_index: usize, num_leaves: usize) {
649 for i in 0..num_leaves {
650 let changelog_index = (first_changelog_index + i) % self.changelog.capacity();
651 for (i, path_node) in self.changelog[changelog_index]
652 .path
653 .iter()
654 .rev()
655 .take(self.canopy_depth)
656 .enumerate()
657 {
658 if let Some(path_node) = path_node {
659 let level = self.height - i - 1;
660 let index = (1 << (self.height - level))
661 + (self.changelog[changelog_index].index >> level);
662 self.canopy[(index - 2) as usize] = *path_node;
664 }
665 }
666 }
667 }
668}
669
impl<H, const HEIGHT: usize> Drop for ConcurrentMerkleTree<H, HEIGHT>
where
    H: Hasher,
{
    /// Frees the three heap cells allocated in `new()`.
    ///
    /// NOTE(review): assumes the pointers came from `new()`'s allocations;
    /// if a tree is ever constructed over account-backed memory through
    /// another path (see `zero_copy`), dropping it here would be unsound —
    /// confirm those paths avoid `Drop` or use `mem::forget`.
    fn drop(&mut self) {
        let layout = Layout::new::<usize>();
        // SAFETY: allocated in `new()` with this exact layout.
        unsafe { alloc::dealloc(self.next_index as *mut u8, layout) };

        let layout = Layout::new::<usize>();
        // SAFETY: allocated in `new()` with this exact layout.
        unsafe { alloc::dealloc(self.sequence_number as *mut u8, layout) };

        let layout = Layout::new::<[u8; 32]>();
        // SAFETY: allocated in `new()` with this exact layout.
        unsafe { alloc::dealloc(self.rightmost_leaf as *mut u8, layout) };
    }
}
685
686impl<H, const HEIGHT: usize> PartialEq for ConcurrentMerkleTree<H, HEIGHT>
687where
688 H: Hasher,
689{
690 fn eq(&self, other: &Self) -> bool {
691 self.height.eq(&other.height)
692 && self.canopy_depth.eq(&other.canopy_depth)
693 && self.next_index().eq(&other.next_index())
694 && self.sequence_number().eq(&other.sequence_number())
695 && self.rightmost_leaf().eq(&other.rightmost_leaf())
696 && self
697 .filled_subtrees
698 .as_slice()
699 .eq(other.filled_subtrees.as_slice())
700 && self.changelog.iter().eq(other.changelog.iter())
701 && self.roots.iter().eq(other.roots.iter())
702 && self.canopy.as_slice().eq(other.canopy.as_slice())
703 }
704}