// commonware_storage/merkle/batch.rs
1//! A lightweight batch layer over a merkleized structure.
2//!
3//! # Overview
4//!
5//! [`UnmerkleizedBatch`] accumulates mutations (appends and overwrites) against a parent
6//! [`MerkleizedBatch`]. Calling [`UnmerkleizedBatch::merkleize`] computes the root and
7//! produces a new [`MerkleizedBatch`]. Batches can be stacked to arbitrary depth
8//! via `Arc`-backed parent pointers, so multiple forks can coexist on the same parent.
9//!
10//! # Lifecycle
11//!
12//! ```text
13//!                          Mem
14//!                           |
15//!              MerkleizedBatch::from_mem()      (root batch, no data)
16//!                           |
17//!                      new_batch()
18//!                           |
19//!                           v
20//!                    UnmerkleizedBatch          (accumulate mutations)
21//!                           |
22//!                  merkleize(&mem, hasher)
23//!                           |
24//!                           v
25//!                 Arc<MerkleizedBatch>           (immutable, has root)
26//!                           |
27//!                  mem.apply_batch(&batch)
28//!                           |
29//!                           v
30//!                          Mem                   (committed)
31//! ```
32//!
33//! # Parent chain and memory
34//!
35//! Each [`MerkleizedBatch`] stores its own local data (appended nodes and overwrites)
36//! plus `Arc` refs to each ancestor's data, collected during
37//! [`UnmerkleizedBatch::merkleize`]. These ancestor batches' data are used by
38//! [`Mem::apply_batch`] to replay uncommitted ancestors without requiring the
39//! ancestor batches to still be alive.
40//!
41//! A `Weak` pointer to the parent is kept for [`MerkleizedBatch::get_node`] lookups
42//! (used during a child's merkleize) and for walking the chain to collect ancestor
43//! batch data. Committed-and-dropped ancestors truncate the `Weak` walk, but their
44//! data is already captured in `ancestor_appended` / `ancestor_overwrites`.
45//!
46//! During [`UnmerkleizedBatch::merkleize`], the parent is held as a strong `Arc`
47//! (keeping it alive for the walk), and the `Weak` chain is walked to collect
48//! ancestor data. After merkleize, the parent is downgraded to `Weak`.
49//!
50//! In a pipelining pattern (build next batch from prev, apply prev, repeat), each batch
51//! holds at most one ancestor batch (its immediate parent's data, as an `Arc` ref).
52//! When that batch is applied and dropped, the ancestor data is freed. Memory per
53//! batch is O(batch size), never growing with chain depth.
54//!
55//! [`MerkleizedBatch::get_node`] resolves positions stored in the batch chain only.
56//! For positions in the committed structure, callers fall through to [`Mem::get_node`]
57//! (or an adapter that layers a batch over a `Mem`).
58//!
59//! # Example (MMR)
60//!
61//! ```ignore
62//! let hasher = StandardHasher::<Sha256>::new();
63//! let mut mmr = Mmr::new(&hasher);
64//!
65//! // Fork two independent speculative chains from the same base.
66//! let a1 = mmr.new_batch()
67//!     .add(&hasher, b"a1")
68//!     .merkleize(&mmr, &hasher);
69//! let b1 = mmr.new_batch()
70//!     .add(&hasher, b"b1")
71//!     .merkleize(&mmr, &hasher);
72//!
73//! // Commit A1.
74//! mmr.apply_batch(&a1).unwrap();
75//! ```
76
77use crate::merkle::{
78    hasher::Hasher, mem::Mem, path, proof::Proof, Error, Family, Location, Position, Readable,
79};
80use alloc::{
81    collections::{BTreeMap, BTreeSet},
82    sync::{Arc, Weak},
83    vec::Vec,
84};
85use commonware_cryptography::Digest;
86use core::ops::Range;
87cfg_if::cfg_if! {
88    if #[cfg(feature = "std")] {
89        use commonware_parallel::ThreadPool;
90        use rayon::prelude::*;
91    }
92}
93
/// Minimum number of digest computations required to trigger parallelization.
///
/// NOTE(review): threshold appears empirical — below this, thread-pool
/// scheduling overhead presumably outweighs the parallel hashing win;
/// confirm against benchmarks.
#[cfg(feature = "std")]
pub(crate) const MIN_TO_PARALLELIZE: usize = 20;
97
98// ---------------------------------------------------------------------------
99// UnmerkleizedBatch
100// ---------------------------------------------------------------------------
101
/// A speculative batch whose root digest has not yet been computed,
/// in contrast to [`MerkleizedBatch`].
pub struct UnmerkleizedBatch<F: Family, D: Digest> {
    /// The merkleized parent this batch's mutations are layered on.
    parent: Arc<MerkleizedBatch<F, D>>,
    /// Nodes appended past the parent's size: `appended[i]` lives at position
    /// `parent.size() + i`.
    appended: Vec<D>,
    /// Overwrites of nodes at positions below the parent's size
    /// (see `store_node`).
    overwrites: BTreeMap<Position<F>, D>,
    /// Internal nodes whose digests must be recomputed, keyed by
    /// `(height, position)` so iteration is bottom-up by height.
    dirty_nodes: BTreeSet<(u32, Position<F>)>,
    /// Optional thread pool used to parallelize merkleization.
    #[cfg(feature = "std")]
    pool: Option<ThreadPool>,
}
112
impl<F: Family, D: Digest> UnmerkleizedBatch<F, D> {
    /// Create a new, empty batch layered on `parent`.
    pub const fn new(parent: Arc<MerkleizedBatch<F, D>>) -> Self {
        Self {
            parent,
            appended: Vec::new(),
            overwrites: BTreeMap::new(),
            dirty_nodes: BTreeSet::new(),
            #[cfg(feature = "std")]
            pool: None,
        }
    }

    /// Set a thread pool for parallel merkleization.
    #[cfg(feature = "std")]
    pub fn with_pool(mut self, pool: Option<ThreadPool>) -> Self {
        self.pool = pool;
        self
    }

    /// Return a reference to the thread pool, if any.
    #[cfg(feature = "std")]
    pub const fn pool(&self) -> Option<&ThreadPool> {
        self.pool.as_ref()
    }

