1use crate::mmr::{
4 hasher::Hasher,
5 iterator::{nodes_needing_parents, nodes_to_pin, PathIterator, PeakIterator},
6 proof,
7 Error::{self, *},
8 Location, Position, Proof,
9};
10use alloc::{
11 collections::{BTreeMap, BTreeSet, VecDeque},
12 vec::Vec,
13};
14use commonware_cryptography::Digest;
15use core::{mem, ops::Range};
16cfg_if::cfg_if! {
17 if #[cfg(feature = "std")] {
18 use commonware_parallel::ThreadPool;
19 use rayon::prelude::*;
20 } else {
21 struct ThreadPool;
22 }
23}
24
/// Minimum number of dirty nodes (and minimum same-height batch size) required
/// before parallel merkleization is attempted instead of the serial path.
#[cfg(feature = "std")]
const MIN_TO_PARALLELIZE: usize = 20;
28
/// An MMR whose recently added or updated nodes have not yet been hashed; it
/// must be `merkleize`d to obtain a [CleanMmr] with a computed root.
pub type DirtyMmr<D> = Mmr<D, Dirty>;

/// An MMR whose nodes are fully hashed and whose root digest is cached.
pub type CleanMmr<D> = Mmr<D, Clean<D>>;
34
mod private {
    /// Sealing trait that prevents external crates from implementing [super::State].
    pub trait Sealed {}
}
39
/// Sealed typestate trait implemented by [Clean] and [Dirty]; dispatches how a
/// new leaf digest is appended in each state.
pub trait State<D: Digest>: private::Sealed + Sized + Send + Sync {
    /// Append `digest` as a new leaf to `mmr`, returning the new leaf's position.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut Mmr<D, Self>, hasher: &mut H, digest: D)
        -> Position;
}
46
/// State of a fully-merkleized MMR: caches the root digest.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    // Root digest of the fully hashed MMR.
    pub root: D,
}

impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut CleanMmr<D>, hasher: &mut H, digest: D) -> Position {
        // A clean MMR eagerly re-merkleizes on each add (see CleanMmr::add_leaf_digest).
        mmr.add_leaf_digest(hasher, digest)
    }
}
60
/// State of an MMR that may contain un-hashed ("dirty") internal nodes.
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    // Positions (paired with their heights) of internal nodes whose digests are
    // stale and must be recomputed by `merkleize`.
    dirty_nodes: BTreeSet<(Position, u32)>,
}

impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {
    fn add_leaf_digest<H: Hasher<D>>(
        mmr: &mut DirtyMmr<D>,
        _hasher: &mut H,
        digest: D,
    ) -> Position {
        // No hashing happens here; new parents stay dirty until merkleize().
        mmr.add_leaf_digest(digest)
    }
}
80
/// Configuration for initializing a [CleanMmr] from previously-persisted state.
pub struct Config<D: Digest> {
    // Digests of all retained (un-pruned) nodes, in position order.
    pub nodes: Vec<D>,

    // Position up to which the MMR was pruned (0 if never pruned).
    pub pruned_to_pos: Position,

