commonware_storage/mmr/
mem.rs

1//! A basic, no_std compatible MMR where all nodes are stored in-memory.
2
3use crate::mmr::{
4    hasher::Hasher,
5    iterator::{nodes_needing_parents, nodes_to_pin, PathIterator, PeakIterator},
6    proof,
7    Error::{self, *},
8    Location, Position, Proof,
9};
10use alloc::{
11    collections::{BTreeMap, BTreeSet, VecDeque},
12    vec::Vec,
13};
14use commonware_cryptography::Digest;
15use core::{mem, ops::Range};
16cfg_if::cfg_if! {
17    if #[cfg(feature = "std")] {
18        use commonware_runtime::ThreadPool;
19        use rayon::prelude::*;
20    } else {
21        struct ThreadPool;
22    }
23}
24
/// Minimum number of digest computations required during batch updates to trigger parallelization.
/// Below this threshold the cost of scheduling work onto the thread pool is assumed to outweigh
/// the benefit of parallel hashing, so the serial path is used instead.
#[cfg(feature = "std")]
const MIN_TO_PARALLELIZE: usize = 20;

/// An MMR whose root digest has not been computed (it may have pending batched updates).
pub type DirtyMmr<D> = Mmr<D, Dirty>;

/// An MMR whose root digest has been computed and is cached in its state.
pub type CleanMmr<D> = Mmr<D, Clean<D>>;
34
/// Sealed trait for MMR state types.
///
/// Keeping `Sealed` in a private module prevents downstream crates from implementing [State]
/// for their own types, ensuring [Clean] and [Dirty] remain the only valid type-states.
mod private {
    pub trait Sealed {}
}
39
/// Trait for valid MMR state types.
///
/// Implemented per-state so that the shared [Mmr::add] method can dispatch to state-specific
/// insertion logic (immediate re-merkleization for [Clean], deferred digests for [Dirty]).
pub trait State<D: Digest>: private::Sealed + Sized {
    /// Add the given leaf digest to the MMR, returning its position.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut Mmr<D, Self>, hasher: &mut H, digest: D)
        -> Position;
}
46
/// Marker type for a MMR whose root digest has been computed.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    /// The root digest of the MMR.
    pub root: D,
}

impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {
    // Delegates to CleanMmr::add_leaf_digest, which re-merkleizes after the insertion so the
    // cached root stays valid.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut CleanMmr<D>, hasher: &mut H, digest: D) -> Position {
        mmr.add_leaf_digest(hasher, digest)
    }
}
60
/// Marker type for a dirty MMR (root digest not computed).
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    /// Non-leaf nodes that need to have their digests recomputed due to a batched update operation.
    ///
    /// This is a set of tuples of the form (node_pos, height).
    dirty_nodes: BTreeSet<(Position, u32)>,
}

impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {
    // Delegates to DirtyMmr::add_leaf_digest, which appends placeholder parents and records them
    // in `dirty_nodes`; digests are computed later by `merkleize`.
    fn add_leaf_digest<H: Hasher<D>>(mmr: &mut DirtyMmr<D>, hasher: &mut H, digest: D) -> Position {
        mmr.add_leaf_digest(hasher, digest)
    }
}
76
/// Configuration for initializing an [Mmr].
pub struct Config<D: Digest> {
    /// The retained nodes of the MMR, in post-order traversal layout.
    pub nodes: Vec<D>,

    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
    /// pruned.
    pub pruned_to_pos: Position,