    /// The total number of nodes visible through this batch
    /// (parent's size plus this batch's appended nodes).
    pub(crate) fn size(&self) -> Position<F> {
        Position::new(*self.parent.size() + self.appended.len() as u64)
    }

    /// The number of leaves visible through this batch.
    pub fn leaves(&self) -> Location<F> {
        Location::try_from(self.size()).expect("invalid size")
    }

    /// Resolve a node: own data -> parent chain -> `base` fallback.
    ///
    /// Returns `None` if `pos` is beyond this batch's size.
    fn get_node(&self, base: &Mem<F, D>, pos: Position<F>) -> Option<D> {
        if pos >= self.size() {
            return None;
        }
        // Overwrites are checked first; `store_node` only records overwrites
        // for positions below the parent's size, so this cannot mask an
        // appended node.
        if let Some(d) = self.overwrites.get(&pos) {
            return Some(*d);
        }
        let parent_size = self.parent.size();
        if pos >= parent_size {
            // Position appended by this batch.
            let index = (*pos - *parent_size) as usize;
            return self.appended.get(index).copied();
        }
        // Ancestor batches first, then the committed structure.
        if let Some(d) = self.parent.get_node(pos) {
            return Some(d);
        }
        base.get_node(pos)
    }

    /// Store a digest at the given position.
    ///
    /// Positions at or above the parent's size update this batch's `appended`
    /// slot in place (the slot must already exist — see `add_leaf_digest`);
    /// lower positions are recorded as overwrites of ancestor/committed nodes.
    fn store_node(&mut self, pos: Position<F>, digest: D) {
        let parent_size = self.parent.size();
        if pos >= parent_size {
            let index = (*pos - *parent_size) as usize;
            self.appended[index] = digest;
        } else {
            self.overwrites.insert(pos, digest);
        }
    }

    /// Mark ancestors of the leaf at `loc` as dirty up to its peak.
    ///
    /// Walks from peak to leaf (top-down) using [`path::Iterator`], then inserts dirty markers
    /// bottom-up so that an early exit is possible when hitting a node that was already
    /// dirtied by a prior `update_leaf`.
    fn mark_dirty(&mut self, loc: Location<F>) {
        // Locate the peak (perfect subtree) containing `loc` by scanning the
        // peaks left to right and counting the leaves under each.
        let mut first_leaf = Location::new(0);
        for (peak_pos, height) in F::peaks(self.size()) {
            let leaves_in_peak = 1u64 << height;
            if loc >= first_leaf + leaves_in_peak {
                first_leaf += leaves_in_peak;
                continue;
            }

            // Record the top-down path into a fixed-size buffer so it can be
            // replayed bottom-up below.
            let mut buf = [(Position::new(0), Position::new(0), 0u32); path::MAX_PATH_LEN];
            let mut len = 0;
            for item in path::Iterator::new(peak_pos, height, first_leaf, loc) {
                buf[len] = item;
                len += 1;
            }
            // Insert bottom-up; if a node is already dirty, all of its
            // ancestors must be as well, so stop early.
            for &(parent_pos, _, h) in buf[..len].iter().rev() {
                if !self.dirty_nodes.insert((h, parent_pos)) {
                    break;
                }
            }
            return;
        }

        // Callers bounds-check `loc`; reaching here is a bug.
        panic!("leaf {loc} not found (size: {})", self.size());
    }

    /// Add a pre-computed leaf digest.
    ///
    /// Any internal parent nodes completed by this leaf are appended as
    /// `D::EMPTY` placeholders and marked dirty; their real digests are
    /// computed later by [`Self::merkleize`].
    pub fn add_leaf_digest(mut self, digest: D) -> Self {
        // Heights of the internal nodes this leaf completes, computed before
        // the push (while `self.leaves()` is still this leaf's index).
        let heights = F::parent_heights(self.leaves());
        self.appended.push(digest);

        for height in heights {
            // `size()` before the push is exactly the placeholder's position.
            let pos = self.size();
            self.appended.push(D::EMPTY);
            self.dirty_nodes.insert((height, pos));
        }

        self
    }

    /// Hash `element` and add it as a leaf.
    pub fn add(self, hasher: &impl Hasher<F, Digest = D>, element: &[u8]) -> Self {
        let digest = hasher.leaf_digest(self.size(), element);
        self.add_leaf_digest(digest)
    }

    /// Update the leaf at `loc` to `element`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LeafOutOfBounds`] if `loc` is not an existing leaf.
    /// Returns [`Error::ElementPruned`] if the leaf has been pruned.
    pub fn update_leaf(
        mut self,
        hasher: &impl Hasher<F, Digest = D>,
        loc: Location<F>,
        element: &[u8],
    ) -> Result<Self, Error<F>> {
        let leaves = self.leaves();
        if loc >= leaves {
            return Err(Error::LeafOutOfBounds(loc));
        }
        if loc < self.parent.pruning_boundary() {
            return Err(Error::ElementPruned(Position::try_from(loc)?));
        }
        let pos = Position::try_from(loc)?;
        let digest = hasher.leaf_digest(pos, element);
        self.store_node(pos, digest);
        self.mark_dirty(loc);
        Ok(self)
    }

    /// Overwrite the digest of an existing leaf and mark ancestors dirty.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LeafOutOfBounds`] if `loc` is not an existing leaf,
    /// [`Error::ElementPruned`] if the leaf has been pruned, and
    /// [`Error::NonLeaf`] if `loc` maps to a non-leaf position.
    #[cfg(any(feature = "std", test))]
    pub fn update_leaf_digest(mut self, loc: Location<F>, digest: D) -> Result<Self, Error<F>> {
        let leaves = self.leaves();
        if loc >= leaves {
            return Err(Error::LeafOutOfBounds(loc));
        }
        if loc < self.parent.pruning_boundary() {
            return Err(Error::ElementPruned(Position::try_from(loc)?));
        }
        let pos = Position::try_from(loc)?;
        if F::position_to_location(pos).is_none() {
            return Err(Error::NonLeaf(pos));
        }
        self.store_node(pos, digest);
        self.mark_dirty(loc);
        Ok(self)
    }

