1use crate::mmr::{
4 hasher::Hasher,
5 iterator::{nodes_to_pin, PeakIterator},
6 proof,
7 read::{BatchChainInfo, Readable},
8 Error, Location, Position, Proof,
9};
10use alloc::{
11 collections::{BTreeMap, BTreeSet, VecDeque},
12 vec::Vec,
13};
14use commonware_cryptography::Digest;
15use core::ops::Range;
16
/// Size threshold for switching to parallel work (only compiled with the `std` feature).
///
/// NOTE(review): the unit (presumably a count of nodes to hash) and the call sites are outside
/// this section of the file — confirm against its users.
#[cfg(feature = "std")]
pub(crate) const MIN_TO_PARALLELIZE: usize = 20;
20
// Sealing module: `Sealed` is public but lives in a private module, so it cannot be named
// (and therefore not implemented) outside this module tree.
mod private {
    /// Supertrait used to seal [`State`](super::State).
    pub trait Sealed {}
}
25
/// Marker trait for MMR states. It is implemented in this file by [`Clean`] and [`Dirty`] and
/// sealed via `private::Sealed`, so no other types can implement it.
pub trait State<D: Digest>: private::Sealed + Sized + Send + Sync {}
28
/// State marker for an MMR whose digests are all up to date; it carries the root digest.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    /// The root digest of the clean MMR.
    pub root: D,
}

impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {}
38
/// State marker for an MMR with nodes whose digests still need to be recomputed.
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    /// Nodes awaiting recomputation, as `(position, height)` pairs. The set keeps them unique
    /// and ordered by position, then height.
    dirty_nodes: BTreeSet<(Position, u32)>,
}

impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {}
50
51impl Dirty {
52 pub(crate) fn insert(&mut self, pos: Position, height: u32) -> bool {
54 self.dirty_nodes.insert((pos, height))
55 }
56
57 pub(crate) fn take_sorted_by_height(&mut self) -> Vec<(Position, u32)> {
59 let mut v: Vec<_> = core::mem::take(&mut self.dirty_nodes).into_iter().collect();
60 v.sort_by_key(|a| a.1);
61 v
62 }
63}
64
/// Parameters for [`Mmr::init`].
pub struct Config<D: Digest> {
    /// Digests of the retained nodes, starting at the position of leaf `pruned_to`.
    pub nodes: Vec<D>,

    /// Leaf location the MMR has been pruned to (`Location::new(0)` for an unpruned MMR).
    pub pruned_to: Location,

