Skip to main content

commonware_storage/merkle/
mem.rs

1//! Generic in-memory Merkle structure, parameterized by [`Family`].
2//!
3//! Both MMR and MMB share the same node storage, pruning, root computation, and proof logic.
4//! This module provides the unified [`Mem`] struct; per-family modules re-export it as
5//! `mmr::mem::Mmr` and `mmb::mem::Mmb` via type aliases.
6
7use crate::merkle::{
8    batch, hasher::Hasher, proof as merkle_proof, Error, Family, Location, Position, Proof,
9    Readable,
10};
11use alloc::{
12    collections::{BTreeMap, VecDeque},
13    vec::Vec,
14};
15use commonware_cryptography::Digest;
16use core::ops::Range;
17
/// Configuration for initializing a [`Mem`].
pub struct Config<F: Family, D: Digest> {
    /// The retained nodes, i.e. all nodes at or after `pruning_boundary`.
    pub nodes: Vec<D>,

    /// The leaf location up to which pruning has been performed, or 0 if never pruned.
    pub pruning_boundary: Location<F>,

    /// The pinned nodes, in the order expected by [`Family::nodes_to_pin`].
    ///
    /// The count must match exactly what `nodes_to_pin(pruning_boundary)` yields, or
    /// initialization fails with `InvalidPinnedNodes`.
    pub pinned_nodes: Vec<D>,
}
29
/// A basic, `no_std`-compatible Merkle structure where all nodes are stored in-memory.
///
/// Nodes are either _retained_, _pruned_, or _pinned_. Retained nodes are stored in the main
/// deque. Pruned nodes precede `pruning_boundary` and are no longer stored unless they are still
/// required for root computation or proof generation, in which case they are kept in
/// `pinned_nodes`.
///
/// The structure is always merkleized (its root is always computed). Mutations go through the
/// batch API: create an [`UnmerkleizedBatch`](batch::UnmerkleizedBatch) via [`Self::new_batch`],
/// accumulate changes, merkleize, then apply the result via [`Self::apply_batch`].
#[derive(Clone, Debug)]
pub struct Mem<F: Family, D: Digest> {
    /// The retained nodes, starting at `pruning_boundary`. Index 0 corresponds to position
    /// `pruning_boundary`.
    nodes: VecDeque<D>,

    /// The highest position for which pruning has been performed, or 0 if never pruned.
    ///
    /// # Invariant
    ///
    /// This is always leaf-aligned (the position corresponding to some `Location`).
    pruning_boundary: Position<F>,

    /// Auxiliary map from node position to the digest of any pinned node. Only positions below
    /// `pruning_boundary` are looked up here.
    pinned_nodes: BTreeMap<Position<F>, D>,

