#[cfg(test)]
mod tests {
    use crate::{
        merkle::{
            hasher::Standard, mmb::mem::Mmb, proof::Blueprint, Bagging, Bagging::ForwardFold,
            Family,
        },
        mmb::Location,
    };
    use commonware_cryptography::Sha256;

    /// Digest type produced by the SHA-256 hasher used throughout these tests.
    type D = <Sha256 as commonware_cryptography::Hasher>::Digest;
    /// Standard Merkle hasher over SHA-256.
    type H = Standard<Sha256>;

    /// Builds an in-memory MMB holding `n` leaves, where leaf `i` (for `i` in
    /// `0..n`) is the big-endian byte encoding of `i`. Returns the MMB together
    /// with the `ForwardFold` hasher that was used to build it.
    fn make_mmb(n: u64) -> (H, Mmb<D>) {
        let hasher = H::new(ForwardFold);
        let mut mmb = Mmb::new();
        // Accumulate every leaf into one batch, merkleize it against the current
        // (empty) MMB, then apply the result in a single step.
        let merkleized = (0..n)
            .fold(mmb.new_batch(), |pending, i| pending.add(&hasher, &i.to_be_bytes()))
            .merkleize(&mmb, &hasher);
        mmb.apply_batch(&merkleized).unwrap();
        (hasher, mmb)
    }

    /// Regression test (per its name, for the recursive fold-prefix path).
    ///
    /// Builds a 5-leaf MMB and range-proves only the leaf at location 4. It then
    /// checks that `verify_proof_and_pinned_nodes` accepts that proof together
    /// with the node digests `nodes_to_pin` lists for location 4.
    #[test]
    fn test_verify_proof_and_pinned_nodes_recursive_fold_prefix_regression() {
        let (hasher, mmb) = make_mmb(5);
        let root = mmb.root(&hasher, 0).unwrap();

        let leaf: u64 = 4;
        let pinned = crate::merkle::mmb::Family::nodes_to_pin(Location::new(leaf))
            .map(|pos| mmb.get_node(pos).unwrap())
            .collect::<Vec<D>>();
        let proof = mmb
            .range_proof(&hasher, Location::new(leaf)..Location::new(leaf + 1), 0)
            .unwrap();

        let elements = [leaf.to_be_bytes()];
        let accepted = proof.verify_proof_and_pinned_nodes(
            &hasher,
            &elements,
            Location::new(leaf),
            &pinned,
            &root,
        );
        assert!(accepted);
    }

    /// Checks each leaf count `n` in `1000, 1100, ..., 5000`.
    ///
    /// For each `n`, the `ForwardFold` blueprint for proving only the last leaf
    /// (`n - 1`) must call for at most two digests. A non-empty fold prefix
    /// counts as one digest, and each fetch node counts as one. The
    /// single-element proof for that leaf must also verify against the current
    /// root. After each check, 100 more leaves are appended.
    #[test]
    fn test_last_element_proof_size_is_two() {
        let (hasher, mut mmb) = make_mmb(1000);
        for n in (1000u64..=5000).step_by(100) {
            let root = mmb.root(&hasher, 0).unwrap();
            let last = n - 1;

            let bp = Blueprint::new(
                mmb.leaves(),
                crate::merkle::mmb::Family::inactive_peaks(mmb.size(), Location::new(0)),
                Bagging::ForwardFold,
                Location::new(last)..Location::new(n),
            )
            .unwrap();

            // Each fetch node is one digest; the whole fold prefix, when present, adds one more.
            let total_digests = bp.fetch_nodes.len() + usize::from(!bp.fold_prefix.is_empty());
            assert!(
                total_digests <= 2,
                "n={n}: expected <= 2 digests, got {total_digests} \
                 (fold_prefix={}, fetch_nodes={})",
                bp.fold_prefix.len(),
                bp.fetch_nodes.len(),
            );

            let proof = mmb.proof(&hasher, Location::new(last), 0).unwrap();
            let verified = proof.verify_element_inclusion(
                &hasher,
                &last.to_be_bytes(),
                Location::new(last),
                &root,
            );
            assert!(verified, "n={n}: verification failed");

            // Grow the MMB by the next 100 leaves (n..n + 100) before the next round.
            let extension = (n..n + 100)
                .fold(mmb.new_batch(), |pending, i| pending.add(&hasher, &i.to_be_bytes()))
                .merkleize(&mmb, &hasher);
            mmb.apply_batch(&extension).unwrap();
        }
    }
}