1use crate::mmr::{
4 hasher::Hasher,
5 iterator::{nodes_needing_parents, nodes_to_pin, PathIterator, PeakIterator},
6 proof,
7 Error::{self, *},
8 Location, Position, Proof,
9};
10use alloc::{
11 collections::{BTreeMap, BTreeSet, VecDeque},
12 vec::Vec,
13};
14use commonware_cryptography::Digest;
15use core::{mem, ops::Range};
16cfg_if::cfg_if! {
17 if #[cfg(feature = "std")] {
18 use commonware_runtime::ThreadPool;
19 use rayon::prelude::*;
20 } else {
21 struct ThreadPool;
22 }
23}
24
// Minimum number of work items (dirty nodes or leaf updates) before the parallel code
// paths are used; below this the scheduling overhead outweighs the benefit.
#[cfg(feature = "std")]
const MIN_TO_PARALLELIZE: usize = 20;
28
/// An MMR that may contain nodes whose digests are stale; it must be
/// [DirtyMmr::merkleize]d before a root can be produced.
pub type DirtyMmr<D> = Mmr<D, Dirty>;

/// An MMR whose node digests are all up to date, making its root available.
pub type CleanMmr<D> = Mmr<D, Clean<D>>;
34
/// Seals [State] so only the states defined in this module can implement it.
mod private {
    pub trait Sealed {}
}
39
/// Type-state of an [Mmr]: either [Clean] (root cached) or [Dirty] (updates pending).
pub trait State<D: Digest>: private::Sealed + Sized {
    /// Append `digest` as a new leaf and return its position. Dispatch point allowing
    /// `Mmr::add` to work for both clean and dirty MMRs.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut Mmr<D, Self>, hasher: &mut H, digest: D)
        -> Position;
}
46
/// State of a fully merkleized MMR: caches the root digest.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    /// The cached root digest of the MMR.
    pub root: D,
}
53
impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut CleanMmr<D>, hasher: &mut H, digest: D) -> Position {
        // Delegates to the inherent method, which re-merkleizes after the add.
        mmr.add_leaf_digest(hasher, digest)
    }
}
60
/// State of an MMR with unprocessed updates.
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    // (position, height) pairs of nodes whose digests are stale and must be recomputed
    // during merkleization.
    dirty_nodes: BTreeSet<(Position, u32)>,
}
69
impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut DirtyMmr<D>, hasher: &mut H, digest: D) -> Position {
        // Delegates to the inherent method, which only marks new parents dirty.
        mmr.add_leaf_digest(hasher, digest)
    }
}
76
/// Configuration for initializing a [CleanMmr] from previously computed state.
pub struct Config<D: Digest> {
    /// The retained (unpruned) nodes of the MMR, in position order.
    pub nodes: Vec<D>,
    /// Position before which all nodes have been pruned.
    pub pruned_to_pos: Position,
    /// Digests of pruned nodes that are still required, in [nodes_to_pin] iteration
    /// order for `pruned_to_pos`.
    pub pinned_nodes: Vec<D>,
}
89
/// An in-memory Merkle Mountain Range. The type-state parameter `S` distinguishes
/// [Dirty] MMRs (cheap batched mutation) from [Clean] MMRs (root available).
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest, S: State<D> = Dirty> {
    // Retained node digests, indexed by (position - pruned_to_pos).
    nodes: VecDeque<D>,
    // Position of the first retained node; all earlier nodes have been pruned.
    pruned_to_pos: Position,
    // Digests of pruned nodes that are still needed, keyed by position.
    pinned_nodes: BTreeMap<Position, D>,
    // Clean or Dirty type-state payload.
    state: S,
}
129
impl<D: Digest> Default for DirtyMmr<D> {
    /// A new, empty dirty MMR.
    fn default() -> Self {
        Self::new()
    }
}
135
136impl<D: Digest> From<CleanMmr<D>> for DirtyMmr<D> {
137 fn from(clean: CleanMmr<D>) -> Self {
138 DirtyMmr {
139 nodes: clean.nodes,
140 pruned_to_pos: clean.pruned_to_pos,
141 pinned_nodes: clean.pinned_nodes,
142 state: Dirty {
143 dirty_nodes: BTreeSet::new(),
144 },
145 }
146 }
147}
148
impl<D: Digest, S: State<D>> Mmr<D, S> {
    /// Total number of nodes in the MMR, including pruned ones.
    pub fn size(&self) -> Position {
        Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
    }

    /// Number of leaves in the MMR.
    pub fn leaves(&self) -> Location {
        Location::try_from(self.size()).expect("invalid mmr size")
    }

    /// Position of the most recently added leaf, or `None` if the MMR is empty.
    pub fn last_leaf_pos(&self) -> Option<Position> {
        if self.size() == 0 {
            return None;
        }

        Some(PeakIterator::last_leaf_pos(self.size()))
    }

