use crate::merkle::{
hasher::Hasher, mem::Mem, path, proof::Proof, Error, Family, Location, Position, Readable,
};
use alloc::{
collections::{BTreeMap, BTreeSet},
sync::{Arc, Weak},
vec::Vec,
};
use commonware_cryptography::Digest;
use core::ops::Range;
cfg_if::cfg_if! {
if #[cfg(feature = "std")] {
use commonware_parallel::ThreadPool;
use rayon::prelude::*;
}
}
/// Minimum number of dirty nodes (overall, and within a single height layer)
/// required before recomputation is dispatched to the thread pool; smaller
/// workloads use the serial path.
#[cfg(feature = "std")]
pub(crate) const MIN_TO_PARALLELIZE: usize = 20;
/// A mutable batch of changes (appended leaves and leaf overwrites) layered on
/// top of an immutable [`MerkleizedBatch`]. Call [`UnmerkleizedBatch::merkleize`]
/// to recompute dirty nodes, derive the new root, and freeze the batch.
pub struct UnmerkleizedBatch<F: Family, D: Digest> {
    /// The immutable snapshot this batch extends.
    parent: Arc<MerkleizedBatch<F, D>>,
    /// Nodes appended past `parent`'s size (leaves plus placeholder parents
    /// that are filled in during merkleization).
    appended: Vec<D>,
    /// Digests that shadow nodes owned by `parent`, keyed by position.
    overwrites: BTreeMap<Position<F>, D>,
    /// Internal nodes needing recomputation, ordered by (height, position) so
    /// iteration always visits children before their parents.
    dirty_nodes: BTreeSet<(u32, Position<F>)>,
    /// Optional thread pool used to parallelize merkleization.
    #[cfg(feature = "std")]
    pool: Option<ThreadPool>,
}
impl<F: Family, D: Digest> UnmerkleizedBatch<F, D> {
/// Creates an empty batch layered on `parent`.
pub const fn new(parent: Arc<MerkleizedBatch<F, D>>) -> Self {
    Self {
        parent,
        appended: Vec::new(),
        overwrites: BTreeMap::new(),
        dirty_nodes: BTreeSet::new(),
        // No pool by default; install one with `with_pool`.
        #[cfg(feature = "std")]
        pool: None,
    }
}
/// Builder-style setter: consumes the batch and returns it with the given
/// thread pool (or `None`) installed.
#[cfg(feature = "std")]
pub fn with_pool(self, pool: Option<ThreadPool>) -> Self {
    Self { pool, ..self }
}
/// Returns the thread pool configured for this batch, if any.
#[cfg(feature = "std")]
pub const fn pool(&self) -> Option<&ThreadPool> {
    self.pool.as_ref()
}
/// Total node count: the parent snapshot's size plus everything appended by
/// this batch (leaves and placeholder parents alike).
pub(crate) fn size(&self) -> Position<F> {
    Position::new(*self.parent.size() + self.appended.len() as u64)
}

/// Number of leaves, derived from the total node count.
pub fn leaves(&self) -> Location<F> {
    Location::try_from(self.size()).expect("invalid size")
}
/// Resolves the digest at `pos`, consulting (in priority order) this batch's
/// overwrites, this batch's appended nodes, the merkleized parent chain, and
/// finally `base`. Returns `None` when `pos` is out of range or unavailable.
fn get_node(&self, base: &Mem<F, D>, pos: Position<F>) -> Option<D> {
    if pos >= self.size() {
        return None;
    }
    // Local overwrites shadow everything older.
    if let Some(d) = self.overwrites.get(&pos) {
        return Some(*d);
    }
    let parent_size = self.parent.size();
    if pos >= parent_size {
        // Node was appended by this batch.
        let index = (*pos - *parent_size) as usize;
        return self.appended.get(index).copied();
    }
    // Fall back to the parent chain, then the base structure.
    if let Some(d) = self.parent.get_node(pos) {
        return Some(d);
    }
    base.get_node(pos)
}
/// Records `digest` at `pos`: positions inherited from the parent snapshot
/// are shadowed via `overwrites`, while positions appended by this batch are
/// written into `appended` in place.
fn store_node(&mut self, pos: Position<F>, digest: D) {
    let parent_size = self.parent.size();
    if pos < parent_size {
        self.overwrites.insert(pos, digest);
    } else {
        let offset = (*pos - *parent_size) as usize;
        self.appended[offset] = digest;
    }
}
/// Marks every ancestor of the leaf at `loc` as dirty so it is recomputed
/// during merkleization.
///
/// # Panics
///
/// Panics if `loc` does not fall under any peak of the current structure.
fn mark_dirty(&mut self, loc: Location<F>) {
    let mut first_leaf = Location::new(0);
    for (peak_pos, height) in F::peaks(self.size()) {
        let leaves_in_peak = 1u64 << height;
        if loc >= first_leaf + leaves_in_peak {
            // Leaf lives under a later peak; skip past this one.
            first_leaf += leaves_in_peak;
            continue;
        }
        // Collect the peak-to-leaf path into a fixed-size buffer (no heap
        // allocation; the path length is bounded by MAX_PATH_LEN).
        let mut buf = [(Position::new(0), Position::new(0), 0u32); path::MAX_PATH_LEN];
        let mut len = 0;
        for item in path::Iterator::new(peak_pos, height, first_leaf, loc) {
            buf[len] = item;
            len += 1;
        }
        // Walk bottom-up. If a node is already marked, every ancestor above
        // it must already be marked too, so stop early.
        for &(parent_pos, _, h) in buf[..len].iter().rev() {
            if !self.dirty_nodes.insert((h, parent_pos)) {
                break;
            }
        }
        return;
    }
    panic!("leaf {loc} not found (size: {})", self.size());
}
/// Appends `digest` as a new leaf, plus a placeholder (`D::EMPTY`) for every
/// parent node the new leaf introduces; the placeholders are marked dirty so
/// merkleization fills them in.
pub fn add_leaf_digest(mut self, digest: D) -> Self {
    // Parent heights must be queried from the leaf count *before* the push.
    let heights = F::parent_heights(self.leaves());
    self.appended.push(digest);
    for height in heights {
        // `size()` is the position the placeholder will occupy once pushed.
        let pos = self.size();
        self.appended.push(D::EMPTY);
        self.dirty_nodes.insert((height, pos));
    }
    self
}
/// Hashes `element` at the next free position and appends it as a new leaf.
pub fn add(self, hasher: &impl Hasher<F, Digest = D>, element: &[u8]) -> Self {
    let pos = self.size();
    self.add_leaf_digest(hasher.leaf_digest(pos, element))
}
/// Replaces the leaf at `loc` with the digest of `element`.
///
/// # Errors
///
/// Returns `LeafOutOfBounds` if `loc` is past the end of the structure and
/// `ElementPruned` if it falls below the parent's pruning boundary.
pub fn update_leaf(
    mut self,
    hasher: &impl Hasher<F, Digest = D>,
    loc: Location<F>,
    element: &[u8],
) -> Result<Self, Error<F>> {
    if loc >= self.leaves() {
        return Err(Error::LeafOutOfBounds(loc));
    }
    if loc < self.parent.pruning_boundary() {
        return Err(Error::ElementPruned(Position::try_from(loc)?));
    }
    let pos = Position::try_from(loc)?;
    self.store_node(pos, hasher.leaf_digest(pos, element));
    self.mark_dirty(loc);
    Ok(self)
}
/// Replaces the leaf at `loc` with a precomputed `digest`.
///
/// # Errors
///
/// Returns `LeafOutOfBounds` if `loc` is past the end of the structure,
/// `ElementPruned` if it falls below the parent's pruning boundary, and
/// `NonLeaf` if the derived position does not map back to a leaf location.
#[cfg(any(feature = "std", test))]
pub fn update_leaf_digest(mut self, loc: Location<F>, digest: D) -> Result<Self, Error<F>> {
    let leaves = self.leaves();
    if loc >= leaves {
        return Err(Error::LeafOutOfBounds(loc));
    }
    if loc < self.parent.pruning_boundary() {
        return Err(Error::ElementPruned(Position::try_from(loc)?));
    }
    let pos = Position::try_from(loc)?;
    if F::position_to_location(pos).is_none() {
        return Err(Error::NonLeaf(pos));
    }
    self.store_node(pos, digest);
    self.mark_dirty(loc);
    Ok(self)
}
/// Applies several leaf overwrites at once, validating every update before
/// mutating any state so a failed call leaves the batch untouched.
///
/// # Errors
///
/// Returns `LeafOutOfBounds` if any location is past the end of the
/// structure, and `ElementPruned` if any falls below the parent's pruning
/// boundary. Conversion failures propagate as errors (matching `update_leaf`
/// and `update_leaf_digest`) instead of panicking.
#[cfg(any(feature = "std", test))]
pub fn update_leaf_batched(mut self, updates: &[(Location<F>, D)]) -> Result<Self, Error<F>> {
    let leaves = self.leaves();
    let prune_boundary = self.parent.pruning_boundary();
    // Validation pass: check bounds and precompute each position exactly
    // once (the original recomputed and `unwrap()`ed it in the apply pass).
    let mut positions = Vec::with_capacity(updates.len());
    for (loc, _) in updates {
        if *loc >= leaves {
            return Err(Error::LeafOutOfBounds(*loc));
        }
        if *loc < prune_boundary {
            return Err(Error::ElementPruned(Position::try_from(*loc)?));
        }
        positions.push(Position::try_from(*loc)?);
    }
    // Apply pass: only reached once every update has been validated.
    for ((loc, digest), pos) in updates.iter().zip(positions) {
        self.store_node(pos, *digest);
        self.mark_dirty(*loc);
    }
    Ok(self)
}
/// Recomputes all dirty internal nodes, derives the new root, and freezes
/// this batch into an immutable [`MerkleizedBatch`].
pub fn merkleize(
    mut self,
    base: &Mem<F, D>,
    hasher: &impl Hasher<F, Digest = D>,
) -> Arc<MerkleizedBatch<F, D>> {
    // BTreeSet iteration yields (height, position) in ascending order, so
    // children are always recomputed before their parents.
    let dirty: Vec<_> = core::mem::take(&mut self.dirty_nodes).into_iter().collect();
    #[cfg(feature = "std")]
    if let Some(pool) = self.pool.take() {
        // Parallelism only pays off above a minimum workload size.
        if dirty.len() >= MIN_TO_PARALLELIZE {
            self.merkleize_parallel(base, hasher, &pool, &dirty);
        } else {
            self.merkleize_serial(base, hasher, &dirty);
        }
        // Put the pool back so it carries over into the frozen batch.
        self.pool = Some(pool);
    } else {
        self.merkleize_serial(base, hasher, &dirty);
    }
    #[cfg(not(feature = "std"))]
    self.merkleize_serial(base, hasher, &dirty);
    let leaves = self.leaves();
    // The root commits to the leaf count plus every peak digest.
    let peaks: Vec<D> = F::peaks(self.size())
        .map(|(peak_pos, _)| self.get_node(base, peak_pos).expect("peak missing"))
        .collect();
    let root = hasher.root(leaves, peaks.iter());
    // Gather strong handles on every ancestor's delta buffers; the `parent`
    // link below is only a `Weak` reference.
    let (ancestor_appended, ancestor_overwrites) = collect_ancestor_batches(&self.parent);
    let parent_size = self.parent.size();
    Arc::new(MerkleizedBatch {
        parent: Some(Arc::downgrade(&self.parent)),
        appended: Arc::new(self.appended),
        overwrites: Arc::new(self.overwrites),
        root,
        parent_size,
        base_size: self.parent.base_size,
        pruning_boundary: self.parent.pruning_boundary(),
        ancestor_appended,
        ancestor_overwrites,
        #[cfg(feature = "std")]
        pool: self.pool,
    })
}
/// Recomputes each dirty node in order. `dirty` is sorted by
/// (height, position), so every child is rebuilt before its parent.
fn merkleize_serial(
    &mut self,
    base: &Mem<F, D>,
    hasher: &impl Hasher<F, Digest = D>,
    dirty: &[(u32, Position<F>)],
) {
    dirty.iter().for_each(|&(height, pos)| {
        let (left_pos, right_pos) = F::children(pos, height);
        let left = self.get_node(base, left_pos).expect("left child missing");
        let right = self.get_node(base, right_pos).expect("right child missing");
        self.store_node(pos, hasher.node_digest(pos, &left, &right));
    });
}
/// Recomputes dirty nodes height layer by height layer, parallelizing each
/// layer that is large enough. As soon as a layer is too small to be worth
/// dispatching, the remainder (that layer and everything above it) is
/// finished on the serial path.
#[cfg(feature = "std")]
fn merkleize_parallel(
    &mut self,
    base: &Mem<F, D>,
    hasher: &impl Hasher<F, Digest = D>,
    pool: &ThreadPool,
    dirty: &[(u32, Position<F>)],
) {
    // Nodes at the same height never depend on one another, so each layer
    // can be computed in parallel once the layers below are done.
    let mut same_height = Vec::new();
    let mut current_height = dirty.first().map_or(1, |&(h, _)| h);
    for (i, &(height, pos)) in dirty.iter().enumerate() {
        if height == current_height {
            same_height.push(pos);
            continue;
        }
        // Height changed: flush the accumulated layer. `i - same_height.len()`
        // is the index where the layer started, so the serial fallback covers
        // the whole layer plus all remaining (higher) entries.
        if same_height.len() < MIN_TO_PARALLELIZE {
            self.merkleize_serial(base, hasher, &dirty[i - same_height.len()..]);
            return;
        }
        self.compute_height_parallel(base, hasher, pool, &same_height, current_height);
        same_height.clear();
        current_height = height;
        same_height.push(pos);
    }
    // Flush the final layer.
    if same_height.len() < MIN_TO_PARALLELIZE {
        self.merkleize_serial(base, hasher, &dirty[dirty.len() - same_height.len()..]);
        return;
    }
    self.compute_height_parallel(base, hasher, pool, &same_height, current_height);
}
/// Computes the digests of all dirty nodes at a single `height` in parallel
/// on `pool`, then writes the results back serially (writes need `&mut self`,
/// so they cannot happen inside the parallel read phase).
#[cfg(feature = "std")]
fn compute_height_parallel(
    &mut self,
    base: &Mem<F, D>,
    hasher: &impl Hasher<F, Digest = D>,
    pool: &ThreadPool,
    same_height: &[Position<F>],
    height: u32,
) {
    let computed: Vec<(Position<F>, D)> = pool.install(|| {
        same_height
            .par_iter()
            .map_init(
                // Per-split hasher state for rayon's `map_init`.
                || hasher.clone(),
                |hasher, &pos| {
                    let (left, right) = F::children(pos, height);
                    let left_d = self.get_node(base, left).expect("left child missing");
                    let right_d = self.get_node(base, right).expect("right child missing");
                    let digest = hasher.node_digest(pos, &left_d, &right_d);
                    (pos, digest)
                },
            )
            .collect()
    });
    for (pos, digest) in computed {
        self.store_node(pos, digest);
    }
}
}
// Walks the batch chain from `parent` back to the root, gathering each
// batch's delta buffers (batches with no deltas are skipped), then flips the
// lists so they are ordered oldest-first. Traversal stops early if an
// ancestor's weak link no longer upgrades.
#[allow(clippy::type_complexity)]
fn collect_ancestor_batches<F: Family, D: Digest>(
    parent: &Arc<MerkleizedBatch<F, D>>,
) -> (Vec<Arc<Vec<D>>>, Vec<Arc<BTreeMap<Position<F>, D>>>) {
    let mut appended = Vec::new();
    let mut overwrites = Vec::new();
    // Start from a strong handle on `parent`; every later hop upgrades a weak
    // link.
    let mut cursor = Some(Arc::clone(parent));
    while let Some(batch) = cursor {
        if !batch.appended.is_empty() || !batch.overwrites.is_empty() {
            appended.push(Arc::clone(&batch.appended));
            overwrites.push(Arc::clone(&batch.overwrites));
        }
        cursor = batch.parent.as_ref().and_then(Weak::upgrade);
    }
    appended.reverse();
    overwrites.reverse();
    (appended, overwrites)
}
/// An immutable, merkleized snapshot: this batch's deltas layered over a
/// chain of ancestors, with the root already computed.
#[derive(Debug)]
pub struct MerkleizedBatch<F: Family, D: Digest> {
    /// Previous snapshot in the chain. Held weakly; the `ancestor_*` buffers
    /// below keep the delta data itself alive.
    parent: Option<Weak<Self>>,
    /// Nodes appended past `parent_size`.
    pub(crate) appended: Arc<Vec<D>>,
    /// Digests shadowing nodes at positions below `parent_size`.
    pub(crate) overwrites: Arc<BTreeMap<Position<F>, D>>,
    /// Root digest computed at merkleization time.
    root: D,
    /// Size of the parent snapshot when this batch was merkleized.
    pub(crate) parent_size: Position<F>,
    /// Size of the base structure the chain was originally forked from.
    pub(crate) base_size: Position<F>,
    /// Locations below this boundary have been pruned.
    pruning_boundary: Location<F>,
    /// Strong handles on ancestors' `appended` buffers, oldest first
    /// (empty-delta ancestors are skipped).
    pub(crate) ancestor_appended: Vec<Arc<Vec<D>>>,
    /// Strong handles on ancestors' `overwrites` maps, oldest first
    /// (empty-delta ancestors are skipped).
    pub(crate) ancestor_overwrites: Vec<Arc<BTreeMap<Position<F>, D>>>,
    /// Thread pool inherited by batches forked from this snapshot.
    #[cfg(feature = "std")]
    pub(crate) pool: Option<ThreadPool>,
}
impl<F: Family, D: Digest> MerkleizedBatch<F, D> {
/// Wraps an existing in-memory structure as the root of a batch chain, with
/// no parent and no deltas of its own.
pub fn from_mem(mem: &Mem<F, D>) -> Arc<Self> {
    Arc::new(Self {
        parent: None,
        appended: Arc::new(Vec::new()),
        overwrites: Arc::new(BTreeMap::new()),
        root: *mem.root(),
        parent_size: mem.size(),
        // The snapshot starts exactly at the base structure's size.
        base_size: mem.size(),
        // Called through the trait (fully qualified).
        pruning_boundary: Readable::pruning_boundary(mem),
        ancestor_appended: Vec::new(),
        ancestor_overwrites: Vec::new(),
        #[cfg(feature = "std")]
        pool: None,
    })
}
/// Total node count: the parent's size plus nodes appended by this batch.
pub fn size(&self) -> Position<F> {
    Position::new(*self.parent_size + self.appended.len() as u64)
}
/// Resolves the digest at `pos` by checking this batch's deltas, then
/// walking the parent chain newest-to-oldest. Returns `None` when `pos` is
/// out of range, not found, or when an ancestor needed for the lookup has
/// been dropped (its `Weak` link no longer upgrades).
pub fn get_node(&self, pos: Position<F>) -> Option<D> {
    if pos >= self.size() {
        return None;
    }
    // This batch's overwrites shadow everything older.
    if let Some(d) = self.overwrites.get(&pos) {
        return Some(*d);
    }
    if pos >= self.parent_size {
        let i = (*pos - *self.parent_size) as usize;
        return self.appended.get(i).copied();
    }
    // Walk ancestors; the first batch whose overwrites or appended range
    // covers `pos` wins.
    let mut current = self.parent.as_ref().and_then(Weak::upgrade);
    while let Some(batch) = current {
        if let Some(d) = batch.overwrites.get(&pos) {
            return Some(*d);
        }
        if pos >= batch.parent_size {
            let i = (*pos - *batch.parent_size) as usize;
            return batch.appended.get(i).copied();
        }
        current = batch.parent.as_ref().and_then(Weak::upgrade);
    }
    None
}
/// Root digest computed when this batch was merkleized.
pub const fn root(&self) -> D {
    self.root
}

/// Locations below this boundary have been pruned.
pub const fn pruning_boundary(&self) -> Location<F> {
    self.pruning_boundary
}

/// Number of leaves, derived from the total node count.
pub fn leaves(&self) -> Location<F> {
    Location::try_from(self.size()).expect("invalid size")
}
/// Starts a new mutable batch layered on this snapshot, inheriting its
/// thread pool when built with `std`.
pub fn new_batch(self: &Arc<Self>) -> UnmerkleizedBatch<F, D> {
    let batch = UnmerkleizedBatch::new(Arc::clone(self));
    #[cfg(feature = "std")]
    let batch = batch.with_pool(self.pool.clone());
    batch
}

/// Size of the base structure the batch chain was originally forked from.
pub const fn base_size(&self) -> Position<F> {
    self.base_size
}
}
// `Readable` implementation: delegates the accessors to the inherent methods
// and builds proofs directly from this snapshot's node view.
impl<F: Family, D: Digest> Readable for MerkleizedBatch<F, D> {
    type Family = F;
    type Digest = D;
    type Error = Error<F>;