    // Digests of pruned nodes that must remain pinned for proof generation,
    // ordered as yielded by `nodes_to_pin(pruned_to_pos)`.
    pub pinned_nodes: Vec<D>,
}
93
/// An in-memory Merkle Mountain Range, parameterized by digest type and a
/// typestate `S` ([Clean] with a cached root, or [Dirty] with pending hashes).
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest, S: State<D> = Dirty> {
    // Digests of all retained nodes; index 0 corresponds to position `pruned_to_pos`.
    nodes: VecDeque<D>,

    // All nodes before this position have been pruned from `nodes`.
    pruned_to_pos: Position,

    // Pruned nodes that are still required (e.g. for proofs), keyed by position.
    pinned_nodes: BTreeMap<Position, D>,

    // State-specific data (cached root or dirty-node set).
    state: S,
}
133
134impl<D: Digest> Default for DirtyMmr<D> {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl<D: Digest> From<CleanMmr<D>> for DirtyMmr<D> {
141 fn from(clean: CleanMmr<D>) -> Self {
142 DirtyMmr {
143 nodes: clean.nodes,
144 pruned_to_pos: clean.pruned_to_pos,
145 pinned_nodes: clean.pinned_nodes,
146 state: Dirty {
147 dirty_nodes: BTreeSet::new(),
148 },
149 }
150 }
151}
152
impl<D: Digest, S: State<D>> Mmr<D, S> {
    /// Total number of nodes in the MMR (including pruned ones); equivalently,
    /// the position the next appended node would occupy.
    pub fn size(&self) -> Position {
        Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
    }

    /// Number of leaves in the MMR.
    ///
    /// # Panics
    ///
    /// Panics if the current size is not a valid MMR size.
    pub fn leaves(&self) -> Location {
        Location::try_from(self.size()).expect("invalid mmr size")
    }

    /// Position of the most recently added leaf, or None if the MMR is empty.
    pub fn last_leaf_pos(&self) -> Option<Position> {
        if self.size() == 0 {
            return None;
        }

        Some(PeakIterator::last_leaf_pos(self.size()))
    }

    /// Position up to which this MMR has been pruned (0 if never pruned).
    pub const fn pruned_to_pos(&self) -> Position {
        self.pruned_to_pos
    }

    /// Position of the oldest node still retained in `nodes`, or None if every
    /// node has been pruned.
    pub fn oldest_retained_pos(&self) -> Option<Position> {
        if self.pruned_to_pos == self.size() {
            return None;
        }

        Some(self.pruned_to_pos)
    }

    /// Iterator over the current peaks of the MMR (position and height pairs).
    pub fn peak_iterator(&self) -> PeakIterator {
        PeakIterator::new(self.size())
    }

    /// Convert an index into `self.nodes` to its absolute MMR position.
    fn index_to_pos(&self, index: usize) -> Position {
        self.pruned_to_pos + (index as u64)
    }

    /// Reference to the node at `pos`, consulting pinned nodes for pruned
    /// positions.
    ///
    /// # Panics
    ///
    /// Panics if the node is pruned and not pinned, or if `pos` is beyond the
    /// last retained node.
    pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
        if pos < self.pruned_to_pos {
            return self
                .pinned_nodes
                .get(&pos)
                .expect("requested node is pruned and not pinned");
        }

        &self.nodes[self.pos_to_index(pos)]
    }

    /// Convert an absolute MMR position to an index into `self.nodes`.
    ///
    /// # Panics
    ///
    /// Panics if `pos` precedes the oldest retained position.
    fn pos_to_index(&self, pos: Position) -> usize {
        assert!(
            pos >= self.pruned_to_pos,
            "pos precedes oldest retained position"
        );

        *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
    }

    /// Merge the given digests into the pinned-node map (used to restore
    /// historical pinned digests needed for proving over pruned nodes).
    #[cfg(any(feature = "std", test))]
    pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
        for (pos, node) in pinned_nodes.into_iter() {
            self.pinned_nodes.insert(pos, node);
        }
    }

    /// Hash `element` as a leaf at the current size and append it, returning
    /// the leaf's position. State-specific handling is dispatched through `S`.
    pub fn add<H: Hasher<D>>(&mut self, hasher: &mut H, element: &[u8]) -> Position {
        let digest = hasher.leaf_digest(self.size(), element);
        S::add_leaf_digest(self, hasher, digest)
    }
}
253
impl<D: Digest> CleanMmr<D> {
    /// Initialize a clean MMR from `config`, validating size and pinned-node
    /// count, then merkleizing to compute the root.
    ///
    /// # Errors
    ///
    /// - [Error::InvalidSize] if `pruned_to_pos + nodes.len()` overflows or is
    ///   not a valid MMR size.
    /// - [Error::InvalidPinnedNodes] if the number of pinned nodes provided
    ///   differs from the number required for `pruned_to_pos`.
    pub fn init(config: Config<D>, hasher: &mut impl Hasher<D>) -> Result<Self, Error> {
        let Some(size) = config.pruned_to_pos.checked_add(config.nodes.len() as u64) else {
            return Err(Error::InvalidSize(u64::MAX));
        };
        if !size.is_mmr_size() {
            return Err(Error::InvalidSize(*size));
        }

        let mut pinned_nodes = BTreeMap::new();
        let mut expected_pinned_nodes = 0;
        for (i, pos) in nodes_to_pin(config.pruned_to_pos).enumerate() {
            expected_pinned_nodes += 1;
            if i >= config.pinned_nodes.len() {
                // Fewer pinned nodes supplied than required.
                return Err(Error::InvalidPinnedNodes);
            }
            pinned_nodes.insert(pos, config.pinned_nodes[i]);
        }

        // Also reject surplus (unused) pinned nodes.
        if config.pinned_nodes.len() != expected_pinned_nodes {
            return Err(Error::InvalidPinnedNodes);
        }

        let mmr = Mmr {
            nodes: VecDeque::from(config.nodes),
            pruned_to_pos: config.pruned_to_pos,
            pinned_nodes,
            state: Dirty::default(),
        };
        Ok(mmr.merkleize(hasher, None))
    }

    /// Return a new, empty clean MMR with its (empty) root computed.
    pub fn new(hasher: &mut impl Hasher<D>) -> Self {
        let mmr: DirtyMmr<D> = Default::default();
        mmr.merkleize(hasher, None)
    }

    /// Build a clean MMR from raw parts and merkleize it. `pinned_nodes` must
    /// be ordered as yielded by [nodes_to_pin] for `pruned_to_pos`.
    pub fn from_components(
        hasher: &mut impl Hasher<D>,
        nodes: Vec<D>,
        pruned_to_pos: Position,
        pinned_nodes: Vec<D>,
    ) -> Self {
        DirtyMmr::from_components(nodes, pruned_to_pos, pinned_nodes).merkleize(hasher, None)
    }