    /// The root digest. Kept up to date by `init`, `apply_batch`, and `truncate`.
    root: D,
}
58
59impl<F: Family, D: Digest> Mem<F, D> {
60    /// Create a new, empty structure.
61    pub fn new(hasher: &impl Hasher<F, Digest = D>) -> Self {
62        let root = hasher.root(Location::new(0), core::iter::empty::<&D>());
63        Self {
64            nodes: VecDeque::new(),
65            pruning_boundary: Position::new(0),
66            pinned_nodes: BTreeMap::new(),
67            root,
68        }
69    }
70
71    /// Return a [`Mem`] initialized with the given `config`.
72    ///
73    /// # Errors
74    ///
75    /// Returns [`Error::InvalidPinnedNodes`] if the number of pinned nodes doesn't match the
76    /// expected count for `config.pruning_boundary`.
77    ///
78    /// Returns [`Error::InvalidSize`] if the resulting size is invalid.
79    pub fn init(
80        config: Config<F, D>,
81        hasher: &impl Hasher<F, Digest = D>,
82    ) -> Result<Self, Error<F>> {
83        let pruning_boundary = Position::try_from(config.pruning_boundary)?;
84
85        let Some(size) = pruning_boundary.checked_add(config.nodes.len() as u64) else {
86            return Err(Error::InvalidSize(u64::MAX));
87        };
88        if !size.is_valid_size() {
89            return Err(Error::InvalidSize(*size));
90        }
91
92        let expected_pinned_positions: Vec<_> = F::nodes_to_pin(config.pruning_boundary).collect();
93        if config.pinned_nodes.len() != expected_pinned_positions.len() {
94            return Err(Error::InvalidPinnedNodes);
95        }
96
97        let pinned_nodes = expected_pinned_positions
98            .into_iter()
99            .zip(config.pinned_nodes)
100            .collect();
101        let nodes = VecDeque::from(config.nodes);
102        let root = Self::compute_root(hasher, &nodes, &pinned_nodes, pruning_boundary);
103
104        Ok(Self {
105            nodes,
106            pruning_boundary,
107            pinned_nodes,
108            root,
109        })
110    }
111
    /// Re-initialize with the given nodes, pruning boundary, and pinned nodes.
    ///
    /// Thin convenience wrapper around [`Self::init`] that takes the fields of [`Config`] as
    /// separate arguments.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidPinnedNodes`] if the provided pinned node count is invalid for the
    /// given state.
    ///
    /// Returns [`Error::LocationOverflow`] if `pruning_boundary` exceeds [`Family::MAX_LEAVES`].
    ///
    /// Returns [`Error::InvalidSize`] if the resulting node count is not a valid tree size
    /// (propagated from [`Self::init`]).
    pub fn from_components(
        hasher: &impl Hasher<F, Digest = D>,
        nodes: Vec<D>,
        pruning_boundary: Location<F>,
        pinned_nodes: Vec<D>,
    ) -> Result<Self, Error<F>> {
        Self::init(
            Config {
                nodes,
                pruning_boundary,
                pinned_nodes,
            },
            hasher,
        )
    }
135
    /// Build a pruned structure that retains nodes above the prune boundary.
    ///
    /// Like `from_components` but also accepts retained nodes (stored in the
    /// `nodes` deque). Used by the grafted MMR which has no disk fallback.
    ///
    /// Note: unlike [`Self::init`], this performs no validation and does not recompute the
    /// root — the caller-supplied `root` is trusted as-is, and `pinned_nodes` is assumed to
    /// be consistent with `pruning_boundary`.
    #[cfg(feature = "std")]
    pub(crate) fn from_pruned_with_retained(
        root: D,
        pruning_boundary: Position<F>,
        pinned_nodes: BTreeMap<Position<F>, D>,
        retained_nodes: Vec<D>,
    ) -> Self {
        Self {
            nodes: VecDeque::from(retained_nodes),
            pruning_boundary,
            pinned_nodes,
            root,
        }
    }
154
    /// Compute the root digest from the current peaks.
    ///
    /// `size` is derived from the retained node count plus the pruning boundary; each peak is
    /// resolved either from the retained deque or, for positions below the boundary, from
    /// `pinned_nodes`.
    ///
    /// # Panics
    ///
    /// Panics if the derived size is not a valid tree size, or if a peak below the pruning
    /// boundary is not present in `pinned_nodes`.
    pub(crate) fn compute_root(
        hasher: &impl Hasher<F, Digest = D>,
        nodes: &VecDeque<D>,
        pinned_nodes: &BTreeMap<Position<F>, D>,
        pruning_boundary: Position<F>,
    ) -> D {
        let size = Position::new(nodes.len() as u64 + *pruning_boundary);
        let leaves = Location::try_from(size).expect("invalid merkle size");
        // Resolve a node from retained storage, falling back to the pinned map for
        // positions that precede the pruning boundary.
        let get_node = |pos: Position<F>| -> &D {
            if pos < pruning_boundary {
                return pinned_nodes
                    .get(&pos)
                    .expect("requested node is pruned and not pinned");
            }
            let index = (*pos - *pruning_boundary) as usize;
            &nodes[index]
        };
        let peaks = F::peaks(size).map(|(p, _)| get_node(p));
        hasher.root(leaves, peaks)
    }
176
177    /// Return the total number of nodes, irrespective of any pruning.
178    pub fn size(&self) -> Position<F> {
179        Position::new(self.nodes.len() as u64 + *self.pruning_boundary)
180    }
181
    /// Return the total number of leaves.
    ///
    /// Panics if the current size is not a valid tree size (an internal invariant).
    pub fn leaves(&self) -> Location<F> {
        Location::try_from(self.size()).expect("invalid merkle size")
    }
186
    /// Returns `[start, end)` where `start` is the oldest retained leaf and `end` is the total
    /// leaf count.
    ///
    /// The range is empty when everything has been pruned (or nothing was ever appended).
    pub fn bounds(&self) -> Range<Location<F>> {
        Location::try_from(self.pruning_boundary).expect("valid pruning_boundary")..self.leaves()
    }
192
    /// Return a new iterator over the peaks (position and height pairs) at the current size.
    pub fn peak_iterator(&self) -> impl Iterator<Item = (Position<F>, u32)> {
        F::peaks(self.size())
    }
197
    /// Get the root digest (always up to date; see the struct-level docs).
    pub const fn root(&self) -> &D {
        &self.root
    }
202
203    /// Return the requested node if it is either retained or present in the pinned_nodes map, and
204    /// panic otherwise.
205    ///
206    /// # Panics
207    ///
208    /// Panics if the requested node does not exist.
209    pub(crate) fn get_node_unchecked(&self, pos: Position<F>) -> &D {
210        if pos < self.pruning_boundary {
211            return self
212                .pinned_nodes
213                .get(&pos)
214                .expect("requested node is pruned and not pinned");
215        }
216
217        &self.nodes[self.pos_to_index(pos)]
218    }
219
220    /// Return the index of the element in the current nodes vector given its position.
221    ///
222    /// # Panics
223    ///
224    /// Panics if `pos` precedes the oldest retained position.
225    fn pos_to_index(&self, pos: Position<F>) -> usize {
226        assert!(
227            pos >= self.pruning_boundary,
228            "pos precedes oldest retained position"
229        );
230        (*pos - *self.pruning_boundary) as usize
231    }
232
233    /// Return the requested node or `None` if it is not stored.
234    pub fn get_node(&self, pos: Position<F>) -> Option<D> {
235        if pos < self.pruning_boundary {
236            return self.pinned_nodes.get(&pos).copied();
237        }
238
239        self.nodes.get(self.pos_to_index(pos)).copied()
240    }
241
    /// Get the nodes (position + digest) that need to be pinned when pruned to `prune_loc`.
    ///
    /// Panics (via [`Self::get_node_unchecked`]) if any required node is neither retained nor
    /// already pinned; callers must invoke this before dropping the nodes in question.
    pub(crate) fn nodes_to_pin(&self, prune_loc: Location<F>) -> BTreeMap<Position<F>, D> {
        F::nodes_to_pin(prune_loc)
            .map(|pos| (pos, *self.get_node_unchecked(pos)))
            .collect()
    }
248
249    /// Prune all nodes up to but not including the given leaf location, and pin the nodes still
250    /// required for root computation and proof generation.
251    ///
252    /// # Errors
253    ///
254    /// Returns [`Error::LocationOverflow`] if `loc` exceeds [`Family::MAX_LEAVES`].
255    /// Returns [`Error::LeafOutOfBounds`] if `loc` exceeds the current leaf count.
256    pub fn prune(&mut self, loc: Location<F>) -> Result<(), Error<F>> {
257        if loc > self.leaves() {
258            return Err(Error::LeafOutOfBounds(loc));
259        }
260
261        let pos = Position::try_from(loc)?;
262        if pos <= self.pruning_boundary {
263            return Ok(());
264        }
265
266        self.prune_to_loc(loc);
267        Ok(())
268    }
269
270    /// Prune all retained nodes.
271    pub fn prune_all(&mut self) {
272        if !self.nodes.is_empty() {
273            self.prune_to_loc(self.leaves());
274        }
275    }
276
    /// Location-based pruning.
    ///
    /// Order matters here: the pin set must be captured via `nodes_to_pin` *before* any nodes
    /// are drained, since it reads digests that are about to be discarded.
    fn prune_to_loc(&mut self, loc: Location<F>) {
        let pinned = self.nodes_to_pin(loc);
        let pos = Position::try_from(loc).expect("valid location");
        // Number of leading deque entries that fall below the new boundary.
        let retained_nodes = self.pos_to_index(pos);
        self.pinned_nodes = pinned;
        self.nodes.drain(0..retained_nodes);
        self.pruning_boundary = pos;
    }
286
287    /// Return an inclusion proof for the element at location `loc`.
288    ///
289    /// # Errors
290    ///
291    /// Returns [`Error::LocationOverflow`] if `loc` exceeds the valid range.
292    /// Returns [`Error::LeafOutOfBounds`] if `loc` >= [`Self::leaves()`].
293    /// Returns [`Error::ElementPruned`] if a required node is missing.
294    pub fn proof(
295        &self,
296        hasher: &impl Hasher<F, Digest = D>,
297        loc: Location<F>,
298    ) -> Result<Proof<F, D>, Error<F>> {
299        if !loc.is_valid_index() {
300            return Err(Error::LocationOverflow(loc));
301        }
302        self.range_proof(hasher, loc..loc + 1).map_err(|e| match e {
303            Error::RangeOutOfBounds(_) => Error::LeafOutOfBounds(loc),
304            _ => e,
305        })
306    }
307
    /// Return an inclusion proof for all elements within the provided `range` of locations.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Empty`] if the range is empty.
    /// Returns [`Error::LocationOverflow`] if any location exceeds the valid range.
    /// Returns [`Error::RangeOutOfBounds`] if `range.end` > [`Self::leaves()`].
    /// Returns [`Error::ElementPruned`] if a required node is missing.
    pub fn range_proof(
        &self,
        hasher: &impl Hasher<F, Digest = D>,
        range: Range<Location<F>>,
    ) -> Result<Proof<F, D>, Error<F>> {
        // Shared proof builder; node lookups that miss (pruned and unpinned) surface as
        // `ElementPruned`.
        merkle_proof::build_range_proof(
            hasher,
            self.leaves(),
            range,
            |pos| self.get_node(pos),
            Error::ElementPruned,
        )
    }
329
    /// Get the digests of nodes that need to be pinned at the provided pruning boundary.
    ///
    /// Like [`Self::nodes_to_pin`] but returns only digests, in [`Family::nodes_to_pin`] order
    /// (the order [`Config::pinned_nodes`] expects).
    #[cfg(test)]
    pub(crate) fn node_digests_to_pin(&self, prune_loc: Location<F>) -> Vec<D> {
        F::nodes_to_pin(prune_loc)
            .map(|pos| *self.get_node_unchecked(pos))
            .collect()
    }
337
338    /// Pin extra nodes. It's up to the caller to ensure this set is valid.
339    #[cfg(any(feature = "std", test))]
340    pub(crate) fn add_pinned_nodes(&mut self, pinned_nodes: BTreeMap<Position<F>, D>) {
341        for (pos, node) in pinned_nodes {
342            self.pinned_nodes.insert(pos, node);
343        }
344    }
345
    /// Truncate the structure to a smaller valid size, discarding all nodes beyond that size.
    /// Recomputes the root after truncation.
    ///
    /// Debug builds assert that `new_size` is a valid tree size and does not precede the
    /// pruning boundary; release builds rely on the caller upholding both.
    #[cfg(feature = "std")]
    #[allow(dead_code)]
    pub(crate) fn truncate(&mut self, new_size: Position<F>, hasher: &impl Hasher<F, Digest = D>) {
        debug_assert!(new_size.is_valid_size());
        debug_assert!(new_size >= self.pruning_boundary);
        // Number of retained nodes that fall within the new size.
        let keep = (*new_size - *self.pruning_boundary) as usize;
        self.nodes.truncate(keep);
        self.root = Self::compute_root(
            hasher,
            &self.nodes,
            &self.pinned_nodes,
            self.pruning_boundary,
        );
    }
362
    /// Return a copy of the nodes this structure currently has pinned.
    #[cfg(test)]
    pub(crate) fn pinned_nodes(&self) -> BTreeMap<Position<F>, D> {
        self.pinned_nodes.clone()
    }
368
369    /// Create a new speculative batch with this structure as its parent.
370    pub fn new_batch(&self) -> batch::UnmerkleizedBatch<F, D> {
371        let root = batch::MerkleizedBatch::from_mem(self);
372        root.new_batch()
373    }
374
    /// Apply a merkleized batch. Already-committed ancestors are skipped automatically.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StaleBatch`] if this structure's size is incompatible with the batch's
    /// base size (the batch was built against a different state).
    ///
    /// Returns [`Error::AncestorDropped`] if the resulting size doesn't match the batch's
    /// expected size, indicating an uncommitted ancestor's appends are missing.
    pub fn apply_batch(&mut self, batch: &batch::MerkleizedBatch<F, D>) -> Result<(), Error<F>> {
        // Decide whether some ancestors were already applied: exact base-size match means
        // none were; a size strictly between base and final means some were.
        let skip_ancestors = if self.size() == batch.base_size {
            false
        } else if self.size() > batch.base_size && self.size() < batch.size() {
            true
        } else if self.size() == batch.size() && batch.appended.is_empty() {
            // All ancestors committed and this batch has overwrites only (no appends).
            true
        } else {
            return Err(Error::StaleBatch {
                expected: batch.base_size,
                actual: self.size(),
            });
        };

        // Apply ancestor batches in root-to-tip order. Already-committed
        // batches (whose appended nodes are already in the Mem) are skipped
        // by tracking a running position through the ancestor chain.
        let mut batch_pos = *batch.base_size;
        for (appended, overwrites) in batch
            .ancestor_appended
            .iter()
            .zip(&batch.ancestor_overwrites)
        {
            batch_pos += appended.len() as u64;
            // Overwrite-only ancestors don't advance batch_pos, so they can't be
            // distinguished from their predecessor by size. Use strict < to
            // avoid skipping them at the boundary. Re-applying committed
            // overwrites is harmless (idempotent).
            let committed = if appended.is_empty() {
                skip_ancestors && batch_pos < *self.size()
            } else {
                skip_ancestors && batch_pos <= *self.size()
            };
            if committed {
                continue;
            }
            // Overwrites below the pruning boundary target nodes that no longer exist.
            for (&pos, &digest) in overwrites.iter() {
                if pos < self.pruning_boundary {
                    continue;
                }
                let index = self.pos_to_index(pos);
                self.nodes[index] = digest;
            }
            for &digest in appended.iter() {
                self.nodes.push_back(digest);
            }
        }

        // Apply this batch's own data.
        for (&pos, &digest) in batch.overwrites.iter() {
            if skip_ancestors && pos < self.pruning_boundary {
                continue;
            }
            let index = self.pos_to_index(pos);
            self.nodes[index] = digest;
        }
        for &digest in batch.appended.iter() {
            self.nodes.push_back(digest);
        }

        // Detect missing ancestor data. If an uncommitted ancestor was dropped
        // before this batch was merkleized, its appended nodes are absent and the
        // Mem ends up smaller than expected. This does not catch dropped
        // overwrite-only ancestors (they don't change the size).
        if self.size() != batch.size() {
            return Err(Error::AncestorDropped {
                expected: batch.size(),
                actual: self.size(),
            });
        }

        // The batch already computed the post-apply root; adopt it rather than recomputing.
        self.root = batch.root();
        Ok(())
    }
451}
452
// [`Readable`] implementation: every method delegates to the inherent method of the same
// name on [`Mem`]; no additional logic lives here.
impl<F: Family, D: Digest> Readable for Mem<F, D> {
    type Family = F;
    type Digest = D;
    type Error = Error<F>;