    /// Position before which all nodes have been pruned.
    pub const fn pruned_to_pos(&self) -> Position {
        self.pruned_to_pos
    }

    /// Position of the oldest retained node, or `None` if everything is pruned.
    pub fn oldest_retained_pos(&self) -> Option<Position> {
        // When the prune point equals the size, no nodes are retained.
        if self.pruned_to_pos == self.size() {
            return None;
        }

        Some(self.pruned_to_pos)
    }

    /// Iterator over the current peaks of the MMR.
    pub fn peak_iterator(&self) -> PeakIterator {
        PeakIterator::new(self.size())
    }

    // Convert an index into `self.nodes` to its MMR position.
    fn index_to_pos(&self, index: usize) -> Position {
        self.pruned_to_pos + (index as u64)
    }

    /// Return a reference to the node at `pos`.
    ///
    /// Panics if `pos` is out of bounds, or pruned and not pinned.
    pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
        if pos < self.pruned_to_pos {
            // Pruned nodes may still be available if pinned.
            return self
                .pinned_nodes
                .get(&pos)
                .expect("requested node is pruned and not pinned");
        }

        &self.nodes[self.pos_to_index(pos)]
    }

    // Convert an MMR position to its index into `self.nodes`.
    //
    // Panics if `pos` precedes the oldest retained position.
    fn pos_to_index(&self, pos: Position) -> usize {
        assert!(
            pos >= self.pruned_to_pos,
            "pos precedes oldest retained position"
        );

        *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
    }

    /// Merge the given digests into the set of pinned (pruned-but-needed) nodes.
    #[cfg(any(feature = "std", test))]
    pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
        for (pos, node) in pinned_nodes.into_iter() {
            self.pinned_nodes.insert(pos, node);
        }
    }

    /// Hash `element` as a leaf and append it, returning the new leaf's position.
    pub fn add<H: Hasher<D>>(&mut self, hasher: &mut H, element: &[u8]) -> Position {
        let digest = hasher.leaf_digest(self.size(), element);
        // State-specific behavior: Clean re-merkleizes, Dirty defers hashing.
        S::add_leaf_digest(self, hasher, digest)
    }
}
244
impl<D: Digest> CleanMmr<D> {
    /// Initialize an MMR from persisted `config`, validating its size and pinned nodes,
    /// and computing its root.
    pub fn init(config: Config<D>, hasher: &mut impl Hasher<D>) -> Result<Self, Error> {
        // Total size = pruned prefix + retained nodes; guard against u64 overflow.
        let Some(size) = config.pruned_to_pos.checked_add(config.nodes.len() as u64) else {
            return Err(Error::InvalidSize(u64::MAX));
        };
        if !size.is_mmr_size() {
            return Err(Error::InvalidSize(*size));
        }

        let mut pinned_nodes = BTreeMap::new();
        // Pair each required pin position with the provided digest, failing early if
        // too few digests were supplied (avoids an out-of-bounds index below).
        let mut expected_pinned_nodes = 0;
        for (i, pos) in nodes_to_pin(config.pruned_to_pos).enumerate() {
            expected_pinned_nodes += 1;
            if i >= config.pinned_nodes.len() {
                return Err(Error::InvalidPinnedNodes);
            }
            pinned_nodes.insert(pos, config.pinned_nodes[i]);
        }

        // Also reject a surplus of pinned digests.
        if config.pinned_nodes.len() != expected_pinned_nodes {
            return Err(Error::InvalidPinnedNodes);
        }

        let mmr = Mmr {
            nodes: VecDeque::from(config.nodes),
            pruned_to_pos: config.pruned_to_pos,
            pinned_nodes,
            state: Dirty::default(),
        };
        // Merkleization computes the root (no nodes are dirty here).
        Ok(mmr.merkleize(hasher, None))
    }

    /// A new, empty clean MMR (root computed over zero peaks).
    pub fn new(hasher: &mut impl Hasher<D>) -> Self {
        let mmr: DirtyMmr<D> = Default::default();
        mmr.merkleize(hasher, None)
    }

    /// Build a clean MMR directly from components (see [DirtyMmr::from_components]).
    pub fn from_components(
        hasher: &mut impl Hasher<D>,
        nodes: Vec<D>,
        pruned_to_pos: Position,
        pinned_nodes: Vec<D>,
    ) -> Self {
        DirtyMmr::from_components(nodes, pruned_to_pos, pinned_nodes).merkleize(hasher, None)
    }

    /// Return the node at `pos`, or `None` if it is out of bounds, or pruned and
    /// not pinned.
    pub fn get_node(&self, pos: Position) -> Option<D> {
        if pos < self.pruned_to_pos {
            return self.pinned_nodes.get(&pos).copied();
        }

        self.nodes.get(self.pos_to_index(pos)).copied()
    }

