commonware_storage/mmr/
mem.rs

1//! A basic, no_std compatible MMR where all nodes are stored in-memory.
2
3use crate::mmr::{
4    hasher::Hasher,
5    iterator::{nodes_needing_parents, nodes_to_pin, PathIterator, PeakIterator},
6    proof,
7    Error::{self, *},
8    Location, Position, Proof,
9};
10use alloc::{
11    collections::{BTreeMap, BTreeSet, VecDeque},
12    vec::Vec,
13};
14use commonware_cryptography::Digest;
15use core::{mem, ops::Range};
16cfg_if::cfg_if! {
17    if #[cfg(feature = "std")] {
18        use commonware_parallel::ThreadPool;
19        use rayon::prelude::*;
20    } else {
21        struct ThreadPool;
22    }
23}
24
25/// Minimum number of digest computations required during batch updates to trigger parallelization.
26#[cfg(feature = "std")]
27const MIN_TO_PARALLELIZE: usize = 20;
28
29/// An MMR whose root digest has not been computed.
30pub type DirtyMmr<D> = Mmr<D, Dirty>;
31
32/// An MMR whose root digest has been computed.
33pub type CleanMmr<D> = Mmr<D, Clean<D>>;
34
/// Sealed trait for MMR state types.
mod private {
    /// Implemented only by the state types declared in this module, preventing
    /// downstream crates from defining new MMR states.
    pub trait Sealed {}
}
39
/// Trait for valid MMR state types.
///
/// Only [Clean] and [Dirty] implement this trait (it is sealed). It dispatches leaf
/// insertion to the state-appropriate implementation.
pub trait State<D: Digest>: private::Sealed + Sized + Send + Sync {
    /// Add the given leaf digest to the MMR, returning its position.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut Mmr<D, Self>, hasher: &mut H, digest: D)
        -> Position;
}
46
/// Marker type for a MMR whose root digest has been computed.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    /// The root digest of the MMR.
    ///
    /// Valid by construction: a `Clean` state only exists after a `merkleize` pass.
    pub root: D,
}
53
impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {
    // A clean MMR re-merkleizes after every insertion, so the hasher is needed here.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut CleanMmr<D>, hasher: &mut H, digest: D) -> Position {
        mmr.add_leaf_digest(hasher, digest)
    }
}
60
/// Marker type for a dirty MMR (root digest not computed).
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    /// Non-leaf nodes that need to have their digests recomputed due to a batched update operation.
    ///
    /// This is a set of tuples of the form (node_pos, height).
    // Kept in a BTreeSet so entries are ordered by position, allowing `pop` to cheaply
    // split off the out-of-bounds tail.
    dirty_nodes: BTreeSet<(Position, u32)>,
}
69
impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {
    // Digest computation is deferred until `merkleize`, so the hasher is unused here.
    fn add_leaf_digest<H: Hasher<D>>(
        mmr: &mut DirtyMmr<D>,
        _hasher: &mut H,
        digest: D,
    ) -> Position {
        mmr.add_leaf_digest(digest)
    }
}
80
/// Configuration for initializing an [Mmr].
pub struct Config<D: Digest> {
    /// The retained nodes of the MMR.
    pub nodes: Vec<D>,

    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
    /// pruned.
    pub pruned_to_pos: Position,