    /// Digests of the pruned nodes that must stay available, in the order yielded by
    /// `nodes_to_pin` for the pruned position. It must contain exactly that many entries.
    pub pinned_nodes: Vec<D>,
}
77
/// An in-memory MMR that supports pruning its oldest nodes.
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest> {
    /// Digests of the retained nodes; `nodes[i]` is the node at position `pruned_to_pos + i`.
    nodes: VecDeque<D>,

    /// Position of the oldest retained node. Every position below this has been pruned.
    pruned_to_pos: Position,

    /// Digests of pruned positions that are still needed (e.g. peaks read when computing the
    /// root), keyed by position.
    pinned_nodes: BTreeMap<Position, D>,

    /// Cached root digest. It is computed on construction and truncation, and replaced by
    /// `apply`.
    root: D,
}
119
120impl<D: Digest> Mmr<D> {
121 pub fn new(hasher: &mut impl Hasher<Digest = D>) -> Self {
123 let root = hasher.root(Location::new(0), core::iter::empty::<&D>());
124 Self {
125 nodes: VecDeque::new(),
126 pruned_to_pos: Position::new(0),
127 pinned_nodes: BTreeMap::new(),
128 root,
129 }
130 }
131
132 pub fn init(config: Config<D>, hasher: &mut impl Hasher<Digest = D>) -> Result<Self, Error> {
141 let pruned_to_pos = Position::try_from(config.pruned_to)?;
142
143 let Some(size) = pruned_to_pos.checked_add(config.nodes.len() as u64) else {
145 return Err(Error::InvalidSize(u64::MAX));
146 };
147 if !size.is_mmr_size() {
148 return Err(Error::InvalidSize(*size));
149 }
150
151 let mut pinned_nodes = BTreeMap::new();
153 let mut expected_pinned_nodes = 0;
154 for (i, pos) in nodes_to_pin(pruned_to_pos).enumerate() {
155 expected_pinned_nodes += 1;
156 if i >= config.pinned_nodes.len() {
157 return Err(Error::InvalidPinnedNodes);
158 }
159 pinned_nodes.insert(pos, config.pinned_nodes[i]);
160 }
161
162 if config.pinned_nodes.len() != expected_pinned_nodes {
164 return Err(Error::InvalidPinnedNodes);
165 }
166
167 let nodes = VecDeque::from(config.nodes);
168 let root = Self::compute_root(hasher, &nodes, &pinned_nodes, pruned_to_pos);
169 Ok(Self {
170 nodes,
171 pruned_to_pos,
172 pinned_nodes,
173 root,
174 })
175 }
176
177 pub fn from_components(
186 hasher: &mut impl Hasher<Digest = D>,
187 nodes: Vec<D>,
188 pruned_to: Location,
189 pinned_nodes_vec: Vec<D>,
190 ) -> Result<Self, Error> {
191 let pruned_to_pos = Position::try_from(pruned_to)?;
192 let expected_count = nodes_to_pin(pruned_to_pos).count();
193 if pinned_nodes_vec.len() != expected_count {
194 return Err(Error::InvalidPinnedNodes);
195 }
196 let pinned_nodes: BTreeMap<Position, D> = nodes_to_pin(pruned_to_pos)
197 .enumerate()
198 .map(|(i, pos)| (pos, pinned_nodes_vec[i]))
199 .collect();
200 let nodes = VecDeque::from(nodes);
201 let root = Self::compute_root(hasher, &nodes, &pinned_nodes, pruned_to_pos);
202 Ok(Self {
203 nodes,
204 pruned_to_pos,
205 pinned_nodes,
206 root,
207 })
208 }
209
210 fn compute_root(
212 hasher: &mut impl Hasher<Digest = D>,
213 nodes: &VecDeque<D>,
214 pinned_nodes: &BTreeMap<Position, D>,
215 pruned_to_pos: Position,
216 ) -> D {
217 let size = Position::new(nodes.len() as u64 + *pruned_to_pos);
218 let leaves = Location::try_from(size).expect("invalid mmr size");
219 let get_node = |pos: Position| -> &D {
220 if pos < pruned_to_pos {
221 return pinned_nodes
222 .get(&pos)
223 .expect("requested node is pruned and not pinned");
224 }
225 let index = (*pos - *pruned_to_pos) as usize;
226 &nodes[index]
227 };
228 let peaks = PeakIterator::new(size).map(|(peak_pos, _)| get_node(peak_pos));
229 hasher.root(leaves, peaks)
230 }
231
232 pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
234 hasher.update(&0u64.to_be_bytes());
235 hasher.finalize()
236 }
237
238 pub fn size(&self) -> Position {
241 Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
242 }
243
244 pub fn leaves(&self) -> Location {
246 Location::try_from(self.size()).expect("invalid mmr size")
247 }
248
249 pub fn bounds(&self) -> Range<Location> {
252 Location::try_from(self.pruned_to_pos).expect("valid pruned_to_pos")..self.leaves()
253 }
254
255 pub fn peak_iterator(&self) -> PeakIterator {
257 PeakIterator::new(self.size())
258 }
259
260 fn index_to_pos(&self, index: usize) -> Position {
262 self.pruned_to_pos + (index as u64)
263 }
264
265 pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
273 if pos < self.pruned_to_pos {
274 return self
275 .pinned_nodes
276 .get(&pos)
277 .expect("requested node is pruned and not pinned");
278 }
279
280 &self.nodes[self.pos_to_index(pos)]
281 }
282
283 fn pos_to_index(&self, pos: Position) -> usize {
289 assert!(
290 pos >= self.pruned_to_pos,
291 "pos precedes oldest retained position"
292 );
293
294 *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
295 }
296
297 #[cfg(any(feature = "std", test))]
300 pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
301 for (pos, node) in pinned_nodes.into_iter() {
302 self.pinned_nodes.insert(pos, node);
303 }
304 }
305
306 pub fn get_node(&self, pos: Position) -> Option<D> {
308 if pos < self.pruned_to_pos {
309 return self.pinned_nodes.get(&pos).copied();
310 }
311
312 self.nodes.get(self.pos_to_index(pos)).copied()
313 }
314
315 pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
318 nodes_to_pin(prune_pos)
319 .map(|pos| (pos, *self.get_node_unchecked(pos)))
320 .collect()
321 }
322
323 pub fn prune(&mut self, loc: Location) -> Result<(), Error> {
331 if loc > self.leaves() {
332 return Err(Error::LeafOutOfBounds(loc));
333 }
334 let pos = Position::try_from(loc)?;
335 if pos <= self.pruned_to_pos {
336 return Ok(());
337 }
338 self.prune_to_pos(pos);
339 Ok(())
340 }
341
342 pub fn prune_all(&mut self) {
345 if !self.nodes.is_empty() {
346 let pos = self.index_to_pos(self.nodes.len());
347 self.prune_to_pos(pos);
348 }
349 }
350
351 fn prune_to_pos(&mut self, pos: Position) {
354 self.pinned_nodes = self.nodes_to_pin(pos);
355 let retained_nodes = self.pos_to_index(pos);
356 self.nodes.drain(0..retained_nodes);
357 self.pruned_to_pos = pos;
358 }
359
360 #[cfg(feature = "std")]
366 pub(crate) fn truncate(&mut self, new_size: Position, hasher: &mut impl Hasher<Digest = D>) {
367 debug_assert!(new_size.is_mmr_size());
368 debug_assert!(new_size >= self.pruned_to_pos);
369 let keep = (*new_size - *self.pruned_to_pos) as usize;
370 self.nodes.truncate(keep);
371 self.root = Self::compute_root(hasher, &self.nodes, &self.pinned_nodes, self.pruned_to_pos);
372 }
373
374 pub fn update_leaf(
388 &mut self,
389 hasher: &mut impl Hasher<Digest = D>,
390 loc: Location,
391 element: &[u8],
392 ) -> Result<(), Error> {
393 let changeset = {
394 let mut batch = self.new_batch();
395 batch.update_leaf(hasher, loc, element)?;
396 batch.merkleize(hasher).finalize()
397 };
398 self.apply(changeset)
399 .expect("db unmodified since batch creation");
400 Ok(())
401 }
402
403 pub const fn root(&self) -> &D {
405 &self.root
406 }
407
408 pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
416 if !loc.is_valid() {
417 return Err(Error::LocationOverflow(loc));
418 }
419 self.range_proof(loc..loc + 1).map_err(|e| match e {
421 Error::RangeOutOfBounds(loc) => Error::LeafOutOfBounds(loc),
422 _ => e,
423 })
424 }
425
426 pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
435 let leaves = self.leaves();
436 let positions = proof::nodes_required_for_range_proof(leaves, range)?;
437 let digests = positions
438 .into_iter()
439 .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
440 .collect::<Result<Vec<_>, _>>()?;
441
442 Ok(Proof { leaves, digests })
443 }
444
445 #[cfg(test)]
448 pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
449 nodes_to_pin(start_pos)
450 .map(|pos| *self.get_node_unchecked(pos))
451 .collect()
452 }
453
454 #[cfg(test)]
457 pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
458 self.pinned_nodes.clone()
459 }
460
461 pub fn new_batch(&self) -> super::batch::UnmerkleizedBatch<'_, D, Self> {
463 super::batch::UnmerkleizedBatch::new(self)
464 }
465
466 pub fn apply(&mut self, changeset: super::batch::Changeset<D>) -> Result<(), super::Error> {
473 if changeset.base_size != self.size() {
474 return Err(super::Error::StaleChangeset {
475 expected: changeset.base_size,
476 actual: self.size(),
477 });
478 }
479
480 for (pos, digest) in changeset.overwrites {
482 let index = self.pos_to_index(pos);
483 self.nodes[index] = digest;
484 }
485
486 for digest in changeset.appended {
488 self.nodes.push_back(digest);
489 }
490
491 self.root = changeset.root;
493 Ok(())
494 }
495}
496
497impl<D: Digest> Readable<D> for Mmr<D> {
498 fn size(&self) -> Position {
499 self.size()
500 }
501
502 fn get_node(&self, pos: Position) -> Option<D> {
503 self.get_node(pos)
504 }
505
506 fn root(&self) -> D {
507 *self.root()
508 }
509
510 fn pruned_to_pos(&self) -> Position {
511 self.pruned_to_pos
512 }
513}
514
515impl<D: Digest> BatchChainInfo<D> for Mmr<D> {
516 fn base_size(&self) -> Position {
517 self.size()
518 }
519
520 fn collect_overwrites(&self, _into: &mut BTreeMap<Position, D>) {}
521}
522
523#[cfg(test)]
524mod tests {
525 use super::*;
526 use crate::mmr::{
527 conformance::build_test_mmr,
528 hasher::{Hasher as _, Standard},
529 iterator::nodes_needing_parents,
530 };
531 use commonware_cryptography::{sha256, Hasher, Sha256};
532 use commonware_parallel::ThreadPool;
533 use commonware_runtime::{deterministic, tokio, Runner, ThreadPooler};
534 use commonware_utils::NZUsize;
535
536 #[test]
538 fn test_mem_mmr_empty() {
539 let executor = deterministic::Runner::default();
540 executor.start(|_| async move {
541 let mut hasher: Standard<Sha256> = Standard::new();
542 let mmr = Mmr::new(&mut hasher);
543 assert_eq!(
544 mmr.peak_iterator().next(),
545 None,
546 "empty iterator should have no peaks"
547 );
548 assert_eq!(mmr.size(), 0);
549 assert_eq!(mmr.leaves(), Location::new(0));
550 assert!(mmr.bounds().is_empty());
551 assert_eq!(mmr.get_node(Position::new(0)), None);
552 assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
553 let mut mmr2 = Mmr::new(&mut hasher);
554 mmr2.prune_all();
555 assert_eq!(mmr2.size(), 0, "prune_all on empty MMR should do nothing");
556
557 assert_eq!(*mmr.root(), hasher.root(Location::new(0), [].iter()));
558 });
559 }
560
561 #[test]
565 fn test_mem_mmr_add_eleven_values() {
566 let executor = deterministic::Runner::default();
567 executor.start(|_| async move {
568 let mut hasher: Standard<Sha256> = Standard::new();
569 let mut mmr = Mmr::new(&mut hasher);
570 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
571 let mut leaves: Vec<Position> = Vec::new();
572 for _ in 0..11 {
573 let changeset = {
574 let mut batch = mmr.new_batch();
575 leaves.push(batch.add(&mut hasher, &element));
576 batch.merkleize(&mut hasher).finalize()
577 };
578 mmr.apply(changeset).unwrap();
579 let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
580 assert_ne!(peaks.len(), 0);
581 assert!(peaks.len() as u64 <= mmr.size());
582 }
583 assert_eq!(mmr.bounds().start, Location::new(0));
584 assert_eq!(mmr.size(), 19, "mmr not of expected size");
585 assert_eq!(
586 leaves,
587 vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
588 .into_iter()
589 .map(Position::new)
590 .collect::<Vec<_>>(),
591 "mmr leaf positions not as expected"
592 );
593 let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
594 assert_eq!(
595 peaks,
596 vec![
597 (Position::new(14), 3),
598 (Position::new(17), 1),
599 (Position::new(18), 0)
600 ],
601 "mmr peaks not as expected"
602 );
603
604 let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
607 assert_eq!(
608 peaks_needing_parents,
609 vec![Position::new(17), Position::new(18)],
610 "mmr nodes needing parents not as expected"
611 );
612
613 for leaf in leaves.iter().by_ref() {
615 let digest = hasher.leaf_digest(*leaf, &element);
616 assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
617 }
618
619 let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
621 assert_eq!(mmr.nodes[2], digest2);
622 let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
623 assert_eq!(mmr.nodes[5], digest5);
624 let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
625 assert_eq!(mmr.nodes[9], digest9);
626 let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
627 assert_eq!(mmr.nodes[12], digest12);
628 let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
629 assert_eq!(mmr.nodes[17], digest17);
630
631 let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
633 assert_eq!(mmr.nodes[6], digest6);
634 let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
635 assert_eq!(mmr.nodes[13], digest13);
636 let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
637 assert_eq!(mmr.nodes[17], digest17);
638
639 let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
641 assert_eq!(mmr.nodes[14], digest14);
642
643 let root = *mmr.root();
645 let peak_digests = [digest14, digest17, mmr.nodes[18]];
646 let expected_root = hasher.root(Location::new(11), peak_digests.iter());
647 assert_eq!(root, expected_root, "incorrect root");
648
649 mmr.prune(Location::new(8)).unwrap(); assert_eq!(mmr.bounds().start, Location::new(8));
652
653 assert!(matches!(
659 mmr.proof(Location::new(0)),
660 Err(Error::ElementPruned(_))
661 ));
662 assert!(matches!(
663 mmr.proof(Location::new(6)),
664 Err(Error::ElementPruned(_))
665 ));
666
667 assert!(mmr.proof(Location::new(8)).is_ok());
670 assert!(mmr.proof(Location::new(10)).is_ok());
671
672 let root_after_prune = *mmr.root();
673 assert_eq!(root, root_after_prune, "root changed after pruning");
674
675 assert!(
676 mmr.range_proof(Location::new(5)..Location::new(9)).is_err(),
677 "attempts to range_prove elements at or before the oldest retained should fail"
678 );
679 assert!(
680 mmr.range_proof(Location::new(8)..mmr.leaves()).is_ok(),
681 "attempts to range_prove over all elements following oldest retained should succeed"
682 );
683
684 let oldest_loc = mmr.bounds().start;
686 let oldest_pos = Position::try_from(oldest_loc).unwrap();
687 let digests = mmr.node_digests_to_pin(oldest_pos);
688 let mmr_copy = Mmr::init(
689 Config {
690 nodes: mmr.nodes.iter().copied().collect(),
691 pruned_to: oldest_loc,
692 pinned_nodes: digests,
693 },
694 &mut hasher,
695 )
696 .unwrap();
697 assert_eq!(mmr_copy.size(), 19);
698 assert_eq!(mmr_copy.leaves(), mmr.leaves());
699 assert_eq!(mmr_copy.bounds().start, mmr.bounds().start);
700 assert_eq!(*mmr_copy.root(), root);
701 });
702 }
703
704 #[test]
706 fn test_mem_mmr_prune_all() {
707 let executor = deterministic::Runner::default();
708 executor.start(|_| async move {
709 let mut hasher: Standard<Sha256> = Standard::new();
710 let mut mmr = Mmr::new(&mut hasher);
711 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
712 for _ in 0..1000 {
713 mmr.prune_all();
714 let changeset = {
715 let mut batch = mmr.new_batch();
716 batch.add(&mut hasher, &element);
717 batch.merkleize(&mut hasher).finalize()
718 };
719 mmr.apply(changeset).unwrap();
720 }
721 });
722 }
723
724 #[test]
726 fn test_mem_mmr_validity() {
727 let executor = deterministic::Runner::default();
728 executor.start(|_| async move {
729 let mut hasher: Standard<Sha256> = Standard::new();
730 let mut mmr = Mmr::new(&mut hasher);
731 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
732 for _ in 0..1001 {
733 assert!(
734 mmr.size().is_mmr_size(),
735 "mmr of size {} should be valid",
736 mmr.size()
737 );
738 let old_size = mmr.size();
739 let changeset = {
740 let mut batch = mmr.new_batch();
741 batch.add(&mut hasher, &element);
742 batch.merkleize(&mut hasher).finalize()
743 };
744 mmr.apply(changeset).unwrap();
745 for size in *old_size + 1..*mmr.size() {
746 assert!(
747 !Position::new(size).is_mmr_size(),
748 "mmr of size {size} should be invalid",
749 );
750 }
751 }
752 });
753 }
754
755 #[test]
758 fn test_mem_mmr_batched_root() {
759 let executor = deterministic::Runner::default();
760 executor.start(|_| async move {
761 let mut hasher: Standard<Sha256> = Standard::new();
762 const NUM_ELEMENTS: u64 = 199;
763 let mut test_mmr = Mmr::new(&mut hasher);
764 test_mmr = build_test_mmr(&mut hasher, test_mmr, NUM_ELEMENTS);
765 let expected_root = test_mmr.root();
766
767 let mut batched_mmr = Mmr::new(&mut hasher);
768
769 let changeset = {
771 let mut batch = batched_mmr.new_batch();
772 for i in 0..NUM_ELEMENTS {
773 hasher.inner().update(&i.to_be_bytes());
774 let element = hasher.inner().finalize();
775 batch.add(&mut hasher, &element);
776 }
777 batch.merkleize(&mut hasher).finalize()
778 };
779 batched_mmr.apply(changeset).unwrap();
780
781 assert_eq!(
782 batched_mmr.root(),
783 expected_root,
784 "Batched MMR root should match reference"
785 );
786 });
787 }
788
789 #[test]
792 fn test_mem_mmr_batched_root_parallel() {
793 let executor = tokio::Runner::default();
794 executor.start(|context| async move {
795 let mut hasher: Standard<Sha256> = Standard::new();
796 const NUM_ELEMENTS: u64 = 199;
797 let test_mmr = Mmr::new(&mut hasher);
798 let test_mmr = build_test_mmr(&mut hasher, test_mmr, NUM_ELEMENTS);
799 let expected_root = test_mmr.root();
800
801 let pool = context.create_thread_pool(NZUsize!(4)).unwrap();
802 let mut hasher: Standard<Sha256> = Standard::new();
803
804 let mut mmr = Mmr::init(
805 Config {
806 nodes: vec![],
807 pruned_to: Location::new(0),
808 pinned_nodes: vec![],
809 },
810 &mut hasher,
811 )
812 .unwrap();
813
814 let changeset = {
815 let mut batch = mmr.new_batch().with_pool(Some(pool));
816 for i in 0u64..NUM_ELEMENTS {
817 hasher.inner().update(&i.to_be_bytes());
818 let element = hasher.inner().finalize();
819 batch.add(&mut hasher, &element);
820 }
821 batch.merkleize(&mut hasher).finalize()
822 };
823 mmr.apply(changeset).unwrap();
824 assert_eq!(
825 mmr.root(),
826 expected_root,
827 "Batched MMR root should match reference"
828 );
829 });
830 }
831
832 #[test]
834 fn test_mem_mmr_root_with_pruning() {
835 let executor = deterministic::Runner::default();
836 executor.start(|_| async move {
837 let mut hasher: Standard<Sha256> = Standard::new();
838 let mut reference_mmr = Mmr::new(&mut hasher);
839 let mut mmr = Mmr::new(&mut hasher);
840 for i in 0u64..200 {
841 hasher.inner().update(&i.to_be_bytes());
842 let element = hasher.inner().finalize();
843
844 let cs = {
846 let mut batch = reference_mmr.new_batch();
847 batch.add(&mut hasher, &element);
848 batch.merkleize(&mut hasher).finalize()
849 };
850 reference_mmr.apply(cs).unwrap();
851
852 let cs = {
854 let mut batch = mmr.new_batch();
855 batch.add(&mut hasher, &element);
856 batch.merkleize(&mut hasher).finalize()
857 };
858 mmr.apply(cs).unwrap();
859
860 mmr.prune_all();
861 assert_eq!(mmr.root(), reference_mmr.root());
862 }
863 });
864 }
865
866 #[test]
867 fn test_mem_mmr_update_leaf() {
868 let mut hasher: Standard<Sha256> = Standard::new();
869 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
870 let executor = deterministic::Runner::default();
871 executor.start(|_| async move {
872 const NUM_ELEMENTS: u64 = 200;
873 let mmr = Mmr::new(&mut hasher);
874 let mut mmr = build_test_mmr(&mut hasher, mmr, NUM_ELEMENTS);
875 let root = *mmr.root();
876
877 for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
880 let leaf_loc = Location::new(leaf as u64);
882 mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
883 let updated_root = *mmr.root();
884 assert!(root != updated_root);
885
886 hasher.inner().update(&leaf.to_be_bytes());
888 let element = hasher.inner().finalize();
889 mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
890 let restored_root = *mmr.root();
891 assert_eq!(root, restored_root);
892 }
893
894 mmr.prune(Location::new(100)).unwrap();
896 for leaf in 100u64..=190 {
897 mmr.prune(Location::new(leaf)).unwrap();
898 let leaf_loc = Location::new(leaf);
899 mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
900 }
901 });
902 }
903
904 #[test]
905 fn test_mem_mmr_update_leaf_error_out_of_bounds() {
906 let mut hasher: Standard<Sha256> = Standard::new();
907 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
908
909 let executor = deterministic::Runner::default();
910 executor.start(|_| async move {
911 let mmr = Mmr::new(&mut hasher);
912 let mut mmr = build_test_mmr(&mut hasher, mmr, 200);
913 let invalid_loc = mmr.leaves();
914 let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
915 assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
916 });
917 }
918
919 #[test]
920 fn test_mem_mmr_update_leaf_error_pruned() {
921 let mut hasher: Standard<Sha256> = Standard::new();
922 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
923
924 let executor = deterministic::Runner::default();
925 executor.start(|_| async move {
926 let mmr = Mmr::new(&mut hasher);
927 let mut mmr = build_test_mmr(&mut hasher, mmr, 100);
928 mmr.prune_all();
929 let result = mmr.update_leaf(&mut hasher, Location::new(0), &element);
930 assert!(matches!(result, Err(Error::ElementPruned(_))));
931 });
932 }
933
934 #[test]
935 fn test_mem_mmr_batch_update_leaf() {
936 let mut hasher: Standard<Sha256> = Standard::new();
937 let executor = deterministic::Runner::default();
938 executor.start(|_| async move {
939 let mmr = Mmr::new(&mut hasher);
940 let mmr = build_test_mmr(&mut hasher, mmr, 200);
941 do_batch_update(&mut hasher, mmr, None);
942 });
943 }
944
945 #[test]
948 fn test_mem_mmr_batch_parallel_update_leaf() {
949 let mut hasher: Standard<Sha256> = Standard::new();
950 let executor = tokio::Runner::default();
951 executor.start(|ctx| async move {
952 let mmr = Mmr::init(
953 Config {
954 nodes: Vec::new(),
955 pruned_to: Location::new(0),
956 pinned_nodes: Vec::new(),
957 },
958 &mut hasher,
959 )
960 .unwrap();
961 let mmr = build_test_mmr(&mut hasher, mmr, 200);
962 let pool = ctx.create_thread_pool(NZUsize!(4)).unwrap();
963 do_batch_update(&mut hasher, mmr, Some(pool));
964 });
965 }
966
967 #[test]
968 fn test_update_leaf_digest() {
969 let mut hasher: Standard<Sha256> = Standard::new();
970 let executor = deterministic::Runner::default();
971 executor.start(|_| async move {
972 const NUM_ELEMENTS: u64 = 200;
973 let mmr = Mmr::new(&mut hasher);
974 let mut mmr = build_test_mmr(&mut hasher, mmr, NUM_ELEMENTS);
975 let root = *mmr.root();
976
977 let updated_digest = Sha256::fill(0xFF);
978
979 let loc = Location::new(5);
981 let leaf_pos = Position::try_from(loc).unwrap();
982 let original_digest = mmr.get_node(leaf_pos).unwrap();
983
984 let changeset = {
986 let mut batch = mmr.new_batch();
987 batch.update_leaf_digest(loc, updated_digest).unwrap();
988 batch.merkleize(&mut hasher).finalize()
989 };
990 mmr.apply(changeset).unwrap();
991 assert_ne!(*mmr.root(), root);
992
993 let changeset = {
995 let mut batch = mmr.new_batch();
996 batch.update_leaf_digest(loc, original_digest).unwrap();
997 batch.merkleize(&mut hasher).finalize()
998 };
999 mmr.apply(changeset).unwrap();
1000 assert_eq!(*mmr.root(), root);
1001
1002 let changeset = {
1004 let mut batch = mmr.new_batch();
1005 for i in [0u64, 1, 50, 100, 199] {
1006 batch
1007 .update_leaf_digest(Location::new(i), updated_digest)
1008 .unwrap();
1009 }
1010 batch.merkleize(&mut hasher).finalize()
1011 };
1012 mmr.apply(changeset).unwrap();
1013 assert_ne!(*mmr.root(), root);
1014 });
1015 }
1016
1017 #[test]
1018 fn test_update_leaf_digest_errors() {
1019 let mut hasher: Standard<Sha256> = Standard::new();
1020 let executor = deterministic::Runner::default();
1021 executor.start(|_| async move {
1022 {
1023 let mmr = Mmr::new(&mut hasher);
1025 let mmr = build_test_mmr(&mut hasher, mmr, 100);
1026 let mut batch = mmr.new_batch();
1027 let result = batch.update_leaf_digest(Location::new(100), Sha256::fill(0));
1028 assert!(matches!(result, Err(Error::InvalidPosition(_))));
1029 }
1030
1031 {
1032 let mmr = Mmr::new(&mut hasher);
1034 let mut mmr = build_test_mmr(&mut hasher, mmr, 100);
1035 mmr.prune(Location::new(27)).unwrap();
1036 let mut batch = mmr.new_batch();
1037 let result = batch.update_leaf_digest(Location::new(0), Sha256::fill(0));
1038 assert!(matches!(result, Err(Error::ElementPruned(_))));
1039 }
1040 });
1041 }
1042
1043 fn do_batch_update(
1044 hasher: &mut Standard<Sha256>,
1045 mut mmr: Mmr<sha256::Digest>,
1046 pool: Option<ThreadPool>,
1047 ) {
1048 let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1049 let root = *mmr.root();
1050
1051 let changeset = {
1053 let mut batch = mmr.new_batch();
1054 if let Some(ref pool) = pool {
1055 batch = batch.with_pool(Some(pool.clone()));
1056 }
1057 for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
1058 let leaf_loc = Location::new(leaf);
1059 batch.update_leaf(hasher, leaf_loc, &element).unwrap();
1060 }
1061 batch.merkleize(hasher).finalize()
1062 };
1063 mmr.apply(changeset).unwrap();
1064 let updated_root = *mmr.root();
1065 assert_ne!(updated_root, root);
1066
1067 let changeset = {
1069 let mut batch = mmr.new_batch();
1070 for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
1071 hasher.inner().update(&leaf.to_be_bytes());
1072 let element = hasher.inner().finalize();
1073 let leaf_loc = Location::new(leaf);
1074 batch.update_leaf(hasher, leaf_loc, &element).unwrap();
1075 }
1076 batch.merkleize(hasher).finalize()
1077 };
1078 mmr.apply(changeset).unwrap();
1079 let restored_root = *mmr.root();
1080 assert_eq!(root, restored_root);
1081 }
1082
    /// `Mmr::init` requires exactly the pinned digests implied by `pruned_to`.
    #[test]
    fn test_init_pinned_nodes_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Unpruned and empty: no pinned nodes needed.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to: Location::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Pruned to leaf 64 but no pinned digests supplied: rejected.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to: Location::new(64),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Nothing pruned but a pinned digest supplied anyway: rejected.
            let config = Config {
                nodes: vec![],
                pruned_to: Location::new(0),
                pinned_nodes: vec![Sha256::hash(b"dummy")],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Take the pins for position 50 from a real 50-leaf MMR. Position 50 is where leaf
            // 27 sits (2*27 - popcount(27)), so they are exactly the pins an MMR pruned to leaf
            // 27 needs.
            let mut mmr = Mmr::new(&mut hasher);
            let changeset = {
                let mut batch = mmr.new_batch();
                for i in 0u64..50 {
                    batch.add(&mut hasher, &i.to_be_bytes());
                }
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
            let config = Config {
                nodes: vec![],
                pruned_to: Location::new(27),
                pinned_nodes,
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());
        });
    }