    fn size(&self) -> Position<F> {
        self.size()
    }

    fn get_node(&self, pos: Position<F>) -> Option<D> {
        self.get_node(pos)
    }

    fn root(&self) -> D {
        // The trait returns by value; copy the stored digest.
        *self.root()
    }

    fn pruning_boundary(&self) -> Location<F> {
        Location::try_from(self.pruning_boundary).expect("valid pruning_boundary")
    }

    fn proof(
        &self,
        hasher: &impl Hasher<F, Digest = D>,
        loc: Location<F>,
    ) -> Result<Proof<F, D>, Error<F>> {
        self.proof(hasher, loc)
    }

    fn range_proof(
        &self,
        hasher: &impl Hasher<F, Digest = D>,
        range: Range<Location<F>>,
    ) -> Result<Proof<F, D>, Error<F>> {
        self.range_proof(hasher, range)
    }
}
490
491#[cfg(test)]
492mod tests {
493    use super::*;
494    use crate::merkle::{hasher::Standard, Error, Location, Position};
495    use commonware_cryptography::{sha256, Sha256};
496    use commonware_runtime::{deterministic, Runner as _, ThreadPooler};
497    use commonware_utils::NZUsize;
498
499    type D = sha256::Digest;
500    type H = Standard<Sha256>;
501
502    fn build<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
503        let mut mem = Mem::new(hasher);
504        let batch = {
505            let mut batch = mem.new_batch();
506            for i in 0u64..n {
507                let element = hasher.digest(&i.to_be_bytes());
508                batch = batch.add(hasher, &element);
509            }
510            batch.merkleize(&mem, hasher)
511        };
512        mem.apply_batch(&batch).unwrap();
513        mem
514    }
515
516    fn build_raw<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
517        let mut mem = Mem::new(hasher);
518        let batch = {
519            let mut batch = mem.new_batch();
520            for i in 0u64..n {
521                batch = batch.add(hasher, &i.to_be_bytes());
522            }
523            batch.merkleize(&mem, hasher)
524        };
525        mem.apply_batch(&batch).unwrap();
526        mem
527    }
528
529    fn empty<F: Family>() {
530        let hasher: H = Standard::new();
531        let mem = Mem::<F, D>::new(&hasher);
532        assert_eq!(*mem.leaves(), 0);
533        assert_eq!(*mem.size(), 0);
534        assert!(mem.bounds().is_empty());
535    }
536
    /// The size is valid after every append, and every size strictly between two consecutive
    /// post-append sizes is invalid.
    fn validity<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut mem = Mem::<F, D>::new(&hasher);
            for i in 0u64..256 {
                assert!(
                    mem.size().is_valid_size(),
                    "size should be valid at step {i}"
                );
                let old_size = mem.size();
                let batch = mem
                    .new_batch()
                    .add(&hasher, &i.to_be_bytes())
                    .merkleize(&mem, &hasher);
                mem.apply_batch(&batch).unwrap();
                // Intermediate sizes (inside the parent nodes added by this append) must be
                // rejected as invalid.
                for size in *old_size + 1..*mem.size() {
                    assert!(
                        !Position::<F>::new(size).is_valid_size(),
                        "size {size} should not be valid"
                    );
                }
            }
        });
    }