    /// Return the node at `pos` (retained or pinned), or None otherwise.
    pub fn get_node(&self, pos: Position) -> Option<D> {
        if pos < self.pruned_to_pos {
            return self.pinned_nodes.get(&pos).copied();
        }

        self.nodes.get(self.pos_to_index(pos)).copied()
    }

    /// Append `digest` as a leaf, eagerly re-merkleizing so the cached root
    /// stays current. Returns the new leaf's position.
    pub(super) fn add_leaf_digest(&mut self, hasher: &mut impl Hasher<D>, digest: D) -> Position {
        // Swap in a placeholder so we can take ownership and round-trip
        // through the dirty state.
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let leaf_pos = dirty_mmr.add_leaf_digest(digest);
        *self = dirty_mmr.merkleize(hasher, None);
        leaf_pos
    }

    /// Pop the most recent leaf (and any parents it produced), then
    /// re-merkleize. Returns the new size on success.
    ///
    /// # Errors
    ///
    /// Returns [Error::Empty] if the MMR is empty, or [Error::ElementPruned]
    /// if the last leaf has been pruned.
    pub fn pop(&mut self, hasher: &mut impl Hasher<D>) -> Result<Position, Error> {
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.pop();
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Map of the nodes that must remain pinned were the MMR pruned to
    /// `prune_pos`.
    pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
        nodes_to_pin(prune_pos)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }

    /// Prune all nodes before `pos`, pinning those still required for proof
    /// generation. The root is unaffected.
    pub fn prune_to_pos(&mut self, pos: Position) {
        // Compute pins before mutating, while all nodes are still accessible.
        self.pinned_nodes = self.nodes_to_pin(pos);
        let retained_nodes = self.pos_to_index(pos);
        self.nodes.drain(0..retained_nodes);
        self.pruned_to_pos = pos;
    }

    /// Prune every retained node (the root remains derivable from pinned
    /// nodes). A no-op if nothing is retained.
    pub fn prune_all(&mut self) {
        if !self.nodes.is_empty() {
            let pos = self.index_to_pos(self.nodes.len());
            self.prune_to_pos(pos);
        }
    }

    /// Re-hash the leaf at `loc` from `element` and re-merkleize.
    ///
    /// # Errors
    ///
    /// Returns [Error::LeafOutOfBounds] if `loc` is past the last leaf, or
    /// [Error::ElementPruned] if the leaf has been pruned.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.update_leaf(hasher, loc, element);
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Convert into a [DirtyMmr], discarding the cached root.
    pub fn into_dirty(self) -> DirtyMmr<D> {
        self.into()
    }

    /// The cached root digest of the MMR.
    pub const fn root(&self) -> &D {
        &self.state.root
    }

    /// Root digest of an empty MMR (the hash of a size of 0).
    pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
        hasher.update(&0u64.to_be_bytes());
        hasher.finalize()
    }

    /// Inclusion proof for the single leaf at `loc`.
    ///
    /// # Errors
    ///
    /// Returns [Error::LocationOverflow] if `loc` is invalid, or
    /// [Error::ElementPruned] if a required node has been pruned.
    pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        self.range_proof(loc..loc + 1)
    }

    /// Inclusion proof for the leaf range `range`.
    ///
    /// # Panics
    ///
    /// Panics if the range start or end exceeds the current leaf count.
    ///
    /// # Errors
    ///
    /// Returns [Error::ElementPruned] if a required node has been pruned.
    pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
        let leaves = self.leaves();
        assert!(
            range.start < leaves,
            "range start {} >= leaf count {}",
            range.start,
            leaves
        );
        assert!(
            range.end <= leaves,
            "range end {} > leaf count {}",
            range.end,
            leaves
        );

        let size = self.size();
        let positions = proof::nodes_required_for_range_proof(size, range)?;
        let digests = positions
            .into_iter()
            .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Proof { size, digests })
    }

    /// Digests of the nodes that would need pinning if pruned to `start_pos`
    /// (test helper).
    #[cfg(test)]
    pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
        nodes_to_pin(start_pos)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }

    /// Snapshot of the pinned-node map (test helper).
    #[cfg(test)]
    pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
        self.pinned_nodes.clone()
    }
}
479
impl<D: Digest> DirtyMmr<D> {
    /// Return a new, empty, unpruned dirty MMR.
    pub fn new() -> Self {
        Self {
            nodes: VecDeque::new(),
            pruned_to_pos: Position::new(0),
            pinned_nodes: BTreeMap::new(),
            state: Dirty::default(),
        }
    }

