1use crate::merkle::{
78 hasher::Hasher, mem::Mem, path, proof::Proof, Error, Family, Location, Position, Readable,
79};
80use alloc::{
81 collections::{BTreeMap, BTreeSet},
82 sync::{Arc, Weak},
83 vec::Vec,
84};
85use commonware_cryptography::Digest;
86use core::ops::Range;
87cfg_if::cfg_if! {
88 if #[cfg(feature = "std")] {
89 use commonware_parallel::ThreadPool;
90 use rayon::prelude::*;
91 }
92}
93
/// Minimum number of dirty nodes before `merkleize` hashes them on the attached thread pool.
///
/// The same threshold is applied per height level inside the parallel path: as soon as a level
/// has fewer nodes than this, that level and every level above it are hashed serially.
#[cfg(feature = "std")]
pub(crate) const MIN_TO_PARALLELIZE: usize = 20;
97
/// A batch of pending leaf appends and leaf updates layered on top of a [`MerkleizedBatch`].
///
/// Reads fall through to the parent batch and then to the base `Mem` handed to the methods
/// that need it. Call [`UnmerkleizedBatch::merkleize`] to recompute the affected internal nodes
/// and obtain a [`MerkleizedBatch`] carrying the new root.
pub struct UnmerkleizedBatch<F: Family, D: Digest> {
    /// The batch this one builds on; positions below `parent.size()` are read through it.
    parent: Arc<MerkleizedBatch<F, D>>,
    /// Digests of nodes appended by this batch (leaves and the parents they complete), stored in
    /// position order starting at `parent.size()`. Parent slots hold `D::EMPTY` until
    /// `merkleize` recomputes them.
    appended: Vec<D>,
    /// Replacement digests for positions below `parent.size()`.
    overwrites: BTreeMap<Position<F>, D>,
    /// `(height, position)` of internal nodes whose digests must be recomputed. Tuple ordering
    /// sorts by height first, so iteration visits lower nodes before higher ones.
    dirty_nodes: BTreeSet<(u32, Position<F>)>,
    /// Optional thread pool used by `merkleize` to hash dirty nodes in parallel.
    #[cfg(feature = "std")]
    pool: Option<ThreadPool>,
}
112
113impl<F: Family, D: Digest> UnmerkleizedBatch<F, D> {
114 pub const fn new(parent: Arc<MerkleizedBatch<F, D>>) -> Self {
116 Self {
117 parent,
118 appended: Vec::new(),
119 overwrites: BTreeMap::new(),
120 dirty_nodes: BTreeSet::new(),
121 #[cfg(feature = "std")]
122 pool: None,
123 }
124 }
125
126 #[cfg(feature = "std")]
128 pub fn with_pool(mut self, pool: Option<ThreadPool>) -> Self {
129 self.pool = pool;
130 self
131 }
132
133 #[cfg(feature = "std")]
135 pub const fn pool(&self) -> Option<&ThreadPool> {
136 self.pool.as_ref()
137 }
138
139 pub(crate) fn size(&self) -> Position<F> {
141 Position::new(*self.parent.size() + self.appended.len() as u64)
142 }
143
144 pub fn leaves(&self) -> Location<F> {
146 Location::try_from(self.size()).expect("invalid size")
147 }
148
149 fn get_node(&self, base: &Mem<F, D>, pos: Position<F>) -> Option<D> {
151 if pos >= self.size() {
152 return None;
153 }
154 if let Some(d) = self.overwrites.get(&pos) {
155 return Some(*d);
156 }
157 let parent_size = self.parent.size();
158 if pos >= parent_size {
159 let index = (*pos - *parent_size) as usize;
160 return self.appended.get(index).copied();
161 }
162 if let Some(d) = self.parent.get_node(pos) {
163 return Some(d);
164 }
165 base.get_node(pos)
166 }
167
168 fn store_node(&mut self, pos: Position<F>, digest: D) {
170 let parent_size = self.parent.size();
171 if pos >= parent_size {
172 let index = (*pos - *parent_size) as usize;
173 self.appended[index] = digest;
174 } else {
175 self.overwrites.insert(pos, digest);
176 }
177 }
178
179 fn mark_dirty(&mut self, loc: Location<F>) {
185 let mut first_leaf = Location::new(0);
186 for (peak_pos, height) in F::peaks(self.size()) {
187 let leaves_in_peak = 1u64 << height;
188 if loc >= first_leaf + leaves_in_peak {
189 first_leaf += leaves_in_peak;
190 continue;
191 }
192
193 let mut buf = [(Position::new(0), Position::new(0), 0u32); path::MAX_PATH_LEN];
194 let mut len = 0;
195 for item in path::Iterator::new(peak_pos, height, first_leaf, loc) {
196 buf[len] = item;
197 len += 1;
198 }
199 for &(parent_pos, _, h) in buf[..len].iter().rev() {
200 if !self.dirty_nodes.insert((h, parent_pos)) {
201 break;
202 }
203 }
204 return;
205 }
206
207 panic!("leaf {loc} not found (size: {})", self.size());
208 }
209
210 pub fn add_leaf_digest(mut self, digest: D) -> Self {
212 let heights = F::parent_heights(self.leaves());
213 self.appended.push(digest);
214
215 for height in heights {
216 let pos = self.size();
217 self.appended.push(D::EMPTY);
218 self.dirty_nodes.insert((height, pos));
219 }
220
221 self
222 }
223
224 pub fn add(self, hasher: &impl Hasher<F, Digest = D>, element: &[u8]) -> Self {
226 let digest = hasher.leaf_digest(self.size(), element);
227 self.add_leaf_digest(digest)
228 }
229
230 pub fn update_leaf(
237 mut self,
238 hasher: &impl Hasher<F, Digest = D>,
239 loc: Location<F>,
240 element: &[u8],
241 ) -> Result<Self, Error<F>> {
242 let leaves = self.leaves();
243 if loc >= leaves {
244 return Err(Error::LeafOutOfBounds(loc));
245 }
246 if loc < self.parent.pruning_boundary() {
247 return Err(Error::ElementPruned(Position::try_from(loc)?));
248 }
249 let pos = Position::try_from(loc)?;
250 let digest = hasher.leaf_digest(pos, element);
251 self.store_node(pos, digest);
252 self.mark_dirty(loc);
253 Ok(self)
254 }
255
256 #[cfg(any(feature = "std", test))]
258 pub fn update_leaf_digest(mut self, loc: Location<F>, digest: D) -> Result<Self, Error<F>> {
259 let leaves = self.leaves();
260 if loc >= leaves {
261 return Err(Error::LeafOutOfBounds(loc));
262 }
263 if loc < self.parent.pruning_boundary() {
264 return Err(Error::ElementPruned(Position::try_from(loc)?));
265 }
266 let pos = Position::try_from(loc)?;
267 if F::position_to_location(pos).is_none() {
268 return Err(Error::NonLeaf(pos));
269 }
270 self.store_node(pos, digest);
271 self.mark_dirty(loc);
272 Ok(self)
273 }
274
275 #[cfg(any(feature = "std", test))]
277 pub fn update_leaf_batched(mut self, updates: &[(Location<F>, D)]) -> Result<Self, Error<F>> {
278 let leaves = self.leaves();
279 let prune_boundary = self.parent.pruning_boundary();
280 for (loc, _) in updates {
281 if *loc >= leaves {
282 return Err(Error::LeafOutOfBounds(*loc));
283 }
284 if *loc < prune_boundary {
285 return Err(Error::ElementPruned(Position::try_from(*loc)?));
286 }
287 }
288 for (loc, digest) in updates {
289 let pos = Position::try_from(*loc).unwrap();
290 self.store_node(pos, *digest);
291 self.mark_dirty(*loc);
292 }
293 Ok(self)
294 }
295
296 pub fn merkleize(
299 mut self,
300 base: &Mem<F, D>,
301 hasher: &impl Hasher<F, Digest = D>,
302 ) -> Arc<MerkleizedBatch<F, D>> {
303 let dirty: Vec<_> = core::mem::take(&mut self.dirty_nodes).into_iter().collect();
304
305 #[cfg(feature = "std")]
306 if let Some(pool) = self.pool.take() {
307 if dirty.len() >= MIN_TO_PARALLELIZE {
308 self.merkleize_parallel(base, hasher, &pool, &dirty);
309 } else {
310 self.merkleize_serial(base, hasher, &dirty);
311 }
312 self.pool = Some(pool);
313 } else {
314 self.merkleize_serial(base, hasher, &dirty);
315 }
316
317 #[cfg(not(feature = "std"))]
318 self.merkleize_serial(base, hasher, &dirty);
319
320 let leaves = self.leaves();
322 let peaks: Vec<D> = F::peaks(self.size())
323 .map(|(peak_pos, _)| self.get_node(base, peak_pos).expect("peak missing"))
324 .collect();
325 let root = hasher.root(leaves, peaks.iter());
326
327 let (ancestor_appended, ancestor_overwrites) = collect_ancestor_batches(&self.parent);
329
330 let parent_size = self.parent.size();
331 Arc::new(MerkleizedBatch {
332 parent: Some(Arc::downgrade(&self.parent)),
333 appended: Arc::new(self.appended),
334 overwrites: Arc::new(self.overwrites),
335 root,
336 parent_size,
337 base_size: self.parent.base_size,
338 pruning_boundary: self.parent.pruning_boundary(),
339 ancestor_appended,
340 ancestor_overwrites,
341 #[cfg(feature = "std")]
342 pool: self.pool,
343 })
344 }
345
346 fn merkleize_serial(
348 &mut self,
349 base: &Mem<F, D>,
350 hasher: &impl Hasher<F, Digest = D>,
351 dirty: &[(u32, Position<F>)],
352 ) {
353 for &(height, pos) in dirty {
354 let (left, right) = F::children(pos, height);
355 let left_d = self.get_node(base, left).expect("left child missing");
356 let right_d = self.get_node(base, right).expect("right child missing");
357 let digest = hasher.node_digest(pos, &left_d, &right_d);
358 self.store_node(pos, digest);
359 }
360 }
361
362 #[cfg(feature = "std")]
365 fn merkleize_parallel(
366 &mut self,
367 base: &Mem<F, D>,
368 hasher: &impl Hasher<F, Digest = D>,
369 pool: &ThreadPool,
370 dirty: &[(u32, Position<F>)],
371 ) {
372 let mut same_height = Vec::new();
373 let mut current_height = dirty.first().map_or(1, |&(h, _)| h);
374 for (i, &(height, pos)) in dirty.iter().enumerate() {
375 if height == current_height {
376 same_height.push(pos);
377 continue;
378 }
379 if same_height.len() < MIN_TO_PARALLELIZE {
380 self.merkleize_serial(base, hasher, &dirty[i - same_height.len()..]);
381 return;
382 }
383 self.compute_height_parallel(base, hasher, pool, &same_height, current_height);
384 same_height.clear();
385 current_height = height;
386 same_height.push(pos);
387 }
388
389 if same_height.len() < MIN_TO_PARALLELIZE {
390 self.merkleize_serial(base, hasher, &dirty[dirty.len() - same_height.len()..]);
391 return;
392 }
393
394 self.compute_height_parallel(base, hasher, pool, &same_height, current_height);
395 }
396
397 #[cfg(feature = "std")]
399 fn compute_height_parallel(
400 &mut self,
401 base: &Mem<F, D>,
402 hasher: &impl Hasher<F, Digest = D>,
403 pool: &ThreadPool,
404 same_height: &[Position<F>],
405 height: u32,
406 ) {
407 let computed: Vec<(Position<F>, D)> = pool.install(|| {
408 same_height
409 .par_iter()
410 .map_init(
411 || hasher.clone(),
412 |hasher, &pos| {
413 let (left, right) = F::children(pos, height);
414 let left_d = self.get_node(base, left).expect("left child missing");
415 let right_d = self.get_node(base, right).expect("right child missing");
416 let digest = hasher.node_digest(pos, &left_d, &right_d);
417 (pos, digest)
418 },
419 )
420 .collect()
421 });
422 for (pos, digest) in computed {
423 self.store_node(pos, digest);
424 }
425 }
426}
427
428#[allow(clippy::type_complexity)]
432fn collect_ancestor_batches<F: Family, D: Digest>(
433 parent: &Arc<MerkleizedBatch<F, D>>,
434) -> (Vec<Arc<Vec<D>>>, Vec<Arc<BTreeMap<Position<F>, D>>>) {
435 let mut appended = Vec::new();
436 let mut overwrites = Vec::new();
437
438 if !parent.appended.is_empty() || !parent.overwrites.is_empty() {
440 appended.push(Arc::clone(&parent.appended));
441 overwrites.push(Arc::clone(&parent.overwrites));
442 }
443
444 let mut current = parent.parent.as_ref().and_then(Weak::upgrade);
446 while let Some(batch) = current {
447 if !batch.appended.is_empty() || !batch.overwrites.is_empty() {
448 appended.push(Arc::clone(&batch.appended));
449 overwrites.push(Arc::clone(&batch.overwrites));
450 }
451 current = batch.parent.as_ref().and_then(Weak::upgrade);
452 }
453
454 appended.reverse();
455 overwrites.reverse();
456 (appended, overwrites)
457}
458
/// An immutable, merkleized batch: the nodes it appended or overwrote, plus the resulting root.
///
/// Batches form a chain through `parent`; [`MerkleizedBatch::new_batch`] starts a child batch
/// and [`MerkleizedBatch::from_mem`] creates the root of a chain from a `Mem`.
#[derive(Debug)]
pub struct MerkleizedBatch<F: Family, D: Digest> {
    /// The batch this one was built on, or `None` for a batch created by `from_mem`. Held weakly,
    /// so this batch does not keep its parent alive; lookups stop at a dropped ancestor.
    parent: Option<Weak<Self>>,