    /// The pinned nodes of the MMR, in the order expected by `nodes_to_pin`. The count must match
    /// what `nodes_to_pin(pruned_to_pos)` yields, or initialization fails.
    pub pinned_nodes: Vec<D>,
}
89
/// A basic MMR where all nodes are stored in-memory.
///
/// # Terminology
///
/// Nodes in this structure are either _retained_, _pruned_, or _pinned_. Retained nodes are nodes
/// that have not yet been pruned, and have digests stored explicitly within the tree structure.
/// Pruned nodes are those whose positions precede that of the _oldest retained_ node, for which no
/// digests are maintained. Pinned nodes are nodes that would otherwise be pruned based on their
/// position, but whose digests remain required for proof generation. The digests for pinned nodes
/// are stored in an auxiliary map, and are at most O(log2(n)) in number.
///
/// # Max Capacity
///
/// The maximum number of elements that can be stored is usize::MAX (u32::MAX on 32-bit
/// architectures).
///
/// # Type States
///
/// The MMR uses the type-state pattern to enforce at compile-time whether the MMR has pending
/// updates that must be merkleized before computing proofs. [CleanMmr] represents a clean
/// MMR whose root digest has been computed. [DirtyMmr] represents a dirty MMR whose root
/// digest needs to be computed. A dirty MMR can be converted into a clean MMR by calling
/// [DirtyMmr::merkleize].
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest, S: State<D> = Dirty> {
    /// The nodes of the MMR, laid out according to a post-order traversal of the MMR trees,
    /// starting from the tallest tree to the shortest.
    nodes: VecDeque<D>,

    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
    /// pruned.
    pruned_to_pos: Position,

    /// The auxiliary map from node position to the digest of any pinned node.
    pinned_nodes: BTreeMap<Position, D>,

    /// Type-state for the MMR.
    state: S,
}
129
130impl<D: Digest> Default for DirtyMmr<D> {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl<D: Digest> From<CleanMmr<D>> for DirtyMmr<D> {
137    fn from(clean: CleanMmr<D>) -> Self {
138        DirtyMmr {
139            nodes: clean.nodes,
140            pruned_to_pos: clean.pruned_to_pos,
141            pinned_nodes: clean.pinned_nodes,
142            state: Dirty {
143                dirty_nodes: BTreeSet::new(),
144            },
145        }
146    }
147}
148
149impl<D: Digest, S: State<D>> Mmr<D, S> {
150    /// Return the total number of nodes in the MMR, irrespective of any pruning. The next added
151    /// element's position will have this value.
152    pub fn size(&self) -> Position {
153        Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
154    }
155
156    /// Return the total number of leaves in the MMR.
157    pub fn leaves(&self) -> Location {
158        Location::try_from(self.size()).expect("invalid mmr size")
159    }
160
161    /// Return the position of the last leaf in this MMR, or None if the MMR is empty.
162    pub fn last_leaf_pos(&self) -> Option<Position> {
163        if self.size() == 0 {
164            return None;
165        }
166
167        Some(PeakIterator::last_leaf_pos(self.size()))
168    }
169
170    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
171    /// pruned.
172    pub const fn pruned_to_pos(&self) -> Position {
173        self.pruned_to_pos
174    }
175
176    /// Return the position of the oldest retained node in the MMR, not including those cached in
177    /// pinned_nodes.
178    pub fn oldest_retained_pos(&self) -> Option<Position> {
179        if self.pruned_to_pos == self.size() {
180            return None;
181        }
182
183        Some(self.pruned_to_pos)
184    }
185
186    /// Return a new iterator over the peaks of the MMR.
187    pub fn peak_iterator(&self) -> PeakIterator {
188        PeakIterator::new(self.size())
189    }
190
191    /// Return the position of the element given its index in the current nodes vector.
192    fn index_to_pos(&self, index: usize) -> Position {
193        self.pruned_to_pos + (index as u64)
194    }
195
196    /// Return the requested node if it is either retained or present in the pinned_nodes map, and
197    /// panic otherwise. Use `get_node` instead if you require a non-panicking getter.
198    ///
199    /// # Panics
200    ///
201    /// Panics if the requested node does not exist for any reason such as the node is pruned or
202    /// `pos` is out of bounds.
203    pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
204        if pos < self.pruned_to_pos {
205            return self
206                .pinned_nodes
207                .get(&pos)
208                .expect("requested node is pruned and not pinned");
209        }
210
211        &self.nodes[self.pos_to_index(pos)]
212    }
213
214    /// Return the index of the element in the current nodes vector given its position in the MMR.
215    ///
216    /// # Panics
217    ///
218    /// Panics if `pos` precedes the oldest retained position.
219    fn pos_to_index(&self, pos: Position) -> usize {
220        assert!(
221            pos >= self.pruned_to_pos,
222            "pos precedes oldest retained position"
223        );
224
225        *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
226    }
227
228    /// Utility used by stores that build on the mem MMR to pin extra nodes if needed. It's up to
229    /// the caller to ensure that this set of pinned nodes is valid for their use case.
230    #[cfg(any(feature = "std", test))]
231    pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
232        for (pos, node) in pinned_nodes.into_iter() {
233            self.pinned_nodes.insert(pos, node);
234        }
235    }
236
237    /// Add `element` to the MMR and return its position.
238    /// The element can be an arbitrary byte slice, and need not be converted to a digest first.
239    pub fn add<H: Hasher<D>>(&mut self, hasher: &mut H, element: &[u8]) -> Position {
240        let digest = hasher.leaf_digest(self.size(), element);
241        S::add_leaf_digest(self, hasher, digest)
242    }
243}
244
/// Implementation for Clean MMR state.
impl<D: Digest> CleanMmr<D> {
    /// Return an [Mmr] initialized with the given `config`.
    ///
    /// # Errors
    ///
    /// Returns [Error::InvalidPinnedNodes] if the number of pinned nodes doesn't match the expected
    /// count for `config.pruned_to_pos`.
    ///
    /// Returns [Error::InvalidSize] if the MMR size is invalid.
    pub fn init(config: Config<D>, hasher: &mut impl Hasher<D>) -> Result<Self, Error> {
        // Validate that the total size is valid
        let Some(size) = config.pruned_to_pos.checked_add(config.nodes.len() as u64) else {
            return Err(Error::InvalidSize(u64::MAX));
        };
        if !size.is_mmr_size() {
            return Err(Error::InvalidSize(*size));
        }

        // Validate and populate pinned nodes
        let mut pinned_nodes = BTreeMap::new();
        let mut expected_pinned_nodes = 0;
        for (i, pos) in nodes_to_pin(config.pruned_to_pos).enumerate() {
            expected_pinned_nodes += 1;
            // Too few pinned nodes were provided for this pruning boundary.
            if i >= config.pinned_nodes.len() {
                return Err(Error::InvalidPinnedNodes);
            }
            pinned_nodes.insert(pos, config.pinned_nodes[i]);
        }

        // Check for too many pinned nodes
        if config.pinned_nodes.len() != expected_pinned_nodes {
            return Err(Error::InvalidPinnedNodes);
        }

        // Build as a Dirty MMR (with an empty dirty set) and merkleize to compute the root.
        let mmr = Mmr {
            nodes: VecDeque::from(config.nodes),
            pruned_to_pos: config.pruned_to_pos,
            pinned_nodes,
            state: Dirty::default(),
        };
        Ok(mmr.merkleize(hasher, None))
    }

    /// Create a new, empty MMR in the Clean state.
    pub fn new(hasher: &mut impl Hasher<D>) -> Self {
        let mmr: DirtyMmr<D> = Default::default();
        mmr.merkleize(hasher, None)
    }

    /// Re-initialize the MMR with the given nodes, pruned_to_pos, and pinned_nodes.
    ///
    /// Unlike [CleanMmr::init], the inputs are not validated (see [DirtyMmr::from_components]).
    pub fn from_components(
        hasher: &mut impl Hasher<D>,
        nodes: Vec<D>,
        pruned_to_pos: Position,
        pinned_nodes: Vec<D>,
    ) -> Self {
        DirtyMmr::from_components(nodes, pruned_to_pos, pinned_nodes).merkleize(hasher, None)
    }

    /// Return the requested node or None if it is not stored in the MMR.
    pub fn get_node(&self, pos: Position) -> Option<D> {
        if pos < self.pruned_to_pos {
            // Pruned range: the node is only available if it was pinned.
            return self.pinned_nodes.get(&pos).copied();
        }

        self.nodes.get(self.pos_to_index(pos)).copied()
    }

    /// Add a leaf's `digest` to the MMR, generating the necessary parent nodes to maintain the
    /// MMR's structure.
    pub(super) fn add_leaf_digest(&mut self, hasher: &mut impl Hasher<D>, digest: D) -> Position {
        // Temporarily swap `self` for an empty MMR so the contents can be moved into the Dirty
        // state, mutated, and then merkleized back into place.
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let leaf_pos = dirty_mmr.add_leaf_digest(hasher, digest);
        *self = dirty_mmr.merkleize(hasher, None);
        leaf_pos
    }