    /// Build a dirty MMR from raw parts. `pinned_nodes` must be ordered as
    /// yielded by [nodes_to_pin] for `pruned_to_pos`.
    ///
    /// # Panics
    ///
    /// Panics if fewer pinned nodes are provided than required.
    pub fn from_components(nodes: Vec<D>, pruned_to_pos: Position, pinned_nodes: Vec<D>) -> Self {
        Self {
            nodes: VecDeque::from(nodes),
            pruned_to_pos,
            pinned_nodes: nodes_to_pin(pruned_to_pos)
                .enumerate()
                .map(|(i, pos)| (pos, pinned_nodes[i]))
                .collect(),
            state: Dirty::default(),
        }
    }

    /// Append `digest` as a new leaf, pushing a placeholder (marked dirty) for
    /// each new parent node the leaf completes. Returns the leaf's position.
    pub(super) fn add_leaf_digest(&mut self, digest: D) -> Position {
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        let leaf_pos = self.size();
        self.nodes.push_back(digest);

        // One placeholder parent per merged peak, at consecutive heights
        // starting from 1; digests are filled in by merkleize().
        let mut height = 1;
        for _ in nodes_needing_parents {
            let new_node_pos = self.size();
            self.nodes.push_back(D::EMPTY);
            self.state.dirty_nodes.insert((new_node_pos, height));
            height += 1;
        }

        leaf_pos
    }

    /// Pop the most recent leaf (and any parents it produced), returning the
    /// new size.
    ///
    /// # Errors
    ///
    /// Returns [Error::Empty] if the MMR is empty, or [Error::ElementPruned]
    /// if the last leaf has been pruned.
    pub fn pop(&mut self) -> Result<Position, Error> {
        if self.size() == 0 {
            return Err(Empty);
        }

        // Walk back to the largest valid MMR size below the current one.
        let mut new_size = self.size() - 1;
        loop {
            if new_size < self.pruned_to_pos {
                return Err(ElementPruned(new_size));
            }
            if new_size.is_mmr_size() {
                break;
            }
            new_size -= 1;
        }
        let num_to_drain = *(self.size() - new_size) as usize;
        self.nodes.drain(self.nodes.len() - num_to_drain..);

        // Drop dirty markers at or beyond the new size; split_off retains
        // entries ordered before `cutoff` in `dirty_nodes` and discards the rest.
        let cutoff = (self.size(), 0);
        self.state.dirty_nodes.split_off(&cutoff);

        Ok(self.size())
    }

    /// Recompute all dirty node digests (in parallel when a pool is provided
    /// and enough work exists), then compute and cache the root, returning the
    /// resulting [CleanMmr].
    pub fn merkleize(
        mut self,
        hasher: &mut impl Hasher<D>,
        pool: Option<ThreadPool>,
    ) -> CleanMmr<D> {
        #[cfg(feature = "std")]
        match (pool, self.state.dirty_nodes.len() >= MIN_TO_PARALLELIZE) {
            (Some(pool), true) => self.merkleize_parallel(hasher, pool, MIN_TO_PARALLELIZE),
            _ => self.merkleize_serial(hasher),
        }

        #[cfg(not(feature = "std"))]
        self.merkleize_serial(hasher);

        // Root is computed over the (now up-to-date) peak digests.
        let peaks = self
            .peak_iterator()
            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
        let size = self.size();
        let digest = hasher.root(size, peaks);

        CleanMmr {
            nodes: self.nodes,
            pruned_to_pos: self.pruned_to_pos,
            pinned_nodes: self.pinned_nodes,
            state: Clean { root: digest },
        }
    }

    /// Recompute all dirty node digests serially, lowest heights first so
    /// children are always current before their parents.
    fn merkleize_serial(&mut self, hasher: &mut impl Hasher<D>) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        for (pos, height) in nodes {
            // For a parent at height h: left child is at pos - 2^h and the
            // right child immediately precedes the parent.
            let left = pos - (1 << height);
            let right = pos - 1;
            let digest = hasher.node_digest(
                pos,
                self.get_node_unchecked(left),
                self.get_node_unchecked(right),
            );
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
    }

    /// Recompute dirty node digests height-by-height, parallelizing within
    /// each height (nodes at the same height are independent). Falls back to
    /// the serial path for any height with fewer than `min_to_parallelize`
    /// dirty nodes.
    #[cfg(feature = "std")]
    fn merkleize_parallel(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        min_to_parallelize: usize,
    ) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        // Gather runs of equal height; heights are consecutive starting at 1.
        let mut same_height = Vec::new();
        let mut current_height = 1;
        for (i, (pos, height)) in nodes.iter().enumerate() {
            if *height == current_height {
                same_height.push(*pos);
                continue;
            }
            if same_height.len() < min_to_parallelize {
                // Too little work at this height: restore the unprocessed tail
                // as dirty and finish serially.
                self.state.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
                self.merkleize_serial(hasher);
                return;
            }
            self.update_node_digests(hasher, pool.clone(), &same_height, current_height);
            same_height.clear();
            current_height += 1;
            same_height.push(*pos);
        }