    /// Append a leaf digest, re-merkleizing immediately so the MMR stays clean.
    pub(super) fn add_leaf_digest(&mut self, hasher: &mut impl Hasher<D>, digest: D) -> Position {
        // Swap an empty MMR into `self` to take ownership of its contents, mutate the
        // dirty form, then swap the re-merkleized result back.
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let leaf_pos = dirty_mmr.add_leaf_digest(hasher, digest);
        *self = dirty_mmr.merkleize(hasher, None);
        leaf_pos
    }

    /// Pop the most recent leaf (and its ancestors) off the MMR, returning the new size.
    pub fn pop(&mut self, hasher: &mut impl Hasher<D>) -> Result<Position, Error> {
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.pop();
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Digests that must be pinned if the MMR is pruned to `prune_pos`, keyed by
    /// position.
    pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
        nodes_to_pin(prune_pos)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }

    /// Prune all nodes before `pos`, pinning those still required.
    pub fn prune_to_pos(&mut self, pos: Position) {
        // Recompute the pinned set before dropping the nodes it is derived from.
        self.pinned_nodes = self.nodes_to_pin(pos);
        let retained_nodes = self.pos_to_index(pos);
        self.nodes.drain(0..retained_nodes);
        self.pruned_to_pos = pos;
    }

    /// Prune every node currently retained by the MMR.
    pub fn prune_all(&mut self) {
        if !self.nodes.is_empty() {
            // index_to_pos(len) is one past the last retained node, i.e. the size.
            let pos = self.index_to_pos(self.nodes.len());
            self.prune_to_pos(pos);
        }
    }

    /// Update the leaf at `loc` to `element`, re-merkleizing immediately.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.update_leaf(hasher, loc, element);
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Convert into a [DirtyMmr], discarding the cached root.
    pub fn into_dirty(self) -> DirtyMmr<D> {
        self.into()
    }

    /// The cached root digest of the MMR.
    pub const fn root(&self) -> &D {
        &self.state.root
    }

    /// The root of an MMR containing no elements (the hash of the 0 size).
    pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
        hasher.update(&0u64.to_be_bytes());
        hasher.finalize()
    }

    /// Generate an inclusion proof for the leaf at `loc`.
    ///
    /// Returns [Error::LocationOverflow] for an invalid location, or
    /// [Error::ElementPruned] if a required node has been pruned (and not pinned).
    pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        self.range_proof(loc..loc + 1)
    }

    /// Generate an inclusion proof for the leaves in `range`.
    ///
    /// Panics if the range exceeds the current leaf count; returns
    /// [Error::ElementPruned] if a required node has been pruned (and not pinned).
    pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
        let leaves = self.leaves();
        assert!(
            range.start < leaves,
            "range start {} >= leaf count {}",
            range.start,
            leaves
        );
        assert!(
            range.end <= leaves,
            "range end {} > leaf count {}",
            range.end,
            leaves
        );

        let size = self.size();
        let positions = proof::nodes_required_for_range_proof(size, range)?;
        // Every required node must still be retrievable (retained or pinned).
        let digests = positions
            .into_iter()
            .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Proof { size, digests })
    }

    // Test helper: digests to pin for `start_pos`, as a Vec in nodes_to_pin order.
    #[cfg(test)]
    pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
        nodes_to_pin(start_pos)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }

    // Test helper: a copy of the current pinned-node map.
    #[cfg(test)]
    pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
        self.pinned_nodes.clone()
    }
}
470
impl<D: Digest> DirtyMmr<D> {
    /// A new, empty dirty MMR.
    pub fn new() -> Self {
        Self {
            nodes: VecDeque::new(),
            pruned_to_pos: Position::new(0),
            pinned_nodes: BTreeMap::new(),
            state: Dirty::default(),
        }
    }

    /// Build an MMR from its raw components. `pinned_nodes` must provide one digest per
    /// position yielded by [nodes_to_pin] for `pruned_to_pos`, in iteration order.
    ///
    /// Panics if `pinned_nodes` contains fewer digests than required.
    pub fn from_components(nodes: Vec<D>, pruned_to_pos: Position, pinned_nodes: Vec<D>) -> Self {
        Self {
            nodes: VecDeque::from(nodes),
            pruned_to_pos,
            pinned_nodes: nodes_to_pin(pruned_to_pos)
                .enumerate()
                .map(|(i, pos)| (pos, pinned_nodes[i]))
                .collect(),
            state: Dirty::default(),
        }
    }

    /// Append `digest` as a leaf, inserting placeholder parent nodes that are marked
    /// dirty for later hashing, and return the new leaf's position.
    pub(super) fn add_leaf_digest<H: Hasher<D>>(&mut self, _hasher: &mut H, digest: D) -> Position {
        // One new parent node is created for each existing node that needs a parent
        // once this leaf is appended, from the lowest height upward.
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        let leaf_pos = self.size();
        self.nodes.push_back(digest);

        let mut height = 1;
        for _ in nodes_needing_parents {
            let new_node_pos = self.size();
            // Placeholder digest; the real value is computed during merkleization.
            self.nodes
                .push_back(<H::Inner as commonware_cryptography::Hasher>::EMPTY);
            self.state.dirty_nodes.insert((new_node_pos, height));
            height += 1;
        }

        leaf_pos
    }

