commonware_storage/mmr/
mem.rs

1//! A basic MMR where all nodes are stored in-memory.
2//!
3//! # Terminology
4//!
5//! Nodes in this structure are either _retained_, _pruned_, or _pinned_. Retained nodes are nodes
6//! that have not yet been pruned, and have digests stored explicitly within the tree structure.
7//! Pruned nodes are those whose positions precede that of the _oldest retained_ node, for which no
8//! digests are maintained. Pinned nodes are nodes that would otherwise be pruned based on their
9//! position, but whose digests remain required for proof generation. The digests for pinned nodes
10//! are stored in an auxiliary map, and are at most O(log2(n)) in number.
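//!
//! # Example
//!
//! A minimal usage sketch (illustrative only, not compiled as a doctest). It assumes the
//! `Standard` hasher wrapper and `Sha256` digest used by this module's tests:
//!
//! ```ignore
//! use crate::mmr::{hasher::Standard, mem::Mmr};
//! use commonware_cryptography::Sha256;
//!
//! let mut hasher: Standard<Sha256> = Standard::new();
//! let mut mmr = Mmr::new();
//!
//! // `add` returns the position of the new leaf.
//! let pos = mmr.add(&mut hasher, b"element");
//!
//! // Compute the root over the current peaks.
//! let root = mmr.root(&mut hasher);
//!
//! // Prune everything, pinning only the nodes still needed for future proofs.
//! mmr.prune_all();
//! assert_eq!(root, mmr.root(&mut hasher)); // pruning does not change the root
//! ```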
11use crate::mmr::{
12    iterator::{leaf_pos_to_num, nodes_needing_parents, PathIterator, PeakIterator},
13    verification::Proof,
14    Builder,
15    Error::{self, ElementPruned, Empty},
16    Hasher,
17};
18use commonware_cryptography::Hasher as CHasher;
19use commonware_runtime::ThreadPool;
20use rayon::prelude::*;
21use std::collections::{HashMap, HashSet, VecDeque};
22
23pub struct Config<H: CHasher> {
24    /// The retained nodes of the MMR.
25    pub nodes: Vec<H::Digest>,
26
27    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
28    /// pruned.
29    pub pruned_to_pos: u64,
30
31    /// The pinned nodes of the MMR, in the order expected by [Proof::nodes_to_pin].
32    pub pinned_nodes: Vec<H::Digest>,
33
34    /// Optional thread pool to use for parallelizing batch updates.
35    pub pool: Option<ThreadPool>,
36}
37
38/// Implementation of `Mmr`.
39///
40/// # Max Capacity
41///
42/// The maximum number of elements that can be stored is usize::MAX
43/// (u32::MAX on 32-bit architectures).
44pub struct Mmr<H: CHasher> {
45    /// The nodes of the MMR, laid out according to a post-order traversal of the MMR trees,
46    /// starting from the tallest tree to the shortest.
47    nodes: VecDeque<H::Digest>,
48
49    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
50    /// pruned.
51    pruned_to_pos: u64,
52
53    /// The auxiliary map from node position to the digest of any pinned node.
54    pub(super) pinned_nodes: HashMap<u64, H::Digest>,
55
56    /// Non-leaf nodes that need to have their digests recomputed due to a batched update operation.
57    ///
58    /// This is a set of tuples of the form (node_pos, height).
59    dirty_nodes: HashSet<(u64, u32)>,
60
61    /// Dummy digest used as a placeholder for nodes whose digests will be updated with the next
62    /// `sync`.
63    dirty_digest: H::Digest,
64
65    /// Thread pool to use for parallelizing updates.
66    pub(super) thread_pool: Option<ThreadPool>,
67}
68
69impl<H: CHasher> Default for Mmr<H> {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl<H: CHasher> Builder<H> for Mmr<H> {
76    async fn add(&mut self, hasher: &mut impl Hasher<H>, element: &[u8]) -> Result<u64, Error> {
77        Ok(self.add(hasher, element))
78    }
79
80    fn root(&self, hasher: &mut impl Hasher<H>) -> H::Digest {
81        self.root(hasher)
82    }
83}
84
85/// Minimum number of digest computations required during batch updates to trigger parallelization.
86const MIN_TO_PARALLELIZE: usize = 20;
87
88impl<H: CHasher> Mmr<H> {
89    /// Return a new (empty) `Mmr`.
90    pub fn new() -> Self {
91        Self {
92            nodes: VecDeque::new(),
93            pruned_to_pos: 0,
94            pinned_nodes: HashMap::new(),
95            dirty_nodes: HashSet::new(),
96            dirty_digest: Self::dirty_digest(),
97            thread_pool: None,
98        }
99    }
100
101    // Computes the digest to use as the `self.dirty_digest` placeholder. The specific value is
102    // unimportant so we simply use the empty hash.
103    fn dirty_digest() -> H::Digest {
104        H::empty()
105    }
106
107    /// Return an [Mmr] initialized with the given `config`.
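    ///
    /// A sketch of restoring a never-pruned MMR from its previously stored nodes (illustrative;
    /// the variable names are placeholders):
    ///
    /// ```ignore
    /// let restored: Mmr<Sha256> = Mmr::init(Config {
    ///     nodes: previously_stored_nodes, // retained node digests, in MMR order
    ///     pruned_to_pos: 0,               // never pruned
    ///     pinned_nodes: vec![],           // nothing to pin when nothing is pruned
    ///     pool: None,                     // no thread pool: updates stay serial
    /// });
    /// ```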
108    pub fn init(config: Config<H>) -> Self {
109        let mut mmr = Self {
110            nodes: VecDeque::from(config.nodes),
111            pruned_to_pos: config.pruned_to_pos,
112            pinned_nodes: HashMap::new(),
113            dirty_nodes: HashSet::new(),
114            dirty_digest: Self::dirty_digest(),
115            thread_pool: config.pool,
116        };
117        if mmr.size() == 0 {
118            return mmr;
119        }
120
121        for (i, pos) in Proof::<H::Digest>::nodes_to_pin(config.pruned_to_pos).enumerate() {
122            mmr.pinned_nodes.insert(pos, config.pinned_nodes[i]);
123        }
124
125        mmr
126    }
127
128    /// Return the total number of nodes in the MMR, irrespective of any pruning. The next added
129    /// element's position will have this value.
130    pub fn size(&self) -> u64 {
131        self.nodes.len() as u64 + self.pruned_to_pos
132    }
133
134    /// Return the total number of leaves in the MMR.
135    pub fn leaves(&self) -> u64 {
136        leaf_pos_to_num(self.size()).expect("invalid mmr size")
137    }
138
139    /// Return the position of the last leaf in this MMR, or None if the MMR is empty.
140    pub fn last_leaf_pos(&self) -> Option<u64> {
141        if self.size() == 0 {
142            return None;
143        }
144
145        Some(PeakIterator::last_leaf_pos(self.size()))
146    }
147
148    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
149    /// pruned.
150    pub fn pruned_to_pos(&self) -> u64 {
151        self.pruned_to_pos
152    }
153
154    /// Return the position of the oldest retained node in the MMR, not including those cached in
155    /// pinned_nodes.
156    pub fn oldest_retained_pos(&self) -> Option<u64> {
157        if self.pruned_to_pos == self.size() {
158            return None;
159        }
160
161        Some(self.pruned_to_pos)
162    }
163
164    /// Return a new iterator over the peaks of the MMR.
165    pub(super) fn peak_iterator(&self) -> PeakIterator {
166        PeakIterator::new(self.size())
167    }
168
169    /// Return the position of the element given its index in the current nodes vector.
170    fn index_to_pos(&self, index: usize) -> u64 {
171        index as u64 + self.pruned_to_pos
172    }
173
174    /// Returns the requested node, assuming it is either retained or known to exist in the
175    /// pinned_nodes map.
176    pub fn get_node_unchecked(&self, pos: u64) -> &H::Digest {
177        if pos < self.pruned_to_pos {
178            return self
179                .pinned_nodes
180                .get(&pos)
181                .expect("requested node is pruned and not pinned");
182        }
183
184        &self.nodes[self.pos_to_index(pos)]
185    }
186
187    /// Returns the requested node or None if it is not stored in the MMR.
188    pub fn get_node(&self, pos: u64) -> Option<H::Digest> {
189        if pos < self.pruned_to_pos {
190            return self.pinned_nodes.get(&pos).copied();
191        }
192
193        self.nodes.get(self.pos_to_index(pos)).copied()
194    }
195
196    /// Return the index of the element in the current nodes vector given its position in the MMR.
197    ///
198    /// Will underflow if `pos` precedes the oldest retained position.
199    fn pos_to_index(&self, pos: u64) -> usize {
200        (pos - self.pruned_to_pos) as usize
201    }
202
203    /// Add `element` to the MMR and return its position in the MMR. The element can be an arbitrary
204    /// byte slice, and need not be converted to a digest first.
205    ///
206    /// # Warning
207    ///
208    /// Panics if there are unprocessed batch updates.
209    pub fn add(&mut self, hasher: &mut impl Hasher<H>, element: &[u8]) -> u64 {
210        let leaf_pos = self.size();
211        let digest = hasher.leaf_digest(leaf_pos, element);
212        self.add_leaf_digest(hasher, digest);
213
214        leaf_pos
215    }
216
217    /// Add `element` to the MMR and return its position in the MMR, but without updating ancestors
218    /// until `sync` is called. The element can be an arbitrary byte slice, and need not be
219    /// converted to a digest first.
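    ///
    /// A sketch of the intended batched-add pattern (illustrative only; `elements` is a
    /// placeholder):
    ///
    /// ```ignore
    /// for element in &elements {
    ///     mmr.add_batched(&mut hasher, element); // parents receive placeholder digests
    /// }
    /// mmr.sync(&mut hasher); // recompute all dirty ancestors before reading the root
    /// let root = mmr.root(&mut hasher);
    /// ```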
220    pub fn add_batched(&mut self, hasher: &mut impl Hasher<H>, element: &[u8]) -> u64 {
221        let leaf_pos = self.size();
222        let digest = hasher.leaf_digest(leaf_pos, element);
223
224        // Compute the positions of any new parent nodes, insert them into the MMR with a
225        // dummy digest, and add each to the dirty nodes set.
226        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
227            .into_iter()
228            .rev();
229        self.nodes.push_back(digest);
230
231        let mut height = 1;
232        for _ in nodes_needing_parents {
233            let new_node_pos = self.size();
234            // The digest we push here doesn't matter as it will be updated later.
235            self.nodes.push_back(self.dirty_digest);
236            self.dirty_nodes.insert((new_node_pos, height));
237            height += 1;
238        }
239
240        leaf_pos
241    }
242
243    /// Add a leaf's `digest` to the MMR, generating the necessary parent nodes to maintain the
244    /// MMR's structure.
245    ///
246    /// # Warning
247    ///
248    /// Panics if there are unprocessed batch updates.
249    pub(super) fn add_leaf_digest(&mut self, hasher: &mut impl Hasher<H>, mut digest: H::Digest) {
250        assert!(
251            self.dirty_nodes.is_empty(),
252            "dirty nodes must be processed before adding an element w/o batching"
253        );
254        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
255            .into_iter()
256            .rev();
257        self.nodes.push_back(digest);
258
259        // Compute the new parent nodes if any, and insert them into the MMR.
260        for sibling_pos in nodes_needing_parents {
261            let new_node_pos = self.size();
262            let sibling_digest = self.get_node_unchecked(sibling_pos);
263            digest = hasher.node_digest(new_node_pos, sibling_digest, &digest);
264            self.nodes.push_back(digest);
265        }
266    }
267
268    /// Pop the most recent leaf element out of the MMR if it exists, returning Empty or
269    /// ElementPruned errors otherwise.
270    ///
271    /// # Warning
272    ///
273    /// Panics if there are unprocessed batch updates.
274    pub fn pop(&mut self) -> Result<u64, Error> {
275        if self.size() == 0 {
276            return Err(Empty);
277        }
278        assert!(
279            self.dirty_nodes.is_empty(),
280            "dirty nodes must be processed before popping elements"
281        );
282
283        let mut new_size = self.size() - 1;
284        loop {
285            if new_size < self.pruned_to_pos {
286                return Err(ElementPruned(new_size));
287            }
288            if PeakIterator::check_validity(new_size) {
289                break;
290            }
291            new_size -= 1;
292        }
293        let num_to_drain = (self.size() - new_size) as usize;
294        self.nodes.drain(self.nodes.len() - num_to_drain..);
295
296        Ok(self.size())
297    }
298
299    /// Change the digest of any retained leaf. This is useful if you want to use the MMR
300    /// implementation as an updatable binary Merkle tree, and otherwise should be avoided.
301    ///
302    /// # Warning
303    ///
304    /// - Panics if `pos` does not correspond to a leaf, or if the leaf has been pruned.
305    ///
306    /// - This method will change the root and invalidate any previous inclusion proofs.
307    ///
308    /// - Use of this method will prevent using this structure as a base mmr for grafting.
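    ///
    /// A sketch of updating a retained leaf in place (illustrative; `leaf_pos` must be a position
    /// previously returned by `add`):
    ///
    /// ```ignore
    /// let root_before = mmr.root(&mut hasher);
    /// mmr.update_leaf(&mut hasher, leaf_pos, b"new value");
    /// assert_ne!(root_before, mmr.root(&mut hasher)); // prior proofs are now invalid
    /// ```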
309    pub fn update_leaf(&mut self, hasher: &mut impl Hasher<H>, pos: u64, element: &[u8]) {
310        if pos < self.pruned_to_pos {
311            panic!("element pruned: pos={pos}");
312        }
313
314        // Update the digest of the leaf node.
315        let mut digest = hasher.leaf_digest(pos, element);
316        let mut index = self.pos_to_index(pos);
317        self.nodes[index] = digest;
318
319        // Update digests of all its ancestors.
320        for (peak_pos, height) in self.peak_iterator() {
321            if peak_pos < pos {
322                continue;
323            }
324            // We have found the mountain containing the path we need to update.
325            let path: Vec<_> = PathIterator::new(pos, peak_pos, height).collect();
326            for (parent_pos, sibling_pos) in path.into_iter().rev() {
327                if parent_pos == pos {
328                    panic!("pos was not for a leaf");
329                }
330                let sibling_digest = self.get_node_unchecked(sibling_pos);
331                digest = if sibling_pos == parent_pos - 1 {
332                    // The sibling is the right child of the parent.
333                    hasher.node_digest(parent_pos, &digest, sibling_digest)
334                } else {
335                    hasher.node_digest(parent_pos, sibling_digest, &digest)
336                };
337                index = self.pos_to_index(parent_pos);
338                self.nodes[index] = digest;
339            }
340            return;
341        }
342
343        panic!("invalid pos {}:{}", pos, self.size())
344    }
345
346    /// Batch update the digests of multiple retained leaves.
347    ///
348    /// # Warning
349    ///
350    /// Panics if any of the updated leaves has been pruned.
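    ///
    /// A sketch of the intended usage (illustrative; positions and values are placeholders):
    ///
    /// ```ignore
    /// let updates = vec![(pos_a, b"new a".as_slice()), (pos_b, b"new b".as_slice())];
    /// mmr.update_leaf_batched(&mut hasher, &updates);
    /// mmr.sync(&mut hasher); // ancestor digests are only recomputed here
    /// let root = mmr.root(&mut hasher);
    /// ```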
351    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
352        &mut self,
353        hasher: &mut impl Hasher<H>,
354        updates: &[(u64, T)],
355    ) {
356        if updates.len() >= MIN_TO_PARALLELIZE && self.thread_pool.is_some() {
357            self.update_leaf_parallel(hasher, updates);
358            return;
359        }
360
361        for (pos, element) in updates {
362            if *pos < self.pruned_to_pos {
363                panic!("element pruned: pos={pos}");
364            }
365
366            // Update the digest of the leaf node and mark its ancestors as dirty.
367            let digest = hasher.leaf_digest(*pos, element.as_ref());
368            let index = self.pos_to_index(*pos);
369            self.nodes[index] = digest;
370            self.mark_dirty(*pos);
371        }
372    }
373
374    /// Mark the non-leaf nodes in the path from the given position to the root as dirty, so that
375    /// their digests are appropriately recomputed during the next `sync`.
376    fn mark_dirty(&mut self, pos: u64) {
377        for (peak_pos, mut height) in self.peak_iterator() {
378            if peak_pos < pos {
379                continue;
380            }
381
382            // We have found the mountain containing the path we are looking for. Traverse it from
383            // leaf to root, that way we can exit early if we hit a node that is already dirty.
384            let path = PathIterator::new(pos, peak_pos, height)
385                .collect::<Vec<_>>()
386                .into_iter()
387                .rev();
388            height = 1;
389            for (parent_pos, _) in path {
390                if !self.dirty_nodes.insert((parent_pos, height)) {
391                    break;
392                }
393                height += 1;
394            }
395            return;
396        }
397
398        panic!("invalid pos {}:{}", pos, self.size());
399    }
400
401    /// Batch update the digests of multiple retained leaves using multiple threads.
402    ///
403    /// # Warning
404    ///
405    /// Assumes `self.thread_pool` is non-None and panics otherwise.
406    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
407        &mut self,
408        hasher: &mut impl Hasher<H>,
409        updates: &[(u64, T)],
410    ) {
411        let pool = self.thread_pool.as_ref().unwrap().clone();
412        pool.install(|| {
413            let digests: Vec<(u64, H::Digest)> = updates
414                .par_iter()
415                .map_init(
416                    || hasher.fork(),
417                    |hasher, (pos, elem)| {
418                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
419                        (*pos, digest)
420                    },
421                )
422                .collect();
423
424            for (pos, digest) in digests {
425                let index = self.pos_to_index(pos);
426                self.nodes[index] = digest;
427                self.mark_dirty(pos);
428            }
429        });
430    }
431
432    /// Returns whether there are pending updates.
433    pub fn is_dirty(&self) -> bool {
434        !self.dirty_nodes.is_empty()
435    }
436
437    /// Process any pending batched updates.
438    pub fn sync(&mut self, hasher: &mut impl Hasher<H>) {
439        if self.dirty_nodes.is_empty() {
440            return;
441        }
442        if self.dirty_nodes.len() >= MIN_TO_PARALLELIZE && self.thread_pool.is_some() {
443            self.sync_parallel(hasher, MIN_TO_PARALLELIZE);
444            return;
445        }
446
447        self.sync_serial(hasher);
448    }
449
450    fn sync_serial(&mut self, hasher: &mut impl Hasher<H>) {
451        let mut nodes: Vec<(u64, u32)> = self.dirty_nodes.iter().copied().collect();
452        self.dirty_nodes.clear();
453        nodes.sort_by(|a, b| a.1.cmp(&b.1));
454
455        for (pos, height) in nodes {
456            let left = pos - (1 << height);
457            let right = pos - 1;
458            let digest = hasher.node_digest(
459                pos,
460                self.get_node_unchecked(left),
461                self.get_node_unchecked(right),
462            );
463            let index = self.pos_to_index(pos);
464            self.nodes[index] = digest;
465        }
466    }
467
468    /// Process any pending batched updates, using parallel hash workers as long as the number of
469    /// computations that can be parallelized is at least `min_to_parallelize`.
470    ///
471    /// This implementation parallelizes the computation of digests across all nodes at the same
472    /// height, starting from the bottom and working up to the peaks. If the number of digest
473    /// computations at any height falls below `min_to_parallelize`, it switches to the serial
474    /// implementation for the remaining nodes.
475    ///
476    /// # Warning
477    ///
478    /// Assumes `self.thread_pool` is non-None and panics otherwise.
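    ///
    /// For example, in the 19-node MMR used by this module's tests, updating the leaves at
    /// positions 0 and 3 marks nodes (2, 1), (5, 1), (6, 2) and (14, 3) dirty: positions 2 and 5
    /// share a height and can be hashed independently, then 6, then 14.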
479    fn sync_parallel(&mut self, hasher: &mut impl Hasher<H>, min_to_parallelize: usize) {
480        let mut nodes: Vec<(u64, u32)> = self.dirty_nodes.iter().copied().collect();
481        self.dirty_nodes.clear();
482        // Sort by increasing height.
483        nodes.sort_by(|a, b| a.1.cmp(&b.1));
484
485        let mut same_height = Vec::new();
486        let mut current_height = 1;
487        for (i, (pos, height)) in nodes.iter().enumerate() {
488            if *height == current_height {
489                same_height.push(*pos);
490                continue;
491            }
492            if same_height.len() < min_to_parallelize {
493                self.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
494                self.sync_serial(hasher);
495                return;
496            }
497            self.update_node_digests(hasher, &same_height, current_height);
498            same_height.clear();
499            current_height += 1;
500            same_height.push(*pos);
501        }
502
503        if same_height.len() < min_to_parallelize {
504            self.dirty_nodes = nodes[nodes.len() - same_height.len()..]
505                .iter()
506                .copied()
507                .collect();
508            self.sync_serial(hasher);
509            return;
510        }
511
512        self.update_node_digests(hasher, &same_height, current_height);
513    }
514
515    /// Update digests of the given set of nodes of equal height in the MMR. Since they are all at
516    /// the same height, this can be done in parallel without synchronization.
517    ///
518    /// # Warning
519    ///
520    /// Assumes `self.thread_pool` is non-None and panics otherwise.
521    fn update_node_digests(
522        &mut self,
523        hasher: &mut impl Hasher<H>,
524        same_height: &[u64],
525        height: u32,
526    ) {
527        let two_h = 1 << height;
528        let pool = self.thread_pool.as_ref().unwrap().clone();
529        pool.install(|| {
530            let computed_digests: Vec<(usize, H::Digest)> = same_height
531                .par_iter()
532                .map_init(
533                    || hasher.fork(),
534                    |hasher, &pos| {
535                        let left = pos - two_h;
536                        let right = pos - 1;
537                        let digest = hasher.node_digest(
538                            pos,
539                            self.get_node_unchecked(left),
540                            self.get_node_unchecked(right),
541                        );
542                        let index = self.pos_to_index(pos);
543                        (index, digest)
544                    },
545                )
546                .collect();
547
548            for (index, digest) in computed_digests {
549                self.nodes[index] = digest;
550            }
551        });
552    }
553
554    /// Computes the root of the MMR.
555    ///
556    /// # Warning
557    ///
558    /// Panics if there are unprocessed batch updates.
559    pub fn root(&self, hasher: &mut impl Hasher<H>) -> H::Digest {
560        assert!(
561            self.dirty_nodes.is_empty(),
562            "dirty nodes must be processed before computing the root"
563        );
564        let peaks = self
565            .peak_iterator()
566            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
567        let size = self.size();
568        hasher.root(size, peaks)
569    }
570
571    /// Return an inclusion proof for the specified element. Returns ElementPruned error if some
572    /// element needed to generate the proof has been pruned.
573    ///
574    /// # Warning
575    ///
576    /// Panics if there are unprocessed batch updates.
577    pub async fn proof(&self, element_pos: u64) -> Result<Proof<H::Digest>, Error> {
578        self.range_proof(element_pos, element_pos).await
579    }
580
581    /// Return an inclusion proof for the specified range of elements, inclusive of both endpoints.
582    /// Returns ElementPruned error if some element needed to generate the proof has been pruned.
583    ///
584    /// # Warning
585    ///
586    /// Panics if there are unprocessed batch updates.
587    pub async fn range_proof(
588        &self,
589        start_element_pos: u64,
590        end_element_pos: u64,
591    ) -> Result<Proof<H::Digest>, Error> {
592        if start_element_pos < self.pruned_to_pos {
593            return Err(ElementPruned(start_element_pos));
594        }
595        assert!(
596            self.dirty_nodes.is_empty(),
597            "dirty nodes must be processed before computing proofs"
598        );
599        Proof::<H::Digest>::range_proof(self, start_element_pos, end_element_pos).await
600    }
601
602    /// Prune all nodes, pinning the at most O(log2(n)) nodes required for proof generation going
603    /// forward.
604    ///
605    /// # Warning
606    ///
607    /// Panics if there are unprocessed batch updates.
608    pub fn prune_all(&mut self) {
609        if !self.nodes.is_empty() {
610            self.prune_to_pos(self.index_to_pos(self.nodes.len()));
611        }
612    }
613
614    /// Prune all nodes up to but not including the given position, pinning the at most
615    /// O(log2(n)) nodes required for proof generation.
616    ///
617    /// # Warning
618    ///
619    /// Panics if there are unprocessed batch updates.
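    ///
    /// A sketch of pruning behavior (illustrative; mirrors the pruning assertions in this module's
    /// tests, where `pos` and `pruned_pos` are placeholders with `pruned_pos < pos`):
    ///
    /// ```ignore
    /// let root_before = mmr.root(&mut hasher);
    /// mmr.prune_to_pos(pos);
    /// assert_eq!(root_before, mmr.root(&mut hasher)); // pruning never changes the root
    /// assert!(mmr.proof(pruned_pos).await.is_err());  // proving pruned elements fails
    /// ```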
620    pub fn prune_to_pos(&mut self, pos: u64) {
621        assert!(
622            self.dirty_nodes.is_empty(),
623            "dirty nodes must be processed before pruning"
624        );
625        // Recompute the set of older nodes to retain.
626        self.pinned_nodes = self.nodes_to_pin(pos);
627        let retained_nodes = self.pos_to_index(pos);
628        self.nodes.drain(0..retained_nodes);
629        self.pruned_to_pos = pos;
630    }
631
632    /// Get the nodes (position + digest) that need to be pinned (those required for proof
633    /// generation) in this MMR when pruned to position `prune_pos`.
634    pub(super) fn nodes_to_pin(&self, prune_pos: u64) -> HashMap<u64, H::Digest> {
635        Proof::<H::Digest>::nodes_to_pin(prune_pos)
636            .map(|pos| (pos, *self.get_node_unchecked(pos)))
637            .collect()
638    }
639
640    /// Get the digests of nodes that need to be pinned (those required for proof generation) in
641    /// this MMR when pruned to position `prune_pos`.
642    pub(super) fn node_digests_to_pin(&self, start_pos: u64) -> Vec<H::Digest> {
643        Proof::<H::Digest>::nodes_to_pin(start_pos)
644            .map(|pos| *self.get_node_unchecked(pos))
645            .collect()
646    }
647
648    /// Utility used by stores that build on the mem MMR to pin extra nodes if needed. It's up to
649    /// the caller to ensure that this set of pinned nodes is valid for their use case.
650    pub(super) fn add_pinned_nodes(&mut self, pinned_nodes: HashMap<u64, H::Digest>) {
651        for (pos, node) in pinned_nodes.into_iter() {
652            self.pinned_nodes.insert(pos, node);
653        }
654    }
655
656    /// A lightweight cloning operation that "clones" only the fully pruned state of this MMR. The
657    /// output is identical to the result of calling `prune_all()`, except that you get a copy
658    /// without mutating the original, and the thread pool (if any) is not cloned.
659    ///
660    /// Runtime is O(log2(n)) in the number of elements, even if the original MMR was never pruned.
661    ///
662    /// # Warning
663    ///
664    /// Panics if there are unprocessed batch updates.
665    pub fn clone_pruned(&self) -> Self {
666        if self.size() == 0 {
667            return Self::new();
668        }
669        assert!(
670            self.dirty_nodes.is_empty(),
671            "dirty nodes must be processed before cloning"
672        );
673
674        // Create the "old_nodes" of the MMR in the fully pruned state.
675        let old_nodes = self.node_digests_to_pin(self.size());
676
677        Self::init(Config {
678            nodes: vec![],
679            pruned_to_pos: self.size(),
680            pinned_nodes: old_nodes,
681            pool: None,
682        })
683    }
684}
685
686#[cfg(test)]
687mod tests {
688    use super::*;
689    use crate::mmr::{
690        hasher::Standard,
691        iterator::leaf_num_to_pos,
692        tests::{build_and_check_test_roots_mmr, build_batched_and_check_test_roots, ROOTS},
693    };
694    use commonware_cryptography::Sha256;
695    use commonware_runtime::{create_pool, deterministic, tokio, Runner};
696    use commonware_utils::hex;
697
698    /// Test empty MMR behavior.
699    #[test]
700    fn test_mem_mmr_empty() {
701        let executor = deterministic::Runner::default();
702        executor.start(|_| async move {
703            let mut hasher: Standard<Sha256> = Standard::new();
704            let mut mmr = Mmr::new();
705            assert_eq!(
706                mmr.peak_iterator().next(),
707                None,
708                "empty iterator should have no peaks"
709            );
710            assert_eq!(mmr.size(), 0);
711            assert_eq!(mmr.leaves(), 0);
712            assert_eq!(mmr.last_leaf_pos(), None);
713            assert_eq!(mmr.oldest_retained_pos(), None);
714            assert_eq!(mmr.get_node(0), None);
715            assert!(matches!(mmr.pop(), Err(Empty)));
716            mmr.prune_all();
717            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");
718
719            assert_eq!(mmr.root(&mut hasher), hasher.root(0, [].iter()));
720
721            let clone = mmr.clone_pruned();
722            assert_eq!(clone.size(), 0);
723        });
724    }
725
726    /// Test MMR building by consecutively adding 11 equal elements to a new MMR, producing the
727    /// structure in the example documented at the top of the mmr crate's mod.rs file with 19 nodes
728    /// and 3 peaks.
729    #[test]
730    fn test_mem_mmr_add_eleven_values() {
731        let executor = deterministic::Runner::default();
732        executor.start(|_| async move {
733            let mut mmr = Mmr::new();
734            let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
735            let mut leaves: Vec<u64> = Vec::new();
736            let mut hasher: Standard<Sha256> = Standard::new();
737            for _ in 0..11 {
738                leaves.push(mmr.add(&mut hasher, &element));
739                let peaks: Vec<(u64, u32)> = mmr.peak_iterator().collect();
740                assert_ne!(peaks.len(), 0);
741                assert!(peaks.len() <= mmr.size() as usize);
742                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
743                assert!(nodes_needing_parents.len() <= peaks.len());
744            }
745            assert_eq!(mmr.oldest_retained_pos().unwrap(), 0);
746            assert_eq!(mmr.size(), 19, "mmr not of expected size");
747            assert_eq!(
748                leaves,
749                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18],
750                "mmr leaf positions not as expected"
751            );
752            let peaks: Vec<(u64, u32)> = mmr.peak_iterator().collect();
753            assert_eq!(
754                peaks,
755                vec![(14, 3), (17, 1), (18, 0)],
756                "mmr peaks not as expected"
757            );
758
759            // Test nodes_needing_parents on the final MMR. Since there's a height gap between the
760            // highest peak (14) and the next, only the lower two peaks (17, 18) should be returned.
761            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
762            assert_eq!(
763                peaks_needing_parents,
764                vec![17, 18],
765                "mmr nodes needing parents not as expected"
766            );
767
768            // verify leaf digests
769            for leaf in leaves.iter().by_ref() {
770                let digest = hasher.leaf_digest(*leaf, &element);
771                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
772            }
773
774            // verify height=1 node digests
775            let digest2 = hasher.node_digest(2, &mmr.nodes[0], &mmr.nodes[1]);
776            assert_eq!(mmr.nodes[2], digest2);
777            let digest5 = hasher.node_digest(5, &mmr.nodes[3], &mmr.nodes[4]);
778            assert_eq!(mmr.nodes[5], digest5);
779            let digest9 = hasher.node_digest(9, &mmr.nodes[7], &mmr.nodes[8]);
780            assert_eq!(mmr.nodes[9], digest9);
781            let digest12 = hasher.node_digest(12, &mmr.nodes[10], &mmr.nodes[11]);
782            assert_eq!(mmr.nodes[12], digest12);
783            let digest17 = hasher.node_digest(17, &mmr.nodes[15], &mmr.nodes[16]);
784            assert_eq!(mmr.nodes[17], digest17);
785
786            // verify height=2 node digests
787            let digest6 = hasher.node_digest(6, &mmr.nodes[2], &mmr.nodes[5]);
788            assert_eq!(mmr.nodes[6], digest6);
789            let digest13 = hasher.node_digest(13, &mmr.nodes[9], &mmr.nodes[12]);
790            assert_eq!(mmr.nodes[13], digest13);
793
794            // verify topmost digest
795            let digest14 = hasher.node_digest(14, &mmr.nodes[6], &mmr.nodes[13]);
796            assert_eq!(mmr.nodes[14], digest14);
797
798            // verify root
799            let root = mmr.root(&mut hasher);
800            let peak_digests = [digest14, digest17, mmr.nodes[18]];
801            let expected_root = hasher.root(19, peak_digests.iter());
802            assert_eq!(root, expected_root, "incorrect root");
803
804            // pruning tests
805            mmr.prune_to_pos(14); // prune up to the tallest peak
806            assert_eq!(mmr.oldest_retained_pos().unwrap(), 14);
807
808            // After pruning up to a peak, we shouldn't be able to prove any elements before it.
809            assert!(matches!(mmr.proof(0).await, Err(ElementPruned(_))));
810            assert!(matches!(mmr.proof(11).await, Err(ElementPruned(_))));
811            // We should still be able to prove any leaf following this peak, the first of which is
812            // at position 15.
813            assert!(mmr.proof(15).await.is_ok());
814
815            let root_after_prune = mmr.root(&mut hasher);
816            assert_eq!(root, root_after_prune, "root changed after pruning");
817            assert!(
818                mmr.proof(11).await.is_err(),
819                "attempts to prove elements at or before the oldest retained should fail"
820            );
821            assert!(
822                mmr.range_proof(10, 15).await.is_err(),
823                "attempts to range_prove elements at or before the oldest retained should fail"
824            );
825            assert!(
826                mmr.range_proof(15, mmr.last_leaf_pos().unwrap())
827                    .await
828                    .is_ok(),
829                "attempts to range_prove over elements following oldest retained should succeed"
830            );
831
832            // Test that we can initialize a new MMR from another's elements.
833            let oldest_pos = mmr.oldest_retained_pos().unwrap();
834            let digests = mmr.node_digests_to_pin(oldest_pos);
835            let mmr_copy = Mmr::init(Config {
836                nodes: mmr.nodes.iter().copied().collect(),
837                pruned_to_pos: oldest_pos,
838                pinned_nodes: digests,
839                pool: None,
840            });
841            assert_eq!(mmr_copy.size(), 19);
842            assert_eq!(mmr_copy.leaves(), mmr.leaves());
843            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
844            assert_eq!(mmr_copy.oldest_retained_pos(), mmr.oldest_retained_pos());
845            assert_eq!(mmr_copy.root(&mut hasher), root);
846
847            // Test that clone_pruned produces a valid copy of the MMR as if it had been cloned
848            // after being fully pruned.
849            mmr.prune_to_pos(17); // prune up to the second peak
850            let clone = mmr.clone_pruned();
851            assert_eq!(clone.oldest_retained_pos(), None);
852            assert_eq!(clone.pruned_to_pos(), clone.size());
853            mmr.prune_all();
854            assert_eq!(mmr.oldest_retained_pos(), None);
855            assert_eq!(mmr.pruned_to_pos(), mmr.size());
856            assert_eq!(mmr.size(), clone.size());
857            assert_eq!(mmr.root(&mut hasher), clone.root(&mut hasher));
858        });
859    }
860
861    /// Test that pruning all nodes never breaks adding new nodes.
862    #[test]
863    fn test_mem_mmr_prune_all() {
864        let executor = deterministic::Runner::default();
865        executor.start(|_| async move {
866            let mut mmr = Mmr::new();
867            let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
868            let mut hasher: Standard<Sha256> = Standard::new();
869            for _ in 0..1000 {
870                mmr.prune_all();
871                mmr.add(&mut hasher, &element);
872            }
873        });
874    }
875
876    /// Test that the MMR validity check works as expected.
877    #[test]
878    fn test_mem_mmr_validity() {
879        let executor = deterministic::Runner::default();
880        executor.start(|_| async move {
881            let mut mmr = Mmr::new();
882            let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
883            let mut hasher: Standard<Sha256> = Standard::new();
884            for _ in 0..1001 {
885                assert!(
886                    PeakIterator::check_validity(mmr.size()),
887                    "mmr of size {} should be valid",
888                    mmr.size()
889                );
890                let old_size = mmr.size();
891                mmr.add(&mut hasher, &element);
892                for size in old_size + 1..mmr.size() {
893                    assert!(
894                        !PeakIterator::check_validity(size),
895                        "mmr of size {size} should be invalid",
896                    );
897                }
898            }
899        });
900    }
901
902    /// Test that the MMR root computation remains stable by comparing against previously computed
903    /// roots.
904    #[test]
905    fn test_mem_mmr_root_stability() {
906        let executor = deterministic::Runner::default();
907        executor.start(|_| async move {
908            // Test root stability under different MMR building methods.
909            let mut mmr = Mmr::new();
910            build_and_check_test_roots_mmr(&mut mmr).await;
911
912            let mut mmr = Mmr::new();
913            build_batched_and_check_test_roots(&mut mmr).await;
914        });
915    }
916
917    /// Test root stability using the parallel builder implementation. This requires we use the
918    /// tokio runtime since the deterministic runtime would block due to being single-threaded.
919    #[test]
920    fn test_mem_mmr_root_stability_parallel() {
921        let executor = tokio::Runner::default();
922        executor.start(|context| async move {
923            let pool = commonware_runtime::create_pool(context, 4).unwrap();
924
925            let mut mmr = Mmr::init(Config {
926                nodes: vec![],
927                pruned_to_pos: 0,
928                pinned_nodes: vec![],
929                pool: Some(pool),
930            });
931            build_batched_and_check_test_roots(&mut mmr).await;
932        });
933    }
934
935    /// Build the MMR corresponding to the stability test while pruning after each add, and confirm
936    /// the computed roots still match the expected static roots.
937    #[test]
938    fn test_mem_mmr_root_stability_while_pruning() {
939        let executor = deterministic::Runner::default();
940        executor.start(|_| async move {
941            let mut hasher: Standard<Sha256> = Standard::new();
942            let mut mmr = Mmr::new();
943            for i in 0u64..199 {
944                let root = mmr.root(&mut hasher);
945                let expected_root = ROOTS[i as usize];
946                assert_eq!(hex(&root), expected_root, "at: {i}");
947                hasher.inner().update(&i.to_be_bytes());
948                let element = hasher.inner().finalize();
949                mmr.add(&mut hasher, &element);
950                mmr.prune_all();
951            }
952        });
953    }
954
955    fn compute_big_mmr(hasher: &mut impl Hasher<Sha256>, mmr: &mut Mmr<Sha256>) -> Vec<u64> {
956        let mut leaves = Vec::new();
957        let mut c_hasher = Sha256::default();
958        for i in 0u64..199 {
959            c_hasher.update(&i.to_be_bytes());
960            let element = c_hasher.finalize();
961            leaves.push(mmr.add(hasher, &element));
962        }
963        mmr.sync(hasher);
964
965        leaves
966    }
967
968    #[test]
969    fn test_mem_mmr_pop() {
970        let executor = deterministic::Runner::default();
971        executor.start(|_| async move {
972            let mut hasher: Standard<Sha256> = Standard::new();
973            let mut mmr = Mmr::new();
974            compute_big_mmr(&mut hasher, &mut mmr);
975            let root = mmr.root(&mut hasher);
976            let expected_root = ROOTS[199];
977            assert_eq!(hex(&root), expected_root);
978
979            // Pop off one node at a time until empty, confirming the root is still as expected.
980            for i in (0..199u64).rev() {
981                assert!(mmr.pop().is_ok());
982                let root = mmr.root(&mut hasher);
983                let expected_root = ROOTS[i as usize];
984                assert_eq!(hex(&root), expected_root);
985            }
986
987            assert!(
988                matches!(mmr.pop().unwrap_err(), Empty),
989                "pop on empty MMR should fail"
990            );
991
992            // Test that we can pop all elements up to and including the oldest retained leaf.
993            for i in 0u64..199 {
994                hasher.inner().update(&i.to_be_bytes());
995                let element = hasher.inner().finalize();
996                mmr.add(&mut hasher, &element);
997            }
998
999            let leaf_pos = leaf_num_to_pos(100);
1000            mmr.prune_to_pos(leaf_pos);
1001            while mmr.size() > leaf_pos {
1002                assert!(mmr.pop().is_ok());
1003            }
1004            assert_eq!(hex(&mmr.root(&mut hasher)), ROOTS[100]);
1005            assert!(matches!(mmr.pop().unwrap_err(), ElementPruned(_)));
1006            assert_eq!(mmr.oldest_retained_pos(), None);
1007        });
1008    }
1009
1010    #[test]
1011    fn test_mem_mmr_update_leaf() {
1012        let mut hasher: Standard<Sha256> = Standard::new();
1013        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
1014        let executor = deterministic::Runner::default();
1015        executor.start(|_| async move {
1016            let mut mmr = Mmr::new();
1018            let leaves = compute_big_mmr(&mut hasher, &mut mmr);
1019            let root = mmr.root(&mut hasher);
1020
1021            // For a few leaves, update the leaf and ensure the root changes, then confirm the root
1022            // reverts to its previous state when we restore the leaf to its original value.
1023            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1024                // Change the leaf.
1025                mmr.update_leaf(&mut hasher, leaves[leaf], &element);
1026                let updated_root = mmr.root(&mut hasher);
1027                assert!(root != updated_root);
1028
1029                // Restore the leaf to its original value, ensure the root is as before.
1030                hasher.inner().update(&leaf.to_be_bytes());
1031                let element = hasher.inner().finalize();
1032                mmr.update_leaf(&mut hasher, leaves[leaf], &element);
1033                let restored_root = mmr.root(&mut hasher);
1034                assert_eq!(root, restored_root);
1035            }
1036
1037            // Confirm the tree has all the hashes necessary to update any element after pruning.
1038            mmr.prune_to_pos(leaves[150]);
1039            for &leaf_pos in &leaves[150..=190] {
1040                mmr.prune_to_pos(leaf_pos);
1041                mmr.update_leaf(&mut hasher, leaf_pos, &element);
1042            }
1043        });
1044    }
1045
1046    #[test]
1047    #[should_panic(expected = "pos was not for a leaf")]
1048    fn test_mem_mmr_update_leaf_panic_invalid() {
1049        let mut hasher: Standard<Sha256> = Standard::new();
1050        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
1051
1052        let executor = deterministic::Runner::default();
1053        executor.start(|_| async move {
1054            let mut mmr = Mmr::new();
1055            compute_big_mmr(&mut hasher, &mut mmr);
1056            let not_a_leaf_pos = 2;
1057            mmr.update_leaf(&mut hasher, not_a_leaf_pos, &element);
1058        });
1059    }
1060
1061    #[test]
1062    #[should_panic(expected = "element pruned")]
1063    fn test_mem_mmr_update_leaf_panic_pruned() {
1064        let mut hasher: Standard<Sha256> = Standard::new();
1065        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
1066
1067        let executor = deterministic::Runner::default();
1068        executor.start(|_| async move {
1069            let mut mmr = Mmr::new();
1070            compute_big_mmr(&mut hasher, &mut mmr);
1071            mmr.prune_all();
1072            mmr.update_leaf(&mut hasher, 0, &element);
1073        });
1074    }
1075
1076    #[test]
1077    fn test_mem_mmr_batch_update_leaf() {
1078        let mut hasher: Standard<Sha256> = Standard::new();
1079        let executor = deterministic::Runner::default();
1080        executor.start(|_| async move {
1081            let mut mmr = Mmr::new();
1082            let leaves = compute_big_mmr(&mut hasher, &mut mmr);
1083            do_batch_update(&mut hasher, &mut mmr, &leaves);
1084        });
1085    }
1086
1087    /// Same test as above, only using a thread pool to trigger parallelization. This requires the
1088    /// tokio runtime instead of the deterministic one.
1089    #[test]
1090    fn test_mem_mmr_batch_parallel_update_leaf() {
1091        let mut hasher: Standard<Sha256> = Standard::new();
1092        let executor = tokio::Runner::default();
1093        executor.start(|ctx| async move {
1094            let pool = create_pool(ctx, 4).unwrap();
1095            let mut mmr = Mmr::init(Config {
1096                nodes: Vec::new(),
1097                pruned_to_pos: 0,
1098                pinned_nodes: Vec::new(),
1099                pool: Some(pool),
1100            });
1101            let leaves = compute_big_mmr(&mut hasher, &mut mmr);
1102            do_batch_update(&mut hasher, &mut mmr, &leaves);
1103        });
1104    }
1105
1106    fn do_batch_update(hasher: &mut Standard<Sha256>, mmr: &mut Mmr<Sha256>, leaves: &[u64]) {
1107        let element = <Sha256 as CHasher>::Digest::from(*b"01234567012345670123456701234567");
1108        let root = mmr.root(hasher);
1109
1110        // Change a handful of leaves using a batch update.
1111        let mut updates = Vec::new();
1112        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1113            updates.push((leaves[leaf], &element));
1114        }
1115        mmr.update_leaf_batched(hasher, &updates);
1116
1117        mmr.sync(hasher);
1118        let updated_root = mmr.root(hasher);
1119        assert_eq!(
1120            "af3acad6aad59c1a880de643b1200a0962a95d06c087ebf677f29eb93fc359a4",
1121            hex(&updated_root)
1122        );
1123
1124        // Batch-restore the changed leaves to their original values.
1125        let mut updates = Vec::new();
1126        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1127            hasher.inner().update(&leaf.to_be_bytes());
1128            let element = hasher.inner().finalize();
1129            updates.push((leaves[leaf], element));
1130        }
1131        mmr.update_leaf_batched(hasher, &updates);
1132
1133        mmr.sync(hasher);
1134        let restored_root = mmr.root(hasher);
1135        assert_eq!(root, restored_root);
1136    }
1137}