1use crate::mmr::{
4 hasher::Hasher,
5 iterator::{nodes_needing_parents, nodes_to_pin, PathIterator, PeakIterator},
6 proof,
7 Error::{self, *},
8 Location, Position, Proof,
9};
10use alloc::{
11 collections::{BTreeMap, BTreeSet, VecDeque},
12 vec::Vec,
13};
14use commonware_cryptography::Digest;
15use core::{mem, ops::Range};
cfg_if::cfg_if! {
    if #[cfg(feature = "std")] {
        use commonware_parallel::ThreadPool;
        use rayon::prelude::*;
    } else {
        // Placeholder type so signatures mentioning `ThreadPool` still
        // compile without `std`; all parallel code paths are compiled out
        // in that configuration.
        pub struct ThreadPool;
    }
}

// Minimum number of dirty nodes (or leaf updates) in a batch before
// parallel hashing is attempted; smaller batches are hashed serially.
#[cfg(feature = "std")]
const MIN_TO_PARALLELIZE: usize = 20;
30
/// An MMR whose recently added or updated nodes may not have been hashed yet.
pub type DirtyMmr<D> = Mmr<D, Dirty>;

/// A fully hashed MMR with a cached root digest.
pub type CleanMmr<D> = Mmr<D, Clean<D>>;

mod private {
    /// Prevents downstream code from implementing [`super::State`].
    pub trait Sealed {}
}

/// Type-state marker for an [`Mmr`]: either [`Clean`] or [`Dirty`]. Sealed.
pub trait State<D: Digest>: private::Sealed + Sized + Send + Sync {}
44
/// State of a fully merkleized MMR.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    /// Root digest computed by the last merkleize.
    pub root: D,
}

impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {}

/// State of an MMR with pending (unhashed) node updates.
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    // (position, height) pairs of internal nodes whose digests are stale
    // and must be recomputed by `merkleize`.
    dirty_nodes: BTreeSet<(Position, u32)>,
}

impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {}
66
/// Inputs for restoring a (possibly pruned) MMR via [`CleanMmr::init`].
pub struct Config<D: Digest> {
    /// Digests of all retained (unpruned) nodes, in position order.
    pub nodes: Vec<D>,

    /// Position of the first retained node; all earlier nodes are pruned.
    pub pruned_to_pos: Position,

    /// Digests of pruned nodes that must remain available (in the order
    /// produced by `nodes_to_pin(pruned_to_pos)`).
    pub pinned_nodes: Vec<D>,
}

/// An in-memory MMR parameterized by its hashing state (`Clean` or `Dirty`).
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest, S: State<D> = Dirty> {
    // Retained node digests; index 0 corresponds to `pruned_to_pos`.
    nodes: VecDeque<D>,

    // Position of the oldest retained node.
    pruned_to_pos: Position,

    // Digests of pruned nodes still needed for proofs/root computation.
    pinned_nodes: BTreeMap<Position, D>,

    // Type-state payload: cached root (`Clean`) or dirty set (`Dirty`).
    state: S,
}
119
impl<D: Digest> Default for DirtyMmr<D> {
    /// An empty, unpruned, dirty MMR.
    fn default() -> Self {
        Self::new()
    }
}
125
126impl<D: Digest> From<CleanMmr<D>> for DirtyMmr<D> {
127 fn from(clean: CleanMmr<D>) -> Self {
128 DirtyMmr {
129 nodes: clean.nodes,
130 pruned_to_pos: clean.pruned_to_pos,
131 pinned_nodes: clean.pinned_nodes,
132 state: Dirty {
133 dirty_nodes: BTreeSet::new(),
134 },
135 }
136 }
137}
138
139impl<D: Digest, S: State<D>> Mmr<D, S> {
140 pub fn size(&self) -> Position {
143 Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
144 }
145
146 pub fn leaves(&self) -> Location {
148 Location::try_from(self.size()).expect("invalid mmr size")
149 }
150
151 pub fn last_leaf_pos(&self) -> Option<Position> {
153 if self.size() == 0 {
154 return None;
155 }
156
157 Some(PeakIterator::last_leaf_pos(self.size()))
158 }
159
160 pub fn bounds(&self) -> Range<Position> {
163 self.pruned_to_pos..self.size()
164 }
165
166 pub fn peak_iterator(&self) -> PeakIterator {
168 PeakIterator::new(self.size())
169 }
170
171 fn index_to_pos(&self, index: usize) -> Position {
173 self.pruned_to_pos + (index as u64)
174 }
175
176 pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
189 if pos < self.pruned_to_pos {
190 return self
191 .pinned_nodes
192 .get(&pos)
193 .expect("requested node is pruned and not pinned");
194 }
195
196 &self.nodes[self.pos_to_index(pos)]
197 }
198
199 fn pos_to_index(&self, pos: Position) -> usize {
205 assert!(
206 pos >= self.pruned_to_pos,
207 "pos precedes oldest retained position"
208 );
209
210 *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
211 }
212
213 #[cfg(any(feature = "std", test))]
216 pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
217 for (pos, node) in pinned_nodes.into_iter() {
218 self.pinned_nodes.insert(pos, node);
219 }
220 }
221}
222
impl<D: Digest> CleanMmr<D> {
    /// Restore a clean MMR from `config`, validating size and pinned nodes,
    /// then merkleizing to compute (and cache) the root.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidSize`] if the total node count is not a valid MMR
    ///   size (or overflows u64).
    /// - [`Error::InvalidPinnedNodes`] if the number of provided pinned
    ///   digests does not match what the prune point requires.
    pub fn init(config: Config<D>, hasher: &mut impl Hasher<Digest = D>) -> Result<Self, Error> {
        // Guard against u64 overflow before checking structural validity.
        let Some(size) = config.pruned_to_pos.checked_add(config.nodes.len() as u64) else {
            return Err(Error::InvalidSize(u64::MAX));
        };
        if !size.is_mmr_size() {
            return Err(Error::InvalidSize(*size));
        }

        let mut pinned_nodes = BTreeMap::new();
        // Count the required pinned digests so both too-few (early return
        // below) and too-many (check after the loop) are rejected.
        let mut expected_pinned_nodes = 0;
        for (i, pos) in nodes_to_pin(config.pruned_to_pos).enumerate() {
            expected_pinned_nodes += 1;
            if i >= config.pinned_nodes.len() {
                return Err(Error::InvalidPinnedNodes);
            }
            pinned_nodes.insert(pos, config.pinned_nodes[i]);
        }

        if config.pinned_nodes.len() != expected_pinned_nodes {
            return Err(Error::InvalidPinnedNodes);
        }

        let mmr = Mmr {
            nodes: VecDeque::from(config.nodes),
            pruned_to_pos: config.pruned_to_pos,
            pinned_nodes,
            state: Dirty::default(),
        };
        Ok(mmr.merkleize(hasher, None))
    }