    fn size(&self) -> Position<F> {
        Self::size(self)
    }

    fn get_node(&self, pos: Position<F>) -> Option<D> {
        Self::get_node(self, pos)
    }

    fn root(&self) -> D {
        Self::root(self)
    }

    fn pruning_boundary(&self) -> Location<F> {
        Self::pruning_boundary(self)
    }

    // Single-element proof: a one-wide range proof, with the range-level
    // bounds error translated back into a leaf-specific one.
    fn proof(
        &self,
        hasher: &impl Hasher<F, Digest = D>,
        loc: Location<F>,
    ) -> Result<Proof<F, D>, Error<F>> {
        if !loc.is_valid_index() {
            return Err(Error::LocationOverflow(loc));
        }
        self.range_proof(hasher, loc..loc + 1).map_err(|e| match e {
            Error::RangeOutOfBounds(_) => Error::LeafOutOfBounds(loc),
            _ => e,
        })
    }

    fn range_proof(
        &self,
        hasher: &impl Hasher<F, Digest = D>,
        range: Range<Location<F>>,
    ) -> Result<Proof<F, D>, Error<F>> {
        crate::merkle::proof::build_range_proof(
            hasher,
            self.leaves(),
            range,
            |pos| Self::get_node(self, pos),
            Error::ElementPruned,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::merkle::{hasher::Standard, mem::Mem};
    use commonware_cryptography::{sha256, Sha256};
    use commonware_runtime::{deterministic, Runner as _};

    type D = sha256::Digest;
    type H = Standard<Sha256>;

    // Builds a reference structure with `n` leaves by appending them in a
    // single batch and applying it to a fresh `Mem`.
    fn build_reference<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
        let mut mem = Mem::new(hasher);
        let batch = {
            let mut batch = mem.new_batch();
            for i in 0u64..n {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(hasher, &element);
            }
            batch.merkleize(&mem, hasher)
        };
        mem.apply_batch(&batch).unwrap();
        mem
    }

    // Batched construction must yield the same root as the reference across a
    // range of sizes.
    fn consistency_with_reference<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            for &n in &[1u64, 2, 10, 100, 199] {
                let reference = build_reference::<F>(&hasher, n);
                let base = Mem::<F, D>::new(&hasher);
                let mut batch = base.new_batch();
                for i in 0..n {
                    let element = hasher.digest(&i.to_be_bytes());
                    batch = batch.add(&hasher, &element);
                }
                let merkleized = batch.merkleize(&base, &hasher);
                let mut result = Mem::<F, D>::new(&hasher);
                result.apply_batch(&merkleized).unwrap();
                assert_eq!(result.root(), reference.root(), "root mismatch for n={n}");
            }
        });
    }

    // Merkleizing a fork must not disturb the base root, and proofs over the
    // applied result must verify against the fork's root.
    fn lifecycle<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let mut batch = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let merkleized = batch.merkleize(&base, &hasher);
            assert_ne!(merkleized.root(), base_root);
            assert_eq!(*base.root(), base_root);
            let mut applied = base;
            applied.apply_batch(&merkleized).unwrap();
            let loc = Location::<F>::new(55);
            let element = hasher.digest(&55u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &merkleized.root()));
        });
    }

    // Applying a batch to the base reproduces the batch root and matches a
    // reference built directly.
    fn apply_batch<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 50);
            let mut batch = base.new_batch();
            for i in 50u64..75 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let merkleized = batch.merkleize(&base, &hasher);
            let batch_root = merkleized.root();
            base.apply_batch(&merkleized).unwrap();
            assert_eq!(*base.root(), batch_root);
            let reference = build_reference::<F>(&hasher, 75);
            assert_eq!(base.root(), reference.root());
        });
    }

    // Two independent forks of the same base produce distinct roots and leave
    // the base untouched.
    fn multiple_forks<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let mut ba = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                ba = ba.add(&hasher, &element);
            }
            let ma = ba.merkleize(&base, &hasher);
            let mut bb = base.new_batch();
            for i in 100u64..105 {
                let element = hasher.digest(&i.to_be_bytes());
                bb = bb.add(&hasher, &element);
            }
            let mb = bb.merkleize(&base, &hasher);
            assert_ne!(ma.root(), mb.root());
            assert_ne!(ma.root(), base_root);
            assert_eq!(*base.root(), base_root);
        });
    }

    // A batch stacked on a merkleized batch must read through its ancestors
    // and match a directly-built reference.
    fn fork_of_fork_reads<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let mut ba = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                ba = ba.add(&hasher, &element);
            }
            let ma = ba.merkleize(&base, &hasher);
            let mut bb = ma.new_batch();
            for i in 60u64..70 {
                let element = hasher.digest(&i.to_be_bytes());
                bb = bb.add(&hasher, &element);
            }
            let mb = bb.merkleize(&base, &hasher);
            let reference = build_reference::<F>(&hasher, 70);
            assert_eq!(mb.root(), *reference.root());
            let mut applied = base;
            applied.apply_batch(&ma).unwrap();
            applied.apply_batch(&mb).unwrap();
            for i in [0u64, 25, 55, 65, 69] {
                let loc = Location::<F>::new(i);
                let element = hasher.digest(&i.to_be_bytes());
                let proof = applied.proof(&hasher, loc).unwrap();
                assert!(proof.verify_element_inclusion(&hasher, &element, loc, &mb.root()));
            }
        });
    }

    // Overwriting a leaf changes the root; restoring the original digest
    // restores the original root.
    fn update_leaf_digest_roundtrip<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 100);
            let base_root = *base.root();
            let updated = Sha256::fill(0xFF);
            let m = base
                .new_batch()
                .update_leaf_digest(Location::new(5), updated)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(m.root(), base_root);
            let pos5 = Position::<F>::try_from(Location::new(5)).unwrap();
            let original = base.get_node(pos5).unwrap();
            let m2 = base
                .new_batch()
                .update_leaf_digest(Location::new(5), original)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_eq!(m2.root(), base_root);
        });
    }

    // A single batch can mix a leaf update with appends; both take effect.
    fn update_and_add<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let updated = Sha256::fill(0xAA);
            let mut batch = base
                .new_batch()
                .update_leaf_digest(Location::new(10), updated)
                .unwrap();
            for i in 50u64..55 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let m = batch.merkleize(&base, &hasher);
            assert_ne!(m.root(), base_root);
            let pos10 = Position::<F>::try_from(Location::new(10)).unwrap();
            assert_eq!(m.get_node(pos10), Some(updated));
        });
    }

    // Batched updates behave like individual ones: overwrite then restore
    // returns to the original root.
    fn update_leaf_batched_roundtrip<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 100);
            let base_root = *base.root();
            let updated = Sha256::fill(0xBB);
            let locs = [0u64, 10, 50, 99];
            let updates: Vec<(Location<F>, D)> =
                locs.iter().map(|&i| (Location::new(i), updated)).collect();
            let m = base
                .new_batch()
                .update_leaf_batched(&updates)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(m.root(), base_root);
            let restore: Vec<(Location<F>, D)> = locs
                .iter()
                .map(|&l| {
                    let pos = Position::<F>::try_from(Location::new(l)).unwrap();
                    (Location::new(l), base.get_node(pos).unwrap())
                })
                .collect();
            let m2 = base
                .new_batch()
                .update_leaf_batched(&restore)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_eq!(m2.root(), base_root);
        });
    }

    // Single-element and range proofs generated after applying a batch verify
    // against the batch root.
    fn proof_verification<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let mut batch = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let m = batch.merkleize(&base, &hasher);
            let mut applied = base;
            applied.apply_batch(&m).unwrap();
            let loc = Location::<F>::new(55);
            let element = hasher.digest(&55u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &m.root()));
            let range = Location::<F>::new(50)..Location::new(55);
            let rp = applied.range_proof(&hasher, range.clone()).unwrap();
            let elements: Vec<D> = (50u64..55)
                .map(|i| hasher.digest(&i.to_be_bytes()))
                .collect();
            assert!(rp.verify_range_inclusion(&hasher, &elements, range.start, &m.root()));
        });
    }

    // Merkleizing an empty batch reproduces the base root exactly.
    fn empty_batch<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let m = base.new_batch().merkleize(&base, &hasher);
            assert_eq!(m.root(), base_root);
        });
    }

    // Appending across two stacked batches matches a reference built in one
    // pass.
    fn batch_roundtrip<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let mut batch = base.new_batch();
            for i in 50u64..55 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let merkleized = batch.merkleize(&base, &hasher);
            let mut batch_again = merkleized.new_batch();
            for i in 55u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                batch_again = batch_again.add(&hasher, &element);
            }
            let reference = build_reference::<F>(&hasher, 60);
            assert_eq!(
                batch_again.merkleize(&base, &hasher).root(),
                *reference.root()
            );
        });
    }

    // Two merkleize/apply cycles in sequence match the reference.
    fn sequential_apply_batch<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 50);
            let mut b1 = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                b1 = b1.add(&hasher, &element);
            }
            let m1 = b1.merkleize(&base, &hasher);
            base.apply_batch(&m1).unwrap();
            let mut b2 = base.new_batch();
            for i in 60u64..70 {
                let element = hasher.digest(&i.to_be_bytes());
                b2 = b2.add(&hasher, &element);
            }
            let m2 = b2.merkleize(&base, &hasher);
            base.apply_batch(&m2).unwrap();
            let reference = build_reference::<F>(&hasher, 70);
            assert_eq!(base.root(), reference.root());
        });
    }

    // Batches built on a pruned base still work: unpruned proofs verify and
    // pruned lookups error.
    fn batch_on_pruned_base<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 100);
            base.prune(Location::new(27)).unwrap();
            let mut batch = base.new_batch();
            for i in 100u64..110 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let m = batch.merkleize(&base, &hasher);
            let mut applied = base;
            applied.apply_batch(&m).unwrap();
            let loc = Location::<F>::new(80);
            let element = hasher.digest(&80u64.to_be_bytes());
            let proof = applied.proof(&hasher, loc).unwrap();
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &m.root()));
            assert!(matches!(
                applied.proof(&hasher, Location::new(0)),
                Err(Error::ElementPruned(_))
            ));
        });
    }

    // Three stacked batches (update, update, append) apply cleanly in one go.
    fn three_deep_stacking<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 100);
            let da = Sha256::fill(0xDD);
            let db = Sha256::fill(0xEE);
            let ma = base
                .new_batch()
                .update_leaf_digest(Location::new(5), da)
                .unwrap()
                .merkleize(&base, &hasher);
            let mb = ma
                .new_batch()
                .update_leaf_digest(Location::new(10), db)
                .unwrap()
                .merkleize(&base, &hasher);
            let mut bc = mb.new_batch();
            for i in 300u64..310 {
                let element = hasher.digest(&i.to_be_bytes());
                bc = bc.add(&hasher, &element);
            }
            let mc = bc.merkleize(&base, &hasher);
            let c_root = mc.root();
            base.apply_batch(&mc).unwrap();
            assert_eq!(*base.root(), c_root);
        });
    }

    // When two stacked batches overwrite the same leaf, the newest wins.
    fn overwrite_collision<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let mut base = build_reference::<F>(&hasher, 100);
            let dx = Sha256::fill(0xAA);
            let dy = Sha256::fill(0xBB);
            let ma = base
                .new_batch()
                .update_leaf_digest(Location::new(5), dx)
                .unwrap()
                .merkleize(&base, &hasher);
            let mb = ma
                .new_batch()
                .update_leaf_digest(Location::new(5), dy)
                .unwrap()
                .merkleize(&base, &hasher);
            let b_root = mb.root();
            base.apply_batch(&mb).unwrap();
            assert_eq!(*base.root(), b_root);
            let pos5 = Position::<F>::try_from(Location::new(5)).unwrap();
            assert_eq!(base.get_node(pos5), Some(dy));
        });
    }

    // Updating a leaf that was appended within the same batch works and
    // matches the equivalent two-step reference.
    fn update_appended_leaf<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let mut batch = base.new_batch();
            for i in 50u64..60 {
                let element = hasher.digest(&i.to_be_bytes());
                batch = batch.add(&hasher, &element);
            }
            let updated = Sha256::fill(0xEE);
            let m = batch
                .update_leaf_digest(Location::new(52), updated)
                .unwrap()
                .merkleize(&base, &hasher);
            let pos52 = Position::<F>::try_from(Location::new(52)).unwrap();
            assert_eq!(m.get_node(pos52), Some(updated));
            let mut reference = build_reference::<F>(&hasher, 60);
            let batch = reference
                .new_batch()
                .update_leaf_digest(Location::new(52), updated)
                .unwrap()
                .merkleize(&reference, &hasher);
            reference.apply_batch(&batch).unwrap();
            assert_eq!(m.root(), *reference.root());
        });
    }

    // `update_leaf` (hashing the raw element) matches the applied result.
    fn update_leaf_element<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let base_root = *base.root();
            let element = b"updated-element";
            let m = base
                .new_batch()
                .update_leaf(&hasher, Location::new(5), element)
                .unwrap()
                .merkleize(&base, &hasher);
            assert_ne!(m.root(), base_root);
            let mut base = base;
            let batch = base
                .new_batch()
                .update_leaf(&hasher, Location::new(5), element)
                .unwrap()
                .merkleize(&base, &hasher);
            base.apply_batch(&batch).unwrap();
            assert_eq!(m.root(), *base.root());
        });
    }

    // Out-of-bounds updates are rejected by both single and batched paths.
    fn update_out_of_bounds<F: Family>() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let hasher: H = Standard::new();
            let base = build_reference::<F>(&hasher, 50);
            let r1 = base
                .new_batch()
                .update_leaf_digest(Location::new(50), Sha256::fill(0xFF));
            assert!(matches!(r1, Err(Error::LeafOutOfBounds(_))));
            let updates = [(Location::<F>::new(50), Sha256::fill(0xFF))];
            let r2 = base.new_batch().update_leaf_batched(&updates);
            assert!(matches!(r2, Err(Error::LeafOutOfBounds(_))));
        });
    }

    // Instantiations of the generic cases over the MMR family.
    #[test]
    fn mmr_consistency() {
        consistency_with_reference::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_lifecycle() {
        lifecycle::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_apply_batch() {
        apply_batch::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_multiple_forks() {
        multiple_forks::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_fork_of_fork_reads() {
        fork_of_fork_reads::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_digest() {
        update_leaf_digest_roundtrip::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_and_add() {
        update_and_add::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_batched() {
        update_leaf_batched_roundtrip::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_proof_verification() {
        proof_verification::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_empty_batch() {
        empty_batch::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_batch_roundtrip() {
        batch_roundtrip::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_sequential_apply_batch() {
        sequential_apply_batch::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_batch_on_pruned_base() {
        batch_on_pruned_base::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_three_deep_stacking() {
        three_deep_stacking::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_overwrite_collision() {
        overwrite_collision::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_appended_leaf() {
        update_appended_leaf::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_leaf_element() {
        update_leaf_element::<crate::mmr::Family>();
    }
    #[test]
    fn mmr_update_out_of_bounds() {
        update_out_of_bounds::<crate::mmr::Family>();
    }

    // Instantiations of the generic cases over the MMB family.
    #[test]
    fn mmb_consistency() {
        consistency_with_reference::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_lifecycle() {
        lifecycle::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_apply_batch() {
        apply_batch::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_multiple_forks() {
        multiple_forks::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_fork_of_fork_reads() {
        fork_of_fork_reads::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_digest() {
        update_leaf_digest_roundtrip::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_and_add() {
        update_and_add::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_batched() {
        update_leaf_batched_roundtrip::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_proof_verification() {
        proof_verification::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_empty_batch() {
        empty_batch::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_batch_roundtrip() {
        batch_roundtrip::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_sequential_apply_batch() {
        sequential_apply_batch::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_batch_on_pruned_base() {
        batch_on_pruned_base::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_three_deep_stacking() {
        three_deep_stacking::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_overwrite_collision() {
        overwrite_collision::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_appended_leaf() {
        update_appended_leaf::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_leaf_element() {
        update_leaf_element::<crate::mmb::Family>();
    }
    #[test]
    fn mmb_update_out_of_bounds() {
        update_out_of_bounds::<crate::mmb::Family>();
    }
}