1use crate::mmr::{
12 iterator::{leaf_pos_to_num, nodes_needing_parents, PathIterator, PeakIterator},
13 verification::Proof,
14 Builder,
15 Error::{self, ElementPruned, Empty},
16 Hasher,
17};
18use commonware_cryptography::Hasher as CHasher;
19use commonware_runtime::ThreadPool;
20use rayon::prelude::*;
21use std::collections::{HashMap, HashSet, VecDeque};
22
pub struct Config<H: CHasher> {
    /// The retained (un-pruned) nodes of the MMR, starting at position
    /// `pruned_to_pos`.
    pub nodes: Vec<H::Digest>,

    /// The position before which all nodes have been pruned.
    pub pruned_to_pos: u64,

    /// Digests of pruned nodes that must remain available ("pinned"), one for
    /// each position yielded by `Proof::nodes_to_pin(pruned_to_pos)`, in the
    /// same order (see `Mmr::init`).
    pub pinned_nodes: Vec<H::Digest>,

    /// Optional thread pool used to parallelize batched updates and syncing.
    pub pool: Option<ThreadPool>,
}
37
pub struct Mmr<H: CHasher> {
    /// The retained nodes of the MMR in position order; index 0 corresponds
    /// to position `pruned_to_pos`.
    nodes: VecDeque<H::Digest>,

    /// The position before which all nodes have been removed from `nodes`.
    pruned_to_pos: u64,

    /// Digests of pruned nodes that are still required (e.g. for proofs and
    /// root computation), keyed by their position.
    pub(super) pinned_nodes: HashMap<u64, H::Digest>,

    /// `(position, height)` of nodes whose digests are stale due to batched
    /// updates and must be recomputed by `sync` before reads.
    dirty_nodes: HashSet<(u64, u32)>,

    /// Placeholder digest written into newly created dirty nodes
    /// (see `add_batched`).
    dirty_digest: H::Digest,

    /// Optional thread pool used to parallelize batched updates and syncing.
    pub(super) thread_pool: Option<ThreadPool>,
}
68
impl<H: CHasher> Default for Mmr<H> {
    /// Defaults to an empty, un-pruned MMR with no thread pool.
    fn default() -> Self {
        Self::new()
    }
}
74
impl<H: CHasher> Builder<H> for Mmr<H> {
    /// Adds `element` to the MMR, returning the position of the new leaf.
    ///
    /// The in-memory implementation is infallible, so the inherent `add`'s
    /// result is simply wrapped in `Ok`.
    async fn add(&mut self, hasher: &mut impl Hasher<H>, element: &[u8]) -> Result<u64, Error> {
        Ok(self.add(hasher, element))
    }

    /// Computes the root of the MMR (delegates to the inherent `root`).
    fn root(&self, hasher: &mut impl Hasher<H>) -> H::Digest {
        self.root(hasher)
    }
}
84
/// Minimum batch size at which updates/syncing use the thread pool (when one
/// is configured) instead of the serial path.
const MIN_TO_PARALLELIZE: usize = 20;
87
impl<H: CHasher> Mmr<H> {
    /// Returns a new (empty) MMR.
    pub fn new() -> Self {
        Self {
            nodes: VecDeque::new(),
            pruned_to_pos: 0,
            pinned_nodes: HashMap::new(),
            dirty_nodes: HashSet::new(),
            dirty_digest: Self::dirty_digest(),
            thread_pool: None,
        }
    }

    /// The placeholder digest stored in nodes whose true digest has not yet
    /// been computed (see `add_batched` / `sync`).
    fn dirty_digest() -> H::Digest {
        H::empty()
    }

    /// Initializes an MMR from `config`.
    ///
    /// `config.pinned_nodes` must contain one digest for each position
    /// yielded by `Proof::nodes_to_pin(config.pruned_to_pos)`, in the same
    /// order; otherwise the indexing below panics or pins wrong digests.
    pub fn init(config: Config<H>) -> Self {
        let mut mmr = Self {
            nodes: VecDeque::from(config.nodes),
            pruned_to_pos: config.pruned_to_pos,
            pinned_nodes: HashMap::new(),
            dirty_nodes: HashSet::new(),
            dirty_digest: Self::dirty_digest(),
            thread_pool: config.pool,
        };
        if mmr.size() == 0 {
            return mmr;
        }

        // Re-key the pinned digests by their MMR position.
        for (i, pos) in Proof::<H::Digest>::nodes_to_pin(config.pruned_to_pos).enumerate() {
            mmr.pinned_nodes.insert(pos, config.pinned_nodes[i]);
        }

        mmr
    }