1139
    /// `Mmr::init` rejects configurations whose pruned position plus retained node count is not
    /// a valid MMR size.
    #[test]
    fn test_init_size_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Size 0: valid.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to: Location::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Size 2 (two leaves without their parent): invalid.
            let config = Config {
                nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
                pruned_to: Location::new(0),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Size 3 (two leaves plus parent): valid.
            let config = Config {
                nodes: vec![
                    Sha256::hash(b"leaf1"),
                    Sha256::hash(b"leaf2"),
                    Sha256::hash(b"parent"),
                ],
                pruned_to: Location::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // A full unpruned copy of a real 64-leaf (127-node) MMR: valid.
            let mut mmr = Mmr::new(&mut hasher);
            let changeset = {
                let mut batch = mmr.new_batch();
                for i in 0u64..64 {
                    batch.add(&mut hasher, &i.to_be_bytes());
                }
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
            assert_eq!(mmr.size(), 127);
            let nodes: Vec<_> = (0..127)
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();

            let config = Config {
                nodes,
                pruned_to: Location::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // An 11-leaf (19-node) MMR pruned to leaf 4, i.e. position 7. The retained nodes
            // 7..19 plus the pins for position 7 rebuild it.
            let mut mmr = Mmr::new(&mut hasher);
            let changeset = {
                let mut batch = mmr.new_batch();
                for i in 0u64..11 {
                    batch.add(&mut hasher, &i.to_be_bytes());
                }
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
            assert_eq!(mmr.size(), 19);
            mmr.prune(Location::new(4)).unwrap();
            let nodes: Vec<_> = (7..*mmr.size())
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));

            let config = Config {
                nodes: nodes.clone(),
                pruned_to: Location::new(4),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Same nodes but claiming pruning to leaf 5: the implied total size is not a valid
            // MMR size.
            let config = Config {
                nodes: nodes.clone(),
                pruned_to: Location::new(5),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Same nodes but claiming pruning to leaf 1: also an invalid total size.
            let config = Config {
                nodes,
                pruned_to: Location::new(1),
                pinned_nodes,
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));
        });
    }
1252
1253 #[test]
1254 fn test_mem_mmr_range_proof_out_of_bounds() {
1255 let mut hasher: Standard<Sha256> = Standard::new();
1256
1257 let executor = deterministic::Runner::default();
1258 executor.start(|_| async move {
1259 let mmr = Mmr::new(&mut hasher);
1261 assert_eq!(mmr.leaves(), Location::new(0));
1262 let result = mmr.range_proof(Location::new(0)..Location::new(1));
1263 assert!(matches!(result, Err(Error::RangeOutOfBounds(_))));
1264
1265 let mmr = build_test_mmr(&mut hasher, mmr, 10);
1267 assert_eq!(mmr.leaves(), Location::new(10));
1268 let result = mmr.range_proof(Location::new(5)..Location::new(11));
1269 assert!(matches!(result, Err(Error::RangeOutOfBounds(_))));
1270
1271 let result = mmr.range_proof(Location::new(5)..Location::new(10));
1273 assert!(result.is_ok());
1274 });
1275 }
1276
/// Checks that single-leaf `proof` rejects locations at or beyond the leaf count and accepts
/// the last existing leaf.
#[test]
fn test_mem_mmr_proof_out_of_bounds() {
    let mut hasher: Standard<Sha256> = Standard::new();

    let executor = deterministic::Runner::default();
    executor.start(|_| async move {
        // Empty MMR: leaf 0 does not exist.
        let mmr = Mmr::new(&mut hasher);
        let result = mmr.proof(Location::new(0));
        assert!(
            matches!(result, Err(Error::LeafOutOfBounds(_))),
            "expected LeafOutOfBounds, got {result:?}"
        );

        // 10 leaves (locations 0..=9): location 10 is one past the end.
        let mmr = build_test_mmr(&mut hasher, mmr, 10);
        let result = mmr.proof(Location::new(10));
        assert!(
            matches!(result, Err(Error::LeafOutOfBounds(_))),
            "expected LeafOutOfBounds, got {result:?}"
        );

        // The last existing leaf is provable.
        let result = mmr.proof(Location::new(9));
        assert!(result.is_ok(), "expected Ok, got {result:?}");
    });
}
1306
/// Two changesets built independently from the same MMR state: once the first is applied,
/// applying the second must fail with `StaleChangeset`.
#[test]
fn test_stale_changeset_sibling() {
    let mut hasher: Standard<Sha256> = Standard::new();

    let executor = deterministic::Runner::default();
    executor.start(|_| async move {
        let mut mmr = Mmr::new(&mut hasher);

        // Both changesets start from the same (empty) MMR.
        let first = {
            let mut pending = mmr.new_batch();
            pending.add(&mut hasher, b"leaf-a");
            pending.merkleize(&mut hasher).finalize()
        };
        let second = {
            let mut pending = mmr.new_batch();
            pending.add(&mut hasher, b"leaf-b");
            pending.merkleize(&mut hasher).finalize()
        };

        mmr.apply(first).unwrap();

        // The second changeset was built against the pre-`first` state.
        let outcome = mmr.apply(second);
        match outcome {
            Err(Error::StaleChangeset { .. }) => {}
            result => panic!("expected StaleChangeset, got {result:?}"),
        }
    });
}
1338
/// Two child changesets built on the same merkleized parent batch: applying one child must
/// make the other fail with `StaleChangeset`.
#[test]
fn test_stale_changeset_chained() {
    let mut hasher: Standard<Sha256> = Standard::new();

    let executor = deterministic::Runner::default();
    executor.start(|_| async move {
        let mut mmr = Mmr::new(&mut hasher);

        // Commit one leaf so the MMR is non-empty.
        let seed = {
            let mut pending = mmr.new_batch();
            pending.add(&mut hasher, b"leaf-0");
            pending.merkleize(&mut hasher).finalize()
        };
        mmr.apply(seed).unwrap();

        // A merkleized (not finalized) batch that both children are derived from.
        let base = {
            let mut pending = mmr.new_batch();
            pending.add(&mut hasher, b"leaf-1");
            pending.merkleize(&mut hasher)
        };
        let left = {
            let mut pending = base.new_batch();
            pending.add(&mut hasher, b"leaf-2a");
            pending.merkleize(&mut hasher).finalize()
        };
        let right = {
            let mut pending = base.new_batch();
            pending.add(&mut hasher, b"leaf-2b");
            pending.merkleize(&mut hasher).finalize()
        };

        mmr.apply(left).unwrap();

        let outcome = mmr.apply(right);
        match outcome {
            Err(Error::StaleChangeset { .. }) => {}
            result => panic!("expected StaleChangeset for sibling, got {result:?}"),
        }
    });
}
1381
/// A child changeset derived from a parent batch: after the parent's changeset is applied,
/// applying the child must fail with `StaleChangeset`.
#[test]
fn test_stale_changeset_parent_before_child() {
    let mut hasher: Standard<Sha256> = Standard::new();

    let executor = deterministic::Runner::default();
    executor.start(|_| async move {
        let mut mmr = Mmr::new(&mut hasher);

        let merkleized_parent = {
            let mut pending = mmr.new_batch();
            pending.add(&mut hasher, b"leaf-0");
            pending.merkleize(&mut hasher)
        };
        // The child is built on top of the merkleized parent before the parent is finalized.
        let child_changeset = {
            let mut pending = merkleized_parent.new_batch();
            pending.add(&mut hasher, b"leaf-1");
            pending.merkleize(&mut hasher).finalize()
        };
        let parent_changeset = merkleized_parent.finalize();

        mmr.apply(parent_changeset).unwrap();

        let outcome = mmr.apply(child_changeset);
        match outcome {
            Err(Error::StaleChangeset { .. }) => {}
            result => panic!(
                "expected StaleChangeset for child after parent applied, got {result:?}"
            ),
        }
    });
}
1412
/// A child changeset derived from a parent batch: after the child's changeset is applied,
/// applying the parent's changeset must fail with `StaleChangeset`.
#[test]
fn test_stale_changeset_child_before_parent() {
    let mut hasher: Standard<Sha256> = Standard::new();

    let executor = deterministic::Runner::default();
    executor.start(|_| async move {
        let mut mmr = Mmr::new(&mut hasher);

        let merkleized_parent = {
            let mut pending = mmr.new_batch();
            pending.add(&mut hasher, b"leaf-0");
            pending.merkleize(&mut hasher)
        };
        // The child is built on top of the merkleized parent before the parent is finalized.
        let child_changeset = {
            let mut pending = merkleized_parent.new_batch();
            pending.add(&mut hasher, b"leaf-1");
            pending.merkleize(&mut hasher).finalize()
        };
        let parent_changeset = merkleized_parent.finalize();

        mmr.apply(child_changeset).unwrap();

        let outcome = mmr.apply(parent_changeset);
        match outcome {
            Err(Error::StaleChangeset { .. }) => {}
            result => panic!(
                "expected StaleChangeset for parent after child applied, got {result:?}"
            ),
        }
    });
}
1443}