commonware_storage/mmr/mem.rs

//! A basic, no_std compatible MMR where all nodes are stored in-memory.

use crate::mmr::{
    hasher::Hasher,
    iterator::{nodes_needing_parents, nodes_to_pin, PathIterator, PeakIterator},
    proof,
    Error::{self, *},
    Location, Position, Proof,
};
use alloc::{
    collections::{BTreeMap, BTreeSet, VecDeque},
    vec::Vec,
};
use commonware_cryptography::Digest;
use core::{mem, ops::Range};
cfg_if::cfg_if! {
    if #[cfg(feature = "std")] {
        use commonware_parallel::ThreadPool;
        use rayon::prelude::*;
    } else {
        /// Placeholder for no_std builds where parallelism is unavailable.
        // TODO(#3001): Migrate to commonware-parallel
        pub struct ThreadPool;
    }
}

/// Minimum number of digest computations required during batch updates to trigger parallelization.
#[cfg(feature = "std")]
const MIN_TO_PARALLELIZE: usize = 20;

/// An MMR whose root digest has not been computed.
pub type DirtyMmr<D> = Mmr<D, Dirty>;

/// An MMR whose root digest has been computed.
pub type CleanMmr<D> = Mmr<D, Clean<D>>;

/// Sealed trait for MMR state types.
mod private {
    pub trait Sealed {}
}

/// Trait for valid MMR state types.
pub trait State<D: Digest>: private::Sealed + Sized + Send + Sync {}

/// Marker type for an MMR whose root digest has been computed.
#[derive(Clone, Copy, Debug)]
pub struct Clean<D: Digest> {
    /// The root digest of the MMR.
    pub root: D,
}

impl<D: Digest> private::Sealed for Clean<D> {}
impl<D: Digest> State<D> for Clean<D> {}

/// Marker type for a dirty MMR (root digest not computed).
#[derive(Clone, Debug, Default)]
pub struct Dirty {
    /// Non-leaf nodes that need to have their digests recomputed due to a batched update operation.
    ///
    /// This is a set of tuples of the form (node_pos, height).
    dirty_nodes: BTreeSet<(Position, u32)>,
}

impl private::Sealed for Dirty {}
impl<D: Digest> State<D> for Dirty {}

/// Configuration for initializing an [Mmr].
pub struct Config<D: Digest> {
    /// The retained nodes of the MMR.
    pub nodes: Vec<D>,

    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
    /// pruned.
    pub pruned_to_pos: Position,

    /// The pinned nodes of the MMR, in the order expected by `nodes_to_pin`.
    pub pinned_nodes: Vec<D>,
}

/// A basic MMR where all nodes are stored in-memory.
///
/// # Terminology
///
/// Nodes in this structure are either _retained_, _pruned_, or _pinned_. Retained nodes are nodes
/// that have not yet been pruned, and have digests stored explicitly within the tree structure.
/// Pruned nodes are those whose positions precede that of the _oldest retained_ node, for which no
/// digests are maintained. Pinned nodes are nodes that would otherwise be pruned based on their
/// position, but whose digests remain required for proof generation. The digests for pinned nodes
/// are stored in an auxiliary map, and are at most O(log2(n)) in number.
///
/// # Max Capacity
///
/// The maximum number of elements that can be stored is usize::MAX (u32::MAX on 32-bit
/// architectures).
///
/// # Type States
///
/// The MMR uses the type-state pattern to enforce at compile-time whether the MMR has pending
/// updates that must be merkleized before computing proofs. [CleanMmr] represents a clean
/// MMR whose root digest has been computed. [DirtyMmr] represents a dirty MMR whose root
/// digest needs to be computed. A dirty MMR can be converted into a clean MMR by calling
/// [DirtyMmr::merkleize].
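///
/// # Example
///
/// A minimal sketch of the type-state flow (a hypothetical illustration, assuming a
/// `Standard<Sha256>` hasher as used in the tests below, with imports elided):
///
/// ```ignore
/// let mut hasher: Standard<Sha256> = Standard::new();
/// let mut mmr = DirtyMmr::new();
/// mmr.add(&mut hasher, b"element"); // mutations happen in the Dirty state
/// let clean = mmr.merkleize(&mut hasher, None); // Dirty -> Clean
/// let root = *clean.root(); // roots and proofs require the Clean state
/// let dirty = clean.into_dirty(); // Clean -> Dirty for further mutation
/// ```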
#[derive(Clone, Debug)]
pub struct Mmr<D: Digest, S: State<D> = Dirty> {
    /// The nodes of the MMR, laid out according to a post-order traversal of the MMR trees,
    /// starting from the tallest tree to the shortest.
    nodes: VecDeque<D>,

    /// The highest position for which this MMR has been pruned, or 0 if this MMR has never been
    /// pruned.
    pruned_to_pos: Position,

    /// The auxiliary map from node position to the digest of any pinned node.
    pinned_nodes: BTreeMap<Position, D>,

    /// Type-state for the MMR.
    state: S,
}

impl<D: Digest> Default for DirtyMmr<D> {
    fn default() -> Self {
        Self::new()
    }
}

impl<D: Digest> From<CleanMmr<D>> for DirtyMmr<D> {
    fn from(clean: CleanMmr<D>) -> Self {
        DirtyMmr {
            nodes: clean.nodes,
            pruned_to_pos: clean.pruned_to_pos,
            pinned_nodes: clean.pinned_nodes,
            state: Dirty {
                dirty_nodes: BTreeSet::new(),
            },
        }
    }
}

impl<D: Digest, S: State<D>> Mmr<D, S> {
    /// Return the total number of nodes in the MMR, irrespective of any pruning. The next added
    /// element's position will have this value.
    pub fn size(&self) -> Position {
        Position::new(self.nodes.len() as u64 + *self.pruned_to_pos)
    }

    /// Return the total number of leaves in the MMR.
    pub fn leaves(&self) -> Location {
        Location::try_from(self.size()).expect("invalid mmr size")
    }

    /// Return the position of the last leaf in this MMR, or None if the MMR is empty.
    pub fn last_leaf_pos(&self) -> Option<Position> {
        if self.size() == 0 {
            return None;
        }

        Some(PeakIterator::last_leaf_pos(self.size()))
    }

    /// Returns [start, end) where `start` and `end - 1` are the positions of the oldest and newest
    /// retained nodes respectively.
    pub fn bounds(&self) -> Range<Position> {
        self.pruned_to_pos..self.size()
    }

    /// Return a new iterator over the peaks of the MMR.
    pub fn peak_iterator(&self) -> PeakIterator {
        PeakIterator::new(self.size())
    }

    /// Return the position of the element given its index in the current nodes vector.
    fn index_to_pos(&self, index: usize) -> Position {
        self.pruned_to_pos + (index as u64)
    }

    /// Return the requested node if it is either retained or present in the pinned_nodes map,
    /// panicking otherwise. Use `get_node` instead if you require a non-panicking getter.
    ///
    /// # Warning
    ///
    /// If the requested digest is for an unmerkleized node (only possible in the Dirty state), a
    /// dummy digest will be returned.
    ///
    /// # Panics
    ///
    /// Panics if the requested node does not exist for any reason, such as the node being pruned
    /// or `pos` being out of bounds.
    pub(crate) fn get_node_unchecked(&self, pos: Position) -> &D {
        if pos < self.pruned_to_pos {
            return self
                .pinned_nodes
                .get(&pos)
                .expect("requested node is pruned and not pinned");
        }

        &self.nodes[self.pos_to_index(pos)]
    }

    /// Return the index of the element in the current nodes vector given its position in the MMR.
    ///
    /// # Panics
    ///
    /// Panics if `pos` precedes the oldest retained position.
    fn pos_to_index(&self, pos: Position) -> usize {
        assert!(
            pos >= self.pruned_to_pos,
            "pos precedes oldest retained position"
        );

        *pos.checked_sub(*self.pruned_to_pos).unwrap() as usize
    }

    /// Utility used by stores that build on the mem MMR to pin extra nodes if needed. It's up to
    /// the caller to ensure that this set of pinned nodes is valid for their use case.
    #[cfg(any(feature = "std", test))]
    pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position, D>) {
        for (pos, node) in pinned_nodes.into_iter() {
            self.pinned_nodes.insert(pos, node);
        }
    }
}

/// Implementation for Clean MMR state.
impl<D: Digest> CleanMmr<D> {
    /// Return an [Mmr] initialized with the given `config`.
    ///
    /// # Errors
    ///
    /// Returns [Error::InvalidPinnedNodes] if the number of pinned nodes doesn't match the expected
    /// count for `config.pruned_to_pos`.
    ///
    /// Returns [Error::InvalidSize] if the MMR size is invalid.
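    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `Standard<Sha256>` hasher (imports elided):
    ///
    /// ```ignore
    /// let mut hasher: Standard<Sha256> = Standard::new();
    /// let mmr = Mmr::init(
    ///     Config { nodes: vec![], pruned_to_pos: Position::new(0), pinned_nodes: vec![] },
    ///     &mut hasher,
    /// )
    /// .unwrap();
    /// assert_eq!(mmr.size(), 0);
    /// ```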
    pub fn init(config: Config<D>, hasher: &mut impl Hasher<Digest = D>) -> Result<Self, Error> {
        // Validate the total size.
        let Some(size) = config.pruned_to_pos.checked_add(config.nodes.len() as u64) else {
            return Err(Error::InvalidSize(u64::MAX));
        };
        if !size.is_mmr_size() {
            return Err(Error::InvalidSize(*size));
        }

        // Validate and populate pinned nodes
        let mut pinned_nodes = BTreeMap::new();
        let mut expected_pinned_nodes = 0;
        for (i, pos) in nodes_to_pin(config.pruned_to_pos).enumerate() {
            expected_pinned_nodes += 1;
            if i >= config.pinned_nodes.len() {
                return Err(Error::InvalidPinnedNodes);
            }
            pinned_nodes.insert(pos, config.pinned_nodes[i]);
        }

        // Check for too many pinned nodes
        if config.pinned_nodes.len() != expected_pinned_nodes {
            return Err(Error::InvalidPinnedNodes);
        }

        let mmr = Mmr {
            nodes: VecDeque::from(config.nodes),
            pruned_to_pos: config.pruned_to_pos,
            pinned_nodes,
            state: Dirty::default(),
        };
        Ok(mmr.merkleize(hasher, None))
    }

    /// Create a new, empty MMR in the Clean state.
    pub fn new(hasher: &mut impl Hasher<Digest = D>) -> Self {
        let mmr: DirtyMmr<D> = Default::default();
        mmr.merkleize(hasher, None)
    }

    /// Re-initialize the MMR with the given nodes, pruned_to_pos, and pinned_nodes.
    pub fn from_components(
        hasher: &mut impl Hasher<Digest = D>,
        nodes: Vec<D>,
        pruned_to_pos: Position,
        pinned_nodes: Vec<D>,
    ) -> Self {
        DirtyMmr::from_components(nodes, pruned_to_pos, pinned_nodes).merkleize(hasher, None)
    }

    /// Return the requested node or None if it is not stored in the MMR.
    pub fn get_node(&self, pos: Position) -> Option<D> {
        if pos < self.pruned_to_pos {
            return self.pinned_nodes.get(&pos).copied();
        }

        self.nodes.get(self.pos_to_index(pos)).copied()
    }

    /// Get the nodes (position + digest) that need to be pinned (those required for proof
    /// generation) in this MMR when pruned to position `prune_pos`.
    pub(crate) fn nodes_to_pin(&self, prune_pos: Position) -> BTreeMap<Position, D> {
        nodes_to_pin(prune_pos)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }

    /// Prune all nodes up to but not including the given position, and pin the O(log2(n)) number of
    /// them required for proof generation.
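    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `mmr` is a [CleanMmr] with at least one leaf:
    ///
    /// ```ignore
    /// let root = *mmr.root();
    /// mmr.prune_to_pos(mmr.size()); // prune every node currently in the MMR
    /// assert!(mmr.bounds().is_empty()); // no retained nodes remain
    /// assert_eq!(*mmr.root(), root); // pruning never changes the root
    /// ```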
    pub fn prune_to_pos(&mut self, pos: Position) {
        // Recompute the set of older nodes to retain.
        self.pinned_nodes = self.nodes_to_pin(pos);
        let retained_nodes = self.pos_to_index(pos);
        self.nodes.drain(0..retained_nodes);
        self.pruned_to_pos = pos;
    }

    /// Prune all nodes and pin the O(log2(n)) number of them required for proof generation going
    /// forward.
    pub fn prune_all(&mut self) {
        if !self.nodes.is_empty() {
            let pos = self.index_to_pos(self.nodes.len());
            self.prune_to_pos(pos);
        }
    }

    /// Change the digest of any retained leaf. This is useful if you want to use the MMR
    /// implementation as an updatable binary Merkle tree, and otherwise should be avoided.
    ///
    /// # Errors
    ///
    /// Returns [Error::ElementPruned] if the leaf has been pruned.
    /// Returns [Error::LeafOutOfBounds] if `loc` is not an existing leaf.
    /// Returns [Error::LocationOverflow] if `loc` > [crate::mmr::MAX_LOCATION].
    ///
    /// # Warning
    ///
    /// This method will change the root and invalidate any previous inclusion proofs.
    /// Use of this method will prevent using this structure as a base mmr for grafting.
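    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `mmr` is an unpruned [CleanMmr] with at least one leaf and
    /// `hasher` is a `Standard<Sha256>`:
    ///
    /// ```ignore
    /// let before = *mmr.root();
    /// mmr.update_leaf(&mut hasher, Location::new_unchecked(0), b"new value").unwrap();
    /// assert_ne!(*mmr.root(), before); // the root now reflects the updated leaf
    /// ```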
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        let mut dirty_mmr = mem::replace(self, Self::new(hasher)).into_dirty();
        let result = dirty_mmr.update_leaf(hasher, loc, element);
        *self = dirty_mmr.merkleize(hasher, None);
        result
    }

    /// Convert this Clean MMR into a Dirty MMR without making any changes to it.
    pub fn into_dirty(self) -> DirtyMmr<D> {
        self.into()
    }

    /// Get the root digest of the MMR.
    pub const fn root(&self) -> &D {
        &self.state.root
    }

    /// Returns the root that would be produced by calling `root` on an empty MMR.
    pub fn empty_mmr_root(hasher: &mut impl commonware_cryptography::Hasher<Digest = D>) -> D {
        hasher.update(&0u64.to_be_bytes());
        hasher.finalize()
    }

    /// Return an inclusion proof for the element at location `loc`.
    ///
    /// # Errors
    ///
    /// Returns [Error::LocationOverflow] if `loc` > [crate::mmr::MAX_LOCATION].
    /// Returns [Error::ElementPruned] if some element needed to generate the proof has been pruned.
    /// Returns [Error::LeafOutOfBounds] if `loc` >= [Self::leaves()].
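    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `mmr` is an unpruned, non-empty [CleanMmr]:
    ///
    /// ```ignore
    /// let proof = mmr.proof(Location::new_unchecked(0)).unwrap();
    /// assert!(mmr.proof(mmr.leaves()).is_err()); // one past the last leaf is out of bounds
    /// ```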
    pub fn proof(&self, loc: Location) -> Result<Proof<D>, Error> {
        if !loc.is_valid() {
            return Err(Error::LocationOverflow(loc));
        }
        // `loc` is valid, so computing `loc + 1` cannot overflow.
        self.range_proof(loc..loc + 1).map_err(|e| match e {
            Error::RangeOutOfBounds(loc) => Error::LeafOutOfBounds(loc),
            _ => e,
        })
    }

    /// Return an inclusion proof for all elements within the provided `range` of locations.
    ///
    /// # Errors
    ///
    /// Returns [Error::Empty] if the range is empty.
    /// Returns [Error::LocationOverflow] if any location in `range` exceeds [crate::mmr::MAX_LOCATION].
    /// Returns [Error::RangeOutOfBounds] if `range.end` > [Self::leaves()].
    /// Returns [Error::ElementPruned] if some element needed to generate the proof has been pruned.
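    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `mmr` is an unpruned, non-empty [CleanMmr]:
    ///
    /// ```ignore
    /// let proof = mmr.range_proof(Location::new_unchecked(0)..mmr.leaves()).unwrap();
    /// assert!(mmr.range_proof(mmr.leaves()..mmr.leaves()).is_err()); // empty ranges are rejected
    /// ```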
    pub fn range_proof(&self, range: Range<Location>) -> Result<Proof<D>, Error> {
        let leaves = self.leaves();
        let positions = proof::nodes_required_for_range_proof(leaves, range)?;
        let digests = positions
            .into_iter()
            .map(|pos| self.get_node(pos).ok_or(Error::ElementPruned(pos)))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Proof { leaves, digests })
    }

    /// Get the digests of nodes that need to be pinned (those required for proof generation) in
    /// this MMR when pruned to position `prune_pos`.
    #[cfg(test)]
    pub(crate) fn node_digests_to_pin(&self, start_pos: Position) -> Vec<D> {
        nodes_to_pin(start_pos)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }

    /// Return the nodes this MMR currently has pinned. Pinned nodes are nodes that would otherwise
    /// be pruned, but whose digests remain required for proof generation.
    #[cfg(test)]
    pub(super) fn pinned_nodes(&self) -> BTreeMap<Position, D> {
        self.pinned_nodes.clone()
    }
}

/// Implementation for Dirty MMR state.
impl<D: Digest> DirtyMmr<D> {
    /// Return a new (empty) `Mmr`.
    pub fn new() -> Self {
        Self {
            nodes: VecDeque::new(),
            pruned_to_pos: Position::new(0),
            pinned_nodes: BTreeMap::new(),
            state: Dirty::default(),
        }
    }

    /// Re-initialize the MMR with the given nodes, pruned_to_pos, and pinned_nodes.
    pub fn from_components(nodes: Vec<D>, pruned_to_pos: Position, pinned_nodes: Vec<D>) -> Self {
        Self {
            nodes: VecDeque::from(nodes),
            pruned_to_pos,
            pinned_nodes: nodes_to_pin(pruned_to_pos)
                .enumerate()
                .map(|(i, pos)| (pos, pinned_nodes[i]))
                .collect(),
            state: Dirty::default(),
        }
    }

    /// Add `digest` as a new leaf in the MMR, returning its position.
    pub(super) fn add_leaf_digest(&mut self, digest: D) -> Position {
        // Compute the new parent nodes, if any.
        let nodes_needing_parents = nodes_needing_parents(self.peak_iterator())
            .into_iter()
            .rev();
        let leaf_pos = self.size();
        self.nodes.push_back(digest);

        let mut height = 1;
        for _ in nodes_needing_parents {
            let new_node_pos = self.size();
            self.nodes.push_back(D::EMPTY);
            self.state.dirty_nodes.insert((new_node_pos, height));
            height += 1;
        }

        leaf_pos
    }

    /// Add `element` to the MMR and return its position.
    /// The element can be an arbitrary byte slice, and need not be converted to a digest first.
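    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `Standard<Sha256>` hasher (imports elided):
    ///
    /// ```ignore
    /// let mut mmr = DirtyMmr::new();
    /// let pos = mmr.add(&mut hasher, b"hello");
    /// assert_eq!(pos, Position::new(0)); // the first leaf lands at position 0
    /// ```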
    pub fn add<H: Hasher<Digest = D>>(&mut self, hasher: &mut H, element: &[u8]) -> Position {
        let digest = hasher.leaf_digest(self.size(), element);
        self.add_leaf_digest(digest)
    }

    /// Pop the most recent leaf element out of the MMR if it exists, returning an `Empty` or
    /// `ElementPruned` error otherwise.
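    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a `Standard<Sha256>` hasher (imports elided):
    ///
    /// ```ignore
    /// let mut mmr = DirtyMmr::new();
    /// mmr.add(&mut hasher, b"only");
    /// assert_eq!(mmr.pop().unwrap(), Position::new(0)); // new size after popping the only leaf
    /// assert!(matches!(mmr.pop(), Err(Empty))); // popping an empty MMR fails
    /// ```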
    pub fn pop(&mut self) -> Result<Position, Error> {
        if self.size() == 0 {
            return Err(Empty);
        }

        let mut new_size = self.size() - 1;
        loop {
            if new_size < self.pruned_to_pos {
                return Err(ElementPruned(new_size));
            }
            if new_size.is_mmr_size() {
                break;
            }
            new_size -= 1;
        }
        let num_to_drain = *(self.size() - new_size) as usize;
        self.nodes.drain(self.nodes.len() - num_to_drain..);

        // Remove dirty nodes that are now out of bounds.
        let cutoff = (self.size(), 0);
        self.state.dirty_nodes.split_off(&cutoff);

        Ok(self.size())
    }

    /// Compute updated digests for all dirty nodes and derive the root, converting this MMR into
    /// a [CleanMmr].
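    ///
    /// # Example
    ///
    /// A minimal sketch; passing `None` forces serial merkleization, while providing a
    /// `ThreadPool` (created as in the tests below) enables parallel digest computation for
    /// sufficiently large batches:
    ///
    /// ```ignore
    /// let clean = dirty_mmr.merkleize(&mut hasher, None);
    /// let root = *clean.root();
    /// ```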
    pub fn merkleize(
        mut self,
        hasher: &mut impl Hasher<Digest = D>,
        #[cfg_attr(not(feature = "std"), allow(unused_variables))] pool: Option<ThreadPool>,
    ) -> CleanMmr<D> {
        #[cfg(feature = "std")]
        match (pool, self.state.dirty_nodes.len() >= MIN_TO_PARALLELIZE) {
            (Some(pool), true) => self.merkleize_parallel(hasher, pool, MIN_TO_PARALLELIZE),
            _ => self.merkleize_serial(hasher),
        }

        #[cfg(not(feature = "std"))]
        self.merkleize_serial(hasher);

        // Compute root
        let peaks = self
            .peak_iterator()
            .map(|(peak_pos, _)| self.get_node_unchecked(peak_pos));
        let digest = hasher.root(self.leaves(), peaks);

        CleanMmr {
            nodes: self.nodes,
            pruned_to_pos: self.pruned_to_pos,
            pinned_nodes: self.pinned_nodes,
            state: Clean { root: digest },
        }
    }

    fn merkleize_serial(&mut self, hasher: &mut impl Hasher<Digest = D>) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        nodes.sort_by_key(|a| a.1);

        for (pos, height) in nodes {
            let left = pos - (1 << height);
            let right = pos - 1;
            let digest = hasher.node_digest(
                pos,
                self.get_node_unchecked(left),
                self.get_node_unchecked(right),
            );
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
    }

    /// Process any pending batched updates, using parallel hash workers as long as the number of
    /// computations that can be parallelized exceeds `min_to_parallelize`.
    ///
    /// This implementation parallelizes the computation of digests across all nodes at the same
    /// height, starting from the bottom and working up to the peaks. If ever the number of
    /// remaining digest computations is less than `min_to_parallelize`, it switches to the
    /// serial implementation.
    #[cfg(feature = "std")]
    fn merkleize_parallel(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        pool: ThreadPool,
        min_to_parallelize: usize,
    ) {
        let mut nodes: Vec<(Position, u32)> = self.state.dirty_nodes.iter().copied().collect();
        self.state.dirty_nodes.clear();
        // Sort by increasing height.
        nodes.sort_by_key(|a| a.1);

        let mut same_height = Vec::new();
        let mut current_height = 1;
        for (i, (pos, height)) in nodes.iter().enumerate() {
            if *height == current_height {
                same_height.push(*pos);
                continue;
            }
            if same_height.len() < min_to_parallelize {
                self.state.dirty_nodes = nodes[i - same_height.len()..].iter().copied().collect();
                self.merkleize_serial(hasher);
                return;
            }
            self.update_node_digests(hasher, pool.clone(), &same_height, current_height);
            same_height.clear();
            current_height += 1;
            same_height.push(*pos);
        }

        if same_height.len() < min_to_parallelize {
            self.state.dirty_nodes = nodes[nodes.len() - same_height.len()..]
                .iter()
                .copied()
                .collect();
            self.merkleize_serial(hasher);
            return;
        }

        self.update_node_digests(hasher, pool, &same_height, current_height);
    }

    /// Update digests of the given set of nodes of equal height in the MMR. Since they are all at
    /// the same height, this can be done in parallel without synchronization.
    #[cfg(feature = "std")]
    fn update_node_digests(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        pool: ThreadPool,
        same_height: &[Position],
        height: u32,
    ) {
        let two_h = 1 << height;
        pool.install(|| {
            let computed_digests: Vec<(usize, D)> = same_height
                .par_iter()
                .map_init(
                    || hasher.fork(),
                    |hasher, &pos| {
                        let left = pos - two_h;
                        let right = pos - 1;
                        let digest = hasher.node_digest(
                            pos,
                            self.get_node_unchecked(left),
                            self.get_node_unchecked(right),
                        );
                        let index = self.pos_to_index(pos);
                        (index, digest)
                    },
                )
                .collect();

            for (index, digest) in computed_digests {
                self.nodes[index] = digest;
            }
        });
    }

    /// Mark the non-leaf nodes in the path from the given position to the root as dirty, so that
    /// their digests are appropriately recomputed during the next `merkleize`.
    fn mark_dirty(&mut self, pos: Position) {
        for (peak_pos, mut height) in self.peak_iterator() {
            if peak_pos < pos {
                continue;
            }

            // We have found the mountain containing the path we are looking for. Traverse it from
            // leaf to root so that we can exit early if we hit a node that is already dirty.
            let path = PathIterator::new(pos, peak_pos, height)
                .collect::<Vec<_>>()
                .into_iter()
                .rev();
            height = 1;
            for (parent_pos, _) in path {
                if !self.state.dirty_nodes.insert((parent_pos, height)) {
                    break;
                }
                height += 1;
            }
            return;
        }

        panic!("invalid pos {pos}:{}", self.size());
    }

    /// Update the leaf at `loc` to `element`.
    pub fn update_leaf(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        loc: Location,
        element: &[u8],
    ) -> Result<(), Error> {
        self.update_leaf_batched(hasher, None, &[(loc, element)])
    }

    /// Batch update the digests of multiple retained leaves.
    ///
    /// # Errors
    ///
    /// Returns [Error::LeafOutOfBounds] if any location is not an existing leaf.
    /// Returns [Error::LocationOverflow] if any location exceeds [crate::mmr::MAX_LOCATION].
    /// Returns [Error::ElementPruned] if any of the leaves has been pruned.
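    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `mmr` is an unpruned [DirtyMmr] with at least two leaves:
    ///
    /// ```ignore
    /// let updates = [
    ///     (Location::new_unchecked(0), b"first".as_slice()),
    ///     (Location::new_unchecked(1), b"second".as_slice()),
    /// ];
    /// mmr.update_leaf_batched(&mut hasher, None, &updates).unwrap();
    /// let mmr = mmr.merkleize(&mut hasher, None); // dirty ancestors are recomputed here
    /// ```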
    pub fn update_leaf_batched<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        #[cfg_attr(not(feature = "std"), allow(unused_variables))] pool: Option<ThreadPool>,
        updates: &[(Location, T)],
    ) -> Result<(), Error> {
        if updates.is_empty() {
            return Ok(());
        }

        let leaves = self.leaves();
        let mut positions = Vec::with_capacity(updates.len());
        for (loc, _) in updates {
            if *loc >= leaves {
                return Err(Error::LeafOutOfBounds(*loc));
            }
            let pos = Position::try_from(*loc)?;
            if pos < self.pruned_to_pos {
                return Err(Error::ElementPruned(pos));
            }
            positions.push(pos);
        }

        #[cfg(feature = "std")]
        if let Some(pool) = pool {
            if updates.len() >= MIN_TO_PARALLELIZE {
                self.update_leaf_parallel(hasher, pool, updates, &positions);
                return Ok(());
            }
        }

        for ((_, element), pos) in updates.iter().zip(positions.iter()) {
            // Update the digest of the leaf node and mark its ancestors as dirty.
            let digest = hasher.leaf_digest(*pos, element.as_ref());
            let index = self.pos_to_index(*pos);
            self.nodes[index] = digest;
            self.mark_dirty(*pos);
        }

        Ok(())
    }

    /// Batch update the digests of multiple retained leaves using multiple threads.
    #[cfg(feature = "std")]
    fn update_leaf_parallel<T: AsRef<[u8]> + Sync>(
        &mut self,
        hasher: &mut impl Hasher<Digest = D>,
        pool: ThreadPool,
        updates: &[(Location, T)],
        positions: &[Position],
    ) {
        pool.install(|| {
            let digests: Vec<(Position, D)> = updates
                .par_iter()
                .zip(positions.par_iter())
                .map_init(
                    || hasher.fork(),
                    |hasher, ((_, elem), pos)| {
                        let digest = hasher.leaf_digest(*pos, elem.as_ref());
                        (*pos, digest)
                    },
                )
                .collect();

            for (pos, digest) in digests {
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
                self.mark_dirty(pos);
            }
        });
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mmr::{
        conformance::build_test_mmr,
        hasher::{Hasher as _, Standard},
    };
    use commonware_cryptography::{sha256, Hasher, Sha256};
    use commonware_runtime::{deterministic, tokio, Runner, ThreadPooler};
    use commonware_utils::NZUsize;

    /// Test empty MMR behavior.
    #[test]
    fn test_mem_mmr_empty() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mmr = CleanMmr::new(&mut hasher);
            assert_eq!(
                mmr.peak_iterator().next(),
                None,
                "empty iterator should have no peaks"
            );
            assert_eq!(mmr.size(), 0);
            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
            assert_eq!(mmr.last_leaf_pos(), None);
            assert!(mmr.bounds().is_empty());
            assert_eq!(mmr.get_node(Position::new(0)), None);
            assert_eq!(*mmr.root(), Mmr::empty_mmr_root(hasher.inner()));
            let mut mmr = mmr.into_dirty();
            assert!(matches!(mmr.pop(), Err(Empty)));
            let mut mmr = mmr.merkleize(&mut hasher, None);
            mmr.prune_all();
            assert_eq!(mmr.size(), 0, "prune_all on empty MMR should do nothing");

            assert_eq!(
                *mmr.root(),
                hasher.root(Location::new_unchecked(0), [].iter())
            );
        });
    }

    /// Test MMR building by consecutively adding 11 equal elements to a new MMR, producing the
    /// structure in the example documented at the top of the mmr crate's mod.rs file with 19 nodes
    /// and 3 peaks.
    #[test]
    fn test_mem_mmr_add_eleven_values() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = DirtyMmr::new();
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            let mut leaves: Vec<Position> = Vec::new();
            for _ in 0..11 {
                leaves.push(mmr.add(&mut hasher, &element));
                let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
                assert_ne!(peaks.len(), 0);
                assert!(peaks.len() as u64 <= mmr.size());
                let nodes_needing_parents = nodes_needing_parents(mmr.peak_iterator());
                assert!(nodes_needing_parents.len() <= peaks.len());
            }
            let mut mmr = mmr.merkleize(&mut hasher, None);
            assert_eq!(mmr.bounds().start, Position::new(0));
            assert_eq!(mmr.size(), 19, "mmr not of expected size");
            assert_eq!(
                leaves,
                vec![0, 1, 3, 4, 7, 8, 10, 11, 15, 16, 18]
                    .into_iter()
                    .map(Position::new)
                    .collect::<Vec<_>>(),
                "mmr leaf positions not as expected"
            );
            let peaks: Vec<(Position, u32)> = mmr.peak_iterator().collect();
            assert_eq!(
                peaks,
                vec![
                    (Position::new(14), 3),
                    (Position::new(17), 1),
                    (Position::new(18), 0)
                ],
                "mmr peaks not as expected"
            );

            // Test nodes_needing_parents on the final MMR. Since there's a height gap between the
            // highest peak (14) and the next, only the lower two peaks (17, 18) should be returned.
            let peaks_needing_parents = nodes_needing_parents(mmr.peak_iterator());
            assert_eq!(
                peaks_needing_parents,
                vec![Position::new(17), Position::new(18)],
                "mmr nodes needing parents not as expected"
            );

            // verify leaf digests
            for leaf in leaves.iter().by_ref() {
                let digest = hasher.leaf_digest(*leaf, &element);
                assert_eq!(mmr.get_node(*leaf).unwrap(), digest);
            }

            // verify height=1 node digests
            let digest2 = hasher.node_digest(Position::new(2), &mmr.nodes[0], &mmr.nodes[1]);
            assert_eq!(mmr.nodes[2], digest2);
            let digest5 = hasher.node_digest(Position::new(5), &mmr.nodes[3], &mmr.nodes[4]);
            assert_eq!(mmr.nodes[5], digest5);
            let digest9 = hasher.node_digest(Position::new(9), &mmr.nodes[7], &mmr.nodes[8]);
            assert_eq!(mmr.nodes[9], digest9);
            let digest12 = hasher.node_digest(Position::new(12), &mmr.nodes[10], &mmr.nodes[11]);
            assert_eq!(mmr.nodes[12], digest12);
            let digest17 = hasher.node_digest(Position::new(17), &mmr.nodes[15], &mmr.nodes[16]);
            assert_eq!(mmr.nodes[17], digest17);

            // verify height=2 node digests
            let digest6 = hasher.node_digest(Position::new(6), &mmr.nodes[2], &mmr.nodes[5]);
            assert_eq!(mmr.nodes[6], digest6);
            let digest13 = hasher.node_digest(Position::new(13), &mmr.nodes[9], &mmr.nodes[12]);
            assert_eq!(mmr.nodes[13], digest13);

            // verify topmost digest
            let digest14 = hasher.node_digest(Position::new(14), &mmr.nodes[6], &mmr.nodes[13]);
            assert_eq!(mmr.nodes[14], digest14);

            // verify root
            let root = *mmr.root();
            let peak_digests = [digest14, digest17, mmr.nodes[18]];
            let expected_root = hasher.root(Location::new_unchecked(11), peak_digests.iter());
            assert_eq!(root, expected_root, "incorrect root");

            // pruning tests
            mmr.prune_to_pos(Position::new(14)); // prune up to the tallest peak
            assert_eq!(mmr.bounds().start, Position::new(14));

            // After pruning, we shouldn't be able to generate a proof for any elements before the
            // pruning boundary. (To be precise, due to the maintenance of pinned nodes, we may in
            // fact still be able to generate them for some, but it's not guaranteed. For example,
            // in this case, we actually can still generate a proof for the node with location 7
            // even though it's pruned.)
            assert!(matches!(
                mmr.proof(Location::new_unchecked(0)),
                Err(ElementPruned(_))
            ));
            assert!(matches!(
                mmr.proof(Location::new_unchecked(6)),
                Err(ElementPruned(_))
            ));

            // We should still be able to generate a proof for any leaf following the pruning
            // boundary, the first of which is at location 8 and the last location 10.
            assert!(mmr.proof(Location::new_unchecked(8)).is_ok());
            assert!(mmr.proof(Location::new_unchecked(10)).is_ok());

            let root_after_prune = *mmr.root();
            assert_eq!(root, root_after_prune, "root changed after pruning");

            assert!(
                mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(9))
                    .is_err(),
                "attempts to range_prove elements at or before the oldest retained should fail"
            );
            assert!(
                mmr.range_proof(Location::new_unchecked(8)..mmr.leaves()).is_ok(),
                "attempts to range_prove over all elements following oldest retained should succeed"
            );

            // Test that we can initialize a new MMR from another's elements.
            let oldest_pos = mmr.bounds().start;
            let digests = mmr.node_digests_to_pin(oldest_pos);
            let mmr_copy = Mmr::init(
                Config {
                    nodes: mmr.nodes.iter().copied().collect(),
                    pruned_to_pos: oldest_pos,
                    pinned_nodes: digests,
                },
                &mut hasher,
            )
            .unwrap();
            assert_eq!(mmr_copy.size(), 19);
            assert_eq!(mmr_copy.leaves(), mmr.leaves());
            assert_eq!(mmr_copy.last_leaf_pos(), mmr.last_leaf_pos());
            assert_eq!(mmr_copy.bounds().start, mmr.bounds().start);
            assert_eq!(*mmr_copy.root(), root);
        });
    }

    /// Test that pruning all nodes never breaks adding new nodes.
    #[test]
    fn test_mem_mmr_prune_all() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = CleanMmr::new(&mut hasher);
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            for _ in 0..1000 {
                mmr.prune_all();
                let mut dirty = mmr.into_dirty();
                dirty.add(&mut hasher, &element);
                mmr = dirty.merkleize(&mut hasher, None);
            }
        });
    }

    /// Test that the MMR validity check works as expected.
    #[test]
    fn test_mem_mmr_validity() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut mmr = DirtyMmr::new();
            let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
            for _ in 0..1001 {
                assert!(
                    mmr.size().is_mmr_size(),
                    "mmr of size {} should be valid",
                    mmr.size()
                );
                let old_size = mmr.size();
                mmr.add(&mut hasher, &element);
                for size in *old_size + 1..*mmr.size() {
                    assert!(
                        !Position::new(size).is_mmr_size(),
                        "mmr of size {size} should be invalid",
                    );
                }
            }
        });
    }

    /// Test that batched MMR building produces the same root as the reference implementation.
    /// Root stability for the reference implementation is verified by the conformance test.
    #[test]
    fn test_mem_mmr_batched_root() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            const NUM_ELEMENTS: u64 = 199;
            let mut test_mmr = CleanMmr::new(&mut hasher);
            test_mmr = build_test_mmr(&mut hasher, test_mmr, NUM_ELEMENTS);
            let expected_root = test_mmr.root();

            let batched_mmr = CleanMmr::new(&mut hasher);

            // First element transitions Clean -> Dirty explicitly
            let mut dirty_mmr = batched_mmr.into_dirty();
            hasher.inner().update(&0u64.to_be_bytes());
            let element = hasher.inner().finalize();
            dirty_mmr.add(&mut hasher, &element);

            // Subsequent elements keep it Dirty
            for i in 1..NUM_ELEMENTS {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                dirty_mmr.add(&mut hasher, &element);
            }

            let batched_mmr = dirty_mmr.merkleize(&mut hasher, None);

            assert_eq!(
                batched_mmr.root(),
                expected_root,
                "Batched MMR root should match reference"
            );
        });
    }

    /// Test that parallel batched MMR building produces the same root as the reference.
    /// This requires the tokio runtime since the deterministic runtime is single-threaded.
    #[test]
    fn test_mem_mmr_batched_root_parallel() {
        let executor = tokio::Runner::default();
        executor.start(|context| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            const NUM_ELEMENTS: u64 = 199;
            let test_mmr = CleanMmr::new(&mut hasher);
            let test_mmr = build_test_mmr(&mut hasher, test_mmr, NUM_ELEMENTS);
            let expected_root = test_mmr.root();

            let pool = context.create_thread_pool(NZUsize!(4)).unwrap();
            let mut hasher: Standard<Sha256> = Standard::new();

            let mut mmr = Mmr::init(
                Config {
                    nodes: vec![],
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: vec![],
                },
                &mut hasher,
            )
            .unwrap()
            .into_dirty();

            let mut hasher: Standard<Sha256> = Standard::new();
            for i in 0u64..NUM_ELEMENTS {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }
            let mmr = mmr.merkleize(&mut hasher, Some(pool));
            assert_eq!(
                mmr.root(),
                expected_root,
                "Batched MMR root should match reference"
            );
        });
    }

    /// Test that pruning after each add does not affect root computation.
    #[test]
    fn test_mem_mmr_root_with_pruning() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            let mut reference_mmr = DirtyMmr::new();
            let mut mmr = DirtyMmr::new();
            for i in 0u64..200 {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                reference_mmr.add(&mut hasher, &element);
                mmr.add(&mut hasher, &element);

                // Merkleize both to compare roots
                let reference_mmr_clean = reference_mmr.merkleize(&mut hasher, None);
                let mut mmr_clean = mmr.merkleize(&mut hasher, None);
                mmr_clean.prune_all();
                assert_eq!(mmr_clean.root(), reference_mmr_clean.root());

                reference_mmr = reference_mmr_clean.into_dirty();
                mmr = mmr_clean.into_dirty();
            }
        });
    }

    #[test]
    fn test_mem_mmr_pop() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            const NUM_ELEMENTS: u64 = 100;

            let mut hasher: Standard<Sha256> = Standard::new();
            let mmr = CleanMmr::new(&mut hasher);
            let mut mmr = build_test_mmr(&mut hasher, mmr, NUM_ELEMENTS);

            // Pop off one leaf at a time until empty, confirming the root matches the reference.
            for i in (0..NUM_ELEMENTS).rev() {
                let mut dirty_mmr = mmr.into_dirty();
                assert!(dirty_mmr.pop().is_ok());
                mmr = dirty_mmr.merkleize(&mut hasher, None);
                let root = *mmr.root();
                let reference_mmr = CleanMmr::new(&mut hasher);
                let reference_mmr = build_test_mmr(&mut hasher, reference_mmr, i);
                assert_eq!(
                    root,
                    *reference_mmr.root(),
                    "root mismatch after pop at {i}"
                );
            }
            let mut mmr = mmr.into_dirty();
            assert!(
                matches!(mmr.pop().unwrap_err(), Empty),
                "pop on empty MMR should fail"
            );

            // Test that we can pop all elements up to and including the oldest retained leaf.
            for i in 0u64..NUM_ELEMENTS {
                hasher.inner().update(&i.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.add(&mut hasher, &element);
            }
            let mut mmr = mmr.merkleize(&mut hasher, None);

            let leaf_pos = Position::try_from(Location::new_unchecked(100)).unwrap();
            mmr.prune_to_pos(leaf_pos);
            let mut mmr = mmr.into_dirty();
            while mmr.size() > leaf_pos {
                mmr.pop().unwrap();
            }
            let mmr = mmr.merkleize(&mut hasher, None);
            let reference_mmr = CleanMmr::new(&mut hasher);
            let reference_mmr = build_test_mmr(&mut hasher, reference_mmr, 100);
            assert_eq!(*mmr.root(), *reference_mmr.root());
            let mut mmr = mmr.into_dirty();
            let result = mmr.pop();
            assert!(matches!(result, Err(ElementPruned(_))));
            assert!(mmr.bounds().is_empty());
        });
    }

    #[test]
    fn test_mem_mmr_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            const NUM_ELEMENTS: u64 = 200;
            let mmr = CleanMmr::new(&mut hasher);
            let mut mmr = build_test_mmr(&mut hasher, mmr, NUM_ELEMENTS);
            let root = *mmr.root();

            // For a few leaves, update the leaf and ensure the root changes, and that the root
            // reverts to its previous state when we restore the leaf to its original value.
            for leaf in [0usize, 1, 10, 50, 100, 150, 197, 198] {
                // Change the leaf.
                let leaf_loc = Location::new_unchecked(leaf as u64);
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let updated_root = *mmr.root();
                assert!(root != updated_root);

                // Restore the leaf to its original value, ensure the root is as before.
                hasher.inner().update(&leaf.to_be_bytes());
                let element = hasher.inner().finalize();
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
                let restored_root = *mmr.root();
                assert_eq!(root, restored_root);
            }

            // Confirm the tree has all the hashes necessary to update any element after pruning.
            mmr.prune_to_pos(Position::new(150));
            for leaf_pos in 150u64..=190 {
                mmr.prune_to_pos(Position::new(leaf_pos));
                let leaf_loc = Location::new_unchecked(leaf_pos);
                mmr.update_leaf(&mut hasher, leaf_loc, &element).unwrap();
            }
        });
    }

    #[test]
    fn test_mem_mmr_update_leaf_error_out_of_bounds() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mmr = CleanMmr::new(&mut hasher);
            let mut mmr = build_test_mmr(&mut hasher, mmr, 200);
            let invalid_loc = mmr.leaves();
            let result = mmr.update_leaf(&mut hasher, invalid_loc, &element);
            assert!(matches!(result, Err(Error::LeafOutOfBounds(_))));
        });
    }

    #[test]
    fn test_mem_mmr_update_leaf_error_pruned() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mmr = CleanMmr::new(&mut hasher);
            let mut mmr = build_test_mmr(&mut hasher, mmr, 100);
            mmr.prune_all();
            let result = mmr.update_leaf(&mut hasher, Location::new_unchecked(0), &element);
            assert!(matches!(result, Err(Error::ElementPruned(_))));
        });
    }

    #[test]
    fn test_mem_mmr_batch_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mmr = CleanMmr::new(&mut hasher);
            let mmr = build_test_mmr(&mut hasher, mmr, 200);
            do_batch_update(&mut hasher, mmr, None);
        });
    }

    /// Same test as above, only using a thread pool to trigger parallelization. This requires the
    /// tokio runtime instead of the deterministic one.
    #[test]
    fn test_mem_mmr_batch_parallel_update_leaf() {
        let mut hasher: Standard<Sha256> = Standard::new();
        let executor = tokio::Runner::default();
        executor.start(|ctx| async move {
            let mmr = Mmr::init(
                Config {
                    nodes: Vec::new(),
                    pruned_to_pos: Position::new(0),
                    pinned_nodes: Vec::new(),
                },
                &mut hasher,
            )
            .unwrap();
            let mmr = build_test_mmr(&mut hasher, mmr, 200);
            let pool = ctx.create_thread_pool(NZUsize!(4)).unwrap();
            do_batch_update(&mut hasher, mmr, Some(pool));
        });
    }

    fn do_batch_update(
        hasher: &mut Standard<Sha256>,
        mmr: CleanMmr<sha256::Digest>,
        pool: Option<ThreadPool>,
    ) {
        let element = <Sha256 as Hasher>::Digest::from(*b"01234567012345670123456701234567");
        let root = *mmr.root();

        // Change a handful of leaves using a batch update.
        let mut updates = Vec::new();
        for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
            let leaf_loc = Location::new_unchecked(leaf);
            updates.push((leaf_loc, &element));
        }
        let mut mmr = mmr.into_dirty();
        mmr.update_leaf_batched(hasher, pool, &updates).unwrap();

        let mmr = mmr.merkleize(hasher, None);
        let updated_root = *mmr.root();
        assert_ne!(updated_root, root);

        // Batch-restore the changed leaves to their original values.
        let mut updates = Vec::new();
        for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
            hasher.inner().update(&leaf.to_be_bytes());
            let element = hasher.inner().finalize();
            let leaf_loc = Location::new_unchecked(leaf);
            updates.push((leaf_loc, element));
        }
        let mut mmr = mmr.into_dirty();
        mmr.update_leaf_batched(hasher, None, &updates).unwrap();

        let mmr = mmr.merkleize(hasher, None);
        let restored_root = *mmr.root();
        assert_eq!(root, restored_root);
    }

    #[test]
    fn test_init_pinned_nodes_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Test with empty config - should succeed
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with too few pinned nodes - should fail
            // Use a valid MMR size (127 is valid: 2^7 - 1 makes a complete tree)
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(127),
                pinned_nodes: vec![], // Should have nodes for position 127
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Test with too many pinned nodes - should fail
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![Sha256::hash(b"dummy")],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidPinnedNodes)
            ));

            // Test with correct number of pinned nodes - should succeed
            // Build a small MMR to get valid pinned nodes
            let mut mmr = DirtyMmr::new();
            for i in 0u64..50 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let mmr = mmr.merkleize(&mut hasher, None);
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(50));
            let config = Config {
                nodes: vec![],
                pruned_to_pos: Position::new(50),
                pinned_nodes,
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());
        });
    }

    #[test]
    fn test_init_size_validation() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut hasher: Standard<Sha256> = Standard::new();
            // Test with valid size 0 - should succeed
            let config = Config::<sha256::Digest> {
                nodes: vec![],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with invalid size 2 - should fail
            // Size 2 is invalid (two sibling leaves must immediately be joined under a parent)
            let config = Config {
                nodes: vec![Sha256::hash(b"node1"), Sha256::hash(b"node2")],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Test with valid size 3 (one full tree with 2 leaves) - should succeed
            let config = Config {
                nodes: vec![
                    Sha256::hash(b"leaf1"),
                    Sha256::hash(b"leaf2"),
                    Sha256::hash(b"parent"),
                ],
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with large valid size (127 = 2^7 - 1, a complete tree) - should succeed
            // Build a real MMR to get the correct structure
            let mut mmr = DirtyMmr::new();
            for i in 0u64..64 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let mmr = mmr.merkleize(&mut hasher, None);
            assert_eq!(mmr.size(), 127); // Verify we have the expected size
            let nodes: Vec<_> = (0..127)
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();

            let config = Config {
                nodes,
                pruned_to_pos: Position::new(0),
                pinned_nodes: vec![],
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Test with non-zero pruned_to_pos - should succeed
            // Build a small MMR (11 leaves -> 19 nodes), prune it, then init from that state
            let mut mmr = DirtyMmr::new();
            for i in 0u64..11 {
                mmr.add(&mut hasher, &i.to_be_bytes());
            }
            let mut mmr = mmr.merkleize(&mut hasher, None);
            assert_eq!(mmr.size(), 19); // 11 leaves = 19 total nodes

            // Prune to position 7
            mmr.prune_to_pos(Position::new(7));
            let nodes: Vec<_> = (7..*mmr.size())
                .map(|i| *mmr.get_node_unchecked(Position::new(i)))
                .collect();
            let pinned_nodes = mmr.node_digests_to_pin(Position::new(7));

            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(7),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(Mmr::init(config, &mut hasher).is_ok());

            // Same nodes but wrong pruned_to_pos - should fail
            // pruned_to_pos=8 + 12 nodes = size 20 (invalid)
            let config = Config {
                nodes: nodes.clone(),
                pruned_to_pos: Position::new(8),
                pinned_nodes: pinned_nodes.clone(),
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));

            // Same nodes but different wrong pruned_to_pos - should fail
            // pruned_to_pos=9 + 12 nodes = size 21 (invalid)
            let config = Config {
                nodes,
                pruned_to_pos: Position::new(9),
                pinned_nodes,
            };
            assert!(matches!(
                Mmr::init(config, &mut hasher),
                Err(Error::InvalidSize(_))
            ));
        });
    }

    #[test]
    fn test_mem_mmr_range_proof_out_of_bounds() {
        let mut hasher: Standard<Sha256> = Standard::new();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Range end > leaves errors on empty MMR
            let mmr = CleanMmr::new(&mut hasher);
            assert_eq!(mmr.leaves(), Location::new_unchecked(0));
            let result = mmr.range_proof(Location::new_unchecked(0)..Location::new_unchecked(1));
            assert!(matches!(result, Err(Error::RangeOutOfBounds(_))));

            // Range end > leaves errors on non-empty MMR
            let mmr = build_test_mmr(&mut hasher, mmr, 10);
            assert_eq!(mmr.leaves(), Location::new_unchecked(10));
            let result = mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(11));
            assert!(matches!(result, Err(Error::RangeOutOfBounds(_))));

            // Range end == leaves succeeds
            let result = mmr.range_proof(Location::new_unchecked(5)..Location::new_unchecked(10));
            assert!(result.is_ok());
        });
    }

    #[test]
    fn test_mem_mmr_proof_out_of_bounds() {
        let mut hasher: Standard<Sha256> = Standard::new();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Test on empty MMR - should return error, not panic
            let mmr = CleanMmr::new(&mut hasher);
            let result = mmr.proof(Location::new_unchecked(0));
            assert!(
                matches!(result, Err(Error::LeafOutOfBounds(_))),
                "expected LeafOutOfBounds, got {:?}",
                result
            );

            // Test on non-empty MMR with location >= leaves
            let mmr = build_test_mmr(&mut hasher, mmr, 10);
            let result = mmr.proof(Location::new_unchecked(10));
            assert!(
                matches!(result, Err(Error::LeafOutOfBounds(_))),
                "expected LeafOutOfBounds, got {:?}",
                result
            );

            // location < leaves should succeed
            let result = mmr.proof(Location::new_unchecked(9));
            assert!(result.is_ok(), "expected Ok, got {:?}", result);
        });
    }
}