    /// Total number of nodes in the MMR, including pruned ones.
    pub fn size(&self) -> u64 {
        self.nodes.len() as u64 + self.pruned_to_pos
    }

    /// Number of leaves in the MMR.
    ///
    /// Panics if the MMR's size is not a valid MMR size (an invariant this
    /// implementation maintains).
    pub fn leaves(&self) -> u64 {
        leaf_pos_to_num(self.size()).expect("invalid mmr size")
    }

    /// Position of the most recently added leaf, or `None` if the MMR is
    /// empty.
    pub fn last_leaf_pos(&self) -> Option<u64> {
        if self.size() == 0 {
            return None;
        }

        Some(PeakIterator::last_leaf_pos(self.size()))
    }

    /// The position before which all nodes have been pruned.
    pub fn pruned_to_pos(&self) -> u64 {
        self.pruned_to_pos
    }

    /// Position of the oldest node still retained in `nodes`, or `None` if
    /// every node has been pruned.
    pub fn oldest_retained_pos(&self) -> Option<u64> {
        if self.pruned_to_pos == self.size() {
            return None;
        }

        Some(self.pruned_to_pos)
    }

    /// Returns an iterator over the current peaks of the MMR.
    pub(super) fn peak_iterator(&self) -> PeakIterator {
        PeakIterator::new(self.size())
    }

    /// Converts an index into `nodes` to its MMR position.
    fn index_to_pos(&self, index: usize) -> u64 {
        index as u64 + self.pruned_to_pos
    }

    /// Returns a reference to the node at position `pos`.
    ///
    /// # Panics
    ///
    /// Panics if the node at `pos` has been pruned and is not pinned, or if
    /// `pos` is beyond the size of the MMR.
    pub fn get_node_unchecked(&self, pos: u64) -> &H::Digest {
        if pos < self.pruned_to_pos {
            return self
                .pinned_nodes
                .get(&pos)
                .expect("requested node is pruned and not pinned");
        }

        &self.nodes[self.pos_to_index(pos)]
    }

    /// Returns the node at position `pos`, or `None` if it is out of range or
    /// was pruned without being pinned.
    pub fn get_node(&self, pos: u64) -> Option<H::Digest> {
        if pos < self.pruned_to_pos {
            return self.pinned_nodes.get(&pos).copied();
        }

        self.nodes.get(self.pos_to_index(pos)).copied()
    }

    /// Converts an MMR position to an index into `nodes`.
    ///
    /// Callers must ensure `pos >= self.pruned_to_pos`; otherwise the
    /// subtraction underflows.
    fn pos_to_index(&self, pos: u64) -> usize {
        (pos - self.pruned_to_pos) as usize
    }

    /// Adds `element` to the MMR, eagerly computing all new parent digests,
    /// and returns the position of the new leaf.
    pub fn add(&mut self, hasher: &mut impl Hasher<H>, element: &[u8]) -> u64 {
        let leaf_pos = self.size();
        let digest = hasher.leaf_digest(leaf_pos, element);
        self.add_leaf_digest(hasher, digest);

        leaf_pos
    }

    /// Adds `element` to the MMR without computing the digests of the new
    /// parent nodes it induces; placeholders are inserted instead and the
    /// nodes are marked dirty for a later `sync`. Returns the position of the
    /// new leaf.
    pub fn add_batched(&mut self, hasher: &mut impl Hasher<H>, element: &[u8]) -> u64 {
        let leaf_pos = self.size();
        let digest = hasher.leaf_digest(leaf_pos, element);

        // Each node needing a parent produces one new parent above the leaf,
        // at heights 1, 2, ... in order.
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        self.nodes.push_back(digest);

        let mut height = 1;
        for _ in nodes_needing_parents {
            let new_node_pos = self.size();
            self.nodes.push_back(self.dirty_digest);
            self.dirty_nodes.insert((new_node_pos, height));
            height += 1;
        }

        leaf_pos
    }

    /// Appends `digest` as a new leaf and computes every parent digest it
    /// induces, from lowest height to highest.
    ///
    /// # Panics
    ///
    /// Panics if there are unprocessed dirty nodes (batched state must be
    /// `sync`ed first).
    pub(super) fn add_leaf_digest(&mut self, hasher: &mut impl Hasher<H>, mut digest: H::Digest) {
        assert!(
            self.dirty_nodes.is_empty(),
            "dirty nodes must be processed before adding an element w/o batching"
        );
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        self.nodes.push_back(digest);

        // Fold the running digest upward with each sibling; the sibling is
        // always the left child of the new parent.
        for sibling_pos in nodes_needing_parents {
            let new_node_pos = self.size();
            let sibling_digest = self.get_node_unchecked(sibling_pos);
            digest = hasher.node_digest(new_node_pos, sibling_digest, &digest);
            self.nodes.push_back(digest);
        }
    }