    /// Pop the most recent leaf element out of the MMR if it exists, returning Empty or
    /// ElementPruned errors otherwise.
    pub fn pop(&mut self, hasher: &mut impl Hasher<D>) -> Result<Position, Error> {
        // Same replace/mutate/merkleize pattern as add_leaf_digest.
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.pop();
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Get the nodes (position + digest) that need to be pinned (those required for proof
    /// generation) in this MMR when pruned to position `prune_pos`.
    pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
        nodes_to_pin(prune_pos)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }

    /// Prune all nodes up to but not including the given position, and pin the O(log2(n)) number of
    /// them required for proof generation.
    pub fn prune_to_pos(&mut self, pos: Position) {
        // Recompute the set of older nodes to retain.
        // (This must happen before draining, since nodes_to_pin reads currently retained nodes.)
        self.pinned_nodes = self.nodes_to_pin(pos);
        let retained_nodes = self.pos_to_index(pos);
        self.nodes.drain(0..retained_nodes);
        self.pruned_to_pos = pos;
    }

    /// Prune all nodes and pin the O(log2(n)) number of them required for proof generation going
    /// forward.
    pub fn prune_all(&mut self) {
        if !self.nodes.is_empty() {
            // index_to_pos(nodes.len()) equals the current size, so this prunes everything.
            let pos = self.index_to_pos(self.nodes.len());
            self.prune_to_pos(pos);
        }
    }

    /// Change the digest of any retained leaf. This is useful if you want to use the MMR
    /// implementation as an updatable binary Merkle tree, and otherwise should be avoided.
    ///
    /// # Errors
    ///
    /// Returns [Error::ElementPruned] if the leaf has been pruned.
    /// Returns [Error::LeafOutOfBounds] if `loc` is not an existing leaf.
    /// Returns [Error::LocationOverflow] if `loc` > [crate::mmr::MAX_LOCATION].
    ///
    /// # Warning
    ///
    /// This method will change the root and invalidate any previous inclusion proofs.
    /// Use of this method will prevent using this structure as a base mmr for grafting.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        // Same replace/mutate/merkleize pattern as add_leaf_digest.
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.update_leaf(hasher, loc, element);
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Convert this Clean MMR into a Dirty MMR without making any changes to it.
    pub fn into_dirty(self) -> DirtyMmr<D> {
        self.into()
    }

    /// Get the root digest of the MMR.
    pub const fn root(&self) -> &D {
        &self.state.root
    }

    /// Returns the root that would be produced by calling `root` on an empty MMR.
    pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
        // An empty MMR's root is the hash of its size (0) as big-endian bytes, with no peaks.
        hasher.update(&0u64.to_be_bytes());
        hasher.finalize()
    }

    /// Return an inclusion proof for the element at location `loc`.
    ///
    /// # Errors
    ///
    /// Returns [Error::LocationOverflow] if `loc` > [crate::mmr::MAX_LOCATION].
    /// Returns [Error::ElementPruned] if some element needed to generate the proof has been pruned.
    ///
    /// # Panics
    ///
    /// Panics if `loc` is out of bounds.
    pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        // loc is valid so it won't overflow from + 1
        self.range_proof(loc..loc + 1)
    }

    /// Return an inclusion proof for all elements within the provided `range` of locations.
    ///
    /// # Errors
    ///
    /// Returns [Error::Empty] if the range is empty.
    /// Returns [Error::LocationOverflow] if any location in `range` exceeds [crate::mmr::MAX_LOCATION].
    /// Returns [Error::ElementPruned] if some element needed to generate the proof has been pruned.
    ///
    /// # Panics
    ///
    /// Panics if the element range is out of bounds.
    pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
        let leaves = self.leaves();
        assert!(
            range.start < leaves,
            "range start {} >= leaf count {}",
            range.start,
            leaves
        );
        assert!(
            range.end <= leaves,
            "range end {} > leaf count {}",
            range.end,
            leaves
        );

        // Gather the digest of every node a verifier needs; any node that is neither retained nor
        // pinned has been pruned away, which makes the proof impossible to construct.
        let size = self.size();
        let positions = proof::nodes_required_for_range_proof(size, range)?;
        let digests = positions
            .into_iter()
            .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Proof { size, digests })
    }

    /// Get the digests of nodes that need to be pinned (those required for proof generation) in
    /// this MMR when pruned to position `prune_pos`.
    #[cfg(test)]
    pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
        nodes_to_pin(start_pos)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }

    /// Return the nodes this MMR currently has pinned. Pinned nodes are nodes that would otherwise
    /// be pruned, but whose digests remain required for proof generation.
    #[cfg(test)]
    pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
        self.pinned_nodes.clone()
    }
}
470
/// Implementation for Dirty MMR state.
impl<D: Digest> DirtyMmr<D> {
    /// Return a new (empty) `Mmr`.
    pub fn new() -> Self {
        Self {
            nodes: VecDeque::new(),
            pruned_to_pos: Position::new(0),
            pinned_nodes: BTreeMap::new(),
            state: Dirty::default(),
        }
    }

    /// Re-initialize the MMR with the given nodes, pruned_to_pos, and pinned_nodes.
    ///
    /// # Panics
    ///
    /// NOTE(review): this indexes `pinned_nodes[i]`, so it panics if fewer pinned nodes are
    /// provided than `nodes_to_pin(pruned_to_pos)` requires (and silently ignores extras). Unlike
    /// [CleanMmr::init] no validation error is returned — confirm callers guarantee the count.
    pub fn from_components(nodes: Vec<D>, pruned_to_pos: Position, pinned_nodes: Vec<D>) -> Self {
        Self {
            nodes: VecDeque::from(nodes),
            pruned_to_pos,
            // Pair each provided digest with its expected position, in `nodes_to_pin` order.
            pinned_nodes: nodes_to_pin(pruned_to_pos)
                .enumerate()
                .map(|(i, pos)| (pos, pinned_nodes[i]))
                .collect(),
            state: Dirty::default(),
        }
    }

    /// Add `digest` as a new leaf in the MMR, returning its position.
    ///
    /// Parent digests are not computed here: placeholder (EMPTY) nodes are appended and recorded
    /// in the dirty set, to be filled in by the next `merkleize`.
    // TODO(#2318): Remove _hasher which is only used to create dummy digests.
    pub(super) fn add_leaf_digest<H: Hasher<D>>(&mut self, _hasher: &mut H, digest: D) -> Position {
        // Compute the new parent nodes, if any.
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        let leaf_pos = self.size();
        self.nodes.push_back(digest);

        // One placeholder parent is appended per node needing one, at successively greater
        // heights; each is marked dirty so merkleize knows to compute its real digest.
        let mut height = 1;
        for _ in nodes_needing_parents {
            let new_node_pos = self.size();
            self.nodes
                .push_back(<H::Inner as commonware_cryptography::Hasher>::EMPTY);
            self.state.dirty_nodes.insert((new_node_pos, height));
            height += 1;
        }

        leaf_pos
    }