    /// Create an empty clean MMR (its root is the empty-MMR root).
    pub fn new(hasher: &mut impl Hasher<Digest = D>) -> Self {
        let mmr: DirtyMmr<D> = Default::default();
        mmr.merkleize(hasher, None)
    }

    /// Build a clean MMR directly from components, merkleizing immediately.
    pub fn from_components(
        hasher: &mut impl Hasher<Digest = D>,
        nodes: Vec<D>,
        pruned_to_pos: Position,
        pinned_nodes: Vec<D>,
    ) -> Self {
        DirtyMmr::from_components(nodes, pruned_to_pos, pinned_nodes).merkleize(hasher, None)
    }

    /// Digest of the node at `pos`, or `None` if it is out of range, or
    /// pruned and not pinned.
    pub fn get_node(&self, pos: Position) -> Option<D> {
        if pos < self.pruned_to_pos {
            return self.pinned_nodes.get(&pos).copied();
        }

        self.nodes.get(self.pos_to_index(pos)).copied()
    }

    /// Map of the nodes that must remain available (pinned) after pruning
    /// to `prune_pos`.
    pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
        nodes_to_pin(prune_pos)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }

    /// Prune all nodes before `pos`, pinning those still needed for proof
    /// and root computation.
    pub fn prune_to_pos(&mut self, pos: Position) {
        // Capture pinned digests BEFORE draining, while they are still
        // readable from `nodes`.
        self.pinned_nodes = self.nodes_to_pin(pos);
        let retained_nodes = self.pos_to_index(pos);
        self.nodes.drain(0..retained_nodes);
        self.pruned_to_pos = pos;
    }

    /// Prune every retained node (the root remains computable from the
    /// pinned digests). No-op on an already-empty node list.
    pub fn prune_all(&mut self) {
        if !self.nodes.is_empty() {
            let pos = self.index_to_pos(self.nodes.len());
            self.prune_to_pos(pos);
        }
    }

    /// Update the leaf at `loc` to `element`, re-merkleizing before
    /// returning so the cached root stays correct even on error.
    ///
    /// NOTE(review): round-trips through a dirty MMR; `self` is temporarily
    /// replaced with an empty MMR while the update runs.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.update_leaf(hasher, loc, element);
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Convert into a dirty MMR, discarding the cached root.
    pub fn into_dirty(self) -> DirtyMmr<D> {
        self.into()
    }

    /// Root digest cached by the last merkleize.
    pub const fn root(&self) -> &D {
        &self.state.root
    }

    /// Root of an MMR containing no elements (hash of leaf count 0).
    pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
        hasher.update(&0u64.to_be_bytes());
        hasher.finalize()
    }

    /// Inclusion proof for the single leaf at `loc`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LocationOverflow`] for an invalid location, and
    /// [`Error::LeafOutOfBounds`] if `loc` is past the last leaf.
    pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        // A single-leaf proof is a width-one range proof; translate the
        // range-flavored error back into the leaf-flavored one.
        self.range_proof(loc..loc + 1).map_err(|e| match e {
            Error::RangeOutOfBounds(loc) => Error::LeafOutOfBounds(loc),
            _ => e,
        })
    }

    /// Inclusion proof for all leaves in `range`.
    ///
    /// # Errors
    ///
    /// Propagates range-validation errors and returns
    /// [`Error::ElementPruned`] if a required node has been pruned.
    pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
        let leaves = self.leaves();
        let positions = proof::nodes_required_for_range_proof(leaves, range)?;
        let digests = positions
            .into_iter()
            .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Proof { leaves, digests })
    }

    /// Test helper: digests that would need pinning when pruning to
    /// `start_pos`, in `nodes_to_pin` order.
    #[cfg(test)]
    pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
        nodes_to_pin(start_pos)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }

    /// Test helper: snapshot of the current pinned-node map.
    #[cfg(test)]
    pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
        self.pinned_nodes.clone()
    }
}
413
impl<D: Digest> DirtyMmr<D> {
    /// Create a new empty, unpruned dirty MMR.
    pub fn new() -> Self {
        Self {
            nodes: VecDeque::new(),
            pruned_to_pos: Position::new(0),
            pinned_nodes: BTreeMap::new(),
            state: Dirty::default(),
        }
    }