    /// Pops the most recently added leaf (and the parent nodes it induced),
    /// returning the new size. Returns `Empty` if the MMR has no elements, or
    /// `ElementPruned` if the last leaf has been pruned.
    ///
    /// # Panics
    ///
    /// Panics if there are unprocessed dirty nodes.
    pub fn pop(&mut self) -> Result<u64, Error> {
        if self.size() == 0 {
            return Err(Empty);
        }
        assert!(
            self.dirty_nodes.is_empty(),
            "dirty nodes must be processed before popping elements"
        );

        // Walk the size downward until it is a valid MMR size again; the
        // nodes above that boundary are exactly the last leaf + its parents.
        let mut new_size = self.size() - 1;
        loop {
            if new_size < self.pruned_to_pos {
                return Err(ElementPruned(new_size));
            }
            if PeakIterator::check_validity(new_size) {
                break;
            }
            new_size -= 1;
        }
        let num_to_drain = (self.size() - new_size) as usize;
        self.nodes.drain(self.nodes.len() - num_to_drain..);

        Ok(self.size())
    }

    /// Updates the leaf at position `pos` to the digest of `element` and
    /// eagerly recomputes all ancestor digests up to its peak.
    ///
    /// # Panics
    ///
    /// Panics if `pos` has been pruned, is not a leaf position, or is beyond
    /// the size of the MMR.
    pub fn update_leaf(&mut self, hasher: &mut impl Hasher<H>, pos: u64, element: &[u8]) {
        if pos < self.pruned_to_pos {
            panic!("element pruned: pos={pos}");
        }

        let mut digest = hasher.leaf_digest(pos, element);
        let mut index = self.pos_to_index(pos);
        self.nodes[index] = digest;

        // Find the peak whose subtree contains `pos`, then rehash the path
        // from the leaf up to that peak.
        for (peak_pos, height) in self.peak_iterator() {
            if peak_pos < pos {
                continue;
            }
            let path: Vec<_> = PathIterator::new(pos, peak_pos, height).collect();
            for (parent_pos, sibling_pos) in path.into_iter().rev() {
                if parent_pos == pos {
                    panic!("pos was not for a leaf");
                }
                let sibling_digest = self.get_node_unchecked(sibling_pos);
                // A sibling at `parent_pos - 1` is the right child of the
                // parent, so the running digest is the left input; otherwise
                // it is the right input.
                digest = if sibling_pos == parent_pos - 1 {
                    hasher.node_digest(parent_pos, &digest, sibling_digest)
                } else {
                    hasher.node_digest(parent_pos, sibling_digest, &digest)
                };
                index = self.pos_to_index(parent_pos);
                self.nodes[index] = digest;
            }
            return;
        }

        panic!("invalid pos {}:{}", pos, self.size())
    }