    /// Digests of nodes appended by this batch, in position order starting at `parent_size`.
    pub(crate) appended: Arc<Vec<D>>,

    /// Digests this batch wrote to positions below `parent_size`.
    pub(crate) overwrites: Arc<BTreeMap<Position<F>, D>>,

    /// Root digest computed by `merkleize` (or copied from the `Mem` by `from_mem`).
    root: D,

    /// Size of the parent batch (or of the `Mem`, for `from_mem`) when this batch was created.
    pub(crate) parent_size: Position<F>,

    /// Size of the `Mem` at the bottom of the chain; set by `from_mem` and copied from the parent
    /// on merkleize.
    pub(crate) base_size: Position<F>,

    /// Pruning boundary set by `from_mem` and inherited from the parent on merkleize.
    pruning_boundary: Location<F>,

    /// `appended` handles of every reachable ancestor that appended or overwrote anything,
    /// oldest first, captured when this batch was merkleized.
    pub(crate) ancestor_appended: Vec<Arc<Vec<D>>>,

    /// `overwrites` handles of the same ancestors, index-aligned with `ancestor_appended`.
    pub(crate) ancestor_overwrites: Vec<Arc<BTreeMap<Position<F>, D>>>,

    /// Thread pool handed on to child batches by `new_batch`.
    #[cfg(feature = "std")]
    pub(crate) pool: Option<ThreadPool>,
}
501
502impl<F: Family, D: Digest> MerkleizedBatch<F, D> {
503 pub fn from_mem(mem: &Mem<F, D>) -> Arc<Self> {
505 Arc::new(Self {
506 parent: None,
507 appended: Arc::new(Vec::new()),
508 overwrites: Arc::new(BTreeMap::new()),
509 root: *mem.root(),
510 parent_size: mem.size(),
511 base_size: mem.size(),
512 pruning_boundary: Readable::pruning_boundary(mem),
513 ancestor_appended: Vec::new(),
514 ancestor_overwrites: Vec::new(),
515 #[cfg(feature = "std")]
516 pool: None,
517 })
518 }
519
520 pub fn size(&self) -> Position<F> {
522 Position::new(*self.parent_size + self.appended.len() as u64)
523 }
524
525 pub fn get_node(&self, pos: Position<F>) -> Option<D> {
531 if pos >= self.size() {
532 return None;
533 }
534 if let Some(d) = self.overwrites.get(&pos) {
535 return Some(*d);
536 }
537 if pos >= self.parent_size {
538 let i = (*pos - *self.parent_size) as usize;
539 return self.appended.get(i).copied();
540 }
541 let mut current = self.parent.as_ref().and_then(Weak::upgrade);
543 while let Some(batch) = current {
544 if let Some(d) = batch.overwrites.get(&pos) {
545 return Some(*d);
546 }
547 if pos >= batch.parent_size {
548 let i = (*pos - *batch.parent_size) as usize;
549 return batch.appended.get(i).copied();
550 }
551 current = batch.parent.as_ref().and_then(Weak::upgrade);
552 }
553 None
554 }
555
556 pub const fn root(&self) -> D {
558 self.root
559 }
560
561 pub const fn pruning_boundary(&self) -> Location<F> {
563 self.pruning_boundary
564 }
565
566 pub fn leaves(&self) -> Location<F> {
568 Location::try_from(self.size()).expect("invalid size")
569 }
570
571 pub fn new_batch(self: &Arc<Self>) -> UnmerkleizedBatch<F, D> {
577 let batch = UnmerkleizedBatch::new(Arc::clone(self));
578 #[cfg(feature = "std")]
579 let batch = batch.with_pool(self.pool.clone());
580 batch
581 }
582
583 pub const fn base_size(&self) -> Position<F> {
585 self.base_size
586 }
587}
588
589impl<F: Family, D: Digest> Readable for MerkleizedBatch<F, D> {
590 type Family = F;
591 type Digest = D;
592 type Error = Error<F>;
593
594 fn size(&self) -> Position<F> {
595 Self::size(self)
596 }
597
598 fn get_node(&self, pos: Position<F>) -> Option<D> {
599 Self::get_node(self, pos)
600 }
601
602 fn root(&self) -> D {
603 Self::root(self)
604 }
605
606 fn pruning_boundary(&self) -> Location<F> {
607 Self::pruning_boundary(self)
608 }
609
610 fn proof(
611 &self,
612 hasher: &impl Hasher<F, Digest = D>,
613 loc: Location<F>,
614 ) -> Result<Proof<F, D>, Error<F>> {
615 if !loc.is_valid_index() {
616 return Err(Error::LocationOverflow(loc));
617 }
618 self.range_proof(hasher, loc..loc + 1).map_err(|e| match e {
619 Error::RangeOutOfBounds(_) => Error::LeafOutOfBounds(loc),
620 _ => e,
621 })
622 }
623
624 fn range_proof(
625 &self,
626 hasher: &impl Hasher<F, Digest = D>,
627 range: Range<Location<F>>,
628 ) -> Result<Proof<F, D>, Error<F>> {
629 crate::merkle::proof::build_range_proof(
630 hasher,
631 self.leaves(),
632 range,
633 |pos| Self::get_node(self, pos),
634 Error::ElementPruned,
635 )
636 }
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::{hasher::Standard, mem::Mem};
    use commonware_cryptography::{sha256, Sha256};
    use commonware_runtime::{deterministic, Runner as _};

    type D = sha256::Digest;
    type H = Standard<Sha256>;

    /// Adds one leaf per value in `range`. Each leaf's element is `hasher.digest` of the value's
    /// big-endian bytes.
    fn add_elements<F: Family>(
        batch: UnmerkleizedBatch<F, D>,
        hasher: &H,
        range: core::ops::Range<u64>,
    ) -> UnmerkleizedBatch<F, D> {
        range.fold(batch, |acc, i| acc.add(hasher, &hasher.digest(&i.to_be_bytes())))
    }

    /// Builds a `Mem` holding leaves `0..n` through a single merkleized batch.
    fn build_reference<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
        let mut mem = Mem::new(hasher);
        let merkleized = add_elements(mem.new_batch(), hasher, 0..n).merkleize(&mem, hasher);
        mem.apply_batch(&merkleized).unwrap();
        mem
    }

    fn consistency_with_reference<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            for n in [1u64, 2, 10, 100, 199] {
                let reference = build_reference::<F>(&hasher, n);
                let base = Mem::<F, D>::new(&hasher);
                let merkleized =
                    add_elements(base.new_batch(), &hasher, 0..n).merkleize(&base, &hasher);
                let mut result = Mem::<F, D>::new(&hasher);
                result.apply_batch(&merkleized).unwrap();
                assert_eq!(result.root(), reference.root(), "root mismatch for n={n}");
            }
        });
    }

    fn lifecycle<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();

            let merkleized =
                add_elements(base.new_batch(), &hasher, 50..60).merkleize(&base, &hasher);
            assert_ne!(merkleized.root(), base_root);
            // Merkleizing must not touch the base.
            assert_eq!(*base.root(), base_root);

            let mut applied = base;
            applied.apply_batch(&merkleized).unwrap();
            let loc = Location::<F>::new(55);
            let element = hasher.digest(&55u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &merkleized.root()));
        });
    }

    fn apply_batch<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 50);
            let merkleized =
                add_elements(base.new_batch(), &hasher, 50..75).merkleize(&base, &hasher);
            let batch_root = merkleized.root();

            base.apply_batch(&merkleized).unwrap();
            assert_eq!(*base.root(), batch_root);
            let reference = build_reference::<F>(&hasher, 75);
            assert_eq!(base.root(), reference.root());
        });
    }

    fn multiple_forks<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();

            let ma = add_elements(base.new_batch(), &hasher, 50..60).merkleize(&base, &hasher);
            let mb = add_elements(base.new_batch(), &hasher, 100..105).merkleize(&base, &hasher);

            assert_ne!(ma.root(), mb.root());
            assert_ne!(ma.root(), base_root);
            assert_eq!(*base.root(), base_root);
        });
    }

    fn fork_of_fork_reads<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            // `ma` stays bound so `mb` can read through its weak parent link.
            let ma = add_elements(base.new_batch(), &hasher, 50..60).merkleize(&base, &hasher);
            let mb = add_elements(ma.new_batch(), &hasher, 60..70).merkleize(&base, &hasher);

            let reference = build_reference::<F>(&hasher, 70);
            assert_eq!(mb.root(), *reference.root());

            let mut applied = base;
            applied.apply_batch(&ma).unwrap();
            applied.apply_batch(&mb).unwrap();
            for i in [0u64, 25, 55, 65, 69] {
                let loc = Location::<F>::new(i);
                let element = hasher.digest(&i.to_be_bytes());
                let proof = applied.proof(&hasher, loc).unwrap();
                assert!(proof.verify_element_inclusion(&hasher, &element, loc, &mb.root()));
            }
        });
    }

    fn update_leaf_digest_roundtrip<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 100);
            let base_root = *base.root();
            let leaf = Location::<F>::new(5);

            let changed = base
                .new_batch()
                .update_leaf_digest(leaf, Sha256::fill(0xFF))
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(changed.root(), base_root);

            // Writing the original digest back must reproduce the original root.
            let original = base.get_node(Position::<F>::try_from(leaf).unwrap()).unwrap();
            let restored = base
                .new_batch()
                .update_leaf_digest(leaf, original)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_eq!(restored.root(), base_root);
        });
    }

    fn update_and_add<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let updated = Sha256::fill(0xAA);

            let batch = base
                .new_batch()
                .update_leaf_digest(Location::new(10), updated)
                .unwrap();
            let merkleized = add_elements(batch, &hasher, 50..55).merkleize(&base, &hasher);

            assert_ne!(merkleized.root(), base_root);
            let pos10 = Position::<F>::try_from(Location::new(10)).unwrap();
            assert_eq!(merkleized.get_node(pos10), Some(updated));
        });
    }

    fn update_leaf_batched_roundtrip<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 100);
            let base_root = *base.root();
            let updated = Sha256::fill(0xBB);

            let updates: Vec<(Location<F>, D)> = [0u64, 10, 50, 99]
                .iter()
                .map(|&i| (Location::new(i), updated))
                .collect();
            let changed = base
                .new_batch()
                .update_leaf_batched(&updates)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(changed.root(), base_root);

            let restore: Vec<(Location<F>, D)> = updates
                .iter()
                .map(|&(loc, _)| {
                    let pos = Position::<F>::try_from(loc).unwrap();
                    (loc, base.get_node(pos).unwrap())
                })
                .collect();
            let restored = base
                .new_batch()
                .update_leaf_batched(&restore)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_eq!(restored.root(), base_root);
        });
    }

    fn proof_verification<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let merkleized =
                add_elements(base.new_batch(), &hasher, 50..60).merkleize(&base, &hasher);
            let mut applied = base;
            applied.apply_batch(&merkleized).unwrap();
            let root = merkleized.root();

            let loc = Location::<F>::new(55);
            let element = hasher.digest(&55u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &root));

            let range = Location::<F>::new(50)..Location::new(55);
            let range_proof = applied.range_proof(&hasher, range.clone()).unwrap();
            let elements: Vec<D> = (50u64..55)
                .map(|i| hasher.digest(&i.to_be_bytes()))
                .collect();
            assert!(range_proof.verify_range_inclusion(&hasher, &elements, range.start, &root));
        });
    }

    fn empty_batch<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let merkleized = base.new_batch().merkleize(&base, &hasher);
            assert_eq!(merkleized.root(), base_root);
        });
    }

    fn batch_roundtrip<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let first = add_elements(base.new_batch(), &hasher, 50..55).merkleize(&base, &hasher);
            let reference = build_reference::<F>(&hasher, 60);
            let stacked =
                add_elements(first.new_batch(), &hasher, 55..60).merkleize(&base, &hasher);
            assert_eq!(stacked.root(), *reference.root());
        });
    }

    fn sequential_apply_batch<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 50);
            for range in [50u64..60, 60..70] {
                let merkleized =
                    add_elements(base.new_batch(), &hasher, range).merkleize(&base, &hasher);
                base.apply_batch(&merkleized).unwrap();
            }
            let reference = build_reference::<F>(&hasher, 70);
            assert_eq!(base.root(), reference.root());
        });
    }

    fn batch_on_pruned_base<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 100);
            base.prune(Location::new(27)).unwrap();

            let merkleized =
                add_elements(base.new_batch(), &hasher, 100..110).merkleize(&base, &hasher);
            let mut applied = base;
            applied.apply_batch(&merkleized).unwrap();

            let loc = Location::<F>::new(80);
            let element = hasher.digest(&80u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &merkleized.root()));
            assert!(matches!(
                applied.proof(&hasher, Location::new(0)),
                Err(Error::ElementPruned(_))
            ));
        });
    }

    fn three_deep_stacking<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 100);

            // `ma` and `mb` stay bound for the whole scenario so the chain remains reachable.
            let ma = base
                .new_batch()
                .update_leaf_digest(Location::new(5), Sha256::fill(0xDD))
                .unwrap()
                .merkleize(&base, &hasher);
            let mb = ma
                .new_batch()
                .update_leaf_digest(Location::new(10), Sha256::fill(0xEE))
                .unwrap()
                .merkleize(&base, &hasher);
            let mc = add_elements(mb.new_batch(), &hasher, 300..310).merkleize(&base, &hasher);

            let c_root = mc.root();
            base.apply_batch(&mc).unwrap();
            assert_eq!(*base.root(), c_root);
        });
    }

    fn overwrite_collision<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 100);
            let later = Sha256::fill(0xBB);

            let ma = base
                .new_batch()
                .update_leaf_digest(Location::new(5), Sha256::fill(0xAA))
                .unwrap()
                .merkleize(&base, &hasher);
            let mb = ma
                .new_batch()
                .update_leaf_digest(Location::new(5), later)
                .unwrap()
                .merkleize(&base, &hasher);

            let b_root = mb.root();
            base.apply_batch(&mb).unwrap();
            assert_eq!(*base.root(), b_root);
            // The newer write to the same leaf wins.
            let pos5 = Position::<F>::try_from(Location::new(5)).unwrap();
            assert_eq!(base.get_node(pos5), Some(later));
        });
    }

    fn update_appended_leaf<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let updated = Sha256::fill(0xEE);

            let merkleized = add_elements(base.new_batch(), &hasher, 50..60)
                .update_leaf_digest(Location::new(52), updated)
                .unwrap()
                .merkleize(&base, &hasher);
            let pos52 = Position::<F>::try_from(Location::new(52)).unwrap();
            assert_eq!(merkleized.get_node(pos52), Some(updated));

            let mut reference = build_reference::<F>(&hasher, 60);
            let fix = reference
                .new_batch()
                .update_leaf_digest(Location::new(52), updated)
                .unwrap()
                .merkleize(&reference, &hasher);
            reference.apply_batch(&fix).unwrap();
            assert_eq!(merkleized.root(), *reference.root());
        });
    }

    fn update_leaf_element<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let element = b"updated-element";

            let unapplied = base
                .new_batch()
                .update_leaf(&hasher, Location::new(5), element)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(unapplied.root(), base_root);

            let applied = base
                .new_batch()
                .update_leaf(&hasher, Location::new(5), element)
                .unwrap()
                .merkleize(&base, &hasher);
            base.apply_batch(&applied).unwrap();
            assert_eq!(unapplied.root(), *base.root());
        });
    }

    fn update_out_of_bounds<F: Family>() {
        deterministic::Runner::default().start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let past_end = Location::<F>::new(50);

            let single = base
                .new_batch()
                .update_leaf_digest(past_end, Sha256::fill(0xFF));
            assert!(matches!(single, Err(Error::LeafOutOfBounds(_))));

            let updates = [(past_end, Sha256::fill(0xFF))];
            let batched = base.new_batch().update_leaf_batched(&updates);
            assert!(matches!(batched, Err(Error::LeafOutOfBounds(_))));
        });
    }

    /// Instantiates each generic scenario once per Merkle family, under the given test names.
    macro_rules! family_tests {
        ($($scenario:ident => $mmr_test:ident, $mmb_test:ident;)+) => {
            $(
                #[test]
                fn $mmr_test() {
                    $scenario::<crate::mmr::Family>();
                }

                #[test]
                fn $mmb_test() {
                    $scenario::<crate::mmb::Family>();
                }
            )+
        };
    }

    family_tests! {
        consistency_with_reference => mmr_consistency, mmb_consistency;
        lifecycle => mmr_lifecycle, mmb_lifecycle;
        apply_batch => mmr_apply_batch, mmb_apply_batch;
        multiple_forks => mmr_multiple_forks, mmb_multiple_forks;
        fork_of_fork_reads => mmr_fork_of_fork_reads, mmb_fork_of_fork_reads;
        update_leaf_digest_roundtrip => mmr_update_leaf_digest, mmb_update_leaf_digest;
        update_and_add => mmr_update_and_add, mmb_update_and_add;
        update_leaf_batched_roundtrip => mmr_update_leaf_batched, mmb_update_leaf_batched;
        proof_verification => mmr_proof_verification, mmb_proof_verification;
        empty_batch => mmr_empty_batch, mmb_empty_batch;
        batch_roundtrip => mmr_batch_roundtrip, mmb_batch_roundtrip;
        sequential_apply_batch => mmr_sequential_apply_batch, mmb_sequential_apply_batch;
        batch_on_pruned_base => mmr_batch_on_pruned_base, mmb_batch_on_pruned_base;
        three_deep_stacking => mmr_three_deep_stacking, mmb_three_deep_stacking;
        overwrite_collision => mmr_overwrite_collision, mmb_overwrite_collision;
        update_appended_leaf => mmr_update_appended_leaf, mmb_update_appended_leaf;
        update_leaf_element => mmr_update_leaf_element, mmb_update_leaf_element;
        update_out_of_bounds => mmr_update_out_of_bounds, mmb_update_out_of_bounds;
    }
}