    /// Build a dirty MMR from components. `pinned_nodes[i]` must be the
    /// digest for the i-th position yielded by `nodes_to_pin(pruned_to_pos)`.
    ///
    /// # Panics
    ///
    /// Panics if `pinned_nodes` is shorter than the pin set requires.
    pub fn from_components(nodes: Vec<D>, pruned_to_pos: Position, pinned_nodes: Vec<D>) -> Self {
        Self {
            nodes: VecDeque::from(nodes),
            pruned_to_pos,
            pinned_nodes: nodes_to_pin(pruned_to_pos)
                .enumerate()
                .map(|(i, pos)| (pos, pinned_nodes[i]))
                .collect(),
            state: Dirty::default(),
        }
    }

    /// Append a pre-computed leaf digest, adding a placeholder parent
    /// (marked dirty) for each merge the new leaf completes. Returns the
    /// new leaf's position.
    pub(super) fn add_leaf_digest(&mut self, digest: D) -> Position {
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        let leaf_pos = self.size();
        self.nodes.push_back(digest);

        // Each completed merge creates one parent, one height above the last.
        let mut height = 1;
        for _ in nodes_needing_parents {
            let new_node_pos = self.size();
            self.nodes.push_back(D::EMPTY);
            self.state.dirty_nodes.insert((new_node_pos, height));
            height += 1;
        }

        leaf_pos
    }

    /// Hash `element` as a leaf and append it. Returns the leaf's position.
    pub fn add<H: Hasher<Digest = D>>(&mut self, hasher: &mut H, element: &[u8]) -> Position {
        let digest = hasher.leaf_digest(self.size(), element);
        self.add_leaf_digest(digest)
    }

    /// Remove the most recently added leaf together with the parent nodes
    /// created on its behalf. Returns the new size.
    ///
    /// # Errors
    ///
    /// - [`Empty`] if the MMR has no nodes.
    /// - [`ElementPruned`] if the last leaf lies within the pruned region.
    pub fn pop(&mut self) -> Result<Position, Error> {
        if self.size() == 0 {
            return Err(Empty);
        }

        // Step the size down until it is a valid MMR size again; this strips
        // the last leaf along with any parents it created.
        let mut new_size = self.size() - 1;
        loop {
            if new_size < self.pruned_to_pos {
                return Err(ElementPruned(new_size));
            }
            if new_size.is_mmr_size() {
                break;
            }
            new_size -= 1;
        }
        let num_to_drain = *(self.size() - new_size) as usize;
        self.nodes.drain(self.nodes.len() - num_to_drain..);

        // Discard dirty entries at or beyond the new size; split_off keeps
        // everything strictly below the cutoff in `dirty_nodes` and returns
        // (here: drops) the rest.
        let cutoff = (self.size(), 0);
        self.state.dirty_nodes.split_off(&cutoff);

        Ok(self.size())
    }

    /// Recompute all dirty node digests — in parallel when `pool` is given
    /// and the dirty set is large enough — and return the clean MMR with
    /// its root computed over the leaf count and peak digests.
    pub fn merkleize(
        mut self,
        hasher: &mut impl Hasher<Digest = D>,
        #[cfg_attr(not(feature = "std"), allow(unused_variables))] pool: Option<ThreadPool>,
    ) -> CleanMmr<D> {
        #[cfg(feature = "std")]
        match (pool, self.state.dirty_nodes.len() >= MIN_TO_PARALLELIZE) {
            (Some(pool), true) => self.merkleize_parallel(hasher, pool, MIN_TO_PARALLELIZE),
            _ => self.merkleize_serial(hasher),
        }

        #[cfg(not(feature = "std"))]
        self.merkleize_serial(hasher);

        // The root commits to the leaf count and every peak digest.
        let peaks = self
            .peak_iterator()
            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
        let digest = hasher.root(self.leaves(), peaks);

        CleanMmr {
            nodes: self.nodes,
            pruned_to_pos: self.pruned_to_pos,
            pinned_nodes: self.pinned_nodes,
            state: Clean { root: digest },
        }
    }

    /// Recompute dirty digests one at a time, lowest height first, so that
    /// children are always up to date before their parents are hashed.
    fn merkleize_serial(&mut self, hasher: &mut impl Hasher<Digest = D>) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        // Stable sort: within a height, ascending position order (from the
        // BTreeSet iteration) is preserved.
        nodes.sort_by_key(|a| a.1);

        for (pos, height) in nodes {
            // A node at (pos, height) has its left child at pos - 2^height
            // and its right child at pos - 1.
            let left = pos - (1 << height);
            let right = pos - 1;
            let digest = hasher.node_digest(
                pos,
                self.get_node_unchecked(left),
                self.get_node_unchecked(right),
            );
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
    }

    /// Recompute dirty digests height by height, hashing each level in
    /// parallel. Any level smaller than `min_to_parallelize` (and everything
    /// above it) is handed back to the serial path.
    #[cfg(feature = "std")]
    fn merkleize_parallel(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        pool: ThreadPool,
        min_to_parallelize: usize,
    ) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        nodes.sort_by_key(|a| a.1);

        let mut same_height = Vec::new();
        let mut current_height = 1;
        for (i, (pos, height)) in nodes.iter().enumerate() {
            if *height == current_height {
                same_height.push(*pos);
                continue;
            }
            // Level complete: if it is too small to parallelize, re-insert
            // the unprocessed tail as dirty and finish serially.
            if same_height.len() < min_to_parallelize {
                self.state.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
                self.merkleize_serial(hasher);
                return;
            }
            self.update_node_digests(hasher, pool.clone(), &same_height, current_height);
            same_height.clear();
            current_height += 1;
            same_height.push(*pos);
        }