    /// Batch update multiple leaf digests.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LeafOutOfBounds`] / [`Error::ElementPruned`] if any
    /// location is invalid; validation happens before any mutation, so an
    /// error leaves the batch unchanged.
    #[cfg(any(feature = "std", test))]
    pub fn update_leaf_batched(mut self, updates: &[(Location<F>, D)]) -> Result<Self, Error<F>> {
        let leaves = self.leaves();
        let prune_boundary = self.parent.pruning_boundary();
        // Validate everything up front so errors cannot leave the batch
        // partially mutated.
        for (loc, _) in updates {
            if *loc >= leaves {
                return Err(Error::LeafOutOfBounds(*loc));
            }
            if *loc < prune_boundary {
                return Err(Error::ElementPruned(Position::try_from(*loc)?));
            }
        }
        for (loc, digest) in updates {
            // All locations were validated above, so conversion cannot fail.
            let pos = Position::try_from(*loc).unwrap();
            self.store_node(pos, *digest);
            self.mark_dirty(*loc);
        }
        Ok(self)
    }

    /// Consume this batch and produce an immutable [`MerkleizedBatch`] with computed root.
    /// `base` provides committed node data as fallback during hash computation.
    pub fn merkleize(
        mut self,
        base: &Mem<F, D>,
        hasher: &impl Hasher<F, Digest = D>,
    ) -> Arc<MerkleizedBatch<F, D>> {
        // Drain dirty markers; `BTreeSet` iteration yields ascending
        // (height, pos), i.e. bottom-up evaluation order.
        let dirty: Vec<_> = core::mem::take(&mut self.dirty_nodes).into_iter().collect();

        #[cfg(feature = "std")]
        // Take the pool out of `self` so `self` can be borrowed mutably below;
        // it is restored afterwards and handed to the resulting batch.
        if let Some(pool) = self.pool.take() {
            if dirty.len() >= MIN_TO_PARALLELIZE {
                self.merkleize_parallel(base, hasher, &pool, &dirty);
            } else {
                self.merkleize_serial(base, hasher, &dirty);
            }
            self.pool = Some(pool);
        } else {
            self.merkleize_serial(base, hasher, &dirty);
        }

        #[cfg(not(feature = "std"))]
        self.merkleize_serial(base, hasher, &dirty);

        // Compute root from peaks.
        let leaves = self.leaves();
        let peaks: Vec<D> = F::peaks(self.size())
            .map(|(peak_pos, _)| self.get_node(base, peak_pos).expect("peak missing"))
            .collect();
        let root = hasher.root(leaves, peaks.iter());

        // Collect ancestor data by walking the parent chain (strong Arc + Weak walk).
        let (ancestor_appended, ancestor_overwrites) = collect_ancestor_batches(&self.parent);

        let parent_size = self.parent.size();
        Arc::new(MerkleizedBatch {
            // The strong parent ref is downgraded here; the parent's data was
            // captured above, so its liveness is no longer required.
            parent: Some(Arc::downgrade(&self.parent)),
            appended: Arc::new(self.appended),
            overwrites: Arc::new(self.overwrites),
            root,
            parent_size,
            base_size: self.parent.base_size,
            pruning_boundary: self.parent.pruning_boundary(),
            ancestor_appended,
            ancestor_overwrites,
            #[cfg(feature = "std")]
            pool: self.pool,
        })
    }

    /// Compute digests for dirty internal nodes, bottom-up by height.
    ///
    /// `dirty` originates from a `BTreeSet<(height, pos)>`, so it is sorted by
    /// height first: every node's children are finalized before the node
    /// itself is hashed.
    fn merkleize_serial(
        &mut self,
        base: &Mem<F, D>,
        hasher: &impl Hasher<F, Digest = D>,
        dirty: &[(u32, Position<F>)],
    ) {
        for &(height, pos) in dirty {
            let (left, right) = F::children(pos, height);
            let left_d = self.get_node(base, left).expect("left child missing");
            let right_d = self.get_node(base, right).expect("right child missing");
            let digest = hasher.node_digest(pos, &left_d, &right_d);
            self.store_node(pos, digest);
        }
    }

    /// Process dirty nodes in parallel, grouping by height. Falls back to serial
    /// when the remaining count drops below the threshold.
    #[cfg(feature = "std")]
    fn merkleize_parallel(
        &mut self,
        base: &Mem<F, D>,
        hasher: &impl Hasher<F, Digest = D>,
        pool: &ThreadPool,
        dirty: &[(u32, Position<F>)],
    ) {
        // Nodes at the same height are independent of each other, so each
        // height group can be hashed in parallel. `dirty` is sorted by
        // (height, pos), so groups are contiguous runs.
        let mut same_height = Vec::new();
        let mut current_height = dirty.first().map_or(1, |&(h, _)| h);
        for (i, &(height, pos)) in dirty.iter().enumerate() {
            if height == current_height {
                same_height.push(pos);
                continue;
            }
            // Height changed: flush the completed group. If it is too small
            // to be worth parallelizing, finish everything from the start of
            // that group serially and bail out.
            if same_height.len() < MIN_TO_PARALLELIZE {
                self.merkleize_serial(base, hasher, &dirty[i - same_height.len()..]);
                return;
            }
            self.compute_height_parallel(base, hasher, pool, &same_height, current_height);
            same_height.clear();
            current_height = height;
            same_height.push(pos);
        }

        // Flush the final group, again falling back to serial when small.
        if same_height.len() < MIN_TO_PARALLELIZE {
            self.merkleize_serial(base, hasher, &dirty[dirty.len() - same_height.len()..]);
            return;
        }

        self.compute_height_parallel(base, hasher, pool, &same_height, current_height);
    }