    /// Pop the most recent leaf (and the parent nodes created by it) off the MMR,
    /// returning the new size, or an error if the MMR is empty or the leaf is pruned.
    pub fn pop(&mut self) -> Result<Position, Error> {
        if self.size() == 0 {
            return Err(Empty);
        }

        // Walk backwards to the largest valid MMR size below the current one; this
        // removes the last leaf together with any ancestors it introduced.
        let mut new_size = self.size() - 1;
        loop {
            if new_size < self.pruned_to_pos {
                return Err(ElementPruned(new_size));
            }
            if new_size.is_mmr_size() {
                break;
            }
            new_size -= 1;
        }
        let num_to_drain = *(self.size() - new_size) as usize;
        self.nodes.drain(self.nodes.len() - num_to_drain..);

        // Drop dirty markers for positions at or beyond the new size; split_off leaves
        // only entries below the cutoff in the set.
        let cutoff = (self.size(), 0);
        self.state.dirty_nodes.split_off(&cutoff);

        Ok(self.size())
    }

    /// Recompute the digests of all dirty nodes (in parallel when `pool` is provided
    /// and the batch is large enough), compute the root, and return a [CleanMmr].
    pub fn merkleize(
        mut self,
        hasher: &mut impl Hasher<D>,
        pool: Option<ThreadPool>,
    ) -> CleanMmr<D> {
        #[cfg(feature = "std")]
        match (pool, self.state.dirty_nodes.len() >= MIN_TO_PARALLELIZE) {
            (Some(pool), true) => self.merkleize_parallel(hasher, pool, MIN_TO_PARALLELIZE),
            _ => self.merkleize_serial(hasher),
        }

        #[cfg(not(feature = "std"))]
        self.merkleize_serial(hasher);

        // The root hashes the MMR size together with all peak digests.
        let peaks = self
            .peak_iterator()
            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
        let size = self.size();
        let digest = hasher.root(size, peaks);

        CleanMmr {
            nodes: self.nodes,
            pruned_to_pos: self.pruned_to_pos,
            pinned_nodes: self.pinned_nodes,
            state: Clean { root: digest },
        }
    }

    // Recompute all dirty node digests serially, lowest heights first so children are
    // always up to date before their parents are hashed.
    fn merkleize_serial(&mut self, hasher: &mut impl Hasher<D>) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        // Stable sort by height; set order already sorted entries by position.
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        for (pos, height) in nodes {
            // A node at height h has its left child at pos - 2^h and its right child
            // immediately preceding it.
            let left = pos - (1 << height);
            let right = pos - 1;
            let digest = hasher.node_digest(
                pos,
                self.get_node_unchecked(left),
                self.get_node_unchecked(right),
            );
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
    }

    // Recompute dirty node digests height by height, hashing each height's batch in
    // parallel. Once a height's batch falls below `min_to_parallelize`, the remaining
    // nodes are re-marked dirty and finished serially.
    #[cfg(feature = "std")]
    fn merkleize_parallel(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        min_to_parallelize: usize,
    ) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        let mut same_height = Vec::new();
        let mut current_height = 1;
        for (i, (pos, height)) in nodes.iter().enumerate() {
            if *height == current_height {
                same_height.push(*pos);
                continue;
            }
            // Height transition: if the completed batch is too small to parallelize,
            // fall back to serial for it and everything after it.
            if same_height.len() < min_to_parallelize {
                self.state.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
                self.merkleize_serial(hasher);
                return;
            }
            self.update_node_digests(hasher, pool.clone(), &same_height, current_height);
            same_height.clear();
            current_height += 1;
            // The current node starts the next height's batch.
            same_height.push(*pos);
        }

        // Handle the final batch (serially if too small).
        if same_height.len() < min_to_parallelize {
            self.state.dirty_nodes = nodes[nodes.len() - same_height.len()..]
                .iter()
                .copied()
                .collect();
            self.merkleize_serial(hasher);
            return;
        }