        // Same treatment for the final (highest) level.
        if same_height.len() < min_to_parallelize {
            self.state.dirty_nodes = nodes[nodes.len() - same_height.len()..]
                .iter()
                .copied()
                .collect();
            self.merkleize_serial(hasher);
            return;
        }

        self.update_node_digests(hasher, pool, &same_height, current_height);
    }

    /// Hash every node in `same_height` (all at `height`) in parallel on
    /// `pool`, then write the results back in a single serial pass.
    #[cfg(feature = "std")]
    fn update_node_digests(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        pool: ThreadPool,
        same_height: &[Position],
        height: u32,
    ) {
        let two_h = 1 << height;
        pool.install(|| {
            let computed_digests: Vec<(usize, D)> = same_height
                .par_iter()
                .map_init(
                    // Each worker thread hashes with its own forked hasher.
                    || hasher.fork(),
                    |hasher, &pos| {
                        let left = pos - two_h;
                        let right = pos - 1;
                        let digest = hasher.node_digest(
                            pos,
                            self.get_node_unchecked(left),
                            self.get_node_unchecked(right),
                        );
                        let index = self.pos_to_index(pos);
                        (index, digest)
                    },
                )
                .collect();

            for (index, digest) in computed_digests {
                self.nodes[index] = digest;
            }
        });
    }

    /// Mark the ancestors of `pos` (up to its peak) as dirty, stopping early
    /// once an already-dirty ancestor is found (its own ancestors must
    /// already be marked).
    ///
    /// # Panics
    ///
    /// Panics if `pos` does not fall under any peak (i.e. exceeds the size).
    fn mark_dirty(&mut self, pos: Position) {
        for (peak_pos, mut height) in self.peak_iterator() {
            if peak_pos < pos {
                continue;
            }

            // Walk the peak-to-node path bottom-up (hence the rev()).
            let path = PathIterator::new(pos, peak_pos, height)
                .collect::<Vec<_>>()
                .into_iter()
                .rev();
            height = 1;
            for (parent_pos, _) in path {
                if !self.state.dirty_nodes.insert((parent_pos, height)) {
                    break;
                }
                height += 1;
            }
            return;
        }

        panic!("invalid pos {pos}:{}", self.size());
    }

    /// Update the leaf at `loc` to `element`. Digests are recomputed lazily
    /// at the next merkleize.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        self.update_leaf_batched(hasher, None, &[(loc, element)])
    }

    /// Update many leaves at once, hashing them in parallel when `pool` is
    /// provided and the batch reaches `MIN_TO_PARALLELIZE`.
    ///
    /// # Errors
    ///
    /// - [`Error::LeafOutOfBounds`] if any location is past the last leaf.
    /// - [`Error::ElementPruned`] if any target leaf lies in the pruned
    ///   region. Validation happens before any mutation, so the batch is
    ///   all-or-nothing.
    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        #[cfg_attr(not(feature = "std"), allow(unused_variables))] pool: Option<ThreadPool>,
        updates: &[(Location, T)],
    ) -> Result<(), Error> {
        if updates.is_empty() {
            return Ok(());
        }

        // Validate every update (and resolve its position) up front.
        let leaves = self.leaves();
        let mut positions = Vec::with_capacity(updates.len());
        for (loc, _) in updates {
            if *loc >= leaves {
                return Err(Error::LeafOutOfBounds(*loc));
            }
            let pos = Position::try_from(*loc)?;
            if pos < self.pruned_to_pos {
                return Err(Error::ElementPruned(pos));
            }
            positions.push(pos);
        }

        #[cfg(feature = "std")]
        if let Some(pool) = pool {
            if updates.len() >= MIN_TO_PARALLELIZE {
                self.update_leaf_parallel(hasher, pool, updates, &positions);
                return Ok(());
            }
        }

        for ((_, element), pos) in updates.iter().zip(positions.iter()) {
            let digest = hasher.leaf_digest(*pos, element.as_ref());
            let index = self.pos_to_index(*pos);
            self.nodes[index] = digest;
            self.mark_dirty(*pos);
        }

        Ok(())
    }

    /// Parallel half of `update_leaf_batched`: leaf digests are computed on
    /// the pool, then written back (and marked dirty) serially.
    #[cfg(feature = "std")]
    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        pool: ThreadPool,
        updates: &[(Location, T)],
        positions: &[Position],
    ) {
        pool.install(|| {
            let digests: Vec<(Position, D)> = updates
                .par_iter()
                .zip(positions.par_iter())
                .map_init(
                    || hasher.fork(),
                    |hasher, ((_, elem), pos)| {
                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
                        (*pos, digest)
                    },
                )
                .collect();

            for (pos, digest) in digests {
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
                self.mark_dirty(pos);
            }
        });
    }
}
742
743#[cfg(test)]
744mod tests {
745 use super::*;
746 use crate::mmr::{
747 conformance::build_test_mmr,
748 hasher::{Hasher as _, Standard},
749 };
750 use commonware_cryptography::{sha256, Hasher, Sha256};
751 use commonware_runtime::{deterministic, tokio, Runner, ThreadPooler};
752 use commonware_utils::NZUsize;
753
    /// Sanity-check every accessor on a freshly created empty MMR.
    #[test]
    fn test_mem_mmr_empty() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mmr = CleanMmr::new(&mut hasher);
            assert_eq!(
                mmr.peak_iterator().next(),
                None,
                "empty iterator should have no peaks"
            );
            assert_eq!(mmr.size(), 0);
            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
            assert_eq!(mmr.last_leaf_pos(), None);
            assert!(mmr.bounds().is_empty());
            assert_eq!(mmr.get_node(Position::new(0)), None);
            assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
            let mut mmr = mmr.into_dirty();
            assert!(matches!(mmr.pop(), Err(Empty)));
            let mut mmr = mmr.merkleize(&mut hasher, None);
            mmr.prune_all();
            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");

            // Root of an empty MMR equals the root over an empty peak list.
            assert_eq!(
                *mmr.root(),
                hasher.root(Location::new_unchecked(0), [].iter())
            );
        });
    }