    /// Compute digests for nodes at the same height in parallel, then store sequentially.
    #[cfg(feature = "std")]
    fn compute_height_parallel(
        &mut self,
        base: &Mem<F, D>,
        hasher: &impl Hasher<F, Digest = D>,
        pool: &ThreadPool,
        same_height: &[Position<F>],
        height: u32,
    ) {
        // `map_init` gives each rayon worker its own clone of the hasher.
        let computed: Vec<(Position<F>, D)> = pool.install(|| {
            same_height
                .par_iter()
                .map_init(
                    || hasher.clone(),
                    |hasher, &pos| {
                        let (left, right) = F::children(pos, height);
                        let left_d = self.get_node(base, left).expect("left child missing");
                        let right_d = self.get_node(base, right).expect("right child missing");
                        let digest = hasher.node_digest(pos, &left_d, &right_d);
                        (pos, digest)
                    },
                )
                .collect()
        });
        // Stores need `&mut self`, so they happen after the parallel phase.
        for (pos, digest) in computed {
            self.store_node(pos, digest);
        }
    }
}
427
428/// Collect ancestor batch data by walking the parent + its Weak chain.
429/// Returns (appended, overwrites) in root-to-tip order. Skips empty batches
430/// (e.g. root batches from `from_mem`).
431#[allow(clippy::type_complexity)]
432fn collect_ancestor_batches<F: Family, D: Digest>(
433    parent: &Arc<MerkleizedBatch<F, D>>,
434) -> (Vec<Arc<Vec<D>>>, Vec<Arc<BTreeMap<Position<F>, D>>>) {
435    let mut appended = Vec::new();
436    let mut overwrites = Vec::new();
437
438    // Parent is alive (strong Arc held by UnmerkleizedBatch).
439    if !parent.appended.is_empty() || !parent.overwrites.is_empty() {
440        appended.push(Arc::clone(&parent.appended));
441        overwrites.push(Arc::clone(&parent.overwrites));
442    }
443
444    // Walk Weak chain for grandparents+.
445    let mut current = parent.parent.as_ref().and_then(Weak::upgrade);
446    while let Some(batch) = current {
447        if !batch.appended.is_empty() || !batch.overwrites.is_empty() {
448            appended.push(Arc::clone(&batch.appended));
449            overwrites.push(Arc::clone(&batch.overwrites));
450        }
451        current = batch.parent.as_ref().and_then(Weak::upgrade);
452    }
453
454    appended.reverse();
455    overwrites.reverse();
456    (appended, overwrites)
457}
458
459// ---------------------------------------------------------------------------
460// MerkleizedBatch
461// ---------------------------------------------------------------------------
462
/// A speculative batch whose root digest has been computed,
/// in contrast to [`UnmerkleizedBatch`].
#[derive(Debug)]
pub struct MerkleizedBatch<F: Family, D: Digest> {
    /// The parent batch in the chain, if any. Held weakly: a dropped parent
    /// only truncates `get_node` chain walks, since its data was captured in
    /// `ancestor_appended` / `ancestor_overwrites` at merkleize time.
    parent: Option<Weak<Self>>,

    /// This batch's appended nodes only (not accumulated from ancestors).
    pub(crate) appended: Arc<Vec<D>>,

    /// This batch's overwrites only (not accumulated from ancestors).
    pub(crate) overwrites: Arc<BTreeMap<Position<F>, D>>,

    /// Root digest after this batch's mutations.
    root: D,

    /// Number of nodes in the parent batch.
    pub(crate) parent_size: Position<F>,

    /// Number of committed nodes when the batch chain was forked. Inherited unchanged
    /// by all descendants. Used by `apply_batch` to detect already-committed ancestors.
    pub(crate) base_size: Position<F>,

    /// Pruning boundary of the [`Mem`] when the batch chain was forked. Inherited
    /// unchanged by all descendants, like `base_size`.
    pruning_boundary: Location<F>,

    /// Arc refs to each ancestor's appended nodes, collected during merkleize while
    /// ancestors are alive. Root-to-tip order.
    pub(crate) ancestor_appended: Vec<Arc<Vec<D>>>,

    /// Arc refs to each ancestor's overwrites, collected during merkleize while
    /// ancestors are alive. Root-to-tip order.
    pub(crate) ancestor_overwrites: Vec<Arc<BTreeMap<Position<F>, D>>>,