    /// Pop the most recent leaf element out of the MMR if it exists, returning Empty or
    /// ElementPruned errors otherwise.
    pub fn pop(&mut self) -> Result<Position, Error> {
        if self.size() == 0 {
            return Err(Empty);
        }

        // Walk backwards from size-1 to the largest valid MMR size, which drops the last leaf
        // along with any parent nodes its addition created.
        let mut new_size = self.size() - 1;
        loop {
            if new_size < self.pruned_to_pos {
                return Err(ElementPruned(new_size));
            }
            if new_size.is_mmr_size() {
                break;
            }
            new_size -= 1;
        }
        let num_to_drain = *(self.size() - new_size) as usize;
        self.nodes.drain(self.nodes.len() - num_to_drain..);

        // Remove dirty nodes that are now out of bounds.
        // (split_off keeps entries ordered before `cutoff`; height 0 is below any real entry's
        // height, so everything at positions >= the new size is discarded.)
        let cutoff = (self.size(), 0);
        self.state.dirty_nodes.split_off(&cutoff);

        Ok(self.size())
    }

    /// Compute updated digests for dirty nodes and compute the root, converting this MMR into a
    /// [CleanMmr].
    pub fn merkleize(
        mut self,
        hasher: &mut impl Hasher<D>,
        pool: Option<ThreadPool>,
    ) -> CleanMmr<D> {
        // Parallelize only when a pool is provided and there is enough dirty work to amortize
        // the scheduling overhead; without std the serial path is the only option.
        #[cfg(feature = "std")]
        match (pool, self.state.dirty_nodes.len() >= MIN_TO_PARALLELIZE) {
            (Some(pool), true) => self.merkleize_parallel(hasher, pool, MIN_TO_PARALLELIZE),
            _ => self.merkleize_serial(hasher),
        }

        #[cfg(not(feature = "std"))]
        self.merkleize_serial(hasher);

        // Compute root
        let peaks = self
            .peak_iterator()
            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
        let size = self.size();
        let digest = hasher.root(size, peaks);

        CleanMmr {
            nodes: self.nodes,
            pruned_to_pos: self.pruned_to_pos,
            pinned_nodes: self.pinned_nodes,
            state: Clean { root: digest },
        }
    }

    /// Recompute all dirty node digests serially, in order of increasing height so children are
    /// always updated before their parents.
    fn merkleize_serial(&mut self, hasher: &mut impl Hasher<D>) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        // Stable sort by height (BTreeSet iteration already ordered by position within height).
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        for (pos, height) in nodes {
            // In the post-order layout, a parent at height h has its left child at pos - 2^h and
            // its right child at pos - 1.
            let left = pos - (1 << height);
            let right = pos - 1;
            let digest = hasher.node_digest(
                pos,
                self.get_node_unchecked(left),
                self.get_node_unchecked(right),
            );
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
    }

    /// Process any pending batched updates, using parallel hash workers as long as the number of
    /// computations that can be parallelized exceeds `min_to_parallelize`.
    ///
    /// This implementation parallelizes the computation of digests across all nodes at the same
    /// height, starting from the bottom and working up to the peaks. If ever the number of
    /// remaining digest computations is less than the `min_to_parallelize`, it switches to the
    /// serial implementation.
    #[cfg(feature = "std")]
    fn merkleize_parallel(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        min_to_parallelize: usize,
    ) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        // Sort by increasing height.
        nodes.sort_by(|a, b| a.1.cmp(&b.1));

        // Accumulate runs of nodes at the same height, flushing each completed run as one
        // parallel batch.
        let mut same_height = Vec::new();
        let mut current_height = 1;
        for (i, (pos, height)) in nodes.iter().enumerate() {
            if *height == current_height {
                same_height.push(*pos);
                continue;
            }
            // Height transition: if the completed run is too small to benefit from parallelism,
            // put it (and all remaining nodes) back in the dirty set and finish serially. The run
            // occupies indices [i - same_height.len(), i) of `nodes`.
            if same_height.len() < min_to_parallelize {
                self.state.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
                self.merkleize_serial(hasher);
                return;
            }
            self.update_node_digests(hasher, pool.clone(), &same_height, current_height);
            same_height.clear();
            current_height += 1;
            same_height.push(*pos);
        }

        // Flush the final run, again falling back to serial if it is too small.
        if same_height.len() < min_to_parallelize {
            self.state.dirty_nodes = nodes[nodes.len() - same_height.len()..]
                .iter()
                .copied()
                .collect();
            self.merkleize_serial(hasher);
            return;
        }