    /// Updates each `(position, element)` pair in `updates`, deferring the
    /// recomputation of ancestor digests by marking them dirty. Call `sync`
    /// to restore a fully-computed state before reading roots or proofs.
    ///
    /// # Panics
    ///
    /// Panics if any position has been pruned or is invalid.
    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<H>,
        updates: &[(u64, T)],
    ) {
        // Large batches are worth hashing on the thread pool when available.
        if updates.len() >= MIN_TO_PARALLELIZE && self.thread_pool.is_some() {
            self.update_leaf_parallel(hasher, updates);
            return;
        }

        for (pos, element) in updates {
            if *pos < self.pruned_to_pos {
                panic!("element pruned: pos={pos}");
            }

            let digest = hasher.leaf_digest(*pos, element.as_ref());
            let index = self.pos_to_index(*pos);
            self.nodes[index] = digest;
            self.mark_dirty(*pos);
        }
    }

    /// Marks every ancestor of the node at `pos` as dirty, stopping early
    /// once an already-dirty ancestor is found (its own ancestors must
    /// already be dirty in that case).
    fn mark_dirty(&mut self, pos: u64) {
        for (peak_pos, mut height) in self.peak_iterator() {
            if peak_pos < pos {
                continue;
            }

            let path = PathIterator::new(pos, peak_pos, height)
                .collect::<Vec<_>>()
                .into_iter()
                .rev();
            // Heights are assigned bottom-up along the reversed path: the
            // leaf's parent is at height 1, its parent at 2, and so on.
            height = 1;
            for (parent_pos, _) in path {
                if !self.dirty_nodes.insert((parent_pos, height)) {
                    break;
                }
                height += 1;
            }
            return;
        }

        panic!("invalid pos {}:{}", pos, self.size());
    }

    /// Parallel implementation of `update_leaf_batched`: leaf digests are
    /// computed on the thread pool, then written back (and ancestors marked
    /// dirty) serially.
    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<H>,
        updates: &[(u64, T)],
    ) {
        // Caller guarantees a pool is configured (checked in
        // `update_leaf_batched`).
        let pool = self.thread_pool.as_ref().unwrap().clone();
        pool.install(|| {
            let digests: Vec<(u64, H::Digest)> = updates
                .par_iter()
                .map_init(
                    || hasher.fork(),
                    |hasher, (pos, elem)| {
                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
                        (*pos, digest)
                    },
                )
                .collect();

            for (pos, digest) in digests {
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
                self.mark_dirty(pos);
            }
        });
    }

    /// Returns true if there are batched updates that have yet to be
    /// processed by `sync`.
    pub fn is_dirty(&self) -> bool {
        !self.dirty_nodes.is_empty()
    }

    /// Recomputes the digest of every dirty node, restoring the MMR to a
    /// fully-computed state. No-op if there are no dirty nodes.
    pub fn sync(&mut self, hasher: &mut impl Hasher<H>) {
        if self.dirty_nodes.is_empty() {
            return;
        }
        if self.dirty_nodes.len() >= MIN_TO_PARALLELIZE && self.thread_pool.is_some() {
            self.sync_parallel(hasher, MIN_TO_PARALLELIZE);
            return;
        }

        self.sync_serial(hasher);
    }

    /// Serially recomputes dirty nodes in order of increasing height, so each
    /// node's children are clean by the time it is hashed.
    fn sync_serial(&mut self, hasher: &mut impl Hasher<H>) {
        let mut nodes: Vec<(u64, u32)> = self.dirty_nodes.iter().copied().collect();
        self.dirty_nodes.clear();
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        for (pos, height) in nodes {
            // For a parent at `pos` of height `height`, the right child is at
            // `pos - 1` and the left child is one full subtree (2^height)
            // before the parent.
            let left = pos - (1 << height);
            let right = pos - 1;
            let digest = hasher.node_digest(
                pos,
                self.get_node_unchecked(left),
                self.get_node_unchecked(right),
            );
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
    }

    /// Recomputes dirty nodes height by height (children before parents),
    /// hashing each height's batch in parallel when it has at least
    /// `min_to_parallelize` entries; otherwise falls back to `sync_serial`
    /// for the remaining nodes.
    fn sync_parallel(&mut self, hasher: &mut impl Hasher<H>, min_to_parallelize: usize) {
        let mut nodes: Vec<(u64, u32)> = self.dirty_nodes.iter().copied().collect();
        self.dirty_nodes.clear();
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        let mut same_height = Vec::new();
        let mut current_height = 1;
        for (i, (pos, height)) in nodes.iter().enumerate() {
            if *height == current_height {
                same_height.push(*pos);
                continue;
            }
            // Height transition: if the finished batch is too small to be
            // worth parallelizing, restore it (and everything after it) as
            // dirty and finish serially.
            if same_height.len() < min_to_parallelize {
                self.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
                self.sync_serial(hasher);
                return;
            }
            self.update_node_digests(hasher, &same_height, current_height);
            same_height.clear();
            current_height += 1;
            same_height.push(*pos);
        }

        // Handle the final (highest) batch the same way.
        if same_height.len() < min_to_parallelize {
            self.dirty_nodes = nodes[nodes.len() - same_height.len()..]
                .iter()
                .copied()
                .collect();
            self.sync_serial(hasher);
            return;
        }

        self.update_node_digests(hasher, &same_height, current_height);
    }

    /// Recomputes, in parallel, the digests of the given nodes which must all
    /// be at the same `height` and whose children must already be clean.
    fn update_node_digests(
        &mut self,
        hasher: &mut impl Hasher<H>,
        same_height: &[u64],
        height: u32,
    ) {
        let two_h = 1 << height;
        // Caller guarantees a pool is configured (checked in `sync`).
        let pool = self.thread_pool.as_ref().unwrap().clone();
        pool.install(|| {
            let computed_digests: Vec<(usize, H::Digest)> = same_height
                .par_iter()
                .map_init(
                    || hasher.fork(),
                    |hasher, &pos| {
                        let left = pos - two_h;
                        let right = pos - 1;
                        let digest = hasher.node_digest(
                            pos,
                            self.get_node_unchecked(left),
                            self.get_node_unchecked(right),
                        );
                        let index = self.pos_to_index(pos);
                        (index, digest)
                    },
                )
                .collect();

            // Write the results back outside the parallel section.
            for (index, digest) in computed_digests {
                self.nodes[index] = digest;
            }
        });
    }

    /// Computes the root of the MMR from its current peaks.
    ///
    /// # Panics
    ///
    /// Panics if there are unprocessed dirty nodes.
    pub fn root(&self, hasher: &mut impl Hasher<H>) -> H::Digest {
        assert!(
            self.dirty_nodes.is_empty(),
            "dirty nodes must be processed before computing the root"
        );
        let peaks = self
            .peak_iterator()
            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
        let size = self.size();
        hasher.root(size, peaks)
    }

    /// Returns an inclusion proof for the element at `element_pos`.
    pub async fn proof(&self, element_pos: u64) -> Result<Proof<H::Digest>, Error> {
        self.range_proof(element_pos, element_pos).await
    }

    /// Returns an inclusion proof for the range of elements from
    /// `start_element_pos` through `end_element_pos` (inclusive). Fails with
    /// `ElementPruned` if the start of the range has been pruned.
    ///
    /// # Panics
    ///
    /// Panics if there are unprocessed dirty nodes.
    pub async fn range_proof(
        &self,
        start_element_pos: u64,
        end_element_pos: u64,
    ) -> Result<Proof<H::Digest>, Error> {
        if start_element_pos < self.pruned_to_pos {
            return Err(ElementPruned(start_element_pos));
        }
        assert!(
            self.dirty_nodes.is_empty(),
            "dirty nodes must be processed before computing proofs"
        );
        Proof::<H::Digest>::range_proof(self, start_element_pos, end_element_pos).await
    }

    /// Prunes every node, retaining only the pinned digests required for
    /// future proving and appending. No-op when nothing is retained.
    pub fn prune_all(&mut self) {
        if !self.nodes.is_empty() {
            self.prune_to_pos(self.index_to_pos(self.nodes.len()));
        }
    }

    /// Prunes all nodes before position `pos`, pinning those that remain
    /// required.
    ///
    /// # Panics
    ///
    /// Panics if there are unprocessed dirty nodes.
    pub fn prune_to_pos(&mut self, pos: u64) {
        assert!(
            self.dirty_nodes.is_empty(),
            "dirty nodes must be processed before pruning"
        );
        // Compute the pinned set before draining, while the nodes are still
        // accessible.
        self.pinned_nodes = self.nodes_to_pin(pos);
        let retained_nodes = self.pos_to_index(pos);
        self.nodes.drain(0..retained_nodes);
        self.pruned_to_pos = pos;
    }

    /// Returns the digests (keyed by position) that must be pinned if the
    /// MMR were pruned to `prune_pos`.
    pub(super) fn nodes_to_pin(&self, prune_pos: u64) -> HashMap<u64, H::Digest> {
        Proof::<H::Digest>::nodes_to_pin(prune_pos)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }

    /// Returns the digests that must be pinned for a prune to `start_pos`, in
    /// the order produced by `Proof::nodes_to_pin` (the order `init`
    /// expects).
    pub(super) fn node_digests_to_pin(&self, start_pos: u64) -> Vec<H::Digest> {
        Proof::<H::Digest>::nodes_to_pin(start_pos)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }

    /// Merges the given digests into the pinned-node map.
    pub(super) fn add_pinned_nodes(&mut self, pinned_nodes: HashMap<u64, H::Digest>) {
        for (pos, node) in pinned_nodes.into_iter() {
            self.pinned_nodes.insert(pos, node);
        }
    }

    /// Returns a fully-pruned copy of this MMR containing only the pinned
    /// digests needed to keep appending and proving. Note that the clone does
    /// not inherit the thread pool.
    ///
    /// # Panics
    ///
    /// Panics if there are unprocessed dirty nodes.
    pub fn clone_pruned(&self) -> Self {
        if self.size() == 0 {
            return Self::new();
        }
        assert!(
            self.dirty_nodes.is_empty(),
            "dirty nodes must be processed before cloning"
        );

        // Collect every digest a fully-pruned MMR of this size must pin.
        let old_nodes = self.node_digests_to_pin(self.size());

        Self::init(Config {
            nodes: vec![],
            pruned_to_pos: self.size(),
            pinned_nodes: old_nodes,
            pool: None,
        })
    }
}
685
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{
        hasher::Standard,
        iterator::leaf_num_to_pos,
        tests::{build_and_check_test_roots_mmr, build_batched_and_check_test_roots, ROOTS},
    };
    use commonware_cryptography::Sha256;
    use commonware_runtime::{create_pool, deterministic, tokio, Runner};
    use commonware_utils::hex;

    /// Every accessor and mutator should behave sanely on an empty MMR.
    #[test]
    fn test_mem_mmr_empty() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new();
            assert_eq!(
                mmr.peak_iterator().next(),
                None,
                "empty iterator should have no peaks"
            );
            assert_eq!(mmr.size(), 0);
            assert_eq!(mmr.leaves(), 0);
            assert_eq!(mmr.last_leaf_pos(), None);
            assert_eq!(mmr.oldest_retained_pos(), None);
            assert_eq!(mmr.get_node(0), None);
            assert!(matches!(mmr.pop(), Err(Empty)));
            mmr.prune_all();
            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");

            // The root of an empty MMR is the root over an empty peak set.
            assert_eq!(mmr.root(&mut hasher), hasher.root(0, [].iter()));

            let clone = mmr.clone_pruned();
            assert_eq!(clone.size(), 0);
        });
    }

    /// Builds an 11-leaf MMR and verifies its structure (peaks, parents,
    /// root) node by node, then exercises pruning and proof generation.
    #[test]
    fn test_mem_mmr_add_eleven_values() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut leaves: Vec<u64> = Vec::new();
            let mut hasher: Standard<Sha256> = Standard::new();
            for _ in 0..11 {
                leaves.push(mmr.add(&mut hasher, &element));
                let peaks: Vec<(u64, u32)> = mmr.peak_iterator().collect();
                assert_ne!(peaks.len(), 0);
                assert!(peaks.len() <= mmr.size() as usize);
                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
                assert!(nodes_needing_parents.len() <= peaks.len());
            }
            assert_eq!(mmr.oldest_retained_pos().unwrap(), 0);
            assert_eq!(mmr.size(), 19, "mmr not of expected size");
            assert_eq!(
                leaves,
                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18],
                "mmr leaf positions not as expected"
            );
            let peaks: Vec<(u64, u32)> = mmr.peak_iterator().collect();
            assert_eq!(
                peaks,
                vec![(14, 3), (17, 1), (18, 0)],
                "mmr peaks not as expected"
            );

            // Only the two lowest peaks (17, 18) lack parents.
            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
            assert_eq!(
                peaks_needing_parents,
                vec![17, 18],
                "mmr nodes needing parents not as expected"
            );

            // Every leaf node should contain the leaf digest of `element`.
            for leaf in leaves.iter() {
                let digest = hasher.leaf_digest(*leaf, &element);
                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
            }

            // Height-1 internal nodes.
            let digest2 = hasher.node_digest(2, &mmr.nodes[0], &mmr.nodes[1]);
            assert_eq!(mmr.nodes[2], digest2);
            let digest5 = hasher.node_digest(5, &mmr.nodes[3], &mmr.nodes[4]);
            assert_eq!(mmr.nodes[5], digest5);
            let digest9 = hasher.node_digest(9, &mmr.nodes[7], &mmr.nodes[8]);
            assert_eq!(mmr.nodes[9], digest9);
            let digest12 = hasher.node_digest(12, &mmr.nodes[10], &mmr.nodes[11]);
            assert_eq!(mmr.nodes[12], digest12);
            let digest17 = hasher.node_digest(17, &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Height-2 internal nodes.
            let digest6 = hasher.node_digest(6, &mmr.nodes[2], &mmr.nodes[5]);
            assert_eq!(mmr.nodes[6], digest6);
            let digest13 = hasher.node_digest(13, &mmr.nodes[9], &mmr.nodes[12]);
            assert_eq!(mmr.nodes[13], digest13);
            let digest17 = hasher.node_digest(17, &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // Height-3 node (the tallest peak).
            let digest14 = hasher.node_digest(14, &mmr.nodes[6], &mmr.nodes[13]);
            assert_eq!(mmr.nodes[14], digest14);

            // The root hashes the size and all peak digests together.
            let root = mmr.root(&mut hasher);
            let peak_digests = [digest14, digest17, mmr.nodes[18]];
            let expected_root = hasher.root(19, peak_digests.iter());
            assert_eq!(root, expected_root, "incorrect root");

            mmr.prune_to_pos(14);
            assert_eq!(mmr.oldest_retained_pos().unwrap(), 14);

            // Proofs over pruned elements must fail; retained must succeed.
            assert!(matches!(mmr.proof(0).await, Err(ElementPruned(_))));
            assert!(matches!(mmr.proof(11).await, Err(ElementPruned(_))));
            assert!(mmr.proof(15).await.is_ok());

            let root_after_prune = mmr.root(&mut hasher);
            assert_eq!(root, root_after_prune, "root changed after pruning");
            assert!(
                mmr.proof(11).await.is_err(),
                "attempts to prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(10, 15).await.is_err(),
                "attempts to range_prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(15, mmr.last_leaf_pos().unwrap())
                    .await
                    .is_ok(),
                "attempts to range_prove over elements following oldest retained should succeed"
            );

            // Rebuilding an MMR from its retained nodes + pinned digests must
            // reproduce an equivalent MMR.
            let oldest_pos = mmr.oldest_retained_pos().unwrap();
            let digests = mmr.node_digests_to_pin(oldest_pos);
            let mmr_copy = Mmr::init(Config {
                nodes: mmr.nodes.iter().copied().collect(),
                pruned_to_pos: oldest_pos,
                pinned_nodes: digests,
                pool: None,
            });
            assert_eq!(mmr_copy.size(), 19);
            assert_eq!(mmr_copy.leaves(), mmr.leaves());
            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
            assert_eq!(mmr_copy.oldest_retained_pos(), mmr.oldest_retained_pos());
            assert_eq!(mmr_copy.root(&mut hasher), root);

            // A pruned clone retains nothing but still produces the same root.
            mmr.prune_to_pos(17);
            let clone = mmr.clone_pruned();
            assert_eq!(clone.oldest_retained_pos(), None);
            assert_eq!(clone.pruned_to_pos(), clone.size());
            mmr.prune_all();
            assert_eq!(mmr.oldest_retained_pos(), None);
            assert_eq!(mmr.pruned_to_pos(), mmr.size());
            assert_eq!(mmr.size(), clone.size());
            assert_eq!(mmr.root(&mut hasher), clone.root(&mut hasher));
        });
    }

    /// Repeatedly pruning everything then adding should never panic.
    #[test]
    fn test_mem_mmr_prune_all() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut hasher: Standard<Sha256> = Standard::new();
            for _ in 0..1000 {
                mmr.prune_all();
                mmr.add(&mut hasher, &element);
            }
        });
    }

    /// `check_validity` should accept exactly the sizes the MMR passes
    /// through, and reject every size strictly between two valid ones.
    #[test]
    fn test_mem_mmr_validity() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut hasher: Standard<Sha256> = Standard::new();
            for _ in 0..1001 {
                assert!(
                    PeakIterator::check_validity(mmr.size()),
                    "mmr of size {} should be valid",
                    mmr.size()
                );
                let old_size = mmr.size();
                mmr.add(&mut hasher, &element);
                for size in old_size + 1..mmr.size() {
                    assert!(
                        !PeakIterator::check_validity(size),
                        "mmr of size {size} should be invalid",
                    );
                }
            }
        });
    }

    /// Roots must match the golden values whether elements are added eagerly
    /// or in batched mode.
    #[test]
    fn test_mem_mmr_root_stability() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            build_and_check_test_roots_mmr(&mut mmr).await;

            let mut mmr = Mmr::new();
            build_batched_and_check_test_roots(&mut mmr).await;
        });
    }

    /// Batched building with a thread pool must produce the same golden
    /// roots as the serial path.
    #[test]
    fn test_mem_mmr_root_stability_parallel() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            let pool = create_pool(context, 4).unwrap();

            let mut mmr = Mmr::init(Config {
                nodes: vec![],
                pruned_to_pos: 0,
                pinned_nodes: vec![],
                pool: Some(pool),
            });
            build_batched_and_check_test_roots(&mut mmr).await;
        });
    }

    /// Pruning after every add must not change the computed roots.
    #[test]
    fn test_mem_mmr_root_stability_while_pruning() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new();
            for i in 0u64..199 {
                let root = mmr.root(&mut hasher);
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root, "at: {i}");
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
                mmr.prune_all();
            }
        });
    }

    /// Builds the standard 199-leaf test MMR (elements are the SHA-256 of
    /// each leaf number) and returns the position of each leaf.
    fn compute_big_mmr(hasher: &mut impl Hasher<Sha256>, mmr: &mut Mmr<Sha256>) -> Vec<u64> {
        let mut leaves = Vec::new();
        let mut c_hasher = Sha256::default();
        for i in 0u64..199 {
            c_hasher.update(&i.to_be_bytes());
            let element = c_hasher.finalize();
            leaves.push(mmr.add(hasher, &element));
        }
        mmr.sync(hasher);

        leaves
    }

    /// Popping must walk the roots backward through the golden values, and
    /// must fail with `Empty` / `ElementPruned` at the boundaries.
    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = Mmr::new();
            compute_big_mmr(&mut hasher, &mut mmr);
            let root = mmr.root(&mut hasher);
            let expected_root = ROOTS[199];
            assert_eq!(hex(&root), expected_root);

            // Pop elements one by one; each intermediate root must match.
            for i in (0..199u64).rev() {
                assert!(mmr.pop().is_ok());
                let root = mmr.root(&mut hasher);
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root);
            }

            assert!(
                matches!(mmr.pop().unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            // Rebuild, then confirm popping down to the prune boundary works.
            for i in 0u64..199 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }

            let leaf_pos = leaf_num_to_pos(100);
            mmr.prune_to_pos(leaf_pos);
            while mmr.size() > leaf_pos {
                assert!(mmr.pop().is_ok());
            }
            assert_eq!(hex(&mmr.root(&mut hasher)), ROOTS[100]);
            assert!(matches!(mmr.pop().unwrap_err(), ElementPruned(_)));
            assert_eq!(mmr.oldest_retained_pos(), None);
        });
    }

    /// Updating a leaf must change the root, and restoring the original
    /// element must restore it.
    #[test]
    fn test_mem_mmr_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            // Build the 199-leaf test MMR exactly once. (Previously this was
            // built twice, producing a 398-leaf tree with the first call's
            // leaf positions discarded.)
            let leaves = compute_big_mmr(&mut hasher, &mut mmr);
            let root = mmr.root(&mut hasher);

            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
                // Changing the leaf's element must change the root.
                mmr.update_leaf(&mut hasher, leaves[leaf], &element);
                let updated_root = mmr.root(&mut hasher);
                assert!(root != updated_root);

                // Restoring the original element must restore the root.
                hasher.inner().update(&leaf.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.update_leaf(&mut hasher, leaves[leaf], &element);
                let restored_root = mmr.root(&mut hasher);
                assert_eq!(root, restored_root);
            }

            // Updating a leaf at the prune boundary must still work.
            mmr.prune_to_pos(leaves[150]);
            for &leaf_pos in &leaves[150..=190] {
                mmr.prune_to_pos(leaf_pos);
                mmr.update_leaf(&mut hasher, leaf_pos, &element);
            }
        });
    }

    /// Updating a non-leaf position must panic.
    #[test]
    #[should_panic(expected = "pos was not for a leaf")]
    fn test_mem_mmr_update_leaf_panic_invalid() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            compute_big_mmr(&mut hasher, &mut mmr);
            let not_a_leaf_pos = 2;
            mmr.update_leaf(&mut hasher, not_a_leaf_pos, &element);
        });
    }

    /// Updating a pruned position must panic.
    #[test]
    #[should_panic(expected = "element pruned")]
    fn test_mem_mmr_update_leaf_panic_pruned() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            compute_big_mmr(&mut hasher, &mut mmr);
            mmr.prune_all();
            mmr.update_leaf(&mut hasher, 0, &element);
        });
    }

    /// Batched leaf updates (serial path) must match the golden roots.
    #[test]
    fn test_mem_mmr_batch_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut mmr = Mmr::new();
            let leaves = compute_big_mmr(&mut hasher, &mut mmr);
            do_batch_update(&mut hasher, &mut mmr, &leaves);
        });
    }

    /// Batched leaf updates via the thread pool must match the serial path.
    #[test]
    fn test_mem_mmr_batch_parallel_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = tokio::Runner::default();
        executor.start(|ctx| async move {
            let pool = create_pool(ctx, 4).unwrap();
            let mut mmr = Mmr::init(Config {
                nodes: Vec::new(),
                pruned_to_pos: 0,
                pinned_nodes: Vec::new(),
                pool: Some(pool),
            });
            let leaves = compute_big_mmr(&mut hasher, &mut mmr);
            do_batch_update(&mut hasher, &mut mmr, &leaves);
        });
    }

    /// Shared body for the batch-update tests: updates a fixed set of leaves
    /// to a constant element (checking a golden root), then restores the
    /// original elements (checking the original root returns).
    fn do_batch_update(hasher: &mut Standard<Sha256>, mmr: &mut Mmr<Sha256>, leaves: &[u64]) {
        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
        let root = mmr.root(hasher);

        let mut updates = Vec::new();
        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
            updates.push((leaves[leaf], &element));
        }
        mmr.update_leaf_batched(hasher, &updates);

        mmr.sync(hasher);
        let updated_root = mmr.root(hasher);
        assert_eq!(
            "af3acad6aad59c1a880de643b1200a0962a95d06c087ebf677f29eb93fc359a4",
            hex(&updated_root)
        );

        // Restore each leaf's original element and confirm the root reverts.
        let mut updates = Vec::new();
        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
            hasher.inner().update(&leaf.to_be_bytes());
            let element = hasher.inner().finalize();
            updates.push((leaves[leaf], element));
        }
        mmr.update_leaf_batched(hasher, &updates);

        mmr.sync(hasher);
        let restored_root = mmr.root(hasher);
        assert_eq!(root, restored_root);
    }
}