562
    /// Appending after a full prune keeps working: the leaf count keeps advancing even when
    /// every retained node is dropped before each append.
    fn prune_all_then_append<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut mem = Mem::<F, D>::new(&hasher);
            for i in 0u64..256 {
                mem.prune_all();
                let batch = mem
                    .new_batch()
                    .add(&hasher, &i.to_be_bytes())
                    .merkleize(&mem, &hasher);
                mem.apply_batch(&batch).unwrap();
                assert_eq!(*mem.leaves(), i + 1);
            }
        });
    }
579
    /// `range_proof` rejects ranges whose end exceeds the leaf count, on both an empty
    /// structure and a populated one, while an in-bounds range succeeds.
    fn range_proof_out_of_bounds<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mem = Mem::<F, D>::new(&hasher);
            // Empty structure: any non-empty range is out of bounds.
            assert!(matches!(
                mem.range_proof(&hasher, Location::new(0)..Location::new(1)),
                Err(Error::RangeOutOfBounds(_))
            ));
            let mem = build::<F>(&hasher, 10);
            // End one past the leaf count.
            assert!(matches!(
                mem.range_proof(&hasher, Location::new(5)..Location::new(11)),
                Err(Error::RangeOutOfBounds(_))
            ));
            assert!(mem
                .range_proof(&hasher, Location::new(5)..Location::new(10))
                .is_ok());
        });
    }
599
    /// `proof` returns `LeafOutOfBounds` for locations at or past the leaf count (including on
    /// an empty structure), and succeeds for the last valid leaf.
    fn proof_out_of_bounds<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mem = Mem::<F, D>::new(&hasher);
            assert!(matches!(
                mem.proof(&hasher, Location::new(0)),
                Err(Error::LeafOutOfBounds(_))
            ));
            let mem = build::<F>(&hasher, 10);
            assert!(matches!(
                mem.proof(&hasher, Location::new(10)),
                Err(Error::LeafOutOfBounds(_))
            ));
            assert!(mem.proof(&hasher, Location::new(9)).is_ok());
        });
    }