784
    /// Add 11 leaves and verify peaks, internal-node digests, the root,
    /// pruning behavior, proof generation, and re-initialization from
    /// components.
    #[test]
    fn test_mem_mmr_add_eleven_values() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = DirtyMmr::new();
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut leaves: Vec<Position> = Vec::new();
            for _ in 0..11 {
                leaves.push(mmr.add(&mut hasher, &element));
                let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
                assert_ne!(peaks.len(), 0);
                assert!(peaks.len() as u64 <= mmr.size());
                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
                assert!(nodes_needing_parents.len() <= peaks.len());
            }
            let mut mmr = mmr.merkleize(&mut hasher, None);
            assert_eq!(mmr.bounds().start, Position::new(0));
            assert_eq!(mmr.size(), 19, "mmr not of expected size");
            assert_eq!(
                leaves,
                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
                    .into_iter()
                    .map(Position::new)
                    .collect::<Vec<_>>(),
                "mmr leaf positions not as expected"
            );
            let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
            assert_eq!(
                peaks,
                vec![
                    (Position::new(14), 3),
                    (Position::new(17), 1),
                    (Position::new(18), 0)
                ],
                "mmr peaks not as expected"
            );

            // At this size, only the two right-most peaks lack a parent.
            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
            assert_eq!(
                peaks_needing_parents,
                vec![Position::new(17), Position::new(18)],
                "mmr nodes needing parents not as expected"
            );

            // Every stored leaf digest must match a freshly computed one.
            for leaf in leaves.iter().by_ref() {
                let digest = hasher.leaf_digest(*leaf, &element);
                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
            }

            // Height-1 internal nodes.
            let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
            assert_eq!(mmr.nodes[2], digest2);
            let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
            assert_eq!(mmr.nodes[5], digest5);
            let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
            assert_eq!(mmr.nodes[9], digest9);
            let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
            assert_eq!(mmr.nodes[12], digest12);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Height-2 internal nodes.
            let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
            assert_eq!(mmr.nodes[6], digest6);
            let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
            assert_eq!(mmr.nodes[13], digest13);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Height-3 node: the tallest peak.
            let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
            assert_eq!(mmr.nodes[14], digest14);

            // Root commits to the leaf count (11) and the three peaks.
            let root = *mmr.root();
            let peak_digests = [digest14, digest17, mmr.nodes[18]];
            let expected_root = hasher.root(Location::new_unchecked(11), peak_digests.iter());
            assert_eq!(root, expected_root, "incorrect root");

            // Prune everything below the tallest peak.
            mmr.prune_to_pos(Position::new(14));
            assert_eq!(mmr.bounds().start, Position::new(14));

            // Proofs over pruned leaves must fail...
            assert!(matches!(
                mmr.proof(Location::new_unchecked(0)),
                Err(ElementPruned(_))
            ));
            assert!(matches!(
                mmr.proof(Location::new_unchecked(6)),
                Err(ElementPruned(_))
            ));

            // ...while proofs over retained leaves still succeed.
            assert!(mmr.proof(Location::new_unchecked(8)).is_ok());
            assert!(mmr.proof(Location::new_unchecked(10)).is_ok());

            // Pruning must not change the root.
            let root_after_prune = *mmr.root();
            assert_eq!(root, root_after_prune, "root changed after pruning");

            assert!(
                mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(9))
                    .is_err(),
                "attempts to range_prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(Location::new_unchecked(8)..mmr.leaves()).is_ok(),
                "attempts to range_prove over all elements following oldest retained should succeed"
            );

            // Rebuild from (retained nodes + pinned digests) and verify the
            // copy matches the original in every respect.
            let oldest_pos = mmr.bounds().start;
            let digests = mmr.node_digests_to_pin(oldest_pos);
            let mmr_copy = Mmr::init(
                Config {
                    nodes: mmr.nodes.iter().copied().collect(),
                    pruned_to_pos: oldest_pos,
                    pinned_nodes: digests,
                },
                &mut hasher,
            )
            .unwrap();
            assert_eq!(mmr_copy.size(), 19);
            assert_eq!(mmr_copy.leaves(), mmr.leaves());
            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
            assert_eq!(mmr_copy.bounds().start, mmr.bounds().start);
            assert_eq!(*mmr_copy.root(), root);
        });
    }