        self.update_node_digests(hasher, pool, &same_height, current_height);
    }

    // Hash the given batch of same-height dirty nodes in parallel, then write the
    // results back. Nodes at the same height have disjoint children, and all children
    // are already clean, so each digest can be computed independently.
    #[cfg(feature = "std")]
    fn update_node_digests(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        same_height: &[Position],
        height: u32,
    ) {
        // Left-child offset for nodes at this height.
        let two_h = 1 << height;
        pool.install(|| {
            let computed_digests: Vec<(usize, D)> = same_height
                .par_iter()
                .map_init(
                    // Each rayon worker gets its own forked hasher.
                    || hasher.fork(),
                    |hasher, &pos| {
                        let left = pos - two_h;
                        let right = pos - 1;
                        let digest = hasher.node_digest(
                            pos,
                            self.get_node_unchecked(left),
                            self.get_node_unchecked(right),
                        );
                        let index = self.pos_to_index(pos);
                        (index, digest)
                    },
                )
                .collect();

            // Write-back is serial since it mutates `self.nodes`.
            for (index, digest) in computed_digests {
                self.nodes[index] = digest;
            }
        });
    }

    // Mark every ancestor of `pos` up to its peak as dirty, stopping early once an
    // ancestor is already in the dirty set (its own ancestors were marked then).
    fn mark_dirty(&mut self, pos: Position) {
        for (peak_pos, mut height) in self.peak_iterator() {
            if peak_pos < pos {
                continue;
            }

            // `pos` lies under this peak; collect the peak-to-node path, then insert
            // ancestors bottom-up with increasing heights.
            let path = PathIterator::new(pos, peak_pos, height)
                .collect::<Vec<_>>()
                .into_iter()
                .rev();
            height = 1;
            for (parent_pos, _) in path {
                if !self.state.dirty_nodes.insert((parent_pos, height)) {
                    break;
                }
                height += 1;
            }
            return;
        }

        // Every valid position belongs to some peak's mountain.
        panic!("invalid pos {pos}:{}", self.size());
    }

    /// Update the leaf at `loc` to `element`, marking its ancestors dirty.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        self.update_leaf_batched(hasher, None, &[(loc, element)])
    }

    /// Update a batch of leaves, computing leaf digests in parallel when `pool` is
    /// provided and the batch is large enough.
    ///
    /// Returns [Error::LeafOutOfBounds] if any location exceeds the leaf count, or
    /// [Error::ElementPruned] if any target leaf has been pruned.
    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: Option<ThreadPool>,
        updates: &[(Location, T)],
    ) -> Result<(), Error> {
        if updates.is_empty() {
            return Ok(());
        }

        // Validate every location up front so no partial update is applied on error.
        let leaves = self.leaves();
        let mut positions = Vec::with_capacity(updates.len());
        for (loc, _) in updates {
            if *loc >= leaves {
                return Err(Error::LeafOutOfBounds(*loc));
            }
            let pos = Position::try_from(*loc)?;
            if pos < self.pruned_to_pos {
                return Err(Error::ElementPruned(pos));
            }
            positions.push(pos);
        }

        #[cfg(feature = "std")]
        if let Some(pool) = pool {
            if updates.len() >= MIN_TO_PARALLELIZE {
                self.update_leaf_parallel(hasher, pool, updates, &positions);
                return Ok(());
            }
        }

        // Serial path: hash, store, and mark ancestors dirty for each update.
        for ((_, element), pos) in updates.iter().zip(positions.iter()) {
            let digest = hasher.leaf_digest(*pos, element.as_ref());
            let index = self.pos_to_index(*pos);
            self.nodes[index] = digest;
            self.mark_dirty(*pos);
        }

        Ok(())
    }

    // Compute the new leaf digests in parallel, then write them back and mark ancestors
    // dirty serially (mark_dirty mutates the shared dirty set).
    #[cfg(feature = "std")]
    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        updates: &[(Location, T)],
        positions: &[Position],
    ) {
        pool.install(|| {
            let digests: Vec<(Position, D)> = updates
                .par_iter()
                .zip(positions.par_iter())
                .map_init(
                    // Each rayon worker gets its own forked hasher.
                    || hasher.fork(),
                    |hasher, ((_, elem), pos)| {
                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
                        (*pos, digest)
                    },
                )
                .collect();

            for (pos, digest) in digests {
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
                self.mark_dirty(pos);
            }
        });
    }
}
795
796#[cfg(test)]
797mod tests {
798 use super::*;
799 use crate::mmr::{
800 hasher::{Hasher as _, Standard},
801 stability::ROOTS,
802 };
803 use commonware_cryptography::{sha256, Hasher, Sha256};
804 use commonware_runtime::{create_pool, deterministic, tokio, Runner};
805 use commonware_utils::hex;
806
    // Build a 199-leaf MMR one element at a time, checking the root against the
    // stability test vectors at every step.
    fn build_and_check_test_roots_mmr(mmr: &mut CleanMmr<sha256::Digest>) {
        let mut hasher: Standard<Sha256> = Standard::new();
        for i in 0u64..199 {
            hasher.inner().update(&i.to_be_bytes());
            let element = hasher.inner().finalize();
            let root = *mmr.root();
            let expected_root = ROOTS[i as usize];
            assert_eq!(hex(&root), expected_root, "at: {i}");
            mmr.add(&mut hasher, &element);
        }
        assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
    }
820
    // Build the same 199-leaf MMR in a single batch (optionally with a thread pool),
    // checking only the final root against the stability test vectors.
    pub fn build_batched_and_check_test_roots(
        mut mmr: DirtyMmr<sha256::Digest>,
        pool: Option<ThreadPool>,
    ) {
        let mut hasher: Standard<Sha256> = Standard::new();
        for i in 0u64..199 {
            hasher.inner().update(&i.to_be_bytes());
            let element = hasher.inner().finalize();
            mmr.add(&mut hasher, &element);
        }
        let mmr = mmr.merkleize(&mut hasher, pool);
        assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
    }
835
    // An empty MMR reports empty accessors, the empty root, and rejects pop.
    #[test]
    fn test_mem_mmr_empty() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            assert_eq!(
                mmr.peak_iterator().next(),
                None,
                "empty iterator should have no peaks"
            );
            assert_eq!(mmr.size(), 0);
            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
            assert_eq!(mmr.last_leaf_pos(), None);
            assert_eq!(mmr.oldest_retained_pos(), None);
            assert_eq!(mmr.get_node(Position::new(0)), None);
            assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
            assert!(matches!(mmr.pop(&mut hasher), Err(Empty)));
            mmr.prune_all();
            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");

            // The empty root equals the root over an empty peak set.
            assert_eq!(*mmr.root(), hasher.root(Position::new(0), [].iter()));
        });
    }