    /// The pinned nodes of the MMR, in the order expected by `nodes_to_pin`.
    ///
    /// Must contain exactly one digest per position yielded by `nodes_to_pin(pruned_to_pos)`,
    /// otherwise initialization fails with `InvalidPinnedNodes`.
    pub pinned_nodes: Vec<D>,
}
93
/// A basic MMR where all nodes are stored in-memory.
///
/// # Terminology
///
/// Nodes in this structure are either _retained_, _pruned_, or _pinned_. Retained nodes are nodes
/// that have not yet been pruned, and have digests stored explicitly within the tree structure.
/// Pruned nodes are those whose positions precede that of the _oldest retained_ node, for which no
/// digests are maintained. Pinned nodes are nodes that would otherwise be pruned based on their
/// position, but whose digests remain required for proof generation. The digests for pinned nodes
/// are stored in an auxiliary map, and are at most O(log2(n)) in number.
///
/// # Max Capacity
///
/// The maximum number of elements that can be stored is usize::MAX (u32::MAX on 32-bit
/// architectures).
///
/// # Type States
///
/// The MMR uses the type-state pattern to enforce at compile-time whether the MMR has pending
/// updates that must be merkleized before computing proofs. [CleanMmr] represents a clean
/// MMR whose root digest has been computed. [DirtyMmr] represents a dirty MMR whose root
/// digest needs to be computed. A dirty MMR can be converted into a clean MMR by calling
/// [DirtyMmr::merkleize].
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest, S: State<D> = Dirty> {
    /// The nodes of the MMR, laid out according to a post-order traversal of the MMR trees,
    /// starting from the tallest tree to the shortest.
    nodes: VecDeque<D>,

    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
    /// pruned.
    pruned_to_pos: Position,

    /// The auxiliary map from node position to the digest of any pinned node.
    pinned_nodes: BTreeMap<Position, D>,

    /// Type-state for the MMR.
    state: S,
}
133
impl<D: Digest> Default for DirtyMmr<D> {
    /// Equivalent to [DirtyMmr::new]: an empty, never-pruned MMR with no dirty nodes.
    fn default() -> Self {
        Self::new()
    }
}
139
140impl<D: Digest> From<CleanMmr<D>> for DirtyMmr<D> {
141    fn from(clean: CleanMmr<D>) -> Self {
142        DirtyMmr {
143            nodes: clean.nodes,
144            pruned_to_pos: clean.pruned_to_pos,
145            pinned_nodes: clean.pinned_nodes,
146            state: Dirty {
147                dirty_nodes: BTreeSet::new(),
148            },
149        }
150    }
151}
152
153impl<D: Digest, S: State<D>> Mmr<D, S> {
154    /// Return the total number of nodes in the MMR, irrespective of any pruning. The next added
155    /// element's position will have this value.
156    pub fn size(&self) -> Position {
157        Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
158    }
159
160    /// Return the total number of leaves in the MMR.
161    pub fn leaves(&self) -> Location {
162        Location::try_from(self.size()).expect("invalid mmr size")
163    }
164
165    /// Return the position of the last leaf in this MMR, or None if the MMR is empty.
166    pub fn last_leaf_pos(&self) -> Option<Position> {
167        if self.size() == 0 {
168            return None;
169        }
170
171        Some(PeakIterator::last_leaf_pos(self.size()))
172    }
173
/// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
/// pruned.
///
/// Nodes below this position are only available if they were pinned.
pub const fn pruned_to_pos(&self) -> Position {
    self.pruned_to_pos
}
179
180    /// Return the position of the oldest retained node in the MMR, not including those cached in
181    /// pinned_nodes.
182    pub fn oldest_retained_pos(&self) -> Option<Position> {
183        if self.pruned_to_pos == self.size() {
184            return None;
185        }
186
187        Some(self.pruned_to_pos)
188    }
189
/// Return a new iterator over the peaks of the MMR.
///
/// The iterator is derived solely from the current size of the MMR.
pub fn peak_iterator(&self) -> PeakIterator {
    PeakIterator::new(self.size())
}
194
/// Return the position of the element given its index in the current nodes vector.
///
/// Inverse of `pos_to_index` for retained nodes.
fn index_to_pos(&self, index: usize) -> Position {
    self.pruned_to_pos + (index as u64)
}
199
200    /// Return the requested node if it is either retained or present in the pinned_nodes map, and
201    /// panic otherwise. Use `get_node` instead if you require a non-panicking getter.
202    ///
203    /// # Warning
204    ///
205    /// If the requested digest is for an unmerkleized node (only possible in the Dirty state) a
206    /// dummy digest will be returned.
207    ///
208    /// # Panics
209    ///
210    /// Panics if the requested node does not exist for any reason such as the node is pruned or
211    /// `pos` is out of bounds.
212    pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
213        if pos < self.pruned_to_pos {
214            return self
215                .pinned_nodes
216                .get(&pos)
217                .expect("requested node is pruned and not pinned");
218        }
219
220        &self.nodes[self.pos_to_index(pos)]
221    }
222
/// Return the index of the element in the current nodes vector given its position in the MMR.
///
/// # Panics
///
/// Panics if `pos` precedes the oldest retained position.
fn pos_to_index(&self, pos: Position) -> usize {
    assert!(
        pos >= self.pruned_to_pos,
        "pos precedes oldest retained position"
    );

    // The assert above guarantees pos >= pruned_to_pos, so the subtraction cannot fail.
    *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
}
236
/// Utility used by stores that build on the mem MMR to pin extra nodes if needed. It's up to
/// the caller to ensure that this set of pinned nodes is valid for their use case.
#[cfg(any(feature = "std", test))]
pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
    // `extend` inserts each entry, overwriting any existing digest at the same
    // position, exactly like the per-entry insert loop it replaces.
    self.pinned_nodes.extend(pinned_nodes);
}
245
/// Add `element` to the MMR and return its position.
/// The element can be an arbitrary byte slice, and need not be converted to a digest first.
pub fn add<H: Hasher<D>>(&mut self, hasher: &mut H, element: &[u8]) -> Position {
    // The leaf digest commits to the leaf's position as well as its content.
    let digest = hasher.leaf_digest(self.size(), element);
    // Dispatch to the state-specific insertion (immediate vs deferred merkleization).
    S::add_leaf_digest(self, hasher, digest)
}
252}
253
254/// Implementation for Clean MMR state.
255impl<D: Digest> CleanMmr<D> {
256    /// Return an [Mmr] initialized with the given `config`.
257    ///
258    /// # Errors
259    ///
260    /// Returns [Error::InvalidPinnedNodes] if the number of pinned nodes doesn't match the expected
261    /// count for `config.pruned_to_pos`.
262    ///
263    /// Returns [Error::InvalidSize] if the MMR size is invalid.
264    pub fn init(config: Config<D>, hasher: &mut impl Hasher<D>) -> Result<Self, Error> {
265        // Validate that the total size is valid
266        let Some(size) = config.pruned_to_pos.checked_add(config.nodes.len() as u64) else {
267            return Err(Error::InvalidSize(u64::MAX));
268        };
269        if !size.is_mmr_size() {
270            return Err(Error::InvalidSize(*size));
271        }
272
273        // Validate and populate pinned nodes
274        let mut pinned_nodes = BTreeMap::new();
275        let mut expected_pinned_nodes = 0;
276        for (i, pos) in nodes_to_pin(config.pruned_to_pos).enumerate() {
277            expected_pinned_nodes += 1;
278            if i >= config.pinned_nodes.len() {
279                return Err(Error::InvalidPinnedNodes);
280            }
281            pinned_nodes.insert(pos, config.pinned_nodes[i]);
282        }
283
284        // Check for too many pinned nodes
285        if config.pinned_nodes.len() != expected_pinned_nodes {
286            return Err(Error::InvalidPinnedNodes);
287        }
288
289        let mmr = Mmr {
290            nodes: VecDeque::from(config.nodes),
291            pruned_to_pos: config.pruned_to_pos,
292            pinned_nodes,
293            state: Dirty::default(),
294        };
295        Ok(mmr.merkleize(hasher, None))
296    }
297
/// Create a new, empty MMR in the Clean state.
pub fn new(hasher: &mut impl Hasher<D>) -> Self {
    // An empty dirty MMR merkleizes to the canonical empty root.
    let mmr: DirtyMmr<D> = Default::default();
    mmr.merkleize(hasher, None)
}

/// Re-initialize the MMR with the given nodes, pruned_to_pos, and pinned_nodes.
///
/// # Panics
///
/// Panics if `pinned_nodes` has fewer digests than positions yielded by
/// `nodes_to_pin(pruned_to_pos)` (see [DirtyMmr::from_components]).
pub fn from_components(
    hasher: &mut impl Hasher<D>,
    nodes: Vec<D>,
    pruned_to_pos: Position,
    pinned_nodes: Vec<D>,
) -> Self {
    DirtyMmr::from_components(nodes, pruned_to_pos, pinned_nodes).merkleize(hasher, None)
}
313
314    /// Return the requested node or None if it is not stored in the MMR.
315    pub fn get_node(&self, pos: Position) -> Option<D> {
316        if pos < self.pruned_to_pos {
317            return self.pinned_nodes.get(&pos).copied();
318        }
319
320        self.nodes.get(self.pos_to_index(pos)).copied()
321    }
322
/// Add a leaf's `digest` to the MMR, generating the necessary parent nodes to maintain the
/// MMR's structure.
pub(super) fn add_leaf_digest(&mut self, hasher: &mut impl Hasher<D>, digest: D) -> Position {
    // Swap an empty MMR into `self` so we can take ownership, convert to the dirty
    // state, apply the insertion, then re-merkleize and swap the result back.
    let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
    let leaf_pos = dirty_mmr.add_leaf_digest(digest);
    *self = dirty_mmr.merkleize(hasher, None);
    leaf_pos
}
331
/// Pop the most recent leaf element out of the MMR if it exists, returning Empty or
/// ElementPruned errors otherwise.
pub fn pop(&mut self, hasher: &mut impl Hasher<D>) -> Result<Position, Error> {
    // Take ownership via swap, pop on the dirty MMR, then re-merkleize and swap back.
    let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
    let result = dirty_mmr.pop();
    *self = dirty_mmr.merkleize(hasher, None);
    result
}
340
/// Get the nodes (position + digest) that need to be pinned (those required for proof
/// generation) in this MMR when pruned to position `prune_pos`.
///
/// # Panics
///
/// Panics if a required node is neither retained nor already pinned.
pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
    nodes_to_pin(prune_pos)
        .map(|pos| (pos, *self.get_node_unchecked(pos)))
        .collect()
}
348
349    /// Prune all nodes up to but not including the given position, and pin the O(log2(n)) number of
350    /// them required for proof generation.
351    pub fn prune_to_pos(&mut self, pos: Position) {
352        // Recompute the set of older nodes to retain.
353        self.pinned_nodes = self.nodes_to_pin(pos);
354        let retained_nodes = self.pos_to_index(pos);
355        self.nodes.drain(0..retained_nodes);
356        self.pruned_to_pos = pos;
357    }
358
359    /// Prune all nodes and pin the O(log2(n)) number of them required for proof generation going
360    /// forward.
361    pub fn prune_all(&mut self) {
362        if !self.nodes.is_empty() {
363            let pos = self.index_to_pos(self.nodes.len());
364            self.prune_to_pos(pos);
365        }
366    }
367
/// Change the digest of any retained leaf. This is useful if you want to use the MMR
/// implementation as an updatable binary Merkle tree, and otherwise should be avoided.
///
/// # Errors
///
/// Returns [Error::ElementPruned] if the leaf has been pruned.
/// Returns [Error::LeafOutOfBounds] if `loc` is not an existing leaf.
/// Returns [Error::LocationOverflow] if `loc` > [crate::mmr::MAX_LOCATION].
///
/// # Warning
///
/// This method will change the root and invalidate any previous inclusion proofs.
/// Use of this method will prevent using this structure as a base mmr for grafting.
pub fn update_leaf(
    &mut self,
    hasher: &mut impl Hasher<D>,
    loc: Location,
    element: &[u8],
) -> Result<(), Error> {
    // Take ownership via swap, update in the dirty state, then re-merkleize and swap back.
    let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
    let result = dirty_mmr.update_leaf(hasher, loc, element);
    *self = dirty_mmr.merkleize(hasher, None);
    result
}
392
/// Convert this Clean MMR into a Dirty MMR without making any changes to it.
pub fn into_dirty(self) -> DirtyMmr<D> {
    self.into()
}