926
927 #[test]
929 fn test_mem_mmr_prune_all() {
930 let executor = deterministic::Runner::default();
931 executor.start(|_| async move {
932 let mut hasher: Standard<Sha256> = Standard::new();
933 let mut mmr = CleanMmr::new(&mut hasher);
934 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
935 for _ in 0..1000 {
936 mmr.prune_all();
937 let mut dirty = mmr.into_dirty();
938 dirty.add(&mut hasher, &element);
939 mmr = dirty.merkleize(&mut hasher, None);
940 }
941 });
942 }
943
944 #[test]
946 fn test_mem_mmr_validity() {
947 let executor = deterministic::Runner::default();
948 executor.start(|_| async move {
949 let mut hasher: Standard<Sha256> = Standard::new();
950 let mut mmr = DirtyMmr::new();
951 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
952 for _ in 0..1001 {
953 assert!(
954 mmr.size().is_mmr_size(),
955 "mmr of size {} should be valid",
956 mmr.size()
957 );
958 let old_size = mmr.size();
959 mmr.add(&mut hasher, &element);
960 for size in *old_size + 1..*mmr.size() {
961 assert!(
962 !Position::new(size).is_mmr_size(),
963 "mmr of size {size} should be invalid",
964 );
965 }
966 }
967 });
968 }
969
970 #[test]
973 fn test_mem_mmr_batched_root() {
974 let executor = deterministic::Runner::default();
975 executor.start(|_| async move {
976 let mut hasher: Standard<Sha256> = Standard::new();
977 const NUM_ELEMENTS: u64 = 199;
978 let mut test_mmr = CleanMmr::new(&mut hasher);
979 test_mmr = build_test_mmr(&mut hasher, test_mmr, NUM_ELEMENTS);
980 let expected_root = test_mmr.root();
981
982 let batched_mmr = CleanMmr::new(&mut hasher);
983
984 let mut dirty_mmr = batched_mmr.into_dirty();
986 hasher.inner().update(&0u64.to_be_bytes());
987 let element = hasher.inner().finalize();
988 dirty_mmr.add(&mut hasher, &element);
989
990 for i in 1..NUM_ELEMENTS {
992 hasher.inner().update(&i.to_be_bytes());
993 let element = hasher.inner().finalize();
994 dirty_mmr.add(&mut hasher, &element);
995 }
996
997 let batched_mmr = dirty_mmr.merkleize(&mut hasher, None);
998
999 assert_eq!(
1000 batched_mmr.root(),
1001 expected_root,
1002 "Batched MMR root should match reference"
1003 );
1004 });
1005 }
1006
    /// Same scenario as `test_mem_mmr_batched_root`, but merkleizing on a
    /// 4-thread pool so the parallel hashing path is exercised.
    #[test]
    fn test_mem_mmr_batched_root_parallel() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            const NUM_ELEMENTS: u64 = 199;
            let test_mmr = CleanMmr::new(&mut hasher);
            let test_mmr = build_test_mmr(&mut hasher, test_mmr, NUM_ELEMENTS);
            let expected_root = test_mmr.root();

            let pool = context.create_thread_pool(NZUsize!(4)).unwrap();
            let mut hasher: Standard<Sha256> = Standard::new();

            // Start from an MMR created via `init` (rather than `new`) so
            // that construction path gets coverage too.
            let mut mmr = Mmr::init(
                Config {
                    nodes: vec![],
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: vec![],
                },
                &mut hasher,
            )
            .unwrap()
            .into_dirty();

            let mut hasher: Standard<Sha256> = Standard::new();
            for i in 0u64..NUM_ELEMENTS {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }
            let mmr = mmr.merkleize(&mut hasher, Some(pool));
            assert_eq!(
                mmr.root(),
                expected_root,
                "Batched MMR root should match reference"
            );
        });
    }
1047
    /// Pruning after every single addition must never change the root.
    #[test]
    fn test_mem_mmr_root_with_pruning() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut reference_mmr = DirtyMmr::new();
            let mut mmr = DirtyMmr::new();
            for i in 0u64..200 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                reference_mmr.add(&mut hasher, &element);
                mmr.add(&mut hasher, &element);

                // Fully prune one of the two; the roots must still agree.
                let reference_mmr_clean = reference_mmr.merkleize(&mut hasher, None);
                let mut mmr_clean = mmr.merkleize(&mut hasher, None);
                mmr_clean.prune_all();
                assert_eq!(mmr_clean.root(), reference_mmr_clean.root());

                reference_mmr = reference_mmr_clean.into_dirty();
                mmr = mmr_clean.into_dirty();
            }
        });
    }
1073
    /// Pop leaves one at a time, checking the root against a freshly built
    /// reference MMR at each step; then verify pop's interaction with
    /// pruning.
    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            const NUM_ELEMENTS: u64 = 100;

            let mut hasher: Standard<Sha256> = Standard::new();
            let mmr = CleanMmr::new(&mut hasher);
            let mut mmr = build_test_mmr(&mut hasher, mmr, NUM_ELEMENTS);

            for i in (0..NUM_ELEMENTS).rev() {
                let mut dirty_mmr = mmr.into_dirty();
                assert!(dirty_mmr.pop().is_ok());
                mmr = dirty_mmr.merkleize(&mut hasher, None);
                let root = *mmr.root();
                // A reference MMR with `i` leaves must have the same root.
                let reference_mmr = CleanMmr::new(&mut hasher);
                let reference_mmr = build_test_mmr(&mut hasher, reference_mmr, i);
                assert_eq!(
                    root,
                    *reference_mmr.root(),
                    "root mismatch after pop at {i}"
                );
            }
            let mut mmr = mmr.into_dirty();
            assert!(
                matches!(mmr.pop().unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            // Refill with the same elements build_test_mmr generates.
            for i in 0u64..NUM_ELEMENTS {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }
            let mut mmr = mmr.merkleize(&mut hasher, None);

            // Prune to (hypothetical) leaf 100's position, then pop any
            // nodes past that boundary.
            let leaf_pos = Position::try_from(Location::new_unchecked(100)).unwrap();
            mmr.prune_to_pos(leaf_pos);
            let mut mmr = mmr.into_dirty();
            while mmr.size() > leaf_pos {
                mmr.pop().unwrap();
            }
            let mmr = mmr.merkleize(&mut hasher, None);
            let reference_mmr = CleanMmr::new(&mut hasher);
            let reference_mmr = build_test_mmr(&mut hasher, reference_mmr, 100);
            assert_eq!(*mmr.root(), *reference_mmr.root());
            // Popping into the pruned region must fail.
            let mut mmr = mmr.into_dirty();
            let result = mmr.pop();
            assert!(matches!(result, Err(ElementPruned(_))));
            assert!(mmr.bounds().is_empty());
        });
    }