        if same_height.len() < min_to_parallelize {
            // Final height batch too small to parallelize; finish serially.
            self.state.dirty_nodes = nodes[nodes.len() - same_height.len()..]
                .iter()
                .copied()
                .collect();
            self.merkleize_serial(hasher);
            return;
        }

        self.update_node_digests(hasher, pool, &same_height, current_height);
    }

    /// Recompute (in parallel) the digests of the given same-height nodes,
    /// then write the results back in a serial pass.
    #[cfg(feature = "std")]
    fn update_node_digests(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        same_height: &[Position],
        height: u32,
    ) {
        // Left-child offset for a parent at this height (2^height).
        let two_h = 1 << height;
        pool.install(|| {
            let computed_digests: Vec<(usize, D)> = same_height
                .par_iter()
                .map_init(
                    || hasher.fork(),
                    |hasher, &pos| {
                        let left = pos - two_h;
                        let right = pos - 1;
                        let digest = hasher.node_digest(
                            pos,
                            self.get_node_unchecked(left),
                            self.get_node_unchecked(right),
                        );
                        let index = self.pos_to_index(pos);
                        (index, digest)
                    },
                )
                .collect();

            for (index, digest) in computed_digests {
                self.nodes[index] = digest;
            }
        });
    }

    /// Mark every ancestor of `pos` (up to its peak) dirty so merkleize()
    /// recomputes them, stopping early once an already-dirty ancestor is hit
    /// (its own ancestors must already be marked).
    ///
    /// # Panics
    ///
    /// Panics if `pos` is not covered by any peak (i.e. invalid for this MMR).
    fn mark_dirty(&mut self, pos: Position) {
        for (peak_pos, mut height) in self.peak_iterator() {
            if peak_pos < pos {
                continue;
            }

            // `pos` lies under this peak: walk peak->node, then mark the
            // ancestors bottom-up with consecutive heights starting at 1.
            let path = PathIterator::new(pos, peak_pos, height)
                .collect::<Vec<_>>()
                .into_iter()
                .rev();
            height = 1;
            for (parent_pos, _) in path {
                if !self.state.dirty_nodes.insert((parent_pos, height)) {
                    break;
                }
                height += 1;
            }
            return;
        }

        panic!("invalid pos {pos}:{}", self.size());
    }

    /// Re-hash the single leaf at `loc` from `element`, marking its ancestors
    /// dirty. See [Self::update_leaf_batched] for errors.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        self.update_leaf_batched(hasher, None, &[(loc, element)])
    }

    /// Re-hash a batch of leaves (in parallel when a pool is provided and the
    /// batch is large enough), marking all affected ancestors dirty.
    ///
    /// # Errors
    ///
    /// Returns [Error::LeafOutOfBounds] if any location is past the last leaf,
    /// or [Error::ElementPruned] if any targeted leaf has been pruned. All
    /// locations are validated before any update is applied.
    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: Option<ThreadPool>,
        updates: &[(Location, T)],
    ) -> Result<(), Error> {
        if updates.is_empty() {
            return Ok(());
        }

        // Validate the entire batch up-front so we never partially apply.
        let leaves = self.leaves();
        let mut positions = Vec::with_capacity(updates.len());
        for (loc, _) in updates {
            if *loc >= leaves {
                return Err(Error::LeafOutOfBounds(*loc));
            }
            let pos = Position::try_from(*loc)?;
            if pos < self.pruned_to_pos {
                return Err(Error::ElementPruned(pos));
            }
            positions.push(pos);
        }

        #[cfg(feature = "std")]
        if let Some(pool) = pool {
            if updates.len() >= MIN_TO_PARALLELIZE {
                self.update_leaf_parallel(hasher, pool, updates, &positions);
                return Ok(());
            }
        }

        for ((_, element), pos) in updates.iter().zip(positions.iter()) {
            let digest = hasher.leaf_digest(*pos, element.as_ref());
            let index = self.pos_to_index(*pos);
            self.nodes[index] = digest;
            self.mark_dirty(*pos);
        }

        Ok(())
    }

    /// Parallel leaf re-hashing for [Self::update_leaf_batched]; dirty-marking
    /// and write-back happen serially after the parallel hashing pass.
    #[cfg(feature = "std")]
    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        updates: &[(Location, T)],
        positions: &[Position],
    ) {
        pool.install(|| {
            let digests: Vec<(Position, D)> = updates
                .par_iter()
                .zip(positions.par_iter())
                .map_init(
                    || hasher.fork(),
                    |hasher, ((_, elem), pos)| {
                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
                        (*pos, digest)
                    },
                )
                .collect();

            for (pos, digest) in digests {
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
                self.mark_dirty(pos);
            }
        });
    }
}
802
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{
        hasher::{Hasher as _, Standard},
        stability::ROOTS,
    };
    use commonware_cryptography::{sha256, Hasher, Sha256};
    use commonware_runtime::{deterministic, tokio, RayonPoolSpawner, Runner};
    use commonware_utils::{hex, NZUsize};

    // Adds 199 elements one at a time, checking the root against the ROOTS
    // stability vectors before each add and once at the end.
    fn build_and_check_test_roots_mmr(mmr: &mut CleanMmr<sha256::Digest>) {
        let mut hasher: Standard<Sha256> = Standard::new();
        for i in 0u64..199 {
            hasher.inner().update(&i.to_be_bytes());
            let element = hasher.inner().finalize();
            let root = *mmr.root();
            let expected_root = ROOTS[i as usize];
            assert_eq!(hex(&root), expected_root, "at: {i}");
            mmr.add(&mut hasher, &element);
        }
        assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
    }

    // Batched variant: adds all elements while dirty, merkleizes once at the
    // end (optionally with a thread pool), then checks only the final root.
    pub fn build_batched_and_check_test_roots(
        mut mmr: DirtyMmr<sha256::Digest>,
        pool: Option<ThreadPool>,
    ) {
        let mut hasher: Standard<Sha256> = Standard::new();
        for i in 0u64..199 {
            hasher.inner().update(&i.to_be_bytes());
            let element = hasher.inner().finalize();
            mmr.add(&mut hasher, &element);
        }
        let mmr = mmr.merkleize(&mut hasher, pool);
        assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
    }

    // Exercises every accessor on an empty MMR and verifies the empty root.
    #[test]
    fn test_mem_mmr_empty() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            assert_eq!(
                mmr.peak_iterator().next(),
                None,
                "empty iterator should have no peaks"
            );
            assert_eq!(mmr.size(), 0);
            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
            assert_eq!(mmr.last_leaf_pos(), None);
            assert_eq!(mmr.oldest_retained_pos(), None);
            assert_eq!(mmr.get_node(Position::new(0)), None);
            assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
            assert!(matches!(mmr.pop(&mut hasher), Err(Empty)));
            mmr.prune_all();
            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");

            assert_eq!(*mmr.root(), hasher.root(Position::new(0), [].iter()));
        });
    }

    // Builds an 11-leaf (19-node) MMR and checks structure, node digests, the
    // root, proof generation, pruning behavior, and re-initialization.
    #[test]
    fn test_mem_mmr_add_eleven_values() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut leaves: Vec<Position> = Vec::new();
            for _ in 0..11 {
                leaves.push(mmr.add(&mut hasher, &element));
                let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
                assert_ne!(peaks.len(), 0);
                assert!(peaks.len() as u64 <= mmr.size());
                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
                assert!(nodes_needing_parents.len() <= peaks.len());
            }
            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(0));
            assert_eq!(mmr.size(), 19, "mmr not of expected size");
            assert_eq!(
                leaves,
                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
                    .into_iter()
                    .map(Position::new)
                    .collect::<Vec<_>>(),
                "mmr leaf positions not as expected"
            );
            let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
            assert_eq!(
                peaks,
                vec![
                    (Position::new(14), 3),
                    (Position::new(17), 1),
                    (Position::new(18), 0)
                ],
                "mmr peaks not as expected"
            );

            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
            assert_eq!(
                peaks_needing_parents,
                vec![Position::new(17), Position::new(18)],
                "mmr nodes needing parents not as expected"
            );

            // Every leaf node should hold its leaf digest.
            for leaf in leaves.iter().by_ref() {
                let digest = hasher.leaf_digest(*leaf, &element);
                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
            }

            // Height-1 internal nodes.
            let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
            assert_eq!(mmr.nodes[2], digest2);
            let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
            assert_eq!(mmr.nodes[5], digest5);
            let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
            assert_eq!(mmr.nodes[9], digest9);
            let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
            assert_eq!(mmr.nodes[12], digest12);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Height-2 internal nodes.
            let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
            assert_eq!(mmr.nodes[6], digest6);
            let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
            assert_eq!(mmr.nodes[13], digest13);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Height-3 internal node (the tallest peak).
            let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
            assert_eq!(mmr.nodes[14], digest14);

            // Root is the hash over the peaks.
            let root = *mmr.root();
            let peak_digests = [digest14, digest17, mmr.nodes[18]];
            let expected_root = hasher.root(Position::new(19), peak_digests.iter());
            assert_eq!(root, expected_root, "incorrect root");

            mmr.prune_to_pos(Position::new(14)); assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(14));

            // Proofs for pruned leaves must fail; retained leaves must succeed.
            assert!(matches!(
                mmr.proof(Location::new_unchecked(0)),
                Err(ElementPruned(_))
            ));
            assert!(matches!(
                mmr.proof(Location::new_unchecked(6)),
                Err(ElementPruned(_))
            ));

            assert!(mmr.proof(Location::new_unchecked(8)).is_ok());
            assert!(mmr.proof(Location::new_unchecked(10)).is_ok());

            let root_after_prune = *mmr.root();
            assert_eq!(root, root_after_prune, "root changed after pruning");

            assert!(
                mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(9))
                    .is_err(),
                "attempts to range_prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(Location::new_unchecked(8)..mmr.leaves()).is_ok(),
                "attempts to range_prove over all elements following oldest retained should succeed"
            );

            // Rebuild a copy from retained nodes + pinned digests; it must match.
            let oldest_pos = mmr.oldest_retained_pos().unwrap();
            let digests = mmr.node_digests_to_pin(oldest_pos);
            let mmr_copy = Mmr::init(
                Config {
                    nodes: mmr.nodes.iter().copied().collect(),
                    pruned_to_pos: oldest_pos,
                    pinned_nodes: digests,
                },
                &mut hasher,
            )
            .unwrap();
            assert_eq!(mmr_copy.size(), 19);
            assert_eq!(mmr_copy.leaves(), mmr.leaves());
            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
            assert_eq!(mmr_copy.oldest_retained_pos(), mmr.oldest_retained_pos());
            assert_eq!(*mmr_copy.root(), root);
        });
    }

    // Interleaves prune_all with adds to ensure pruning never corrupts state.
    #[test]
    fn test_mem_mmr_prune_all() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            for _ in 0..1000 {
                mmr.prune_all();
                mmr.add(&mut hasher, &element);
            }
        });
    }

    // Checks that every size reached by adding is a valid MMR size, and all
    // intermediate sizes (mid-add) are invalid.
    #[test]
    fn test_mem_mmr_validity() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            for _ in 0..1001 {
                assert!(
                    mmr.size().is_mmr_size(),
                    "mmr of size {} should be valid",
                    mmr.size()
                );
                let old_size = mmr.size();
                mmr.add(&mut hasher, &element);
                for size in *old_size + 1..*mmr.size() {
                    assert!(
                        !Position::new(size).is_mmr_size(),
                        "mmr of size {size} should be invalid",
                    );
                }
            }
        });
    }

    // Roots must match the stability vectors for both the incremental (clean)
    // and batched (dirty) build paths.
    #[test]
    fn test_mem_mmr_root_stability() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            build_and_check_test_roots_mmr(&mut mmr);

            let mut hasher: Standard<Sha256> = Standard::new();
            let mmr = CleanMmr::new(&mut hasher);
            build_batched_and_check_test_roots(mmr.into_dirty(), None);
        });
    }

    // Same stability check but merkleizing with a 4-thread pool.
    #[test]
    fn test_mem_mmr_root_stability_parallel() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            let pool = context.create_pool(NZUsize!(4)).unwrap();
            let mut hasher: Standard<Sha256> = Standard::new();

            let mmr = Mmr::init(
                Config {
                    nodes: vec![],
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: vec![],
                },
                &mut hasher,
            )
            .unwrap();
            build_batched_and_check_test_roots(mmr.into_dirty(), Some(pool));
        });
    }

    // Roots must remain stable even when pruning everything after each add.
    #[test]
    fn test_mem_mmr_root_stability_while_pruning() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..199 {
                let root = *mmr.root();
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root, "at: {i}");
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
                mmr.prune_all();
            }
        });
    }

    // Builds the standard 199-element MMR (optionally with a pool), returning
    // it clean together with the positions of its leaves.
    fn compute_big_mmr(
        hasher: &mut Standard<Sha256>,
        mut mmr: DirtyMmr<sha256::Digest>,
        pool: Option<ThreadPool>,
    ) -> (CleanMmr<sha256::Digest>, Vec<Position>) {
        let mut leaves = Vec::new();
        let mut c_hasher = Sha256::default();
        for i in 0u64..199 {
            c_hasher.update(&i.to_be_bytes());
            let element = c_hasher.finalize();
            let leaf_pos = mmr.size();
            mmr.add(hasher, &element);
            leaves.push(leaf_pos);
        }

        (mmr.merkleize(hasher, pool), leaves)
    }

    // Pops all elements (checking roots shrink through the stability vectors),
    // then re-adds, prunes, and pops down to the pruning boundary.
    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let root = *mmr.root();
            let expected_root = ROOTS[199];
            assert_eq!(hex(&root), expected_root);

            for i in (0..199u64).rev() {
                assert!(mmr.pop(&mut hasher).is_ok());
                let root = *mmr.root();
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root);
            }

            assert!(
                matches!(mmr.pop(&mut hasher).unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            for i in 0u64..199 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }

            let leaf_pos = Position::try_from(Location::new_unchecked(100)).unwrap();
            mmr.prune_to_pos(leaf_pos);
            while mmr.size() > leaf_pos {
                mmr.pop(&mut hasher).unwrap();
            }
            assert_eq!(hex(mmr.root()), ROOTS[100]);
            let result = mmr.pop(&mut hasher);
            assert!(matches!(result, Err(ElementPruned(_))));
            assert_eq!(mmr.oldest_retained_pos(), None);
        });
    }

    // Updating a leaf changes the root; restoring the original content
    // restores the original root. Also exercises updates after pruning.
    #[test]
    fn test_mem_mmr_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let root = *mmr.root();

            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
                let leaf_loc =
                    Location::try_from(leaves[leaf]).expect("leaf position should map to location");
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let updated_root = *mmr.root();
                assert!(root != updated_root);

                // Restore the original element to restore the original root.
                hasher.inner().update(&leaf.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let restored_root = *mmr.root();
                assert_eq!(root, restored_root);
            }

            // Updating the boundary leaf right after pruning to it must work.
            mmr.prune_to_pos(leaves[150]);
            for &leaf_pos in &leaves[150..=190] {
                mmr.prune_to_pos(leaf_pos);
                let leaf_loc =
                    Location::try_from(leaf_pos).expect("leaf position should map to location");
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
            }
        });
    }

    // Updating a location beyond the last leaf must fail.
    #[test]
    fn test_mem_mmr_update_leaf_error_out_of_bounds() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let invalid_loc = mmr.leaves();
            let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
            assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
        });
    }

    // Updating a pruned leaf must fail.
    #[test]
    fn test_mem_mmr_update_leaf_error_pruned() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            mmr.prune_all();
            let result = mmr.update_leaf(&mut hasher, Location::new_unchecked(0), &element);
            assert!(matches!(result, Err(Error::ElementPruned(_))));
        });
    }

    // Serial batched leaf update.
    #[test]
    fn test_mem_mmr_batch_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            do_batch_update(&mut hasher, mmr, &leaves);
        });
    }

    // Parallel batched leaf update (4-thread pool).
    #[test]
    fn test_mem_mmr_batch_parallel_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = tokio::Runner::default();
        executor.start(|ctx| async move {
            let pool = ctx.create_pool(NZUsize!(4)).unwrap();
            let mmr = Mmr::init(
                Config {
                    nodes: Vec::new(),
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: Vec::new(),
                },
                &mut hasher,
            )
            .unwrap();
            let (mmr, leaves) = compute_big_mmr(&mut hasher, mmr.into_dirty(), Some(pool));
            do_batch_update(&mut hasher, mmr, &leaves);
        });
    }

    // Applies a batch update (checking a known root), then a second batch that
    // restores the original elements (and hence the original root).
    fn do_batch_update(
        hasher: &mut Standard<Sha256>,
        mmr: CleanMmr<sha256::Digest>,
        leaves: &[Position],
    ) {
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let root = *mmr.root();

        let mut updates = Vec::new();
        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
            let leaf_loc =
                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
            updates.push((leaf_loc, &element));
        }
        let mut dirty_mmr = mmr.into_dirty();
        dirty_mmr
            .update_leaf_batched(hasher, None, &updates)
            .unwrap();

        let mmr = dirty_mmr.merkleize(hasher, None);
        let updated_root = *mmr.root();
        assert_eq!(
            "af3acad6aad59c1a880de643b1200a0962a95d06c087ebf677f29eb93fc359a4",
            hex(&updated_root)
        );

        let mut updates = Vec::new();
        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
            hasher.inner().update(&leaf.to_be_bytes());
            let element = hasher.inner().finalize();
            let leaf_loc =
                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
            updates.push((leaf_loc, element));
        }
        let mut dirty_mmr = mmr.into_dirty();
        dirty_mmr
            .update_leaf_batched(hasher, None, &updates)
            .unwrap();

        let mmr = dirty_mmr.merkleize(hasher, None);
        let restored_root = *mmr.root();
        assert_eq!(root, restored_root);
    }

    // init() must reject pinned-node vectors that are too short or too long.
    #[test]
    fn test_init_pinned_nodes_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Pruned boundary requires pinned nodes, but none supplied.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(127),
                pinned_nodes: vec![], };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // No pruning, but an extra pinned node supplied.
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![Sha256::hash(b"dummy")],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Exactly the required pinned nodes: OK.
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..50 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(50),
                pinned_nodes,
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());
        });
    }

    // init() must reject node counts that don't correspond to a valid MMR size.
    #[test]
    fn test_init_size_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Two nodes is not a valid MMR size.
            let config = Config {
                nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Three nodes (two leaves + parent) is valid.
            let config = Config {
                nodes: vec![
                    Sha256::hash(b"leaf1"),
                    Sha256::hash(b"leaf2"),
                    Sha256::hash(b"parent"),
                ],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // A full 64-leaf MMR (127 nodes) round-trips through init.
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..64 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 127); let nodes: Vec<_> = (0..127)
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();

            let config = Config {
                nodes,
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // A pruned MMR round-trips only with the matching pruning boundary.
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..11 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 19); mmr.prune_to_pos(Position::new(7));
            let nodes: Vec<_> = (7..*mmr.size())
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));

            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(7),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Wrong boundary (8): total size no longer valid.
            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(8),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Wrong boundary (9): also invalid.
            let config = Config {
                nodes,
                pruned_to_pos: Position::new(9),
                pinned_nodes,
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));
        });
    }
}