861
    // Build an 11-leaf MMR, checking structure (leaf positions, peaks, internal node
    // digests, root), then exercise pruning, proofs, and re-init from pinned nodes.
    #[test]
    fn test_mem_mmr_add_eleven_values() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut leaves: Vec<Position> = Vec::new();
            for _ in 0..11 {
                leaves.push(mmr.add(&mut hasher, &element));
                let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
                assert_ne!(peaks.len(), 0);
                assert!(peaks.len() as u64 <= mmr.size());
                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
                assert!(nodes_needing_parents.len() <= peaks.len());
            }
            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(0));
            assert_eq!(mmr.size(), 19, "mmr not of expected size");
            assert_eq!(
                leaves,
                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
                    .into_iter()
                    .map(Position::new)
                    .collect::<Vec<_>>(),
                "mmr leaf positions not as expected"
            );
            let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
            assert_eq!(
                peaks,
                vec![
                    (Position::new(14), 3),
                    (Position::new(17), 1),
                    (Position::new(18), 0)
                ],
                "mmr peaks not as expected"
            );

            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
            assert_eq!(
                peaks_needing_parents,
                vec![Position::new(17), Position::new(18)],
                "mmr nodes needing parents not as expected"
            );

            // Every leaf node should store its leaf digest.
            for leaf in leaves.iter().by_ref() {
                let digest = hasher.leaf_digest(*leaf, &element);
                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
            }

            // Verify the height-1 internal nodes.
            let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
            assert_eq!(mmr.nodes[2], digest2);
            let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
            assert_eq!(mmr.nodes[5], digest5);
            let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
            assert_eq!(mmr.nodes[9], digest9);
            let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
            assert_eq!(mmr.nodes[12], digest12);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Verify the height-2 internal nodes.
            let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
            assert_eq!(mmr.nodes[6], digest6);
            let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
            assert_eq!(mmr.nodes[13], digest13);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Verify the height-3 node.
            let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
            assert_eq!(mmr.nodes[14], digest14);

            // The root hashes the size together with the three peak digests.
            let root = *mmr.root();
            let peak_digests = [digest14, digest17, mmr.nodes[18]];
            let expected_root = hasher.root(Position::new(19), peak_digests.iter());
            assert_eq!(root, expected_root, "incorrect root");

            // Prune to position 14 and confirm the oldest retained position.
            mmr.prune_to_pos(Position::new(14));
            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(14));

            // Proofs that require pruned, unpinned nodes must fail...
            assert!(matches!(
                mmr.proof(Location::new_unchecked(0)),
                Err(ElementPruned(_))
            ));
            assert!(matches!(
                mmr.proof(Location::new_unchecked(6)),
                Err(ElementPruned(_))
            ));

            // ...while proofs over retained elements succeed.
            assert!(mmr.proof(Location::new_unchecked(8)).is_ok());
            assert!(mmr.proof(Location::new_unchecked(10)).is_ok());

            // Pruning must not change the root.
            let root_after_prune = *mmr.root();
            assert_eq!(root, root_after_prune, "root changed after pruning");

            assert!(
                mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(9))
                    .is_err(),
                "attempts to range_prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(Location::new_unchecked(8)..mmr.leaves()).is_ok(),
                "attempts to range_prove over all elements following oldest retained should succeed"
            );

            // Rebuild a copy from the retained nodes plus pinned digests and confirm it
            // matches the original.
            let oldest_pos = mmr.oldest_retained_pos().unwrap();
            let digests = mmr.node_digests_to_pin(oldest_pos);
            let mmr_copy = Mmr::init(
                Config {
                    nodes: mmr.nodes.iter().copied().collect(),
                    pruned_to_pos: oldest_pos,
                    pinned_nodes: digests,
                },
                &mut hasher,
            )
            .unwrap();
            assert_eq!(mmr_copy.size(), 19);
            assert_eq!(mmr_copy.leaves(), mmr.leaves());
            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
            assert_eq!(mmr_copy.oldest_retained_pos(), mmr.oldest_retained_pos());
            assert_eq!(*mmr_copy.root(), root);
        });
    }