        self.update_node_digests(hasher, pool, &same_height, current_height);
    }

    /// Update digests of the given set of nodes of equal height in the MMR. Since they are all at
    /// the same height, this can be done in parallel without synchronization.
    #[cfg(feature = "std")]
    fn update_node_digests(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        same_height: &[Position],
        height: u32,
    ) {
        // A node at height h has children at pos - 2^h (left) and pos - 1 (right).
        let two_h = 1 << height;
        pool.install(|| {
            // Phase 1 (parallel, read-only): compute each digest with a forked hasher.
            // Phase 2 (sequential): write the results back into `nodes`.
            let computed_digests: Vec<(usize, D)> = same_height
                .par_iter()
                .map_init(
                    || hasher.fork(),
                    |hasher, &pos| {
                        let left = pos - two_h;
                        let right = pos - 1;
                        let digest = hasher.node_digest(
                            pos,
                            self.get_node_unchecked(left),
                            self.get_node_unchecked(right),
                        );
                        let index = self.pos_to_index(pos);
                        (index, digest)
                    },
                )
                .collect();

            for (index, digest) in computed_digests {
                self.nodes[index] = digest;
            }
        });
    }

    /// Mark the non-leaf nodes in the path from the given position to the root as dirty, so that
    /// their digests are appropriately recomputed during the next `merkleize`.
    ///
    /// # Panics
    ///
    /// Panics if `pos` does not fall within any mountain of the MMR.
    fn mark_dirty(&mut self, pos: Position) {
        for (peak_pos, mut height) in self.peak_iterator() {
            // Skip mountains that end before `pos`.
            if peak_pos < pos {
                continue;
            }

            // We have found the mountain containing the path we are looking for. Traverse it from
            // leaf to root, that way we can exit early if we hit a node that is already dirty.
            let path = PathIterator::new(pos, peak_pos, height)
                .collect::<Vec<_>>()
                .into_iter()
                .rev();
            height = 1;
            for (parent_pos, _) in path {
                // `insert` returning false means this ancestor (and hence everything above it)
                // is already dirty.
                if !self.state.dirty_nodes.insert((parent_pos, height)) {
                    break;
                }
                height += 1;
            }
            return;
        }

        panic!("invalid pos {pos}:{}", self.size());
    }

    /// Update the leaf at `loc` to `element`.
    ///
    /// Convenience wrapper around [Self::update_leaf_batched] for a single leaf.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        self.update_leaf_batched(hasher, None, &[(loc, element)])
    }

    /// Batch update the digests of multiple retained leaves.
    ///
    /// # Errors
    ///
    /// Returns [Error::LeafOutOfBounds] if any location is not an existing leaf.
    /// Returns [Error::LocationOverflow] if any location exceeds [crate::mmr::MAX_LOCATION].
    /// Returns [Error::ElementPruned] if any of the leaves has been pruned.
    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: Option<ThreadPool>,
        updates: &[(Location, T)],
    ) -> Result<(), Error> {
        if updates.is_empty() {
            return Ok(());
        }

        // Validate every location up front so the batch is all-or-nothing: no leaf is modified
        // if any update would fail.
        let leaves = self.leaves();
        let mut positions = Vec::with_capacity(updates.len());
        for (loc, _) in updates {
            if *loc >= leaves {
                return Err(Error::LeafOutOfBounds(*loc));
            }
            let pos = Position::try_from(*loc)?;
            if pos < self.pruned_to_pos {
                return Err(Error::ElementPruned(pos));
            }
            positions.push(pos);
        }

        // Parallelize only with a pool and enough work to amortize scheduling overhead.
        #[cfg(feature = "std")]
        if let Some(pool) = pool {
            if updates.len() >= MIN_TO_PARALLELIZE {
                self.update_leaf_parallel(hasher, pool, updates, &positions);
                return Ok(());
            }
        }

        for ((_, element), pos) in updates.iter().zip(positions.iter()) {
            // Update the digest of the leaf node and mark its ancestors as dirty.
            let digest = hasher.leaf_digest(*pos, element.as_ref());
            let index = self.pos_to_index(*pos);
            self.nodes[index] = digest;
            self.mark_dirty(*pos);
        }

        Ok(())
    }

    /// Batch update the digests of multiple retained leaves using multiple threads.
    #[cfg(feature = "std")]
    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<D>,
        pool: ThreadPool,
        updates: &[(Location, T)],
        positions: &[Position],
    ) {
        pool.install(|| {
            // Leaf digests are independent of one another, so compute them in parallel (each
            // worker forks the hasher), then apply writes and dirty-marking sequentially.
            let digests: Vec<(Position, D)> = updates
                .par_iter()
                .zip(positions.par_iter())
                .map_init(
                    || hasher.fork(),
                    |hasher, ((_, elem), pos)| {
                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
                        (*pos, digest)
                    },
                )
                .collect();

            for (pos, digest) in digests {
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
                self.mark_dirty(pos);
            }
        });
    }
}
795
796#[cfg(test)]
797mod tests {
798    use super::*;
799    use crate::mmr::{
800        hasher::{Hasher as _, Standard},
801        stability::ROOTS,
802    };
803    use commonware_cryptography::{sha256, Hasher, Sha256};
804    use commonware_runtime::{create_pool, deterministic, tokio, Runner};
805    use commonware_utils::hex;
806
807    /// Build the MMR corresponding to the stability test `ROOTS` and confirm the roots match.
808    fn build_and_check_test_roots_mmr(mmr: &mut CleanMmr<sha256::Digest>) {
809        let mut hasher: Standard<Sha256> = Standard::new();
810        for i in 0u64..199 {
811            hasher.inner().update(&i.to_be_bytes());
812            let element = hasher.inner().finalize();
813            let root = *mmr.root();
814            let expected_root = ROOTS[i as usize];
815            assert_eq!(hex(&root), expected_root, "at: {i}");
816            mmr.add(&mut hasher, &element);
817        }
818        assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
819    }
820
821    /// Same as `build_and_check_test_roots` but uses `add` + `merkleize` instead of `add`.
822    pub fn build_batched_and_check_test_roots(
823        mut mmr: DirtyMmr<sha256::Digest>,
824        pool: Option<ThreadPool>,
825    ) {
826        let mut hasher: Standard<Sha256> = Standard::new();
827        for i in 0u64..199 {
828            hasher.inner().update(&i.to_be_bytes());
829            let element = hasher.inner().finalize();
830            mmr.add(&mut hasher, &element);
831        }
832        let mmr = mmr.merkleize(&mut hasher, pool);
833        assert_eq!(hex(mmr.root()), ROOTS[199], "Root after 200 elements");
834    }
835
    /// Test empty MMR behavior: peak iteration, size/position accessors, root computation, and
    /// pop/prune on a freshly-created MMR.
    #[test]
    fn test_mem_mmr_empty() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            assert_eq!(
                mmr.peak_iterator().next(),
                None,
                "empty iterator should have no peaks"
            );
            // All size/position accessors should reflect a completely empty structure.
            assert_eq!(mmr.size(), 0);
            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
            assert_eq!(mmr.last_leaf_pos(), None);
            assert_eq!(mmr.oldest_retained_pos(), None);
            assert_eq!(mmr.get_node(Position::new(0)), None);
            // The empty MMR still has a well-defined root digest.
            assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
            // Popping from an empty MMR must fail gracefully rather than panic.
            assert!(matches!(mmr.pop(&mut hasher), Err(Empty)));
            mmr.prune_all();
            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");

            // The empty root must equal the root computed over an empty peak set.
            assert_eq!(*mmr.root(), hasher.root(Position::new(0), [].iter()));
        });
    }