/// Get the root digest of the MMR.
///
/// Always available in the Clean state; computed by the preceding merkleization.
pub const fn root(&self) -> &D {
    &self.state.root
}

/// Returns the root that would be produced by calling `root` on an empty MMR.
pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
    // Hash the size (0) with no peak digests; presumably matches `Hasher::root`
    // over an empty peak iterator (the empty-MMR test asserts this equivalence).
    hasher.update(&0u64.to_be_bytes());
    hasher.finalize()
}
408
/// Return an inclusion proof for the element at location `loc`.
///
/// # Errors
///
/// Returns [Error::LocationOverflow] if `loc` > [crate::mmr::MAX_LOCATION].
/// Returns [Error::ElementPruned] if some element needed to generate the proof has been pruned.
///
/// # Panics
///
/// Panics if `loc` is out of bounds.
pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
    if !loc.is_valid() {
        return Err(Error::LocationOverflow(loc));
    }
    // loc is valid so it won't overflow from + 1
    // A single-element proof is just a range proof over [loc, loc + 1).
    self.range_proof(loc..loc + 1)
}
426
/// Return an inclusion proof for all elements within the provided `range` of locations.
///
/// # Errors
///
/// Returns [Error::Empty] if the range is empty.
/// Returns [Error::LocationOverflow] if any location in `range` exceeds [crate::mmr::MAX_LOCATION].
/// Returns [Error::ElementPruned] if some element needed to generate the proof has been pruned.
///
/// # Panics
///
/// Panics if the element range is out of bounds.
pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
    // Bounds violations are programmer error, hence asserts rather than Err returns.
    let leaves = self.leaves();
    assert!(
        range.start < leaves,
        "range start {} >= leaf count {}",
        range.start,
        leaves
    );
    assert!(
        range.end <= leaves,
        "range end {} > leaf count {}",
        range.end,
        leaves
    );

    let size = self.size();
    // Gather the digest of every node needed to verify the range; a node that is
    // neither retained nor pinned fails the proof with ElementPruned.
    let positions = proof::nodes_required_for_range_proof(size, range)?;
    let digests = positions
        .into_iter()
        .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
        .collect::<Result<Vec<_>, _>>()?;

    Ok(Proof { size, digests })
}
462
/// Get the digests of nodes that need to be pinned (those required for proof generation) in
/// this MMR when pruned to position `start_pos`, in `nodes_to_pin` order.
#[cfg(test)]
pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
    nodes_to_pin(start_pos)
        .map(|pos| *self.get_node_unchecked(pos))
        .collect()
}
471
/// Return the nodes this MMR currently has pinned. Pinned nodes are nodes that would otherwise
/// be pruned, but whose digests remain required for proof generation.
///
/// Test-only accessor; returns a clone of the internal map.
#[cfg(test)]
pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
    self.pinned_nodes.clone()
}
478}
479
480/// Implementation for Dirty MMR state.
481impl<D: Digest> DirtyMmr<D> {
/// Return a new (empty) `Mmr` in the Dirty state, with nothing pruned or pinned.
pub fn new() -> Self {
    Self {
        nodes: VecDeque::new(),
        pruned_to_pos: Position::new(0),
        pinned_nodes: BTreeMap::new(),
        state: Dirty::default(),
    }
}
491
492    /// Re-initialize the MMR with the given nodes, pruned_to_pos, and pinned_nodes.
493    pub fn from_components(nodes: Vec<D>, pruned_to_pos: Position, pinned_nodes: Vec<D>) -> Self {
494        Self {
495            nodes: VecDeque::from(nodes),
496            pruned_to_pos,
497            pinned_nodes: nodes_to_pin(pruned_to_pos)
498                .enumerate()
499                .map(|(i, pos)| (pos, pinned_nodes[i]))
500                .collect(),
501            state: Dirty::default(),
502        }
503    }
504
505    /// Add `digest` as a new leaf in the MMR, returning its position.
506    pub(super) fn add_leaf_digest(&mut self, digest: D) -> Position {
507        // Compute the new parent nodes, if any.
508        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
509            .into_iter()
510            .rev();
511        let leaf_pos = self.size();
512        self.nodes.push_back(digest);
513
514        let mut height = 1;
515        for _ in nodes_needing_parents {
516            let new_node_pos = self.size();
517            self.nodes.push_back(D::EMPTY);
518            self.state.dirty_nodes.insert((new_node_pos, height));
519            height += 1;
520        }
521
522        leaf_pos
523    }
524
/// Pop the most recent leaf element out of the MMR if it exists, returning Empty or
/// ElementPruned errors otherwise.
pub fn pop(&mut self) -> Result<Position, Error> {
    if self.size() == 0 {
        return Err(Empty);
    }

    // Walk the size backwards until it corresponds to a valid MMR again; removing a
    // leaf also removes any parent nodes its insertion created.
    let mut new_size = self.size() - 1;
    loop {
        if new_size < self.pruned_to_pos {
            return Err(ElementPruned(new_size));
        }
        if new_size.is_mmr_size() {
            break;
        }
        new_size -= 1;
    }
    let num_to_drain = *(self.size() - new_size) as usize;
    self.nodes.drain(self.nodes.len() - num_to_drain..);

    // Remove dirty nodes that are now out of bounds. `split_off` keeps entries below
    // (size, 0) and discards the rest, since the set is ordered by (position, height).
    let cutoff = (self.size(), 0);
    self.state.dirty_nodes.split_off(&cutoff);

    Ok(self.size())
}
551
/// Compute updated digests for dirty nodes and compute the root, converting this MMR into a
/// [CleanMmr].
pub fn merkleize(
    mut self,
    hasher: &mut impl Hasher<D>,
    pool: Option<ThreadPool>,
) -> CleanMmr<D> {
    // Parallelize only when a pool was supplied and the batch is large enough to
    // amortize the scheduling overhead; otherwise fall back to the serial path.
    #[cfg(feature = "std")]
    match (pool, self.state.dirty_nodes.len() >= MIN_TO_PARALLELIZE) {
        (Some(pool), true) => self.merkleize_parallel(hasher, pool, MIN_TO_PARALLELIZE),
        _ => self.merkleize_serial(hasher),
    }

    // no_std builds have no real thread pool; always merkleize serially.
    #[cfg(not(feature = "std"))]
    self.merkleize_serial(hasher);

    // Compute root
    let peaks = self
        .peak_iterator()
        .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
    let size = self.size();
    let digest = hasher.root(size, peaks);

    CleanMmr {
        nodes: self.nodes,
        pruned_to_pos: self.pruned_to_pos,
        pinned_nodes: self.pinned_nodes,
        state: Clean { root: digest },
    }
}
582
583    fn merkleize_serial(&mut self, hasher: &mut impl Hasher<D>) {
584        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
585        self.state.dirty_nodes.clear();
586        nodes.sort_by(|a, b| a.1.cmp(&b.1));
587
588        for (pos, height) in nodes {
589            let left = pos - (1 << height);
590            let right = pos - 1;
591            let digest = hasher.node_digest(
592                pos,
593                self.get_node_unchecked(left),
594                self.get_node_unchecked(right),
595            );
596            let index = self.pos_to_index(pos);
597            self.nodes[index] = digest;
598        }
599    }
600
/// Process any pending batched updates, using parallel hash workers as long as the number of
/// computations that can be parallelized exceeds `min_to_parallelize`.
///
/// This implementation parallelizes the computation of digests across all nodes at the same
/// height, starting from the bottom and working up to the peaks. If ever the number of
/// remaining digest computations is less than the `min_to_parallelize`, it switches to the
/// serial implementation.
#[cfg(feature = "std")]
fn merkleize_parallel(
    &mut self,
    hasher: &mut impl Hasher<D>,
    pool: ThreadPool,
    min_to_parallelize: usize,
) {
    let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
    self.state.dirty_nodes.clear();
    // Sort by increasing height.
    nodes.sort_by(|a, b| a.1.cmp(&b.1));

    // Accumulate a batch of nodes sharing the current height, flushing the batch
    // (in parallel) each time the height increases.
    let mut same_height = Vec::new();
    let mut current_height = 1;
    for (i, (pos, height)) in nodes.iter().enumerate() {
        if *height == current_height {
            same_height.push(*pos);
            continue;
        }
        if same_height.len() < min_to_parallelize {
            // Too few nodes at this height to be worth parallelizing: put the
            // unprocessed tail back into the dirty set and finish serially.
            self.state.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
            self.merkleize_serial(hasher);
            return;
        }
        self.update_node_digests(hasher, pool.clone(), &same_height, current_height);
        same_height.clear();
        current_height += 1;
        same_height.push(*pos);
    }

    // Flush the final (tallest) batch, again falling back to serial if it is small.
    if same_height.len() < min_to_parallelize {
        self.state.dirty_nodes = nodes[nodes.len() - same_height.len()..]
            .iter()
            .copied()
            .collect();
        self.merkleize_serial(hasher);
        return;
    }

    self.update_node_digests(hasher, pool, &same_height, current_height);
}
649
/// Update digests of the given set of nodes of equal height in the MMR. Since they are all at
/// the same height, this can be done in parallel without synchronization.
#[cfg(feature = "std")]
fn update_node_digests(
    &mut self,
    hasher: &mut impl Hasher<D>,
    pool: ThreadPool,
    same_height: &[Position],
    height: u32,
) {
    // Offset from a parent at `height` to its left child (the right child is at pos - 1).
    let two_h = 1 << height;
    pool.install(|| {
        // Compute (index, digest) pairs in parallel; each rayon worker forks the hasher.
        let computed_digests: Vec<(usize, D)> = same_height
            .par_iter()
            .map_init(
                || hasher.fork(),
                |hasher, &pos| {
                    let left = pos - two_h;
                    let right = pos - 1;
                    let digest = hasher.node_digest(
                        pos,
                        self.get_node_unchecked(left),
                        self.get_node_unchecked(right),
                    );
                    let index = self.pos_to_index(pos);
                    (index, digest)
                },
            )
            .collect();

        // Write back sequentially: mutation of self.nodes can't happen inside the
        // parallel map, which only holds a shared borrow.
        for (index, digest) in computed_digests {
            self.nodes[index] = digest;
        }
    });
}
685
/// Mark the non-leaf nodes in the path from the given position to the root as dirty, so that
/// their digests are appropriately recomputed during the next `merkleize`.
///
/// # Panics
///
/// Panics if `pos` does not fall within any mountain of the MMR.
fn mark_dirty(&mut self, pos: Position) {
    for (peak_pos, mut height) in self.peak_iterator() {
        if peak_pos < pos {
            continue;
        }

        // We have found the mountain containing the path we are looking for. Traverse it from
        // leaf to root, that way we can exit early if we hit a node that is already dirty.
        let path = PathIterator::new(pos, peak_pos, height)
            .collect::<Vec<_>>()
            .into_iter()
            .rev();
        height = 1;
        for (parent_pos, _) in path {
            // `insert` returning false means this ancestor (and hence all above it)
            // is already marked, so we can stop early.
            if !self.state.dirty_nodes.insert((parent_pos, height)) {
                break;
            }
            height += 1;
        }
        return;
    }

    panic!("invalid pos {pos}:{}", self.size());
}
712
/// Update the leaf at `loc` to `element`.
///
/// Single-element convenience wrapper around [Self::update_leaf_batched]; see it for
/// the possible errors.
pub fn update_leaf(
    &mut self,
    hasher: &mut impl Hasher<D>,
    loc: Location,
    element: &[u8],
) -> Result<(), Error> {
    self.update_leaf_batched(hasher, None, &[(loc, element)])
}
722
/// Batch update the digests of multiple retained leaves.
///
/// # Errors
///
/// Returns [Error::LeafOutOfBounds] if any location is not an existing leaf.
/// Returns [Error::LocationOverflow] if any location exceeds [crate::mmr::MAX_LOCATION].
/// Returns [Error::ElementPruned] if any of the leaves has been pruned.
pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
    &mut self,
    hasher: &mut impl Hasher<D>,
    pool: Option<ThreadPool>,
    updates: &[(Location, T)],
) -> Result<(), Error> {
    if updates.is_empty() {
        return Ok(());
    }

    // Validate every location up front so no update is applied unless all are valid.
    let leaves = self.leaves();
    let mut positions = Vec::with_capacity(updates.len());
    for (loc, _) in updates {
        if *loc >= leaves {
            return Err(Error::LeafOutOfBounds(*loc));
        }
        let pos = Position::try_from(*loc)?;
        if pos < self.pruned_to_pos {
            return Err(Error::ElementPruned(pos));
        }
        positions.push(pos);
    }

    // Use the thread pool only when the batch is large enough to amortize overhead.
    #[cfg(feature = "std")]
    if let Some(pool) = pool {
        if updates.len() >= MIN_TO_PARALLELIZE {
            self.update_leaf_parallel(hasher, pool, updates, &positions);
            return Ok(());
        }
    }

    for ((_, element), pos) in updates.iter().zip(positions.iter()) {
        // Update the digest of the leaf node and mark its ancestors as dirty.
        let digest = hasher.leaf_digest(*pos, element.as_ref());
        let index = self.pos_to_index(*pos);
        self.nodes[index] = digest;
        self.mark_dirty(*pos);
    }

    Ok(())
}
771
/// Batch update the digests of multiple retained leaves using multiple threads.
///
/// `positions` must hold the (already validated) position of each leaf in `updates`,
/// in the same order.
#[cfg(feature = "std")]
fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
    &mut self,
    hasher: &mut impl Hasher<D>,
    pool: ThreadPool,
    updates: &[(Location, T)],
    positions: &[Position],
) {
    pool.install(|| {
        // Compute the new leaf digests in parallel; each worker forks the hasher.
        let digests: Vec<(Position, D)> = updates
            .par_iter()
            .zip(positions.par_iter())
            .map_init(
                || hasher.fork(),
                |hasher, ((_, elem), pos)| {
                    let digest = hasher.leaf_digest(*pos, elem.as_ref());
                    (*pos, digest)
                },
            )
            .collect();

        // Apply results and mark ancestors dirty serially (requires &mut self).
        for (pos, digest) in digests {
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
            self.mark_dirty(pos);
        }
    });
}
801}
802
803#[cfg(test)]
804mod tests {
805    use super::*;
806    use crate::mmr::{
807        hasher::{Hasher as _, Standard},
808        stability::ROOTS,
809    };
810    use commonware_cryptography::{sha256, Hasher, Sha256};
811    use commonware_runtime::{deterministic, tokio, RayonPoolSpawner, Runner};
812    use commonware_utils::{hex, NZUsize};
813
/// Build the MMR corresponding to the stability test `ROOTS` and confirm the roots match.
fn build_and_check_test_roots_mmr(mmr: &mut CleanMmr<sha256::Digest>) {
    let mut hasher: Standard<Sha256> = Standard::new();
    for i in 0u64..199 {
        // Element i is the SHA-256 of its big-endian index, matching how ROOTS was generated.
        hasher.inner().update(&i.to_be_bytes());
        let element = hasher.inner().finalize();
        // Check the root before adding element i: ROOTS[i] is the root of an i-element MMR.
        let root = *mmr.root();
        let expected_root = ROOTS[i as usize];
        assert_eq!(hex(&root), expected_root, "at: {i}");
        mmr.add(&mut hasher, &element);
    }
    // NOTE(review): 199 elements were added but the message says 200 — confirm intent.
    assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
}
827
/// Same as `build_and_check_test_roots_mmr` but builds a dirty MMR with batched adds and a
/// single final `merkleize`, optionally using a thread pool.
pub fn build_batched_and_check_test_roots(
    mut mmr: DirtyMmr<sha256::Digest>,
    pool: Option<ThreadPool>,
) {
    let mut hasher: Standard<Sha256> = Standard::new();
    for i in 0u64..199 {
        // Same elements as the unbatched variant: SHA-256 of the big-endian index.
        hasher.inner().update(&i.to_be_bytes());
        let element = hasher.inner().finalize();
        mmr.add(&mut hasher, &element);
    }
    // A single merkleize at the end must yield the same root as per-add merkleization.
    let mmr = mmr.merkleize(&mut hasher, pool);
    assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
}
842
843    /// Test empty MMR behavior.
844    #[test]
845    fn test_mem_mmr_empty() {
846        let executor = deterministic::Runner::default();
847        executor.start(|_| async move {
848            let mut hasher: Standard<Sha256> = Standard::new();
849            let mut mmr = CleanMmr::new(&mut hasher);
850            assert_eq!(
851                mmr.peak_iterator().next(),
852                None,
853                "empty iterator should have no peaks"
854            );
855            assert_eq!(mmr.size(), 0);
856            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
857            assert_eq!(mmr.last_leaf_pos(), None);
858            assert_eq!(mmr.oldest_retained_pos(), None);
859            assert_eq!(mmr.get_node(Position::new(0)), None);
860            assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
861            assert!(matches!(mmr.pop(&mut hasher), Err(Empty)));
862            mmr.prune_all();
863            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");
864
865            assert_eq!(*mmr.root(), hasher.root(Position::new(0), [].iter()));
866        });
867    }
868
869    /// Test MMR building by consecutively adding 11 equal elements to a new MMR, producing the
870    /// structure in the example documented at the top of the mmr crate's mod.rs file with 19 nodes
871    /// and 3 peaks.
872    #[test]
873    fn test_mem_mmr_add_eleven_values() {
874        let executor = deterministic::Runner::default();
875        executor.start(|_| async move {
876            let mut hasher: Standard<Sha256> = Standard::new();
877            let mut mmr = CleanMmr::new(&mut hasher);
878            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
879            let mut leaves: Vec<Position> = Vec::new();
880            for _ in 0..11 {
881                leaves.push(mmr.add(&mut hasher, &element));
882                let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
883                assert_ne!(peaks.len(), 0);
884                assert!(peaks.len() as u64 <= mmr.size());
885                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
886                assert!(nodes_needing_parents.len() <= peaks.len());
887            }
888            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(0));
889            assert_eq!(mmr.size(), 19, "mmr not of expected size");
890            assert_eq!(
891                leaves,
892                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
893                    .into_iter()
894                    .map(Position::new)
895                    .collect::<Vec<_>>(),
896                "mmr leaf positions not as expected"
897            );
898            let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
899            assert_eq!(
900                peaks,
901                vec![
902                    (Position::new(14), 3),
903                    (Position::new(17), 1),
904                    (Position::new(18), 0)
905                ],
906                "mmr peaks not as expected"
907            );
908
909            // Test nodes_needing_parents on the final MMR. Since there's a height gap between the
910            // highest peak (14) and the next, only the lower two peaks (17, 18) should be returned.
911            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
912            assert_eq!(
913                peaks_needing_parents,
914                vec![Position::new(17), Position::new(18)],
915                "mmr nodes needing parents not as expected"
916            );
917
918            // verify leaf digests
919            for leaf in leaves.iter().by_ref() {
920                let digest = hasher.leaf_digest(*leaf, &element);
921                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
922            }
923
924            // verify height=1 node digests
925            let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
926            assert_eq!(mmr.nodes[2], digest2);
927            let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
928            assert_eq!(mmr.nodes[5], digest5);
929            let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
930            assert_eq!(mmr.nodes[9], digest9);
931            let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
932            assert_eq!(mmr.nodes[12], digest12);
933            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
934            assert_eq!(mmr.nodes[17], digest17);
935
936            // verify height=2 node digests
937            let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
938            assert_eq!(mmr.nodes[6], digest6);
939            let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
940            assert_eq!(mmr.nodes[13], digest13);
941            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
942            assert_eq!(mmr.nodes[17], digest17);
943
944            // verify topmost digest
945            let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
946            assert_eq!(mmr.nodes[14], digest14);
947
948            // verify root
949            let root = *mmr.root();
950            let peak_digests = [digest14, digest17, mmr.nodes[18]];
951            let expected_root = hasher.root(Position::new(19), peak_digests.iter());
952            assert_eq!(root, expected_root, "incorrect root");
953
954            // pruning tests
955            mmr.prune_to_pos(Position::new(14)); // prune up to the tallest peak
956            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(14));
957
958            // After pruning, we shouldn't be able to generate a proof for any elements before the
959            // pruning boundary. (To be precise, due to the maintenance of pinned nodes, we may in
960            // fact still be able to generate them for some, but it's not guaranteed. For example,
961            // in this case, we actually can still generate a proof for the node with location 7
962            // even though it's pruned.)
963            assert!(matches!(
964                mmr.proof(Location::new_unchecked(0)),
965                Err(ElementPruned(_))
966            ));
967            assert!(matches!(
968                mmr.proof(Location::new_unchecked(6)),
969                Err(ElementPruned(_))
970            ));
971
972            // We should still be able to generate a proof for any leaf following the pruning
973            // boundary, the first of which is at location 8 and the last location 10.
974            assert!(mmr.proof(Location::new_unchecked(8)).is_ok());
975            assert!(mmr.proof(Location::new_unchecked(10)).is_ok());
976
977            let root_after_prune = *mmr.root();
978            assert_eq!(root, root_after_prune, "root changed after pruning");
979
980            assert!(
981                mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(9))
982                    .is_err(),
983                "attempts to range_prove elements at or before the oldest retained should fail"
984            );
985            assert!(
986                mmr.range_proof(Location::new_unchecked(8)..mmr.leaves()).is_ok(),
987                "attempts to range_prove over all elements following oldest retained should succeed"
988            );
989
990            // Test that we can initialize a new MMR from another's elements.
991            let oldest_pos = mmr.oldest_retained_pos().unwrap();
992            let digests = mmr.node_digests_to_pin(oldest_pos);
993            let mmr_copy = Mmr::init(
994                Config {
995                    nodes: mmr.nodes.iter().copied().collect(),
996                    pruned_to_pos: oldest_pos,
997                    pinned_nodes: digests,
998                },
999                &mut hasher,
1000            )
1001            .unwrap();
1002            assert_eq!(mmr_copy.size(), 19);
1003            assert_eq!(mmr_copy.leaves(), mmr.leaves());
1004            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
1005            assert_eq!(mmr_copy.oldest_retained_pos(), mmr.oldest_retained_pos());
1006            assert_eq!(*mmr_copy.root(), root);
1007        });
1008    }
1009
1010    /// Test that pruning all nodes never breaks adding new nodes.
1011    #[test]
1012    fn test_mem_mmr_prune_all() {
1013        let executor = deterministic::Runner::default();
1014        executor.start(|_| async move {
1015            let mut hasher: Standard<Sha256> = Standard::new();
1016            let mut mmr = CleanMmr::new(&mut hasher);
1017            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1018            for _ in 0..1000 {
1019                mmr.prune_all();
1020                mmr.add(&mut hasher, &element);
1021            }
1022        });
1023    }
1024
1025    /// Test that the MMR validity check works as expected.
1026    #[test]
1027    fn test_mem_mmr_validity() {
1028        let executor = deterministic::Runner::default();
1029        executor.start(|_| async move {
1030            let mut hasher: Standard<Sha256> = Standard::new();
1031            let mut mmr = CleanMmr::new(&mut hasher);
1032            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1033            for _ in 0..1001 {
1034                assert!(
1035                    mmr.size().is_mmr_size(),
1036                    "mmr of size {} should be valid",
1037                    mmr.size()
1038                );
1039                let old_size = mmr.size();
1040                mmr.add(&mut hasher, &element);
1041                for size in *old_size + 1..*mmr.size() {
1042                    assert!(
1043                        !Position::new(size).is_mmr_size(),
1044                        "mmr of size {size} should be invalid",
1045                    );
1046                }
1047            }
1048        });
1049    }
1050
1051    /// Test that the MMR root computation remains stable by comparing against previously computed
1052    /// roots.
1053    #[test]
1054    fn test_mem_mmr_root_stability() {
1055        let executor = deterministic::Runner::default();
1056        executor.start(|_| async move {
1057            // Test root stability under different MMR building methods.
1058            let mut hasher: Standard<Sha256> = Standard::new();
1059            let mut mmr = CleanMmr::new(&mut hasher);
1060            build_and_check_test_roots_mmr(&mut mmr);
1061
1062            let mut hasher: Standard<Sha256> = Standard::new();
1063            let mmr = CleanMmr::new(&mut hasher);
1064            build_batched_and_check_test_roots(mmr.into_dirty(), None);
1065        });
1066    }
1067
1068    /// Test root stability using the parallel builder implementation. This requires we use the
1069    /// tokio runtime since the deterministic runtime would block due to being single-threaded.
1070    #[test]
1071    fn test_mem_mmr_root_stability_parallel() {
1072        let executor = tokio::Runner::default();
1073        executor.start(|context| async move {
1074            let pool = context.create_pool(NZUsize!(4)).unwrap();
1075            let mut hasher: Standard<Sha256> = Standard::new();
1076
1077            let mmr = Mmr::init(
1078                Config {
1079                    nodes: vec![],
1080                    pruned_to_pos: Position::new(0),
1081                    pinned_nodes: vec![],
1082                },
1083                &mut hasher,
1084            )
1085            .unwrap();
1086            build_batched_and_check_test_roots(mmr.into_dirty(), Some(pool));
1087        });
1088    }
1089
1090    /// Build the MMR corresponding to the stability test while pruning after each add, and confirm
1091    /// the static roots match that from the root computation.
1092    #[test]
1093    fn test_mem_mmr_root_stability_while_pruning() {
1094        let executor = deterministic::Runner::default();
1095        executor.start(|_| async move {
1096            let mut hasher: Standard<Sha256> = Standard::new();
1097            let mut mmr = CleanMmr::new(&mut hasher);
1098            for i in 0u64..199 {
1099                let root = *mmr.root();
1100                let expected_root = ROOTS[i as usize];
1101                assert_eq!(hex(&root), expected_root, "at: {i}");
1102                hasher.inner().update(&i.to_be_bytes());
1103                let element = hasher.inner().finalize();
1104                mmr.add(&mut hasher, &element);
1105                mmr.prune_all();
1106            }
1107        });
1108    }
1109
1110    fn compute_big_mmr(
1111        hasher: &mut Standard<Sha256>,
1112        mut mmr: DirtyMmr<sha256::Digest>,
1113        pool: Option<ThreadPool>,
1114    ) -> (CleanMmr<sha256::Digest>, Vec<Position>) {
1115        let mut leaves = Vec::new();
1116        let mut c_hasher = Sha256::default();
1117        for i in 0u64..199 {
1118            c_hasher.update(&i.to_be_bytes());
1119            let element = c_hasher.finalize();
1120            let leaf_pos = mmr.size();
1121            mmr.add(hasher, &element);
1122            leaves.push(leaf_pos);
1123        }
1124
1125        (mmr.merkleize(hasher, pool), leaves)
1126    }
1127
    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Build the standard 199-leaf test MMR and confirm its stability root.
            let mut hasher: Standard<Sha256> = Standard::new();
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let root = *mmr.root();
            let expected_root = ROOTS[199];
            assert_eq!(hex(&root), expected_root);

            // Pop off one node at a time until empty, confirming the root is still as expected
            // (ROOTS[i] is the root of an MMR with i leaves).
            for i in (0..199u64).rev() {
                assert!(mmr.pop(&mut hasher).is_ok());
                let root = *mmr.root();
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root);
            }

            assert!(
                matches!(mmr.pop(&mut hasher).unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            // Test that we can pop all elements up to and including the oldest retained leaf.
            // First rebuild the 199-leaf MMR.
            for i in 0u64..199 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }

            // Prune to the position of leaf 100, then pop everything at or after that position.
            let leaf_pos = Position::try_from(Location::new_unchecked(100)).unwrap();
            mmr.prune_to_pos(leaf_pos);
            while mmr.size() > leaf_pos {
                mmr.pop(&mut hasher).unwrap();
            }
            assert_eq!(hex(mmr.root()), ROOTS[100]);
            // Everything that remains is pruned, so any further pop must fail.
            let result = mmr.pop(&mut hasher);
            assert!(matches!(result, Err(ElementPruned(_))));
            assert_eq!(mmr.oldest_retained_pos(), None);
        });
    }