1128
    /// Updating a leaf must change the root; restoring the original element
    /// must restore the original root. Updates must also keep working on
    /// leaves that survive pruning.
    #[test]
    fn test_mem_mmr_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            const NUM_ELEMENTS: u64 = 200;
            let mmr = CleanMmr::new(&mut hasher);
            let mut mmr = build_test_mmr(&mut hasher, mmr, NUM_ELEMENTS);
            let root = *mmr.root();

            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
                // Overwrite with a constant element: the root must change.
                let leaf_loc = Location::new_unchecked(leaf as u64);
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let updated_root = *mmr.root();
                assert!(root != updated_root);

                // Regenerate the leaf's index-derived element and write it
                // back: the original root must be restored.
                hasher.inner().update(&leaf.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let restored_root = *mmr.root();
                assert_eq!(root, restored_root);
            }

            // Updates on retained leaves must succeed as the prune point
            // advances.
            mmr.prune_to_pos(Position::new(150));
            for leaf_pos in 150u64..=190 {
                mmr.prune_to_pos(Position::new(leaf_pos));
                let leaf_loc = Location::new_unchecked(leaf_pos);
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
            }
        });
    }
1166
1167 #[test]
1168 fn test_mem_mmr_update_leaf_error_out_of_bounds() {
1169 let mut hasher: Standard<Sha256> = Standard::new();
1170 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1171
1172 let executor = deterministic::Runner::default();
1173 executor.start(|_| async move {
1174 let mmr = CleanMmr::new(&mut hasher);
1175 let mut mmr = build_test_mmr(&mut hasher, mmr, 200);
1176 let invalid_loc = mmr.leaves();
1177 let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
1178 assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
1179 });
1180 }
1181
1182 #[test]
1183 fn test_mem_mmr_update_leaf_error_pruned() {
1184 let mut hasher: Standard<Sha256> = Standard::new();
1185 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1186
1187 let executor = deterministic::Runner::default();
1188 executor.start(|_| async move {
1189 let mmr = CleanMmr::new(&mut hasher);
1190 let mut mmr = build_test_mmr(&mut hasher, mmr, 100);
1191 mmr.prune_all();
1192 let result = mmr.update_leaf(&mut hasher, Location::new_unchecked(0), &element);
1193 assert!(matches!(result, Err(Error::ElementPruned(_))));
1194 });
1195 }
1196
1197 #[test]
1198 fn test_mem_mmr_batch_update_leaf() {
1199 let mut hasher: Standard<Sha256> = Standard::new();
1200 let executor = deterministic::Runner::default();
1201 executor.start(|_| async move {
1202 let mmr = CleanMmr::new(&mut hasher);
1203 let mmr = build_test_mmr(&mut hasher, mmr, 200);
1204 do_batch_update(&mut hasher, mmr, None);
1205 });
1206 }
1207
    /// Runs the shared batched-update scenario with a 4-thread pool so the
    /// parallel leaf-hashing path is exercised.
    #[test]
    fn test_mem_mmr_batch_parallel_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = tokio::Runner::default();
        executor.start(|ctx| async move {
            let mmr = Mmr::init(
                Config {
                    nodes: Vec::new(),
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: Vec::new(),
                },
                &mut hasher,
            )
            .unwrap();
            let mmr = build_test_mmr(&mut hasher, mmr, 200);
            let pool = ctx.create_thread_pool(NZUsize!(4)).unwrap();
            do_batch_update(&mut hasher, mmr, Some(pool));
        });
    }
1229
    /// Shared body for the batched-update tests: batch-overwrite a spread of
    /// leaves (root must change), then batch-restore them (root must return
    /// to its original value). `pool` selects serial vs parallel hashing for
    /// the first batch.
    fn do_batch_update(
        hasher: &mut Standard<Sha256>,
        mmr: CleanMmr<sha256::Digest>,
        pool: Option<ThreadPool>,
    ) {
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let root = *mmr.root();

        // Overwrite the chosen leaves with the same constant element.
        let mut updates = Vec::new();
        for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
            let leaf_loc = Location::new_unchecked(leaf);
            updates.push((leaf_loc, &element));
        }
        let mut mmr = mmr.into_dirty();
        mmr.update_leaf_batched(hasher, pool, &updates).unwrap();

        let mmr = mmr.merkleize(hasher, None);
        let updated_root = *mmr.root();
        assert_ne!(updated_root, root);

        // Restore each leaf to its index-derived element.
        let mut updates = Vec::new();
        for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
            hasher.inner().update(&leaf.to_be_bytes());
            let element = hasher.inner().finalize();
            let leaf_loc = Location::new_unchecked(leaf);
            updates.push((leaf_loc, element));
        }
        let mut mmr = mmr.into_dirty();
        mmr.update_leaf_batched(hasher, None, &updates).unwrap();

        let mmr = mmr.merkleize(hasher, None);
        let restored_root = *mmr.root();
        assert_eq!(root, restored_root);
    }
1266
    /// `init` must reject pinned-node vectors that are too short or too long
    /// for the claimed prune point, and accept an exact match.
    #[test]
    fn test_init_pinned_nodes_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Unpruned, no pinned nodes: valid.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Pruned but missing the required pinned digests: invalid.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(127),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Unpruned but carrying a spurious pinned digest: invalid.
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![Sha256::hash(b"dummy")],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // A prune point accompanied by exactly the pinned digests it
            // requires: valid.
            let mut mmr = DirtyMmr::new();
            for i in 0u64..50 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let mmr = mmr.merkleize(&mut hasher, None);
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(50),
                pinned_nodes,
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());
        });
    }