617
    /// `init` accepts pinned-node sets exactly when their count matches what the pruning
    /// boundary requires, and rejects both missing and extraneous pinned nodes.
    fn init_pinned_nodes_validation<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();

            // No pruning, no pinned nodes: valid.
            assert!(Mem::<F, D>::init(
                Config {
                    nodes: vec![],
                    pruning_boundary: Location::new(0),
                    pinned_nodes: vec![],
                },
                &hasher,
            )
            .is_ok());

            // Pruned to 8 but no pinned nodes supplied: invalid.
            assert!(matches!(
                Mem::<F, D>::init(
                    Config {
                        nodes: vec![],
                        pruning_boundary: Location::new(8),
                        pinned_nodes: vec![],
                    },
                    &hasher,
                ),
                Err(Error::InvalidPinnedNodes)
            ));

            // No pruning but a spurious pinned node supplied: invalid.
            assert!(matches!(
                Mem::<F, D>::init(
                    Config {
                        nodes: vec![],
                        pruning_boundary: Location::new(0),
                        pinned_nodes: vec![hasher.digest(b"dummy")],
                    },
                    &hasher,
                ),
                Err(Error::InvalidPinnedNodes)
            ));

            // A pinned set produced by an actual structure for the same boundary: valid.
            let mem = build::<F>(&hasher, 50);
            let prune_loc = Location::<F>::new(25);
            let pinned_nodes = mem.node_digests_to_pin(prune_loc);
            assert!(Mem::<F, D>::init(
                Config {
                    nodes: vec![],
                    pruning_boundary: prune_loc,
                    pinned_nodes,
                },
                &hasher,
            )
            .is_ok());
        });
    }
671
    /// A structure pruned after every append produces the same root as an unpruned reference
    /// built from the same elements.
    fn root_stable_under_pruning<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut reference = Mem::<F, D>::new(&hasher);
            let mut pruned = Mem::<F, D>::new(&hasher);
            for i in 0u64..200 {
                let element = hasher.digest(&i.to_be_bytes());
                // Append to the unpruned reference.
                let cs = reference
                    .new_batch()
                    .add(&hasher, &element)
                    .merkleize(&reference, &hasher);
                reference.apply_batch(&cs).unwrap();
                // Append the same element to the aggressively pruned copy.
                let cs = pruned
                    .new_batch()
                    .add(&hasher, &element)
                    .merkleize(&pruned, &hasher);
                pruned.apply_batch(&cs).unwrap();
                pruned.prune_all();
                assert_eq!(pruned.root(), reference.root());
            }
        });
    }
695
    /// Shared driver: overwrite a set of leaves with a fixed element (root must change), then
    /// restore the original elements (root must return to its starting value).
    ///
    /// `pool` optionally enables parallel merkleization; `mem` must have been built by
    /// `build` with 200 leaves so the restore step reproduces the original digests.
    fn do_batch_update<F: Family>(
        hasher: &H,
        mut mem: Mem<F, D>,
        pool: Option<commonware_parallel::ThreadPool>,
    ) {
        let element = D::from(*b"01234567012345670123456701234567");
        let root = *mem.root();

        // Overwrite a spread of leaves with the constant element.
        let batch = {
            let mut batch = mem.new_batch();
            if let Some(ref pool) = pool {
                batch = batch.with_pool(Some(pool.clone()));
            }
            for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
                batch = batch
                    .update_leaf(hasher, Location::new(leaf), &element)
                    .unwrap();
            }
            batch.merkleize(&mem, hasher)
        };
        mem.apply_batch(&batch).unwrap();
        assert_ne!(*mem.root(), root);

        // Restore each leaf to its original digest; the root must round-trip exactly.
        let batch = {
            let mut batch = mem.new_batch();
            for leaf in [0u64, 1, 10, 50, 100, 150, 197, 198] {
                let element = hasher.digest(&leaf.to_be_bytes());
                batch = batch
                    .update_leaf(hasher, Location::new(leaf), &element)
                    .unwrap();
            }
            batch.merkleize(&mem, hasher)
        };
        mem.apply_batch(&batch).unwrap();
        assert_eq!(*mem.root(), root);
    }
732
733    fn batch_update_leaf<F: Family>() {
734        let executor = deterministic::Runner::default();
735        executor.start(|_| async move {
736            let hasher: H = Standard::new();
737            let mem = build::<F>(&hasher, 200);
738            do_batch_update(&hasher, mem, None);
739        });
740    }
741
    /// Parallel leaf updates: run the shared update driver with a 4-thread pool on the tokio
    /// runtime (the pool requires a real executor context).
    fn batch_parallel_update_leaf<F: Family>() {
        let executor = commonware_runtime::tokio::Runner::default();
        executor.start(|ctx| async move {
            let hasher: H = Standard::new();
            let mem = build::<F>(&hasher, 200);
            let pool = ctx.create_thread_pool(NZUsize!(4)).unwrap();
            do_batch_update(&hasher, mem, Some(pool));
        });
    }
751
752    fn root_changes_with_each_append<F: Family>() {
753        let hasher: H = Standard::new();
754        let mut mem = Mem::<F, D>::new(&hasher);
755        let mut prev_root = *mem.root();
756        for i in 0u64..16 {
757            let batch = {
758                let batch = mem.new_batch();
759                let batch = batch.add(&hasher, &i.to_be_bytes());
760                batch.merkleize(&mem, &hasher)
761            };
762            mem.apply_batch(&batch).unwrap();
763            assert_ne!(
764                *mem.root(),
765                prev_root,
766                "root should change after append {i}"
767            );
768            prev_root = *mem.root();
769        }
770    }
771
    /// Every leaf of a 16-leaf structure has a single-element proof that verifies against the
    /// root with its original (raw) element bytes.
    fn single_element_proof_roundtrip<F: Family>() {
        let hasher: H = Standard::new();
        let mem = build_raw::<F>(&hasher, 16);
        let root = *mem.root();
        for i in 0u64..16 {
            let proof = mem
                .proof(&hasher, Location::new(i))
                .unwrap_or_else(|e| panic!("loc={i}: {e:?}"));
            assert!(
                proof.verify_element_inclusion(&hasher, &i.to_be_bytes(), Location::new(i), &root),
                "loc={i}: proof should verify"
            );
        }
    }