1169
1170    #[test]
1171    fn test_mem_mmr_update_leaf() {
1172        let mut hasher: Standard<Sha256> = Standard::new();
1173        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1174        let executor = deterministic::Runner::default();
1175        executor.start(|_| async move {
1176            let (mut mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1177            let root = *mmr.root();
1178
1179            // For a few leaves, update the leaf and ensure the root changes, and the root reverts
1180            // to its previous state then we update the leaf to its original value.
1181            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1182                // Change the leaf.
1183                let leaf_loc =
1184                    Location::try_from(leaves[leaf]).expect("leaf position should map to location");
1185                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
1186                let updated_root = *mmr.root();
1187                assert!(root != updated_root);
1188
1189                // Restore the leaf to its original value, ensure the root is as before.
1190                hasher.inner().update(&leaf.to_be_bytes());
1191                let element = hasher.inner().finalize();
1192                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
1193                let restored_root = *mmr.root();
1194                assert_eq!(root, restored_root);
1195            }
1196
1197            // Confirm the tree has all the hashes necessary to update any element after pruning.
1198            mmr.prune_to_pos(leaves[150]);
1199            for &leaf_pos in &leaves[150..=190] {
1200                mmr.prune_to_pos(leaf_pos);
1201                let leaf_loc =
1202                    Location::try_from(leaf_pos).expect("leaf position should map to location");
1203                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
1204            }
1205        });
1206    }
1207
1208    #[test]
1209    fn test_mem_mmr_update_leaf_error_out_of_bounds() {
1210        let mut hasher: Standard<Sha256> = Standard::new();
1211        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1212
1213        let executor = deterministic::Runner::default();
1214        executor.start(|_| async move {
1215            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1216            let invalid_loc = mmr.leaves();
1217            let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
1218            assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
1219        });
1220    }
1221
1222    #[test]
1223    fn test_mem_mmr_update_leaf_error_pruned() {
1224        let mut hasher: Standard<Sha256> = Standard::new();
1225        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1226
1227        let executor = deterministic::Runner::default();
1228        executor.start(|_| async move {
1229            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1230            mmr.prune_all();
1231            let result = mmr.update_leaf(&mut hasher, Location::new_unchecked(0), &element);
1232            assert!(matches!(result, Err(Error::ElementPruned(_))));
1233        });
1234    }
1235
    #[test]
    fn test_mem_mmr_batch_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Build the standard 199-leaf test MMR, then exercise batched leaf updates on it
            // (serial path: no thread pool).
            let (mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            do_batch_update(&mut hasher, mmr, &leaves);
        });
    }