    /// Optional thread pool, inherited by child batches via `new_batch`.
    #[cfg(feature = "std")]
    pub(crate) pool: Option<ThreadPool>,
}
501
502impl<F: Family, D: Digest> MerkleizedBatch<F, D> {
503    /// Create a root batch representing the committed state of `mem`.
504    pub fn from_mem(mem: &Mem<F, D>) -> Arc<Self> {
505        Arc::new(Self {
506            parent: None,
507            appended: Arc::new(Vec::new()),
508            overwrites: Arc::new(BTreeMap::new()),
509            root: *mem.root(),
510            parent_size: mem.size(),
511            base_size: mem.size(),
512            pruning_boundary: Readable::pruning_boundary(mem),
513            ancestor_appended: Vec::new(),
514            ancestor_overwrites: Vec::new(),
515            #[cfg(feature = "std")]
516            pool: None,
517        })
518    }
519
520    /// The total number of nodes visible through this batch.
521    pub fn size(&self) -> Position<F> {
522        Position::new(*self.parent_size + self.appended.len() as u64)
523    }
524
525    /// Resolve a node: own data -> Weak parent chain.
526    ///
527    /// Returns `None` for positions that only exist in the committed [`Mem`].
528    /// Callers that need committed data should fall back to [`Mem::get_node`]
529    /// (or use a layered adapter such as the one in `qmdb::current::batch`).
530    pub fn get_node(&self, pos: Position<F>) -> Option<D> {
531        if pos >= self.size() {
532            return None;
533        }
534        if let Some(d) = self.overwrites.get(&pos) {
535            return Some(*d);
536        }
537        if pos >= self.parent_size {
538            let i = (*pos - *self.parent_size) as usize;
539            return self.appended.get(i).copied();
540        }
541        // Walk Weak parent chain.
542        let mut current = self.parent.as_ref().and_then(Weak::upgrade);
543        while let Some(batch) = current {
544            if let Some(d) = batch.overwrites.get(&pos) {
545                return Some(*d);
546            }
547            if pos >= batch.parent_size {
548                let i = (*pos - *batch.parent_size) as usize;
549                return batch.appended.get(i).copied();
550            }
551            current = batch.parent.as_ref().and_then(Weak::upgrade);
552        }
553        None
554    }
555
556    /// Return the root digest after this batch is applied.
557    pub const fn root(&self) -> D {
558        self.root
559    }
560
561    /// Items before this location have been pruned.
562    pub const fn pruning_boundary(&self) -> Location<F> {
563        self.pruning_boundary
564    }
565
566    /// The number of leaves visible through this batch.
567    pub fn leaves(&self) -> Location<F> {
568        Location::try_from(self.size()).expect("invalid size")
569    }
570
571    /// Create a child batch on top of this merkleized batch.
572    ///
573    /// All uncommitted ancestors in the chain must be kept alive until the child (or any
574    /// descendant) is merkleized. Dropping an uncommitted ancestor causes data
575    /// loss detected at `apply_batch` time.
576    pub fn new_batch(self: &Arc<Self>) -> UnmerkleizedBatch<F, D> {
577        let batch = UnmerkleizedBatch::new(Arc::clone(self));
578        #[cfg(feature = "std")]
579        let batch = batch.with_pool(self.pool.clone());
580        batch
581    }
582
583    /// Number of nodes in the committed Mem when the batch chain was forked.
584    pub const fn base_size(&self) -> Position<F> {
585        self.base_size
586    }
587}
588
589impl<F: Family, D: Digest> Readable for MerkleizedBatch<F, D> {
590    type Family = F;
591    type Digest = D;
592    type Error = Error<F>;
593
594    fn size(&self) -> Position<F> {
595        Self::size(self)
596    }
597
598    fn get_node(&self, pos: Position<F>) -> Option<D> {
599        Self::get_node(self, pos)
600    }
601
602    fn root(&self) -> D {
603        Self::root(self)
604    }
605
606    fn pruning_boundary(&self) -> Location<F> {
607        Self::pruning_boundary(self)
608    }
609
610    fn proof(
611        &self,
612        hasher: &impl Hasher<F, Digest = D>,
613        loc: Location<F>,
614    ) -> Result<Proof<F, D>, Error<F>> {
615        if !loc.is_valid_index() {
616            return Err(Error::LocationOverflow(loc));
617        }
618        self.range_proof(hasher, loc..loc + 1).map_err(|e| match e {
619            Error::RangeOutOfBounds(_) => Error::LeafOutOfBounds(loc),
620            _ => e,
621        })
622    }
623
624    fn range_proof(
625        &self,
626        hasher: &impl Hasher<F, Digest = D>,
627        range: Range<Location<F>>,
628    ) -> Result<Proof<F, D>, Error<F>> {
629        crate::merkle::proof::build_range_proof(
630            hasher,
631            self.leaves(),
632            range,
633            |pos| Self::get_node(self, pos),
634            Error::ElementPruned,
635        )
636    }
637}
638
639// ---------------------------------------------------------------------------
640// Tests
641// ---------------------------------------------------------------------------
642
643#[cfg(test)]
644mod tests {
645    use super::*;
646    use crate::merkle::{hasher::Standard, mem::Mem};
647    use commonware_cryptography::{sha256, Sha256};
648    use commonware_runtime::{deterministic, Runner as _};
649
650    type D = sha256::Digest;
651    type H = Standard<Sha256>;
652
653    fn build_reference<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
654        let mut mem = Mem::new(hasher);
655        let batch = {
656            let mut batch = mem.new_batch();
657            for i in 0u64..n {
658                let element = hasher.digest(&i.to_be_bytes());
659                batch = batch.add(hasher, &element);
660            }
661            batch.merkleize(&mem, hasher)
662        };
663        mem.apply_batch(&batch).unwrap();
664        mem
665    }
666
    /// For several sizes, a batch built on an empty `Mem` and applied must
    /// produce the same root as the reference construction. Generic driver —
    /// presumably invoked by per-family `#[test]` wrappers elsewhere in the
    /// file (not visible here).
    fn consistency_with_reference<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            for &n in &[1u64, 2, 10, 100, 199] {
                let reference = build_reference::<F>(&hasher, n);
                let base = Mem::<F, D>::new(&hasher);
                let mut batch = base.new_batch();
                for i in 0..n {
                    let element = hasher.digest(&i.to_be_bytes());
                    batch = batch.add(&hasher, &element);
                }
                let merkleized = batch.merkleize(&base, &hasher);
                let mut result = Mem::<F, D>::new(&hasher);
                result.apply_batch(&merkleized).unwrap();
                assert_eq!(result.root(), reference.root(), "root mismatch for n={n}");
            }
        });
    }
686
    /// Full lifecycle: fork a batch, add leaves, merkleize (base unchanged),
    /// apply, then verify an inclusion proof against the new root.
    fn lifecycle<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let mut batch = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let merkleized = batch.merkleize(&base, &hasher);
            // Merkleizing must not mutate the committed base.
            assert_ne!(merkleized.root(), base_root);
            assert_eq!(*base.root(), base_root);
            // Apply and verify proof from the resulting Mem.
            let mut applied = base;
            applied.apply_batch(&merkleized).unwrap();
            let loc = Location::<F>::new(55);
            let element = hasher.digest(&55u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &merkleized.root()));
        });
    }
710
    /// Applying a merkleized batch makes the base adopt the batch's root, and
    /// that root matches an equivalent from-scratch reference build.
    fn apply_batch<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 50);
            let mut batch = base.new_batch();
            for i in 50u64..75 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let merkleized = batch.merkleize(&base, &hasher);
            let batch_root = merkleized.root();
            base.apply_batch(&merkleized).unwrap();
            assert_eq!(*base.root(), batch_root);
            let reference = build_reference::<F>(&hasher, 75);
            assert_eq!(base.root(), reference.root());
        });
    }
729
    /// Two independent forks from the same base produce distinct roots
    /// without disturbing the base.
    fn multiple_forks<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let mut ba = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                ba = ba.add(&hasher, &element);
            }
            let ma = ba.merkleize(&base, &hasher);
            // Fork B appends different content (digests of 100..105) so its
            // root must diverge from fork A's.
            let mut bb = base.new_batch();
            for i in 100u64..105 {
                let element = hasher.digest(&i.to_be_bytes());
                bb = bb.add(&hasher, &element);
            }
            let mb = bb.merkleize(&base, &hasher);
            assert_ne!(ma.root(), mb.root());
            assert_ne!(ma.root(), base_root);
            assert_eq!(*base.root(), base_root);
        });
    }