786
    /// For every structure size up to 24 leaves, every non-empty `[start, end)` sub-range has a
    /// range proof that verifies against the root with the original elements.
    fn range_proof_roundtrip_exhaustive<F: Family>() {
        for n in 1u64..=24 {
            let hasher: H = Standard::new();
            let mem = build_raw::<F>(&hasher, n);
            let root = *mem.root();

            // Exhaust all (start, end) pairs with start < end <= n.
            for start in 0..n {
                for end in start + 1..=n {
                    let range = Location::new(start)..Location::new(end);
                    let proof = mem
                        .range_proof(&hasher, range.clone())
                        .unwrap_or_else(|e| panic!("n={n}, range={start}..{end}: {e:?}"));
                    let elements: Vec<_> = (start..end).map(|i| i.to_be_bytes()).collect();

                    assert!(
                        proof.verify_range_inclusion(&hasher, &elements, range.start, &root),
                        "n={n}, range={start}..{end}: range proof should verify"
                    );
                }
            }
        }
    }
809
    /// Pruning one leaf at a time never changes the root, and both the new boundary leaf and
    /// the latest leaf stay provable after every prune; a final `prune_all` also preserves
    /// the root and leaves no retained leaves.
    fn root_with_repeated_pruning<F: Family>() {
        let hasher: H = Standard::new();
        let mut mem = build::<F>(&hasher, 32);
        let root = *mem.root();

        for prune_leaf in 1..*mem.leaves() {
            let prune_loc = Location::new(prune_leaf);
            mem.prune(prune_loc).unwrap();
            assert_eq!(
                *mem.root(),
                root,
                "root changed after pruning to {prune_loc}"
            );
            assert_eq!(mem.bounds().start, prune_loc);
            assert!(
                mem.proof(&hasher, prune_loc).is_ok(),
                "boundary leaf {prune_loc} should remain provable"
            );
            assert!(
                mem.proof(&hasher, mem.leaves() - 1).is_ok(),
                "latest leaf should remain provable after pruning to {prune_loc}"
            );
        }

        mem.prune_all();
        assert_eq!(*mem.root(), root, "root changed after prune_all");
        assert!(mem.bounds().is_empty(), "prune_all should retain no leaves");
    }