1319
1320 #[test]
1321 fn test_init_size_validation() {
1322 let executor = deterministic::Runner::default();
1323 executor.start(|_| async move {
1324 let mut hasher: Standard<Sha256> = Standard::new();
1325 let config = Config::<sha256::Digest> {
1327 nodes: vec![],
1328 pruned_to_pos: Position::new(0),
1329 pinned_nodes: vec![],
1330 };
1331 assert!(Mmr::init(config, &mut hasher).is_ok());
1332
1333 let config = Config {
1336 nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
1337 pruned_to_pos: Position::new(0),
1338 pinned_nodes: vec![],
1339 };
1340 assert!(matches!(
1341 Mmr::init(config, &mut hasher),
1342 Err(Error::InvalidSize(_))
1343 ));
1344
1345 let config = Config {
1347 nodes: vec![
1348 Sha256::hash(b"leaf1"),
1349 Sha256::hash(b"leaf2"),
1350 Sha256::hash(b"parent"),
1351 ],
1352 pruned_to_pos: Position::new(0),
1353 pinned_nodes: vec![],
1354 };
1355 assert!(Mmr::init(config, &mut hasher).is_ok());
1356
1357 let mut mmr = DirtyMmr::new();
1360 for i in 0u64..64 {
1361 mmr.add(&mut hasher, &i.to_be_bytes());
1362 }
1363 let mmr = mmr.merkleize(&mut hasher, None);
1364 assert_eq!(mmr.size(), 127); let nodes: Vec<_> = (0..127)
1366 .map(|i| *mmr.get_node_unchecked(Position::new(i)))
1367 .collect();
1368
1369 let config = Config {
1370 nodes,
1371 pruned_to_pos: Position::new(0),
1372 pinned_nodes: vec![],
1373 };
1374 assert!(Mmr::init(config, &mut hasher).is_ok());
1375
1376 let mut mmr = DirtyMmr::new();
1379 for i in 0u64..11 {
1380 mmr.add(&mut hasher, &i.to_be_bytes());
1381 }
1382 let mut mmr = mmr.merkleize(&mut hasher, None);
1383 assert_eq!(mmr.size(), 19); mmr.prune_to_pos(Position::new(7));
1387 let nodes: Vec<_> = (7..*mmr.size())
1388 .map(|i| *mmr.get_node_unchecked(Position::new(i)))
1389 .collect();
1390 let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));
1391
1392 let config = Config {
1393 nodes: nodes.clone(),
1394 pruned_to_pos: Position::new(7),
1395 pinned_nodes: pinned_nodes.clone(),
1396 };
1397 assert!(Mmr::init(config, &mut hasher).is_ok());
1398
1399 let config = Config {
1402 nodes: nodes.clone(),
1403 pruned_to_pos: Position::new(8),
1404 pinned_nodes: pinned_nodes.clone(),
1405 };
1406 assert!(matches!(
1407 Mmr::init(config, &mut hasher),
1408 Err(Error::InvalidSize(_))
1409 ));
1410
1411 let config = Config {
1414 nodes,
1415 pruned_to_pos: Position::new(9),
1416 pinned_nodes,
1417 };
1418 assert!(matches!(
1419 Mmr::init(config, &mut hasher),
1420 Err(Error::InvalidSize(_))
1421 ));
1422 });
1423 }
1424
1425 #[test]
1426 fn test_mem_mmr_range_proof_out_of_bounds() {
1427 let mut hasher: Standard<Sha256> = Standard::new();
1428
1429 let executor = deterministic::Runner::default();
1430 executor.start(|_| async move {
1431 let mmr = CleanMmr::new(&mut hasher);
1433 assert_eq!(mmr.leaves(), Location::new_unchecked(0));
1434 let result = mmr.range_proof(Location::new_unchecked(0)..Location::new_unchecked(1));
1435 assert!(matches!(result, Err(Error::RangeOutOfBounds(_))));
1436
1437 let mmr = build_test_mmr(&mut hasher, mmr, 10);
1439 assert_eq!(mmr.leaves(), Location::new_unchecked(10));
1440 let result = mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(11));
1441 assert!(matches!(result, Err(Error::RangeOutOfBounds(_))));
1442
1443 let result = mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(10));
1445 assert!(result.is_ok());
1446 });
1447 }
1448
1449 #[test]
1450 fn test_mem_mmr_proof_out_of_bounds() {
1451 let mut hasher: Standard<Sha256> = Standard::new();
1452
1453 let executor = deterministic::Runner::default();
1454 executor.start(|_| async move {
1455 let mmr = CleanMmr::new(&mut hasher);
1457 let result = mmr.proof(Location::new_unchecked(0));
1458 assert!(
1459 matches!(result, Err(Error::LeafOutOfBounds(_))),
1460 "expected LeafOutOfBounds, got {:?}",
1461 result
1462 );
1463
1464 let mmr = build_test_mmr(&mut hasher, mmr, 10);
1466 let result = mmr.proof(Location::new_unchecked(10));
1467 assert!(
1468 matches!(result, Err(Error::LeafOutOfBounds(_))),
1469 "expected LeafOutOfBounds, got {:?}",
1470 result
1471 );
1472
1473 let result = mmr.proof(Location::new_unchecked(9));
1475 assert!(result.is_ok(), "expected Ok, got {:?}", result);
1476 });
1477 }
1478}