753
    /// A stacked batch (child forked from a merkleized parent) reads through
    /// the chain correctly: its root matches a flat reference build, and
    /// proofs verify after both batches are applied in order.
    fn fork_of_fork_reads<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let mut ba = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                ba = ba.add(&hasher, &element);
            }
            let ma = ba.merkleize(&base, &hasher);
            // Child batch forked from the merkleized parent, not the base.
            let mut bb = ma.new_batch();
            for i in 60u64..70 {
                let element = hasher.digest(&i.to_be_bytes());
                bb = bb.add(&hasher, &element);
            }
            let mb = bb.merkleize(&base, &hasher);
            let reference = build_reference::<F>(&hasher, 70);
            assert_eq!(mb.root(), *reference.root());
            // Apply both batches and verify proofs from the resulting Mem.
            let mut applied = base;
            applied.apply_batch(&ma).unwrap();
            applied.apply_batch(&mb).unwrap();
            for i in [0u64, 25, 55, 65, 69] {
                let loc = Location::<F>::new(i);
                let element = hasher.digest(&i.to_be_bytes());
                let proof = applied.proof(&hasher, loc).unwrap();
                assert!(proof.verify_element_inclusion(&hasher, &element, loc, &mb.root()));
            }
        });
    }
785
    /// Overwriting a leaf changes the root; restoring the original digest
    /// restores the original root (roundtrip).
    fn update_leaf_digest_roundtrip<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 100);
            let base_root = *base.root();
            let updated = Sha256::fill(0xFF);
            let m = base
                .new_batch()
                .update_leaf_digest(Location::new(5), updated)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(m.root(), base_root);
            // Write the original digest back; the root must roundtrip.
            let pos5 = Position::<F>::try_from(Location::new(5)).unwrap();
            let original = base.get_node(pos5).unwrap();
            let m2 = base
                .new_batch()
                .update_leaf_digest(Location::new(5), original)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_eq!(m2.root(), base_root);
        });
    }