1245
1246    /// Same test as above only using a thread pool to trigger parallelization. This requires we use
1247    /// tokio runtime instead of the deterministic one.
1248    #[test]
1249    fn test_mem_mmr_batch_parallel_update_leaf() {
1250        let mut hasher: Standard<Sha256> = Standard::new();
1251        let executor = tokio::Runner::default();
1252        executor.start(|ctx| async move {
1253            let pool = ctx.create_pool(NZUsize!(4)).unwrap();
1254            let mmr = Mmr::init(
1255                Config {
1256                    nodes: Vec::new(),
1257                    pruned_to_pos: Position::new(0),
1258                    pinned_nodes: Vec::new(),
1259                },
1260                &mut hasher,
1261            )
1262            .unwrap();
1263            let (mmr, leaves) = compute_big_mmr(&mut hasher, mmr.into_dirty(), Some(pool));
1264            do_batch_update(&mut hasher, mmr, &leaves);
1265        });
1266    }
1267
1268    fn do_batch_update(
1269        hasher: &mut Standard<Sha256>,
1270        mmr: CleanMmr<sha256::Digest>,
1271        leaves: &[Position],
1272    ) {
1273        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1274        let root = *mmr.root();
1275
1276        // Change a handful of leaves using a batch update.
1277        let mut updates = Vec::new();
1278        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1279            let leaf_loc =
1280                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
1281            updates.push((leaf_loc, &element));
1282        }
1283        let mut dirty_mmr = mmr.into_dirty();
1284        dirty_mmr
1285            .update_leaf_batched(hasher, None, &updates)
1286            .unwrap();
1287
1288        let mmr = dirty_mmr.merkleize(hasher, None);
1289        let updated_root = *mmr.root();
1290        assert_eq!(
1291            "af3acad6aad59c1a880de643b1200a0962a95d06c087ebf677f29eb93fc359a4",
1292            hex(&updated_root)
1293        );
1294
1295        // Batch-restore the changed leaves to their original values.
1296        let mut updates = Vec::new();
1297        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1298            hasher.inner().update(&leaf.to_be_bytes());
1299            let element = hasher.inner().finalize();
1300            let leaf_loc =
1301                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
1302            updates.push((leaf_loc, element));
1303        }
1304        let mut dirty_mmr = mmr.into_dirty();
1305        dirty_mmr
1306            .update_leaf_batched(hasher, None, &updates)
1307            .unwrap();
1308
1309        let mmr = dirty_mmr.merkleize(hasher, None);
1310        let restored_root = *mmr.root();
1311        assert_eq!(root, restored_root);
1312    }
1313
1314    #[test]
1315    fn test_init_pinned_nodes_validation() {
1316        let executor = deterministic::Runner::default();
1317        executor.start(|_| async move {
1318            let mut hasher: Standard<Sha256> = Standard::new();
1319            // Test with empty config - should succeed
1320            let config = Config::<sha256::Digest> {
1321                nodes: vec![],
1322                pruned_to_pos: Position::new(0),
1323                pinned_nodes: vec![],
1324            };
1325            assert!(Mmr::init(config, &mut hasher).is_ok());
1326
1327            // Test with too few pinned nodes - should fail
1328            // Use a valid MMR size (127 is valid: 2^7 - 1 makes a complete tree)
1329            let config = Config::<sha256::Digest> {
1330                nodes: vec![],
1331                pruned_to_pos: Position::new(127),
1332                pinned_nodes: vec![], // Should have nodes for position 127
1333            };
1334            assert!(matches!(
1335                Mmr::init(config, &mut hasher),
1336                Err(Error::InvalidPinnedNodes)
1337            ));
1338
1339            // Test with too many pinned nodes - should fail
1340            let config = Config {
1341                nodes: vec![],
1342                pruned_to_pos: Position::new(0),
1343                pinned_nodes: vec![Sha256::hash(b"dummy")],
1344            };
1345            assert!(matches!(
1346                Mmr::init(config, &mut hasher),
1347                Err(Error::InvalidPinnedNodes)
1348            ));
1349
1350            // Test with correct number of pinned nodes - should succeed
1351            // Build a small MMR to get valid pinned nodes
1352            let mut mmr = CleanMmr::new(&mut hasher);
1353            for i in 0u64..50 {
1354                mmr.add(&mut hasher, &i.to_be_bytes());
1355            }
1356            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
1357            let config = Config {
1358                nodes: vec![],
1359                pruned_to_pos: Position::new(50),
1360                pinned_nodes,
1361            };
1362            assert!(Mmr::init(config, &mut hasher).is_ok());
1363        });
1364    }
1365
    #[test]
    fn test_init_size_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Test with valid size 0 - should succeed
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with invalid size 2 - should fail
            // Size 2 is invalid (can't have just one parent node + one leaf)
            let config = Config {
                nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Test with valid size 3 (one full tree with 2 leaves) - should succeed
            let config = Config {
                nodes: vec![
                    Sha256::hash(b"leaf1"),
                    Sha256::hash(b"leaf2"),
                    Sha256::hash(b"parent"),
                ],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with large valid size (127 = 2^7 - 1, a complete tree) - should succeed
            // Build a real MMR to get the correct structure
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..64 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 127); // Verify we have the expected size
            let nodes: Vec<_> = (0..127)
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();

            let config = Config {
                nodes,
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with non-zero pruned_to_pos - should succeed
            // Build a small MMR (11 leaves -> 19 nodes), prune it, then init from that state
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..11 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 19); // 11 leaves = 19 total nodes

            // Prune to position 7
            // The retained nodes cover positions [7, 19); the implied total size used by `init`
            // is pruned_to_pos + nodes.len().
            mmr.prune_to_pos(Position::new(7));
            let nodes: Vec<_> = (7..*mmr.size())
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));

            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(7),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Same nodes but wrong pruned_to_pos - should fail
            // pruned_to_pos=8 + 12 nodes = size 20 (invalid)
            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(8),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Same nodes but different wrong pruned_to_pos - should fail
            // pruned_to_pos=9 + 12 nodes = size 21 (invalid)
            let config = Config {
                nodes,
                pruned_to_pos: Position::new(9),
                pinned_nodes,
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));
        });
    }
1468}