861
    /// Test MMR building by consecutively adding 11 equal elements to a new MMR, producing the
    /// structure in the example documented at the top of the mmr crate's mod.rs file with 19 nodes
    /// and 3 peaks. Also exercises proof generation, pruning, and re-initialization from the
    /// resulting state.
    #[test]
    fn test_mem_mmr_add_eleven_values() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut leaves: Vec<Position> = Vec::new();
            for _ in 0..11 {
                leaves.push(mmr.add(&mut hasher, &element));
                // Invariants that must hold after every add: at least one peak, no more peaks
                // than nodes, and only a subset of peaks can ever need a parent.
                let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
                assert_ne!(peaks.len(), 0);
                assert!(peaks.len() as u64 <= mmr.size());
                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
                assert!(nodes_needing_parents.len() <= peaks.len());
            }
            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(0));
            assert_eq!(mmr.size(), 19, "mmr not of expected size");
            assert_eq!(
                leaves,
                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
                    .into_iter()
                    .map(Position::new)
                    .collect::<Vec<_>>(),
                "mmr leaf positions not as expected"
            );
            // Peaks are reported as (position, height) pairs, tallest first.
            let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
            assert_eq!(
                peaks,
                vec![
                    (Position::new(14), 3),
                    (Position::new(17), 1),
                    (Position::new(18), 0)
                ],
                "mmr peaks not as expected"
            );

            // Test nodes_needing_parents on the final MMR. Since there's a height gap between the
            // highest peak (14) and the next, only the lower two peaks (17, 18) should be returned.
            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
            assert_eq!(
                peaks_needing_parents,
                vec![Position::new(17), Position::new(18)],
                "mmr nodes needing parents not as expected"
            );

            // verify leaf digests
            for leaf in leaves.iter().by_ref() {
                let digest = hasher.leaf_digest(*leaf, &element);
                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
            }

            // verify height=1 node digests
            let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
            assert_eq!(mmr.nodes[2], digest2);
            let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
            assert_eq!(mmr.nodes[5], digest5);
            let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
            assert_eq!(mmr.nodes[9], digest9);
            let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
            assert_eq!(mmr.nodes[12], digest12);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // verify height=2 node digests
            let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
            assert_eq!(mmr.nodes[6], digest6);
            let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
            assert_eq!(mmr.nodes[13], digest13);
            // NOTE(review): digest17 is a height-1 node and was already verified above; this
            // repeated check appears to be a copy-paste duplicate (harmless).
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // verify topmost digest
            let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
            assert_eq!(mmr.nodes[14], digest14);

            // verify root: the root commits to the MMR size and all peak digests.
            let root = *mmr.root();
            let peak_digests = [digest14, digest17, mmr.nodes[18]];
            let expected_root = hasher.root(Position::new(19), peak_digests.iter());
            assert_eq!(root, expected_root, "incorrect root");

            // pruning tests
            mmr.prune_to_pos(Position::new(14)); // prune up to the tallest peak
            assert_eq!(mmr.oldest_retained_pos().unwrap(), Position::new(14));

            // After pruning, we shouldn't be able to generate a proof for any elements before the
            // pruning boundary. (To be precise, due to the maintenance of pinned nodes, we may in
            // fact still be able to generate them for some, but it's not guaranteed. For example,
            // in this case, we actually can still generate a proof for the node with location 7
            // even though it's pruned.)
            assert!(matches!(
                mmr.proof(Location::new_unchecked(0)),
                Err(ElementPruned(_))
            ));
            assert!(matches!(
                mmr.proof(Location::new_unchecked(6)),
                Err(ElementPruned(_))
            ));

            // We should still be able to generate a proof for any leaf following the pruning
            // boundary, the first of which is at location 8 and the last location 10.
            assert!(mmr.proof(Location::new_unchecked(8)).is_ok());
            assert!(mmr.proof(Location::new_unchecked(10)).is_ok());

            // Pruning must never change the root.
            let root_after_prune = *mmr.root();
            assert_eq!(root, root_after_prune, "root changed after pruning");

            assert!(
                mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(9))
                    .is_err(),
                "attempts to range_prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(Location::new_unchecked(8)..mmr.leaves()).is_ok(),
                "attempts to range_prove over all elements following oldest retained should succeed"
            );

            // Test that we can initialize a new MMR from another's elements.
            let oldest_pos = mmr.oldest_retained_pos().unwrap();
            let digests = mmr.node_digests_to_pin(oldest_pos);
            let mmr_copy = Mmr::init(
                Config {
                    nodes: mmr.nodes.iter().copied().collect(),
                    pruned_to_pos: oldest_pos,
                    pinned_nodes: digests,
                },
                &mut hasher,
            )
            .unwrap();
            // The copy must be indistinguishable from the original.
            assert_eq!(mmr_copy.size(), 19);
            assert_eq!(mmr_copy.leaves(), mmr.leaves());
            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
            assert_eq!(mmr_copy.oldest_retained_pos(), mmr.oldest_retained_pos());
            assert_eq!(*mmr_copy.root(), root);
        });
    }