1002
    // Alternating prune_all and add for many iterations must not panic.
    #[test]
    fn test_mem_mmr_prune_all() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            for _ in 0..1000 {
                mmr.prune_all();
                mmr.add(&mut hasher, &element);
            }
        });
    }
1017
    // Sizes reached by adding leaves are valid MMR sizes, and every size strictly
    // between two consecutive reachable sizes is invalid.
    #[test]
    fn test_mem_mmr_validity() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            for _ in 0..1001 {
                assert!(
                    mmr.size().is_mmr_size(),
                    "mmr of size {} should be valid",
                    mmr.size()
                );
                let old_size = mmr.size();
                mmr.add(&mut hasher, &element);
                for size in *old_size + 1..*mmr.size() {
                    assert!(
                        !Position::new(size).is_mmr_size(),
                        "mmr of size {size} should be invalid",
                    );
                }
            }
        });
    }
1043
    // Roots must match the stability vectors whether built incrementally (clean) or in
    // one batch (dirty, serial merkleization).
    #[test]
    fn test_mem_mmr_root_stability() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            build_and_check_test_roots_mmr(&mut mmr);

            let mut hasher: Standard<Sha256> = Standard::new();
            let mmr = CleanMmr::new(&mut hasher);
            build_batched_and_check_test_roots(mmr.into_dirty(), None);
        });
    }
1060
    // Parallel merkleization must produce the same stable root as serial.
    #[test]
    fn test_mem_mmr_root_stability_parallel() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            let pool = commonware_runtime::create_pool(context, 4).unwrap();
            let mut hasher: Standard<Sha256> = Standard::new();

            let mmr = Mmr::init(
                Config {
                    nodes: vec![],
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: vec![],
                },
                &mut hasher,
            )
            .unwrap();
            build_batched_and_check_test_roots(mmr.into_dirty(), Some(pool));
        });
    }
1082
    // Pruning everything after every add must not affect the computed roots.
    #[test]
    fn test_mem_mmr_root_stability_while_pruning() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..199 {
                let root = *mmr.root();
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root, "at: {i}");
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
                mmr.prune_all();
            }
        });
    }
1102
    // Build a 199-leaf MMR and return it merkleized along with each leaf's position.
    fn compute_big_mmr(
        hasher: &mut Standard<Sha256>,
        mut mmr: DirtyMmr<sha256::Digest>,
        pool: Option<ThreadPool>,
    ) -> (CleanMmr<sha256::Digest>, Vec<Position>) {
        let mut leaves = Vec::new();
        let mut c_hasher = Sha256::default();
        for i in 0u64..199 {
            c_hasher.update(&i.to_be_bytes());
            let element = c_hasher.finalize();
            // The leaf lands at the current size before the add.
            let leaf_pos = mmr.size();
            mmr.add(hasher, &element);
            leaves.push(leaf_pos);
        }

        (mmr.merkleize(hasher, pool), leaves)
    }
1120
    // Popping leaves walks the stability roots backwards; popping an empty MMR or into
    // the pruned region must fail.
    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let root = *mmr.root();
            let expected_root = ROOTS[199];
            assert_eq!(hex(&root), expected_root);

            for i in (0..199u64).rev() {
                assert!(mmr.pop(&mut hasher).is_ok());
                let root = *mmr.root();
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root);
            }

            assert!(
                matches!(mmr.pop(&mut hasher).unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            // Rebuild, then prune and pop down to the prune boundary.
            for i in 0u64..199 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }

            let leaf_pos = Position::try_from(Location::new_unchecked(100)).unwrap();
            mmr.prune_to_pos(leaf_pos);
            while mmr.size() > leaf_pos {
                mmr.pop(&mut hasher).unwrap();
            }
            assert_eq!(hex(mmr.root()), ROOTS[100]);
            // All remaining nodes are pruned, so further pops fail.
            let result = mmr.pop(&mut hasher);
            assert!(matches!(result, Err(ElementPruned(_))));
            assert_eq!(mmr.oldest_retained_pos(), None);
        });
    }
1162
    // Updating a leaf changes the root; writing the original element back restores it.
    // Updates must also work as the MMR is progressively pruned.
    #[test]
    fn test_mem_mmr_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let root = *mmr.root();

            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
                let leaf_loc =
                    Location::try_from(leaves[leaf]).expect("leaf position should map to location");
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let updated_root = *mmr.root();
                assert!(root != updated_root);

                // Writing the original element back restores the original root.
                hasher.inner().update(&leaf.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let restored_root = *mmr.root();
                assert_eq!(root, restored_root);
            }

            // Updating retained leaves must keep working while pruning advances.
            mmr.prune_to_pos(leaves[150]);
            for &leaf_pos in &leaves[150..=190] {
                mmr.prune_to_pos(leaf_pos);
                let leaf_loc =
                    Location::try_from(leaf_pos).expect("leaf position should map to location");
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
            }
        });
    }