809
810    fn update_and_add<F: Family>() {
811        let executor = deterministic::Runner::default();
812        executor.start(|_| async move {
813            let hasher: H = Standard::new();
814            let base = build_reference::<F>(&hasher, 50);
815            let base_root = *base.root();
816            let updated = Sha256::fill(0xAA);
817            let mut batch = base
818                .new_batch()
819                .update_leaf_digest(Location::new(10), updated)
820                .unwrap();
821            for i in 50u64..55 {
822                let element = hasher.digest(&i.to_be_bytes());
823                batch = batch.add(&hasher, &element);
824            }
825            let m = batch.merkleize(&base, &hasher);
826            assert_ne!(m.root(), base_root);
827            let pos10 = Position::<F>::try_from(Location::new(10)).unwrap();
828            assert_eq!(m.get_node(pos10), Some(updated));
829        });
830    }
831
832    fn update_leaf_batched_roundtrip<F: Family>() {
833        let executor = deterministic::Runner::default();
834        executor.start(|_| async move {
835            let hasher: H = Standard::new();
836            let base = build_reference::<F>(&hasher, 100);
837            let base_root = *base.root();
838            let updated = Sha256::fill(0xBB);
839            let locs = [0u64, 10, 50, 99];
840            let updates: Vec<(Location<F>, D)> =
841                locs.iter().map(|&i| (Location::new(i), updated)).collect();
842            let m = base
843                .new_batch()
844                .update_leaf_batched(&updates)
845                .unwrap()
846                .merkleize(&base, &hasher);
847            assert_ne!(m.root(), base_root);
848            let restore: Vec<(Location<F>, D)> = locs
849                .iter()
850                .map(|&l| {
851                    let pos = Position::<F>::try_from(Location::new(l)).unwrap();
852                    (Location::new(l), base.get_node(pos).unwrap())
853                })
854                .collect();
855            let m2 = base
856                .new_batch()
857                .update_leaf_batched(&restore)
858                .unwrap()
859                .merkleize(&base, &hasher);
860            assert_eq!(m2.root(), base_root);
861        });
862    }
863
864    fn proof_verification<F: Family>() {
865        let executor = deterministic::Runner::default();
866        executor.start(|_| async move {
867            let hasher: H = Standard::new();
868            let base = build_reference::<F>(&hasher, 50);
869            let mut batch = base.new_batch();
870            for i in 50u64..60 {
871                let element = hasher.digest(&i.to_be_bytes());
872                batch = batch.add(&hasher, &element);
873            }
874            let m = batch.merkleize(&base, &hasher);
875            // Apply and verify proofs from the resulting Mem.
876            let mut applied = base;
877            applied.apply_batch(&m).unwrap();
878            let loc = Location::<F>::new(55);
879            let element = hasher.digest(&55u64.to_be_bytes());
880            let proof = applied.proof(&hasher, loc).unwrap();
881            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &m.root()));
882            let range = Location::<F>::new(50)..Location::new(55);
883            let rp = applied.range_proof(&hasher, range.clone()).unwrap();
884            let elements: Vec<D> = (50u64..55)
885                .map(|i| hasher.digest(&i.to_be_bytes()))
886                .collect();
887            assert!(rp.verify_range_inclusion(&hasher, &elements, range.start, &m.root()));
888        });
889    }
890
891    fn empty_batch<F: Family>() {
892        let executor = deterministic::Runner::default();
893        executor.start(|_| async move {
894            let hasher: H = Standard::new();
895            let base = build_reference::<F>(&hasher, 50);
896            let base_root = *base.root();
897            let m = base.new_batch().merkleize(&base, &hasher);
898            assert_eq!(m.root(), base_root);
899        });
900    }
901
902    fn batch_roundtrip<F: Family>() {
903        let executor = deterministic::Runner::default();
904        executor.start(|_| async move {
905            let hasher: H = Standard::new();
906            let base = build_reference::<F>(&hasher, 50);
907            let mut batch = base.new_batch();
908            for i in 50u64..55 {
909                let element = hasher.digest(&i.to_be_bytes());
910                batch = batch.add(&hasher, &element);
911            }
912            let merkleized = batch.merkleize(&base, &hasher);
913            let mut batch_again = merkleized.new_batch();
914            for i in 55u64..60 {
915                let element = hasher.digest(&i.to_be_bytes());
916                batch_again = batch_again.add(&hasher, &element);
917            }
918            let reference = build_reference::<F>(&hasher, 60);
919            assert_eq!(
920                batch_again.merkleize(&base, &hasher).root(),
921                *reference.root()
922            );
923        });
924    }
925
926    fn sequential_apply_batch<F: Family>() {
927        let executor = deterministic::Runner::default();
928        executor.start(|_| async move {
929            let hasher: H = Standard::new();
930            let mut base = build_reference::<F>(&hasher, 50);
931            let mut b1 = base.new_batch();
932            for i in 50u64..60 {
933                let element = hasher.digest(&i.to_be_bytes());
934                b1 = b1.add(&hasher, &element);
935            }
936            let m1 = b1.merkleize(&base, &hasher);
937            base.apply_batch(&m1).unwrap();
938            let mut b2 = base.new_batch();
939            for i in 60u64..70 {
940                let element = hasher.digest(&i.to_be_bytes());
941                b2 = b2.add(&hasher, &element);
942            }
943            let m2 = b2.merkleize(&base, &hasher);
944            base.apply_batch(&m2).unwrap();
945            let reference = build_reference::<F>(&hasher, 70);
946            assert_eq!(base.root(), reference.root());
947        });
948    }
949
950    fn batch_on_pruned_base<F: Family>() {
951        let executor = deterministic::Runner::default();
952        executor.start(|_| async move {
953            let hasher: H = Standard::new();
954            let mut base = build_reference::<F>(&hasher, 100);
955            base.prune(Location::new(27)).unwrap();
956            let mut batch = base.new_batch();
957            for i in 100u64..110 {
958                let element = hasher.digest(&i.to_be_bytes());
959                batch = batch.add(&hasher, &element);
960            }
961            let m = batch.merkleize(&base, &hasher);
962            // Apply and verify proofs from the resulting Mem.
963            let mut applied = base;
964            applied.apply_batch(&m).unwrap();
965            let loc = Location::<F>::new(80);
966            let element = hasher.digest(&80u64.to_be_bytes());
967            let proof = applied.proof(&hasher, loc).unwrap();
968            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &m.root()));
969            assert!(matches!(
970                applied.proof(&hasher, Location::new(0)),
971                Err(Error::ElementPruned(_))
972            ));
973        });
974    }
975
976    fn three_deep_stacking<F: Family>() {
977        let executor = deterministic::Runner::default();
978        executor.start(|_| async move {
979            let hasher: H = Standard::new();
980            let mut base = build_reference::<F>(&hasher, 100);
981            let da = Sha256::fill(0xDD);
982            let db = Sha256::fill(0xEE);
983            let ma = base
984                .new_batch()
985                .update_leaf_digest(Location::new(5), da)
986                .unwrap()
987                .merkleize(&base, &hasher);
988            let mb = ma
989                .new_batch()
990                .update_leaf_digest(Location::new(10), db)
991                .unwrap()
992                .merkleize(&base, &hasher);
993            let mut bc = mb.new_batch();
994            for i in 300u64..310 {
995                let element = hasher.digest(&i.to_be_bytes());
996                bc = bc.add(&hasher, &element);
997            }
998            let mc = bc.merkleize(&base, &hasher);
999            let c_root = mc.root();
1000            base.apply_batch(&mc).unwrap();
1001            assert_eq!(*base.root(), c_root);
1002        });
1003    }
1004
1005    fn overwrite_collision<F: Family>() {
1006        let executor = deterministic::Runner::default();
1007        executor.start(|_| async move {
1008            let hasher: H = Standard::new();
1009            let mut base = build_reference::<F>(&hasher, 100);
1010            let dx = Sha256::fill(0xAA);
1011            let dy = Sha256::fill(0xBB);
1012            let ma = base
1013                .new_batch()
1014                .update_leaf_digest(Location::new(5), dx)
1015                .unwrap()
1016                .merkleize(&base, &hasher);
1017            let mb = ma
1018                .new_batch()
1019                .update_leaf_digest(Location::new(5), dy)
1020                .unwrap()
1021                .merkleize(&base, &hasher);
1022            let b_root = mb.root();
1023            base.apply_batch(&mb).unwrap();
1024            assert_eq!(*base.root(), b_root);
1025            let pos5 = Position::<F>::try_from(Location::new(5)).unwrap();
1026            assert_eq!(base.get_node(pos5), Some(dy));
1027        });
1028    }
1029
1030    fn update_appended_leaf<F: Family>() {
1031        let executor = deterministic::Runner::default();
1032        executor.start(|_| async move {
1033            let hasher: H = Standard::new();
1034            let base = build_reference::<F>(&hasher, 50);
1035            let mut batch = base.new_batch();
1036            for i in 50u64..60 {
1037                let element = hasher.digest(&i.to_be_bytes());
1038                batch = batch.add(&hasher, &element);
1039            }
1040            let updated = Sha256::fill(0xEE);
1041            let m = batch
1042                .update_leaf_digest(Location::new(52), updated)
1043                .unwrap()
1044                .merkleize(&base, &hasher);
1045            let pos52 = Position::<F>::try_from(Location::new(52)).unwrap();
1046            assert_eq!(m.get_node(pos52), Some(updated));
1047            let mut reference = build_reference::<F>(&hasher, 60);
1048            let batch = reference
1049                .new_batch()
1050                .update_leaf_digest(Location::new(52), updated)
1051                .unwrap()
1052                .merkleize(&reference, &hasher);
1053            reference.apply_batch(&batch).unwrap();
1054            assert_eq!(m.root(), *reference.root());
1055        });
1056    }
1057
1058    fn update_leaf_element<F: Family>() {
1059        let executor = deterministic::Runner::default();
1060        executor.start(|_| async move {
1061            let hasher: H = Standard::new();
1062            let base = build_reference::<F>(&hasher, 50);
1063            let base_root = *base.root();
1064            let element = b"updated-element";
1065            let m = base
1066                .new_batch()
1067                .update_leaf(&hasher, Location::new(5), element)
1068                .unwrap()
1069                .merkleize(&base, &hasher);
1070            assert_ne!(m.root(), base_root);
1071            let mut base = base;
1072            let batch = base
1073                .new_batch()
1074                .update_leaf(&hasher, Location::new(5), element)
1075                .unwrap()
1076                .merkleize(&base, &hasher);
1077            base.apply_batch(&batch).unwrap();
1078            assert_eq!(m.root(), *base.root());
1079        });
1080    }
1081
1082    fn update_out_of_bounds<F: Family>() {
1083        let executor = deterministic::Runner::default();
1084        executor.start(|_| async move {
1085            let hasher: H = Standard::new();
1086            let base = build_reference::<F>(&hasher, 50);
1087            let r1 = base
1088                .new_batch()
1089                .update_leaf_digest(Location::new(50), Sha256::fill(0xFF));
1090            assert!(matches!(r1, Err(Error::LeafOutOfBounds(_))));
1091            let updates = [(Location::<F>::new(50), Sha256::fill(0xFF))];
1092            let r2 = base.new_batch().update_leaf_batched(&updates);
1093            assert!(matches!(r2, Err(Error::LeafOutOfBounds(_))));
1094        });
1095    }
1096
1097    // --- MMR tests ---
1098
1099    #[test]
1100    fn mmr_consistency() {
1101        consistency_with_reference::<crate::mmr::Family>();
1102    }
1103    #[test]
1104    fn mmr_lifecycle() {
1105        lifecycle::<crate::mmr::Family>();
1106    }
1107    #[test]
1108    fn mmr_apply_batch() {
1109        apply_batch::<crate::mmr::Family>();
1110    }
1111    #[test]
1112    fn mmr_multiple_forks() {
1113        multiple_forks::<crate::mmr::Family>();
1114    }
1115    #[test]
1116    fn mmr_fork_of_fork_reads() {
1117        fork_of_fork_reads::<crate::mmr::Family>();
1118    }
1119    #[test]
1120    fn mmr_update_leaf_digest() {
1121        update_leaf_digest_roundtrip::<crate::mmr::Family>();
1122    }
1123    #[test]
1124    fn mmr_update_and_add() {
1125        update_and_add::<crate::mmr::Family>();
1126    }
1127    #[test]
1128    fn mmr_update_leaf_batched() {
1129        update_leaf_batched_roundtrip::<crate::mmr::Family>();
1130    }
1131    #[test]
1132    fn mmr_proof_verification() {
1133        proof_verification::<crate::mmr::Family>();
1134    }
1135    #[test]
1136    fn mmr_empty_batch() {
1137        empty_batch::<crate::mmr::Family>();
1138    }
1139    #[test]
1140    fn mmr_batch_roundtrip() {
1141        batch_roundtrip::<crate::mmr::Family>();
1142    }
1143    #[test]
1144    fn mmr_sequential_apply_batch() {
1145        sequential_apply_batch::<crate::mmr::Family>();
1146    }
1147    #[test]
1148    fn mmr_batch_on_pruned_base() {
1149        batch_on_pruned_base::<crate::mmr::Family>();
1150    }
1151    #[test]
1152    fn mmr_three_deep_stacking() {
1153        three_deep_stacking::<crate::mmr::Family>();
1154    }
1155    #[test]
1156    fn mmr_overwrite_collision() {
1157        overwrite_collision::<crate::mmr::Family>();
1158    }
1159    #[test]
1160    fn mmr_update_appended_leaf() {
1161        update_appended_leaf::<crate::mmr::Family>();
1162    }
1163    #[test]
1164    fn mmr_update_leaf_element() {
1165        update_leaf_element::<crate::mmr::Family>();
1166    }
1167    #[test]
1168    fn mmr_update_out_of_bounds() {
1169        update_out_of_bounds::<crate::mmr::Family>();
1170    }
1171
1172    // --- MMB tests ---
1173
1174    #[test]
1175    fn mmb_consistency() {
1176        consistency_with_reference::<crate::mmb::Family>();
1177    }
1178    #[test]
1179    fn mmb_lifecycle() {
1180        lifecycle::<crate::mmb::Family>();
1181    }
1182    #[test]
1183    fn mmb_apply_batch() {
1184        apply_batch::<crate::mmb::Family>();
1185    }
1186    #[test]
1187    fn mmb_multiple_forks() {
1188        multiple_forks::<crate::mmb::Family>();
1189    }
1190    #[test]
1191    fn mmb_fork_of_fork_reads() {
1192        fork_of_fork_reads::<crate::mmb::Family>();
1193    }
1194    #[test]
1195    fn mmb_update_leaf_digest() {
1196        update_leaf_digest_roundtrip::<crate::mmb::Family>();
1197    }
1198    #[test]
1199    fn mmb_update_and_add() {
1200        update_and_add::<crate::mmb::Family>();
1201    }
1202    #[test]
1203    fn mmb_update_leaf_batched() {
1204        update_leaf_batched_roundtrip::<crate::mmb::Family>();
1205    }
1206    #[test]
1207    fn mmb_proof_verification() {
1208        proof_verification::<crate::mmb::Family>();
1209    }
1210    #[test]
1211    fn mmb_empty_batch() {
1212        empty_batch::<crate::mmb::Family>();
1213    }
1214    #[test]
1215    fn mmb_batch_roundtrip() {
1216        batch_roundtrip::<crate::mmb::Family>();
1217    }
1218    #[test]
1219    fn mmb_sequential_apply_batch() {
1220        sequential_apply_batch::<crate::mmb::Family>();
1221    }
1222    #[test]
1223    fn mmb_batch_on_pruned_base() {
1224        batch_on_pruned_base::<crate::mmb::Family>();
1225    }
1226    #[test]
1227    fn mmb_three_deep_stacking() {
1228        three_deep_stacking::<crate::mmb::Family>();
1229    }
1230    #[test]
1231    fn mmb_overwrite_collision() {
1232        overwrite_collision::<crate::mmb::Family>();
1233    }
1234    #[test]
1235    fn mmb_update_appended_leaf() {
1236        update_appended_leaf::<crate::mmb::Family>();
1237    }
1238    #[test]
1239    fn mmb_update_leaf_element() {
1240        update_leaf_element::<crate::mmb::Family>();
1241    }
1242    #[test]
1243    fn mmb_update_out_of_bounds() {
1244        update_out_of_bounds::<crate::mmb::Family>();
1245    }
1246}