1002
1003    /// Test that pruning all nodes never breaks adding new nodes.
1004    #[test]
1005    fn test_mem_mmr_prune_all() {
1006        let executor = deterministic::Runner::default();
1007        executor.start(|_| async move {
1008            let mut hasher: Standard<Sha256> = Standard::new();
1009            let mut mmr = CleanMmr::new(&mut hasher);
1010            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1011            for _ in 0..1000 {
1012                mmr.prune_all();
1013                mmr.add(&mut hasher, &element);
1014            }
1015        });
1016    }
1017
1018    /// Test that the MMR validity check works as expected.
1019    #[test]
1020    fn test_mem_mmr_validity() {
1021        let executor = deterministic::Runner::default();
1022        executor.start(|_| async move {
1023            let mut hasher: Standard<Sha256> = Standard::new();
1024            let mut mmr = CleanMmr::new(&mut hasher);
1025            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1026            for _ in 0..1001 {
1027                assert!(
1028                    mmr.size().is_mmr_size(),
1029                    "mmr of size {} should be valid",
1030                    mmr.size()
1031                );
1032                let old_size = mmr.size();
1033                mmr.add(&mut hasher, &element);
1034                for size in *old_size + 1..*mmr.size() {
1035                    assert!(
1036                        !Position::new(size).is_mmr_size(),
1037                        "mmr of size {size} should be invalid",
1038                    );
1039                }
1040            }
1041        });
1042    }
1043
1044    /// Test that the MMR root computation remains stable by comparing against previously computed
1045    /// roots.
1046    #[test]
1047    fn test_mem_mmr_root_stability() {
1048        let executor = deterministic::Runner::default();
1049        executor.start(|_| async move {
1050            // Test root stability under different MMR building methods.
1051            let mut hasher: Standard<Sha256> = Standard::new();
1052            let mut mmr = CleanMmr::new(&mut hasher);
1053            build_and_check_test_roots_mmr(&mut mmr);
1054
1055            let mut hasher: Standard<Sha256> = Standard::new();
1056            let mmr = CleanMmr::new(&mut hasher);
1057            build_batched_and_check_test_roots(mmr.into_dirty(), None);
1058        });
1059    }
1060
1061    /// Test root stability using the parallel builder implementation. This requires we use the
1062    /// tokio runtime since the deterministic runtime would block due to being single-threaded.
1063    #[test]
1064    fn test_mem_mmr_root_stability_parallel() {
1065        let executor = tokio::Runner::default();
1066        executor.start(|context| async move {
1067            let pool = commonware_runtime::create_pool(context, 4).unwrap();
1068            let mut hasher: Standard<Sha256> = Standard::new();
1069
1070            let mmr = Mmr::init(
1071                Config {
1072                    nodes: vec![],
1073                    pruned_to_pos: Position::new(0),
1074                    pinned_nodes: vec![],
1075                },
1076                &mut hasher,
1077            )
1078            .unwrap();
1079            build_batched_and_check_test_roots(mmr.into_dirty(), Some(pool));
1080        });
1081    }
1082
1083    /// Build the MMR corresponding to the stability test while pruning after each add, and confirm
1084    /// the static roots match that from the root computation.
1085    #[test]
1086    fn test_mem_mmr_root_stability_while_pruning() {
1087        let executor = deterministic::Runner::default();
1088        executor.start(|_| async move {
1089            let mut hasher: Standard<Sha256> = Standard::new();
1090            let mut mmr = CleanMmr::new(&mut hasher);
1091            for i in 0u64..199 {
1092                let root = *mmr.root();
1093                let expected_root = ROOTS[i as usize];
1094                assert_eq!(hex(&root), expected_root, "at: {i}");
1095                hasher.inner().update(&i.to_be_bytes());
1096                let element = hasher.inner().finalize();
1097                mmr.add(&mut hasher, &element);
1098                mmr.prune_all();
1099            }
1100        });
1101    }
1102
1103    fn compute_big_mmr(
1104        hasher: &mut Standard<Sha256>,
1105        mut mmr: DirtyMmr<sha256::Digest>,
1106        pool: Option<ThreadPool>,
1107    ) -> (CleanMmr<sha256::Digest>, Vec<Position>) {
1108        let mut leaves = Vec::new();
1109        let mut c_hasher = Sha256::default();
1110        for i in 0u64..199 {
1111            c_hasher.update(&i.to_be_bytes());
1112            let element = c_hasher.finalize();
1113            let leaf_pos = mmr.size();
1114            mmr.add(hasher, &element);
1115            leaves.push(leaf_pos);
1116        }
1117
1118        (mmr.merkleize(hasher, pool), leaves)
1119    }
1120
    /// Test popping leaves one at a time: roots must track `ROOTS` in reverse, popping an empty
    /// MMR must fail, and popping down to the pruning boundary must stop with `ElementPruned`.
    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
            let root = *mmr.root();
            let expected_root = ROOTS[199];
            assert_eq!(hex(&root), expected_root);

            // Pop off one node at a time until empty, confirming the root is still as expected.
            for i in (0..199u64).rev() {
                assert!(mmr.pop(&mut hasher).is_ok());
                let root = *mmr.root();
                let expected_root = ROOTS[i as usize];
                assert_eq!(hex(&root), expected_root);
            }

            assert!(
                matches!(mmr.pop(&mut hasher).unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            // Test that we can pop all elements up to and including the oldest retained leaf.
            // Rebuild the 199-element MMR first.
            for i in 0u64..199 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }

            // Prune at the position of the leaf with location 100, then pop back down to it.
            let leaf_pos = Position::try_from(Location::new_unchecked(100)).unwrap();
            mmr.prune_to_pos(leaf_pos);
            while mmr.size() > leaf_pos {
                mmr.pop(&mut hasher).unwrap();
            }
            assert_eq!(hex(mmr.root()), ROOTS[100]);
            // One more pop would cross the pruning boundary and must fail.
            let result = mmr.pop(&mut hasher);
            assert!(matches!(result, Err(ElementPruned(_))));
            assert_eq!(mmr.oldest_retained_pos(), None);
        });
    }