838
    /// Appending to a partially pruned structure works, and every retained leaf (old and newly
    /// appended) remains provable against the new root.
    fn append_after_partial_prune<F: Family>() {
        let hasher: H = Standard::new();
        let mut mem = build_raw::<F>(&hasher, 20);
        mem.prune(Location::new(7)).unwrap();

        // Append 28 more leaves on top of the pruned structure.
        let batch = {
            let mut batch = mem.new_batch();
            for i in 20u64..48 {
                batch = batch.add(&hasher, &i.to_be_bytes());
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();

        let root = *mem.root();
        // Every retained leaf — from the pruning boundary through the newest — must verify.
        for loc in *mem.bounds().start..*mem.leaves() {
            let proof = mem
                .proof(&hasher, Location::new(loc))
                .unwrap_or_else(|e| panic!("loc={loc}: {e:?}"));
            assert!(
                proof.verify_element_inclusion(
                    &hasher,
                    &loc.to_be_bytes(),
                    Location::new(loc),
                    &root
                ),
                "loc={loc}: proof should verify after append on pruned structure"
            );
        }
    }
869
870    fn update_leaf<F: Family>() {
871        let hasher: H = Standard::new();
872        let mut mem = build_raw::<F>(&hasher, 11);
873        let root_before = *mem.root();
874
875        let batch = {
876            let batch = mem.new_batch();
877            let batch = batch
878                .update_leaf(&hasher, Location::new(5), b"updated-5")
879                .unwrap();
880            batch.merkleize(&mem, &hasher)
881        };
882        mem.apply_batch(&batch).unwrap();
883
884        assert_ne!(*mem.root(), root_before, "root should change after update");
885        assert_eq!(*mem.leaves(), 11);
886
887        let proof = mem.proof(&hasher, Location::new(5)).unwrap();
888        assert!(
889            proof.verify_element_inclusion(&hasher, b"updated-5", Location::new(5), mem.root()),
890            "updated leaf should verify with new data"
891        );
892
893        assert!(
894            !proof.verify_element_inclusion(
895                &hasher,
896                &5u64.to_be_bytes(),
897                Location::new(5),
898                mem.root()
899            ),
900            "old data should not verify"
901        );
902
903        for i in [0u64, 3, 7, 10] {
904            let p = mem.proof(&hasher, Location::new(i)).unwrap();
905            assert!(
906                p.verify_element_inclusion(&hasher, &i.to_be_bytes(), Location::new(i), mem.root()),
907                "leaf {i} should still verify with original data"
908            );
909        }
910    }
911
912    fn update_leaf_every_position<F: Family>() {
913        let n = 20u64;
914        let hasher: H = Standard::new();
915        let mut mem = build::<F>(&hasher, n);
916
917        for update_loc in 0..n {
918            let batch = {
919                let batch = mem.new_batch();
920                let batch = batch
921                    .update_leaf(&hasher, Location::new(update_loc), b"new-value")
922                    .unwrap();
923                batch.merkleize(&mem, &hasher)
924            };
925            mem.apply_batch(&batch).unwrap();
926
927            let proof = mem.proof(&hasher, Location::new(update_loc)).unwrap();
928            assert!(
929                proof.verify_element_inclusion(
930                    &hasher,
931                    b"new-value",
932                    Location::new(update_loc),
933                    mem.root()
934                ),
935                "update at {update_loc} should verify"
936            );
937        }
938    }
939
940    fn update_leaf_errors<F: Family>() {
941        let hasher: H = Standard::new();
942        let mut mem = build::<F>(&hasher, 10);
943
944        {
945            let batch = mem.new_batch();
946            assert!(matches!(
947                batch.update_leaf(&hasher, Location::new(10), b"x"),
948                Err(Error::LeafOutOfBounds(_))
949            ));
950        }
951
952        mem.prune(Location::new(5)).unwrap();
953        {
954            let batch = mem.new_batch();
955            assert!(matches!(
956                batch.update_leaf(&hasher, Location::new(3), b"x"),
957                Err(Error::ElementPruned(_))
958            ));
959            let batch = mem.new_batch();
960            assert!(batch.update_leaf(&hasher, Location::new(5), b"x").is_ok());
961        }
962    }
963
964    fn update_leaf_with_append<F: Family>() {
965        let hasher: H = Standard::new();
966        let mut mem = build::<F>(&hasher, 8);
967
968        let batch = {
969            let batch = mem.new_batch();
970            let batch = batch
971                .update_leaf(&hasher, Location::new(3), b"updated-3")
972                .unwrap();
973            let batch = batch.add(&hasher, &100u64.to_be_bytes());
974            let batch = batch.add(&hasher, &101u64.to_be_bytes());
975            batch.merkleize(&mem, &hasher)
976        };
977        mem.apply_batch(&batch).unwrap();
978
979        assert_eq!(*mem.leaves(), 10);
980
981        let proof = mem.proof(&hasher, Location::new(3)).unwrap();
982        assert!(proof.verify_element_inclusion(
983            &hasher,
984            b"updated-3",
985            Location::new(3),
986            mem.root()
987        ));
988
989        let proof = mem.proof(&hasher, Location::new(8)).unwrap();
990        assert!(proof.verify_element_inclusion(
991            &hasher,
992            &100u64.to_be_bytes(),
993            Location::new(8),
994            mem.root()
995        ));
996    }
997
998    fn update_leaf_under_merge_parent<F: Family>() {
999        let hasher: H = Standard::new();
1000        let mut mem = build::<F>(&hasher, 2);
1001        let batch = {
1002            let batch = mem.new_batch();
1003            let batch = batch.add(&hasher, &2u64.to_be_bytes());
1004            let batch = batch
1005                .update_leaf(&hasher, Location::new(0), b"updated-0")
1006                .unwrap();
1007            batch.merkleize(&mem, &hasher)
1008        };
1009        mem.apply_batch(&batch).unwrap();
1010
1011        let ref_hasher: H = Standard::new();
1012        let mut ref_mem = build::<F>(&ref_hasher, 2);
1013        let cs = {
1014            let batch = ref_mem.new_batch();
1015            let batch = batch.add(&ref_hasher, &2u64.to_be_bytes());
1016            batch.merkleize(&ref_mem, &ref_hasher)
1017        };
1018        ref_mem.apply_batch(&cs).unwrap();
1019        let cs = {
1020            let batch = ref_mem.new_batch();
1021            let batch = batch
1022                .update_leaf(&ref_hasher, Location::new(0), b"updated-0")
1023                .unwrap();
1024            batch.merkleize(&ref_mem, &ref_hasher)
1025        };
1026        ref_mem.apply_batch(&cs).unwrap();
1027
1028        assert_eq!(*mem.root(), *ref_mem.root(), "roots must match");
1029
1030        let proof = mem.proof(&hasher, Location::new(0)).unwrap();
1031        assert!(
1032            proof.verify_element_inclusion(&hasher, b"updated-0", Location::new(0), mem.root()),
1033            "updated leaf should verify"
1034        );
1035    }
1036
1037    /// Prune to every valid boundary in structures of size 1..=max_n, then update_leaf +
1038    /// merkleize each retained leaf and verify its inclusion proof. This exercises the pinned
1039    /// nodes produced by `nodes_to_pin` under re-merkleization.
1040    fn update_leaf_after_prune<F: Family>() {
1041        let max_n = 20u64;
1042        let hasher: H = Standard::new();
1043        for n in 1..=max_n {
1044            for prune_to in 1..n {
1045                let mut mem = build_raw::<F>(&hasher, n);
1046                mem.prune(Location::new(prune_to)).unwrap();
1047
1048                for update_loc in prune_to..n {
1049                    // Clone so each update starts from the same pruned state.
1050                    let mut m = mem.clone();
1051                    let batch = {
1052                        let batch = m.new_batch();
1053                        let batch = batch
1054                            .update_leaf(&hasher, Location::new(update_loc), b"new")
1055                            .unwrap();
1056                        batch.merkleize(&m, &hasher)
1057                    };
1058                    m.apply_batch(&batch).unwrap();
1059
1060                    let proof = m.proof(&hasher, Location::new(update_loc)).unwrap();
1061                    assert!(
1062                        proof.verify_element_inclusion(
1063                            &hasher,
1064                            b"new",
1065                            Location::new(update_loc),
1066                            m.root()
1067                        ),
1068                        "n={n} prune={prune_to} update={update_loc}: proof should verify"
1069                    );
1070                }
1071            }
1072        }
1073    }
1074
1075    /// Applying C (child of B, grandchild of A) after only A is applied
1076    /// must apply B's uncommitted data + C's data, skipping only A.
1077    fn apply_batch_skips_only_committed_ancestors<F: Family>() {
1078        let hasher: H = Standard::new();
1079        let mut mem = Mem::<F, D>::new(&hasher);
1080
1081        // Chain: Mem -> A -> B -> C
1082        let a = mem.new_batch().add(&hasher, b"a").merkleize(&mem, &hasher);
1083        let b = a.new_batch().add(&hasher, b"b").merkleize(&mem, &hasher);
1084        let c = b.new_batch().add(&hasher, b"c").merkleize(&mem, &hasher);
1085
1086        // Apply A, then apply C directly (skipping B's apply_batch).
1087        // C's ancestor batches carry [A.data, B.data]. A is already committed
1088        // so only B + C should be applied.
1089        mem.apply_batch(&a).unwrap();
1090        mem.apply_batch(&c).unwrap();
1091
1092        // Verify against a reference that applied all three in order.
1093        let mut reference = Mem::<F, D>::new(&hasher);
1094        let full = {
1095            let mut batch = reference.new_batch();
1096            for leaf in [b"a".as_slice(), b"b", b"c"] {
1097                batch = batch.add(&hasher, leaf);
1098            }
1099            batch.merkleize(&reference, &hasher)
1100        };
1101        reference.apply_batch(&full).unwrap();
1102        assert_eq!(mem.root(), reference.root());
1103    }
1104
1105    /// Dropping an uncommitted ancestor before merkleizing a descendant must
1106    /// be detected at apply time, not silently corrupt data.
1107    fn apply_batch_detects_dropped_ancestor<F: Family>() {
1108        let hasher: H = Standard::new();
1109        let mut mem = Mem::<F, D>::new(&hasher);
1110
1111        let a = mem.new_batch().add(&hasher, b"a").merkleize(&mem, &hasher);
1112        let b = a.new_batch().add(&hasher, b"b").merkleize(&mem, &hasher);
1113        drop(a); // A dropped before C is merkleized — its data is lost
1114        let c = b.new_batch().add(&hasher, b"c").merkleize(&mem, &hasher);
1115
1116        let result = mem.apply_batch(&c);
1117        assert!(
1118            matches!(result, Err(Error::AncestorDropped { .. })),
1119            "expected AncestorDropped, got {result:?}"
1120        );
1121    }
1122
1123    /// Overwrite-only ancestor B must not be skipped when applying C after A.
1124    fn apply_batch_overwrite_only_ancestor<F: Family>() {
1125        let hasher: H = Standard::new();
1126        let mut mem = build_raw::<F>(&hasher, 10);
1127
1128        let pos0 = Position::<F>::try_from(Location::new(0)).unwrap();
1129
1130        // A: add 5 leaves.
1131        let a = {
1132            let mut b = mem.new_batch();
1133            for i in 100u64..105 {
1134                b = b.add(&hasher, &i.to_be_bytes());
1135            }
1136            b.merkleize(&mem, &hasher)
1137        };
1138
1139        // B: overwrite leaf 0, no appends.
1140        let b = a
1141            .new_batch()
1142            .update_leaf(&hasher, Location::new(0), b"updated-0")
1143            .unwrap()
1144            .merkleize(&mem, &hasher);
1145
1146        // C: add 5 more leaves.
1147        let c = {
1148            let mut batch = b.new_batch();
1149            for i in 200u64..205 {
1150                batch = batch.add(&hasher, &i.to_be_bytes());
1151            }
1152            batch.merkleize(&mem, &hasher)
1153        };
1154
1155        // Apply A, then C (skipping B's apply_batch).
1156        mem.apply_batch(&a).unwrap();
1157        mem.apply_batch(&c).unwrap();
1158
1159        // B's overwrite must have been applied.
1160        let updated = hasher.leaf_digest(pos0, b"updated-0");
1161        assert_eq!(
1162            mem.get_node(pos0),
1163            Some(updated),
1164            "overwrite-only ancestor B's overwrites were skipped"
1165        );
1166    }
1167
1168    // --- MMR tests ---
1169
    // Each wrapper below instantiates one of the generic test functions above with
    // the MMR family, so both families exercise identical test logic.
    #[test]
    fn mmr_empty() {
        empty::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_validity() {
        validity::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_prune_all_then_append() {
        prune_all_then_append::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_range_proof_oob() {
        range_proof_out_of_bounds::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_proof_oob() {
        proof_out_of_bounds::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_init_pinned_nodes() {
        init_pinned_nodes_validation::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_root_stable_under_pruning() {
        root_stable_under_pruning::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_batch_update_leaf() {
        batch_update_leaf::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_batch_parallel_update_leaf() {
        batch_parallel_update_leaf::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_root_changes_with_each_append() {
        root_changes_with_each_append::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_single_element_proof_roundtrip() {
        single_element_proof_roundtrip::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_range_proof_roundtrip_exhaustive() {
        range_proof_roundtrip_exhaustive::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_root_with_repeated_pruning() {
        root_with_repeated_pruning::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_append_after_partial_prune() {
        append_after_partial_prune::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf() {
        update_leaf::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_every_position() {
        update_leaf_every_position::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_errors() {
        update_leaf_errors::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_with_append() {
        update_leaf_with_append::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_under_merge_parent() {
        update_leaf_under_merge_parent::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_after_prune() {
        update_leaf_after_prune::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_apply_batch_skips_only_committed_ancestors() {
        apply_batch_skips_only_committed_ancestors::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_apply_batch_detects_dropped_ancestor() {
        apply_batch_detects_dropped_ancestor::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_apply_batch_overwrite_only_ancestor() {
        apply_batch_overwrite_only_ancestor::<crate::mmr::Family>();
    }
1262
1263    // --- MMB tests ---
1264
    // Each wrapper below instantiates one of the generic test functions above with
    // the MMB family; the list mirrors the MMR wrappers one-for-one.
    #[test]
    fn mmb_empty() {
        empty::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_validity() {
        validity::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_prune_all_then_append() {
        prune_all_then_append::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_range_proof_oob() {
        range_proof_out_of_bounds::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_proof_oob() {
        proof_out_of_bounds::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_init_pinned_nodes() {
        init_pinned_nodes_validation::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_root_stable_under_pruning() {
        root_stable_under_pruning::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_batch_update_leaf() {
        batch_update_leaf::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_batch_parallel_update_leaf() {
        batch_parallel_update_leaf::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_root_changes_with_each_append() {
        root_changes_with_each_append::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_single_element_proof_roundtrip() {
        single_element_proof_roundtrip::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_range_proof_roundtrip_exhaustive() {
        range_proof_roundtrip_exhaustive::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_root_with_repeated_pruning() {
        root_with_repeated_pruning::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_append_after_partial_prune() {
        append_after_partial_prune::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf() {
        update_leaf::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_every_position() {
        update_leaf_every_position::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_errors() {
        update_leaf_errors::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_with_append() {
        update_leaf_with_append::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_under_merge_parent() {
        update_leaf_under_merge_parent::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_after_prune() {
        update_leaf_after_prune::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_apply_batch_skips_only_committed_ancestors() {
        apply_batch_skips_only_committed_ancestors::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_apply_batch_detects_dropped_ancestor() {
        apply_batch_detects_dropped_ancestor::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_apply_batch_overwrite_only_ancestor() {
        apply_batch_overwrite_only_ancestor::<crate::mmb::Family>();
    }
1357}