1200
    // update_leaf with a location equal to the leaf count must fail with
    // LeafOutOfBounds.
    #[test]
    fn test_mem_mmr_update_leaf_error_out_of_bounds() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let invalid_loc = mmr.leaves();
            let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
            assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
        });
    }
1214
    // update_leaf targeting a pruned leaf must fail with ElementPruned.
    #[test]
    fn test_mem_mmr_update_leaf_error_pruned() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            mmr.prune_all();
            let result = mmr.update_leaf(&mut hasher, Location::new_unchecked(0), &element);
            assert!(matches!(result, Err(Error::ElementPruned(_))));
        });
    }
1228
    // Batched leaf updates, serial code path.
    #[test]
    fn test_mem_mmr_batch_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let (mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            do_batch_update(&mut hasher, mmr, &leaves);
        });
    }
1238
    // Batched leaf updates with a thread pool, exercising the parallel code path.
    #[test]
    fn test_mem_mmr_batch_parallel_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = tokio::Runner::default();
        executor.start(|ctx| async move {
            let pool = create_pool(ctx, 4).unwrap();
            let mmr = Mmr::init(
                Config {
                    nodes: Vec::new(),
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: Vec::new(),
                },
                &mut hasher,
            )
            .unwrap();
            let (mmr, leaves) = compute_big_mmr(&mut hasher, mmr.into_dirty(), Some(pool));
            do_batch_update(&mut hasher, mmr, &leaves);
        });
    }
1260
    // Apply a batch of leaf updates and check the resulting root against a known value;
    // then apply a second batch writing the original elements back and confirm the root
    // reverts.
    fn do_batch_update(
        hasher: &mut Standard<Sha256>,
        mmr: CleanMmr<sha256::Digest>,
        leaves: &[Position],
    ) {
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let root = *mmr.root();

        // First batch: overwrite a spread of leaves with a constant element.
        let mut updates = Vec::new();
        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
            let leaf_loc =
                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
            updates.push((leaf_loc, &element));
        }
        let mut dirty_mmr = mmr.into_dirty();
        dirty_mmr
            .update_leaf_batched(hasher, None, &updates)
            .unwrap();

        let mmr = dirty_mmr.merkleize(hasher, None);
        let updated_root = *mmr.root();
        assert_eq!(
            "af3acad6aad59c1a880de643b1200a0962a95d06c087ebf677f29eb93fc359a4",
            hex(&updated_root)
        );

        // Second batch: write the original elements back.
        let mut updates = Vec::new();
        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
            hasher.inner().update(&leaf.to_be_bytes());
            let element = hasher.inner().finalize();
            let leaf_loc =
                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
            updates.push((leaf_loc, element));
        }
        let mut dirty_mmr = mmr.into_dirty();
        dirty_mmr
            .update_leaf_batched(hasher, None, &updates)
            .unwrap();

        let mmr = dirty_mmr.merkleize(hasher, None);
        let restored_root = *mmr.root();
        assert_eq!(root, restored_root);
    }
1306
    // Mmr::init must reject pinned-node vectors of the wrong length and accept
    // correctly sized ones.
    #[test]
    fn test_init_pinned_nodes_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Empty MMR: no pinned nodes required.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Pruned MMR supplied too few pinned digests.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(127),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Unpruned MMR supplied a surplus pinned digest.
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![Sha256::hash(b"dummy")],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Pinned digests taken from a real MMR for prune point 50 are accepted.
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..50 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(50),
                pinned_nodes,
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());
        });
    }
1358
    // Mmr::init must reject configurations whose implied total size is not a valid MMR
    // size, and accept valid ones (including pruned configurations).
    #[test]
    fn test_init_size_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Size 0 is valid.
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Size 2 is not a valid MMR size.
            let config = Config {
                nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Size 3 (two leaves + one parent) is valid.
            let config = Config {
                nodes: vec![
                    Sha256::hash(b"leaf1"),
                    Sha256::hash(b"leaf2"),
                    Sha256::hash(b"parent"),
                ],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // A full 64-leaf MMR (127 nodes) round-trips through init.
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..64 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 127);
            let nodes: Vec<_> = (0..127)
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();

            let config = Config {
                nodes,
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // A pruned MMR round-trips when pruned_to_pos matches the node vector...
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..11 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 19);
            mmr.prune_to_pos(Position::new(7));
            let nodes: Vec<_> = (7..*mmr.size())
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));

            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(7),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // ...but fails when the implied total size is invalid.
            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(8),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            let config = Config {
                nodes,
                pruned_to_pos: Position::new(9),
                pinned_nodes,
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));
        });
    }
1461}