1162
1163    #[test]
1164    fn test_mem_mmr_update_leaf() {
1165        let mut hasher: Standard<Sha256> = Standard::new();
1166        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1167        let executor = deterministic::Runner::default();
1168        executor.start(|_| async move {
1169            let (mut mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1170            let root = *mmr.root();
1171
1172            // For a few leaves, update the leaf and ensure the root changes, and the root reverts
1173            // to its previous state then we update the leaf to its original value.
1174            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1175                // Change the leaf.
1176                let leaf_loc =
1177                    Location::try_from(leaves[leaf]).expect("leaf position should map to location");
1178                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
1179                let updated_root = *mmr.root();
1180                assert!(root != updated_root);
1181
1182                // Restore the leaf to its original value, ensure the root is as before.
1183                hasher.inner().update(&leaf.to_be_bytes());
1184                let element = hasher.inner().finalize();
1185                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
1186                let restored_root = *mmr.root();
1187                assert_eq!(root, restored_root);
1188            }
1189
1190            // Confirm the tree has all the hashes necessary to update any element after pruning.
1191            mmr.prune_to_pos(leaves[150]);
1192            for &leaf_pos in &leaves[150..=190] {
1193                mmr.prune_to_pos(leaf_pos);
1194                let leaf_loc =
1195                    Location::try_from(leaf_pos).expect("leaf position should map to location");
1196                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
1197            }
1198        });
1199    }
1200
1201    #[test]
1202    fn test_mem_mmr_update_leaf_error_out_of_bounds() {
1203        let mut hasher: Standard<Sha256> = Standard::new();
1204        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1205
1206        let executor = deterministic::Runner::default();
1207        executor.start(|_| async move {
1208            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1209            let invalid_loc = mmr.leaves();
1210            let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
1211            assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
1212        });
1213    }
1214
1215    #[test]
1216    fn test_mem_mmr_update_leaf_error_pruned() {
1217        let mut hasher: Standard<Sha256> = Standard::new();
1218        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1219
1220        let executor = deterministic::Runner::default();
1221        executor.start(|_| async move {
1222            let (mut mmr, _) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1223            mmr.prune_all();
1224            let result = mmr.update_leaf(&mut hasher, Location::new_unchecked(0), &element);
1225            assert!(matches!(result, Err(Error::ElementPruned(_))));
1226        });
1227    }
1228
1229    #[test]
1230    fn test_mem_mmr_batch_update_leaf() {
1231        let mut hasher: Standard<Sha256> = Standard::new();
1232        let executor = deterministic::Runner::default();
1233        executor.start(|_| async move {
1234            let (mmr, leaves) = compute_big_mmr(&mut hasher, Mmr::default(), None);
1235            do_batch_update(&mut hasher, mmr, &leaves);
1236        });
1237    }
1238
1239    #[test]
1240    /// Same test as above only using a thread pool to trigger parallelization. This requires we use
1241    /// tokio runtime instead of the deterministic one.
1242    fn test_mem_mmr_batch_parallel_update_leaf() {
1243        let mut hasher: Standard<Sha256> = Standard::new();
1244        let executor = tokio::Runner::default();
1245        executor.start(|ctx| async move {
1246            let pool = create_pool(ctx, 4).unwrap();
1247            let mmr = Mmr::init(
1248                Config {
1249                    nodes: Vec::new(),
1250                    pruned_to_pos: Position::new(0),
1251                    pinned_nodes: Vec::new(),
1252                },
1253                &mut hasher,
1254            )
1255            .unwrap();
1256            let (mmr, leaves) = compute_big_mmr(&mut hasher, mmr.into_dirty(), Some(pool));
1257            do_batch_update(&mut hasher, mmr, &leaves);
1258        });
1259    }
1260
1261    fn do_batch_update(
1262        hasher: &mut Standard<Sha256>,
1263        mmr: CleanMmr<sha256::Digest>,
1264        leaves: &[Position],
1265    ) {
1266        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
1267        let root = *mmr.root();
1268
1269        // Change a handful of leaves using a batch update.
1270        let mut updates = Vec::new();
1271        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1272            let leaf_loc =
1273                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
1274            updates.push((leaf_loc, &element));
1275        }
1276        let mut dirty_mmr = mmr.into_dirty();
1277        dirty_mmr
1278            .update_leaf_batched(hasher, None, &updates)
1279            .unwrap();
1280
1281        let mmr = dirty_mmr.merkleize(hasher, None);
1282        let updated_root = *mmr.root();
1283        assert_eq!(
1284            "af3acad6aad59c1a880de643b1200a0962a95d06c087ebf677f29eb93fc359a4",
1285            hex(&updated_root)
1286        );
1287
1288        // Batch-restore the changed leaves to their original values.
1289        let mut updates = Vec::new();
1290        for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
1291            hasher.inner().update(&leaf.to_be_bytes());
1292            let element = hasher.inner().finalize();
1293            let leaf_loc =
1294                Location::try_from(leaves[leaf]).expect("leaf position should map to location");
1295            updates.push((leaf_loc, element));
1296        }
1297        let mut dirty_mmr = mmr.into_dirty();
1298        dirty_mmr
1299            .update_leaf_batched(hasher, None, &updates)
1300            .unwrap();
1301
1302        let mmr = dirty_mmr.merkleize(hasher, None);
1303        let restored_root = *mmr.root();
1304        assert_eq!(root, restored_root);
1305    }
1306
    /// Test that `Mmr::init` validates the number of pinned nodes against `pruned_to_pos`:
    /// too few or too many must be rejected, the exact set must be accepted.
    #[test]
    fn test_init_pinned_nodes_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Test with empty config - should succeed
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with too few pinned nodes - should fail
            // Use a valid MMR size (127 is valid: 2^7 - 1 makes a complete tree)
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(127),
                pinned_nodes: vec![], // Should have nodes for position 127
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Test with too many pinned nodes - should fail
            // (pruned_to_pos of 0 requires zero pinned nodes, yet one is supplied)
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![Sha256::hash(b"dummy")],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Test with correct number of pinned nodes - should succeed
            // Build a small MMR to get valid pinned nodes
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..50 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(50),
                pinned_nodes,
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());
        });
    }
1358
    /// Test that `Mmr::init` rejects configurations whose implied total size
    /// (`pruned_to_pos` + number of retained nodes) is not a valid MMR size.
    #[test]
    fn test_init_size_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Test with valid size 0 - should succeed
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with invalid size 2 - should fail
            // Size 2 is invalid (can't have just one parent node + one leaf)
            let config = Config {
                nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Test with valid size 3 (one full tree with 2 leaves) - should succeed
            let config = Config {
                nodes: vec![
                    Sha256::hash(b"leaf1"),
                    Sha256::hash(b"leaf2"),
                    Sha256::hash(b"parent"),
                ],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with large valid size (127 = 2^7 - 1, a complete tree) - should succeed
            // Build a real MMR to get the correct structure
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..64 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 127); // Verify we have the expected size
            let nodes: Vec<_> = (0..127)
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();

            let config = Config {
                nodes,
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with non-zero pruned_to_pos - should succeed
            // Build a small MMR (11 leaves -> 19 nodes), prune it, then init from that state
            let mut mmr = CleanMmr::new(&mut hasher);
            for i in 0u64..11 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            assert_eq!(mmr.size(), 19); // 11 leaves = 19 total nodes

            // Prune to position 7, retaining nodes [7, 19) and pinning what's still needed.
            mmr.prune_to_pos(Position::new(7));
            let nodes: Vec<_> = (7..*mmr.size())
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));

            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(7),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Same nodes but wrong pruned_to_pos - should fail
            // pruned_to_pos=8 + 12 nodes = size 20 (invalid)
            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(8),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Same nodes but different wrong pruned_to_pos - should fail
            // pruned_to_pos=9 + 12 nodes = size 21 (invalid)
            let config = Config {
                nodes,
                pruned_to_pos: Position::new(9),
                pinned_nodes,
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));
        });
    }
1461}