use crate::merkle::{hasher::Hasher, Error, Family, Location, Position};
use alloc::{
collections::{BTreeMap, BTreeSet},
vec,
vec::Vec,
};
use bytes::{Buf, BufMut};
use commonware_codec::{EncodeSize, ReadExt, ReadRangeExt, Write};
use commonware_cryptography::Digest;
use core::ops::Range;
/// Errors that can arise while reconstructing a root digest from a [`Proof`].
///
/// Callers of the `verify_*` helpers never see these directly; they are
/// collapsed into a `false` (or an opaque `Err`) result.
#[derive(thiserror::Error, Debug)]
pub enum ReconstructionError {
    /// The proof supplies fewer digests than reconstruction requires.
    #[error("missing digests in proof")]
    MissingDigests,
    /// The proof supplies digests (or elements) beyond those consumed.
    #[error("extra digests in proof")]
    ExtraDigests,
    /// The starting location is not a valid leaf index.
    #[error("start location is out of bounds")]
    InvalidStartLoc,
    /// The element range ends past the tree's leaf count (or overflows).
    #[error("end location is out of bounds")]
    InvalidEndLoc,
    /// Elements were required (non-zero start) but none were provided.
    #[error("missing elements")]
    MissingElements,
    /// The claimed tree size could not be turned into a valid blueprint.
    #[error("invalid size")]
    InvalidSize,
}
/// An inclusion proof for one or more leaves of a Merkle structure.
///
/// `digests` are ordered as produced by [`Blueprint`]: an optional folded
/// prefix digest first, then the fetched node digests.
#[derive(Clone, Debug, Eq)]
pub struct Proof<F: Family, D: Digest> {
    /// Total leaf count of the tree this proof commits to.
    pub leaves: Location<F>,
    /// Supporting digests needed to re-derive the root.
    pub digests: Vec<D>,
}
// Manual impl rather than `#[derive(PartialEq)]` — presumably to avoid the
// derive placing a `PartialEq` bound on `F` itself; compares both fields,
// which matches what a derived impl would do. TODO(review): confirm intent.
impl<F: Family, D: Digest> PartialEq for Proof<F, D> {
    fn eq(&self, other: &Self) -> bool {
        self.leaves == other.leaves && self.digests == other.digests
    }
}
impl<F: Family, D: Digest> EncodeSize for Proof<F, D> {
    // Total encoded size is the sum of both fields' codec sizes.
    fn encode_size(&self) -> usize {
        self.leaves.encode_size() + self.digests.encode_size()
    }
}
impl<F: Family, D: Digest> Write for Proof<F, D> {
    // Serialization order mirrors `read_cfg`: leaf count, then digests.
    fn write(&self, buf: &mut impl BufMut) {
        self.leaves.write(buf);
        self.digests.write(buf);
    }
}
impl<F: Family, D: Digest> commonware_codec::Read for Proof<F, D> {
    /// Upper bound on the number of digests accepted while decoding; caps
    /// allocation when reading proofs from untrusted input.
    type Cfg = usize;
    fn read_cfg(
        buf: &mut impl Buf,
        max_digests: &Self::Cfg,
    ) -> Result<Self, commonware_codec::Error> {
        let leaves = Location::<F>::read(buf)?;
        // `read_range` rejects payloads with more than `max_digests` entries.
        let digests = Vec::<D>::read_range(buf, ..=*max_digests)?;
        Ok(Self { leaves, digests })
    }
}
/// The default proof: an empty tree (zero leaves) with no supporting digests.
impl<F: Family, D: Digest> Default for Proof<F, D> {
    fn default() -> Self {
        let leaves = Location::new(0);
        let digests = Vec::new();
        Self { leaves, digests }
    }
}
impl<F: Family, D: Digest> Proof<F, D> {
pub fn verify_element_inclusion<H>(
&self,
hasher: &H,
element: &[u8],
loc: Location<F>,
root: &D,
) -> bool
where
H: Hasher<F, Digest = D>,
{
self.verify_range_inclusion(hasher, &[element], loc, root)
}
pub fn verify_range_inclusion<H, E>(
&self,
hasher: &H,
elements: &[E],
start_loc: Location<F>,
root: &D,
) -> bool
where
H: Hasher<F, Digest = D>,
E: AsRef<[u8]>,
{
match self.reconstruct_root(hasher, elements, start_loc) {
Ok(reconstructed_root) => *root == reconstructed_root,
Err(_error) => {
#[cfg(feature = "std")]
tracing::debug!(error = ?_error, "invalid proof input");
false
}
}
}
    /// Verifies inclusion of multiple (element, location) pairs against a
    /// single deduplicated multi-proof.
    ///
    /// The proof's digests must correspond, in sorted-position order, to the
    /// union of node positions required by every element's single-element
    /// blueprint. Each element is then re-verified individually by assembling
    /// a per-element proof from the shared digest map.
    pub fn verify_multi_inclusion<H, E>(
        &self,
        hasher: &H,
        elements: &[(E, Location<F>)],
        root: &D,
    ) -> bool
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        if elements.is_empty() {
            // An empty multi-proof is only valid for the empty tree.
            return self.digests.is_empty()
                && self.leaves == Location::new(0)
                && *root == hasher.root(Location::new(0), core::iter::empty());
        }
        // Union of all node positions any element's proof needs (BTreeSet
        // gives the canonical sorted order the prover must have used).
        let mut node_positions = BTreeSet::new();
        let mut blueprints = BTreeMap::new();
        for (_, loc) in elements {
            if !loc.is_valid_index() {
                return false;
            }
            let Ok(bp) = Blueprint::new(self.leaves, *loc..*loc + 1) else {
                return false;
            };
            node_positions.extend(&bp.fold_prefix);
            node_positions.extend(&bp.fetch_nodes);
            blueprints.insert(*loc, bp);
        }
        // The proof must contain exactly one digest per required position.
        if node_positions.len() != self.digests.len() {
            return false;
        }
        // Pair sorted positions with digests; prover must supply digests in
        // the same sorted-position order.
        let node_digests: BTreeMap<Position<F>, D> = node_positions
            .iter()
            .zip(self.digests.iter())
            .map(|(&pos, digest)| (pos, *digest))
            .collect();
        for (element, loc) in elements {
            let bp = &blueprints[loc];
            // Reassemble this element's single-element proof: optional folded
            // prefix digest first, then the fetched nodes in blueprint order.
            let mut digests = Vec::with_capacity(
                if bp.fold_prefix.is_empty() { 0 } else { 1 } + bp.fetch_nodes.len(),
            );
            if let Some((&first_pos, rest)) = bp.fold_prefix.split_first() {
                let first = *node_digests
                    .get(&first_pos)
                    .expect("must exist by construction");
                let acc = rest.iter().fold(first, |acc, &pos| {
                    let d = node_digests.get(&pos).expect("must exist by construction");
                    hasher.fold(&acc, d)
                });
                digests.push(acc);
            }
            for &pos in &bp.fetch_nodes {
                let d = node_digests.get(&pos).expect("must exist by construction");
                digests.push(*d);
            }
            let proof = Self {
                leaves: self.leaves,
                digests,
            };
            if !proof.verify_element_inclusion(hasher, element.as_ref(), *loc, root) {
                return false;
            }
        }
        true
    }
    /// Recomputes the root digest implied by this proof together with the
    /// `elements` occupying the range starting at `start_loc`.
    ///
    /// Delegates to [`Self::reconstruct_root_collecting`] without collecting
    /// intermediate node digests.
    pub fn reconstruct_root<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
    ) -> Result<D, ReconstructionError>
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        self.reconstruct_root_collecting(hasher, elements, start_loc, None)
    }
pub fn verify_range_inclusion_and_extract_digests<H, E>(
&self,
hasher: &H,
elements: &[E],
start_loc: Location<F>,
root: &D,
) -> Result<Vec<(Position<F>, D)>, Error<F>>
where
H: Hasher<F, Digest = D>,
E: AsRef<[u8]>,
{
let mut collected_digests = Vec::new();
let Ok(reconstructed_root) = self.reconstruct_root_collecting(
hasher,
elements,
start_loc,
Some(&mut collected_digests),
) else {
return Err(Error::InvalidProof);
};
if reconstructed_root != *root {
return Err(Error::RootMismatch);
}
Ok(collected_digests)
}
    /// Verifies a range proof together with a set of "pinned" node digests
    /// (nodes retained after pruning up to `start_loc`).
    ///
    /// Each pinned digest must either fold into the proof's prefix digest or
    /// match the corresponding digest extracted during reconstruction.
    pub fn verify_proof_and_pinned_nodes<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
        pinned_nodes: &[D],
        root: &D,
    ) -> bool
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        // First establish the range proof itself is valid, capturing all
        // intermediate digests for cross-checking the pinned nodes.
        let collected = match self
            .verify_range_inclusion_and_extract_digests(hasher, elements, start_loc, root)
        {
            Ok(c) => c,
            Err(_) => return false,
        };
        if elements.is_empty() {
            // Nothing proven, so no pinned nodes can be justified.
            return pinned_nodes.is_empty();
        }
        if !start_loc.is_valid() || start_loc > self.leaves {
            return false;
        }
        // Pinned digests must match the pin positions one-for-one.
        let pinned_positions: alloc::vec::Vec<_> = F::nodes_to_pin(start_loc).collect();
        if pinned_positions.len() != pinned_nodes.len() {
            return false;
        }
        let Ok(fold_prefix) = Blueprint::fold_prefix(self.leaves, start_loc) else {
            return false;
        };
        let mut pinned_map: alloc::collections::BTreeMap<Position<F>, D> = pinned_positions
            .into_iter()
            .zip(pinned_nodes.iter().copied())
            .collect();
        if !fold_prefix.is_empty() {
            // Pinned peaks preceding the range must fold (left to right) into
            // the proof's first digest. `remove` consumes them from the map so
            // only non-prefix pins remain for the extracted-digest check.
            if self.digests.is_empty() {
                return false;
            }
            let Some(first) = pinned_map.remove(&fold_prefix[0]) else {
                return false;
            };
            let mut acc = first;
            for pos in &fold_prefix[1..] {
                let Some(digest) = pinned_map.remove(pos) else {
                    return false;
                };
                acc = hasher.fold(&acc, &digest);
            }
            if acc != self.digests[0] {
                return false;
            }
        }
        // Every remaining pinned node must have been recomputed (with the
        // same digest) during reconstruction.
        let extracted: alloc::collections::BTreeMap<Position<F>, D> =
            collected.into_iter().collect();
        for (pos, digest) in pinned_map {
            if extracted.get(&pos) != Some(&digest) {
                return false;
            }
        }
        true
    }
    /// Core reconstruction routine: recomputes the root from `elements` at
    /// `start_loc`, optionally recording every intermediate
    /// `(position, digest)` pair into `collected`.
    ///
    /// Digest layout consumed from `self.digests`:
    /// `[folded prefix (0 or 1)] ++ [after-range peaks] ++ [in-range siblings]`,
    /// matching the order `Blueprint::build_proof` emits.
    pub(crate) fn reconstruct_root_collecting<H, E>(
        &self,
        hasher: &H,
        elements: &[E],
        start_loc: Location<F>,
        mut collected: Option<&mut Vec<(Position<F>, D)>>,
    ) -> Result<D, ReconstructionError>
    where
        H: Hasher<F, Digest = D>,
        E: AsRef<[u8]>,
    {
        if elements.is_empty() {
            if start_loc == 0 {
                // Empty range at the origin: the "root" is just a commitment
                // to the claimed leaf count. No digests may accompany it.
                return if self.digests.is_empty() {
                    Ok(hasher.digest(&self.leaves.to_be_bytes()))
                } else {
                    Err(ReconstructionError::ExtraDigests)
                };
            }
            return Err(ReconstructionError::MissingElements);
        }
        if !start_loc.is_valid_index() {
            return Err(ReconstructionError::InvalidStartLoc);
        }
        let end_loc = start_loc
            .checked_add(elements.len() as u64)
            .ok_or(ReconstructionError::InvalidEndLoc)?;
        if end_loc > self.leaves {
            return Err(ReconstructionError::InvalidEndLoc);
        }
        let range = start_loc..end_loc;
        let bp =
            Blueprint::new(self.leaves, range).map_err(|_| ReconstructionError::InvalidSize)?;
        // Minimum digest count: optional folded prefix + all fetch nodes.
        // `fetch_nodes` includes the after-range peaks, so `expected_min`
        // also bounds `after_end` below — the slicing that follows is safe.
        let prefix_digests = usize::from(!bp.fold_prefix.is_empty());
        let expected_min = prefix_digests + bp.fetch_nodes.len();
        if self.digests.len() < expected_min {
            return Err(ReconstructionError::MissingDigests);
        }
        let after_start = prefix_digests;
        let after_peaks_count = bp.after_peaks.len();
        let after_end = after_start + after_peaks_count;
        let siblings = &self.digests[after_end..];
        let mut peak_digests = Vec::new();
        if !bp.fold_prefix.is_empty() {
            // The pre-range peaks arrive pre-folded as the first digest.
            peak_digests.push(self.digests[0]);
        }
        // Walk each peak overlapping the range, consuming elements and
        // sibling digests in DFS order.
        let mut sibling_cursor = 0usize;
        let mut elements_iter = elements.iter();
        for &peak in &bp.range_peaks {
            let peak_digest = reconstruct_peak_from_range(
                hasher,
                peak,
                &bp.range,
                &mut elements_iter,
                siblings,
                &mut sibling_cursor,
                collected.as_deref_mut(),
            )?;
            if let Some(ref mut cd) = collected {
                cd.push((peak.pos, peak_digest));
            }
            peak_digests.push(peak_digest);
        }
        // Peaks entirely after the range are supplied verbatim by the proof.
        for (i, &after_peak_pos) in bp.after_peaks.iter().enumerate() {
            let digest = self.digests[after_start + i];
            if let Some(ref mut cd) = collected {
                cd.push((after_peak_pos, digest));
            }
            peak_digests.push(digest);
        }
        // Everything supplied must have been consumed exactly.
        if elements_iter.next().is_some() {
            return Err(ReconstructionError::ExtraDigests);
        }
        if sibling_cursor != siblings.len() {
            return Err(ReconstructionError::ExtraDigests);
        }
        Ok(hasher.root(self.leaves, peak_digests.iter()))
    }
}
/// A perfect subtree within the structure: root position, height, and the
/// location of its leftmost leaf. Covers `2^height` leaves.
#[derive(Copy, Clone)]
pub(crate) struct Subtree<F: Family> {
    /// Position of the subtree's root node.
    pub pos: Position<F>,
    /// Height of the root; a leaf has height 0.
    pub height: u32,
    /// Location of the first (leftmost) leaf under this subtree.
    pub leaf_start: Location<F>,
}
impl<F: Family> Subtree<F> {
    /// One past the last leaf covered by this subtree.
    fn leaf_end(&self) -> Location<F> {
        let span = 1u64 << self.height;
        self.leaf_start + span
    }
    /// Splits this subtree into its left and right child subtrees.
    fn children(&self) -> (Self, Self) {
        let (left_pos, right_pos) = F::children(self.pos, self.height);
        let next_height = self.height - 1;
        let left = Self {
            pos: left_pos,
            height: next_height,
            leaf_start: self.leaf_start,
        };
        let right = Self {
            pos: right_pos,
            height: next_height,
            // The right child starts halfway through this subtree's leaves.
            leaf_start: self.leaf_start + (1u64 << next_height),
        };
        (left, right)
    }
}
/// A plan describing exactly which node digests a range proof must carry.
pub(crate) struct Blueprint<F: Family> {
    // Leaf count of the tree the plan was built for.
    leaves: Location<F>,
    /// The proven leaf range.
    pub range: Range<Location<F>>,
    /// Peaks entirely before the range; folded into one leading digest.
    pub fold_prefix: Vec<Position<F>>,
    /// Peaks entirely after the range; included verbatim.
    pub after_peaks: Vec<Position<F>>,
    /// Peaks whose subtrees overlap the range; reconstructed from elements.
    pub range_peaks: Vec<Subtree<F>>,
    /// All positions to fetch: `after_peaks` followed by in-range siblings.
    pub fetch_nodes: Vec<Position<F>>,
}
impl<F: Family> Blueprint<F> {
pub(crate) fn fold_prefix(
leaves: Location<F>,
start_loc: Location<F>,
) -> Result<Vec<Position<F>>, super::Error<F>> {
let size = Position::<F>::try_from(leaves)?;
let mut fold_prefix = Vec::new();
let mut leaf_cursor = Location::new(0);
for (peak_pos, height) in F::peaks(size) {
let leaf_end = leaf_cursor + (1u64 << height);
if leaf_end <= start_loc {
fold_prefix.push(peak_pos);
} else {
break;
}
leaf_cursor = leaf_end;
}
Ok(fold_prefix)
}
    /// Builds the proof plan for the leaf `range` of a tree with `leaves`
    /// leaves.
    ///
    /// # Errors
    /// `Error::Empty` for an empty range; `Error::RangeOutOfBounds` when the
    /// range extends past the leaf count.
    pub(crate) fn new(
        leaves: Location<F>,
        range: Range<Location<F>>,
    ) -> Result<Self, super::Error<F>> {
        if range.is_empty() {
            return Err(super::Error::Empty);
        }
        let end_minus_one = range
            .end
            .checked_sub(1)
            .expect("can't underflow because range is non-empty");
        if end_minus_one >= leaves {
            return Err(super::Error::RangeOutOfBounds(range.end));
        }
        let size = Position::try_from(leaves)?;
        // Classify each peak by where its leaf span falls relative to the
        // range: entirely before, entirely after, or overlapping.
        let mut fold_prefix = Vec::new();
        let mut after_peaks = Vec::new();
        let mut range_peaks = Vec::new();
        let mut leaf_cursor = Location::new(0);
        for (peak_pos, height) in F::peaks(size) {
            let leaf_start = leaf_cursor;
            let leaf_end = leaf_start + (1u64 << height);
            if leaf_end <= range.start {
                fold_prefix.push(peak_pos);
            } else if leaf_start >= range.end {
                after_peaks.push(peak_pos);
            } else {
                range_peaks.push(Subtree {
                    pos: peak_pos,
                    height,
                    leaf_start,
                });
            }
            leaf_cursor = leaf_end;
        }
        // A non-empty, in-bounds range always intersects some peak.
        assert!(
            !range_peaks.is_empty(),
            "at least one peak must contain range elements"
        );
        // Fetch order: after-range peaks first, then DFS siblings for each
        // overlapping peak — the order reconstruction consumes them in.
        let mut fetch_nodes = after_peaks.clone();
        for &peak in &range_peaks {
            collect_siblings_dfs(peak, &range, &mut fetch_nodes);
        }
        Ok(Self {
            leaves,
            range,
            fold_prefix,
            after_peaks,
            range_peaks,
            fetch_nodes,
        })
    }
pub(crate) fn build_proof<D, H, E>(
self,
hasher: &H,
get_node: impl Fn(Position<F>) -> Option<D>,
element_pruned: impl Fn(Position<F>) -> E,
) -> Result<Proof<F, D>, E>
where
D: Digest,
H: Hasher<F, Digest = D>,
{
let mut digests = Vec::with_capacity(
if self.fold_prefix.is_empty() { 0 } else { 1 } + self.fetch_nodes.len(),
);
if let Some((&first_pos, rest)) = self.fold_prefix.split_first() {
let first = get_node(first_pos).ok_or_else(|| element_pruned(first_pos))?;
let acc = rest.iter().try_fold(first, |acc, &pos| {
let d = get_node(pos).ok_or_else(|| element_pruned(pos))?;
Ok(hasher.fold(&acc, &d))
})?;
digests.push(acc);
}
for &pos in &self.fetch_nodes {
digests.push(get_node(pos).ok_or_else(|| element_pruned(pos))?);
}
Ok(Proof {
leaves: self.leaves,
digests,
})
}
}
/// Upper bound on the number of digests a single element's proof may carry.
/// NOTE(review): 122 presumably derives from the maximum supported tree
/// height for any `Family` — confirm against the family definitions.
pub const MAX_PROOF_DIGESTS_PER_ELEMENT: usize = 122;
/// Convenience wrapper: plans and builds an inclusion proof for the leaf
/// `range` of a tree with `leaves` leaves in one call.
pub(crate) fn build_range_proof<F, D, H, E>(
    hasher: &H,
    leaves: Location<F>,
    range: Range<Location<F>>,
    get_node: impl Fn(Position<F>) -> Option<D>,
    element_pruned: impl Fn(Position<F>) -> E,
) -> Result<Proof<F, D>, E>
where
    F: Family,
    D: Digest,
    H: Hasher<F, Digest = D>,
    E: From<super::Error<F>>,
{
    let blueprint = Blueprint::new(leaves, range)?;
    blueprint.build_proof(hasher, get_node, element_pruned)
}
#[cfg(any(feature = "std", test))]
/// Computes the (sorted, deduplicated) set of node positions a multi-proof
/// over `locations` must supply for a tree with `leaves` leaves.
pub(crate) fn nodes_required_for_multi_proof<F: Family>(
    leaves: Location<F>,
    locations: &[Location<F>],
) -> Result<BTreeSet<Position<F>>, super::Error<F>> {
    if locations.is_empty() {
        return Err(super::Error::Empty);
    }
    let mut positions = BTreeSet::new();
    for loc in locations {
        if !loc.is_valid_index() {
            return Err(super::Error::LocationOverflow(*loc));
        }
        // Union the single-element blueprint's requirements for each location;
        // the set removes nodes shared between elements.
        let bp = Blueprint::new(leaves, *loc..*loc + 1)?;
        positions.extend(bp.fold_prefix);
        positions.extend(bp.fetch_nodes);
    }
    Ok(positions)
}
/// Depth-first collection of the sibling digests a range proof must supply:
/// the roots of the maximal subtrees that lie entirely outside `range`.
pub(crate) fn collect_siblings_dfs<F: Family>(
    node: Subtree<F>,
    range: &Range<Location<F>>,
    out: &mut Vec<Position<F>>,
) {
    let disjoint = node.leaf_end() <= range.start || node.leaf_start >= range.end;
    if disjoint {
        // Entirely outside the proven range: the proof must carry this
        // node's digest, and there is no need to descend further.
        out.push(node.pos);
    } else if node.height > 0 {
        // Overlapping an internal node: recurse so only out-of-range
        // descendants get collected. An in-range leaf contributes nothing.
        let (left, right) = node.children();
        collect_siblings_dfs::<F>(left, range, out);
        collect_siblings_dfs::<F>(right, range, out);
    }
}
/// Recomputes the digest of `node` from in-range `elements` plus supplied
/// sibling digests, consuming both in the same DFS order
/// `collect_siblings_dfs` used to emit them.
///
/// `cursor` indexes into `siblings`; `collected`, when present, records the
/// digests of both children at every internal node visited.
pub(crate) fn reconstruct_peak_from_range<F, D, H, E>(
    hasher: &H,
    node: Subtree<F>,
    range: &Range<Location<F>>,
    elements: &mut E,
    siblings: &[D],
    cursor: &mut usize,
    mut collected: Option<&mut Vec<(Position<F>, D)>>,
) -> Result<D, ReconstructionError>
where
    F: Family,
    D: Digest,
    H: Hasher<F, Digest = D>,
    E: Iterator<Item: AsRef<[u8]>>,
{
    if node.leaf_end() <= range.start || node.leaf_start >= range.end {
        // Subtree disjoint from the range: its digest comes straight from
        // the proof's sibling list.
        let Some(digest) = siblings.get(*cursor).copied() else {
            return Err(ReconstructionError::MissingDigests);
        };
        *cursor += 1;
        return Ok(digest);
    }
    if node.height == 0 {
        // In-range leaf: hash the next supplied element.
        let elem = elements
            .next()
            .ok_or(ReconstructionError::MissingElements)?;
        return Ok(hasher.leaf_digest(node.pos, elem.as_ref()));
    }
    // Internal node overlapping the range: recurse left then right (this
    // order must match the element/sibling consumption order).
    let (left, right) = node.children();
    let left_pos = left.pos;
    let right_pos = right.pos;
    let left_d = reconstruct_peak_from_range::<F, D, H, E>(
        hasher,
        left,
        range,
        elements,
        siblings,
        cursor,
        collected.as_deref_mut(),
    )?;
    let right_d = reconstruct_peak_from_range::<F, D, H, E>(
        hasher,
        right,
        range,
        elements,
        siblings,
        cursor,
        collected.as_deref_mut(),
    )?;
    if let Some(ref mut cd) = collected {
        cd.push((left_pos, left_d));
        cd.push((right_pos, right_d));
    }
    Ok(hasher.node_digest(node.pos, &left_d, &right_d))
}
#[cfg(feature = "arbitrary")]
// Fuzzing support: generates unconstrained (possibly invalid) proofs, which
// is what verification fuzz targets want.
impl<F: Family, D: Digest> arbitrary::Arbitrary<'_> for Proof<F, D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        Ok(Self {
            leaves: u.arbitrary()?,
            digests: u.arbitrary()?,
        })
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::merkle::{
hasher::Standard,
mem::Mem,
mmb, mmr,
proof::{nodes_required_for_multi_proof, Blueprint, Proof},
Family, Location, LocationRangeExt as _,
};
use alloc::vec;
use commonware_codec::{Decode, Encode, EncodeSize};
use commonware_cryptography::{sha256, Sha256};
use commonware_macros::test_traced;
type D = sha256::Digest;
type H = Standard<Sha256>;
    // Deterministic test digest: SHA-256 of the single byte `v`.
    fn test_digest(v: u8) -> D {
        <Sha256 as commonware_cryptography::Hasher>::hash(&[v])
    }
    // Builds an in-memory tree with `n` leaves, each leaf being the
    // big-endian bytes of its index.
    fn build_raw<F: Family>(hasher: &H, n: u64) -> Mem<F, D> {
        let mut mem = Mem::new(hasher);
        let batch = {
            let mut batch = mem.new_batch();
            for i in 0..n {
                batch = batch.add(hasher, &i.to_be_bytes());
            }
            batch.merkleize(&mem, hasher)
        };
        mem.apply_batch(&batch).unwrap();
        mem
    }
    // A default (empty) proof verifies only the empty range at location 0
    // against the empty tree's root.
    fn empty_proof<F: Family>() {
        let hasher = H::new();
        let mem = Mem::<F, D>::new(&hasher);
        let root = mem.root();
        let proof: Proof<F, D> = Proof::default();
        assert!(proof.verify_range_inclusion(&hasher, &[] as &[D], Location::new(0), root));
        assert!(!proof.verify_range_inclusion(&hasher, &[] as &[D], Location::new(1), root));
        let td = test_digest(0);
        assert!(!proof.verify_range_inclusion(&hasher, &[] as &[D], Location::new(0), &td));
        assert!(!proof.verify_range_inclusion(&hasher, &[td], Location::new(0), root));
    }
    // Single-element proofs verify at every leaf, and fail under every kind
    // of tampering: wrong location, wrong element/root, mangled, shortened,
    // or lengthened digest lists, and a wrong `leaves` count.
    fn verify_element<F: Family>() {
        let element = D::from(*b"01234567012345670123456701234567");
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let batch = {
            let mut batch = mem.new_batch();
            for _ in 0..11 {
                batch = batch.add(&hasher, &element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = mem.root();
        for leaf in 0u64..11 {
            let leaf = Location::new(leaf);
            let proof: Proof<F, D> = mem.proof(&hasher, leaf).unwrap();
            assert!(
                proof.verify_element_inclusion(&hasher, &element, leaf, root),
                "valid proof should verify successfully"
            );
        }
        let leaf = Location::<F>::new(10);
        let proof = mem.proof(&hasher, leaf).unwrap();
        assert!(
            proof.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should be successful"
        );
        assert!(
            !proof.verify_element_inclusion(&hasher, &element, leaf + 1, root),
            "proof verification should fail with incorrect element position"
        );
        assert!(
            !proof.verify_element_inclusion(&hasher, &element, leaf - 1, root),
            "proof verification should fail with incorrect element position 2"
        );
        assert!(
            !proof.verify_element_inclusion(&hasher, &test_digest(0), leaf, root),
            "proof verification should fail with mangled element"
        );
        let root2 = test_digest(0);
        assert!(
            !proof.verify_element_inclusion(&hasher, &element, leaf, &root2),
            "proof verification should fail with mangled root"
        );
        let mut proof2 = proof.clone();
        proof2.digests[0] = test_digest(0);
        assert!(
            !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should fail with mangled proof hash"
        );
        proof2 = proof.clone();
        proof2.leaves = Location::new(10);
        assert!(
            !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should fail with incorrect leaves"
        );
        proof2 = proof.clone();
        proof2.digests.push(test_digest(0));
        assert!(
            !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
            "proof verification should fail with extra hash"
        );
        proof2 = proof.clone();
        while !proof2.digests.is_empty() {
            proof2.digests.pop();
            assert!(
                !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
                "proof verification should fail with missing digests"
            );
        }
        if proof.digests.len() >= 2 {
            // Insert an extra digest in the middle — even an unused digest
            // must cause rejection (no malleability).
            proof2 = proof.clone();
            proof2.digests.clear();
            proof2.digests.extend(proof.digests[0..1].iter().cloned());
            proof2.digests.push(test_digest(0));
            proof2.digests.extend(proof.digests[1..].iter().cloned());
            assert!(
                !proof2.verify_element_inclusion(&hasher, &element, leaf, root),
                "proof verification should fail with extra hash even if it's unused by the computation"
            );
        }
    }
    // Range proofs verify for every sub-range of a 49-leaf tree, and fail
    // for wrong element windows, wrong roots, removed/mangled/inserted
    // digests, and wrong start locations.
    fn verify_range<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..49).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = mem.root();
        for i in 0..elements.len() {
            for j in i + 1..elements.len() {
                let range = Location::new(i as u64)..Location::new(j as u64);
                let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
                assert!(
                    range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[range.to_usize_range()],
                        range.start,
                        root,
                    ),
                    "valid range proof should verify successfully {i}:{j}",
                );
            }
        }
        let range = Location::new(33)..Location::new(40);
        let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
        let valid_elements = &elements[range.to_usize_range()];
        assert!(
            range_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root),
            "valid range proof should verify successfully"
        );
        let mut invalid_proof = range_proof.clone();
        for _i in 0..range_proof.digests.len() {
            invalid_proof.digests.remove(0);
            assert!(
                !invalid_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root,),
                "range proof with removed elements should fail"
            );
        }
        for i in 0..elements.len() {
            for j in i + 1..elements.len() {
                if Location::<F>::from(i) == range.start && Location::<F>::from(j) == range.end {
                    continue;
                }
                assert!(
                    !range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[i..j],
                        range.start,
                        root,
                    ),
                    "range proof with invalid element range should fail {i}:{j}",
                );
            }
        }
        let invalid_root = test_digest(1);
        assert!(
            !range_proof.verify_range_inclusion(
                &hasher,
                valid_elements,
                range.start,
                &invalid_root,
            ),
            "range proof with invalid root should fail"
        );
        for i in 0..range_proof.digests.len() {
            let mut invalid_proof = range_proof.clone();
            invalid_proof.digests[i] = test_digest(0);
            assert!(
                !invalid_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root,),
                "mangled range proof should fail verification"
            );
        }
        for i in 0..range_proof.digests.len() {
            let mut invalid_proof = range_proof.clone();
            invalid_proof.digests.insert(i, test_digest(0));
            assert!(
                !invalid_proof.verify_range_inclusion(&hasher, valid_elements, range.start, root,),
                "mangled range proof should fail verification. inserted element at: {i}",
            );
        }
        for loc in 0..elements.len() {
            let loc = Location::new(loc as u64);
            if loc == range.start {
                continue;
            }
            assert!(
                !range_proof.verify_range_inclusion(&hasher, valid_elements, loc, root),
                "bad start_loc should fail verification {loc}",
            );
        }
    }
    // After pruning to each successive location, the root is unchanged and
    // every unpruned leaf remains provable.
    fn retained_nodes_provable_after_pruning<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..49).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = *mem.root();
        for prune_leaf in 1..*mem.leaves() {
            let prune_loc = Location::new(prune_leaf);
            mem.prune(prune_loc).unwrap();
            let pruned_root = mem.root();
            assert_eq!(root, *pruned_root);
            for loc in 0..elements.len() {
                let loc = Location::new(loc as u64);
                let proof = mem.proof(&hasher, loc);
                if loc < prune_loc {
                    // Pruned leaves are out of scope for this check.
                    continue;
                }
                assert!(proof.is_ok());
                assert!(proof.unwrap().verify_element_inclusion(
                    &hasher,
                    &elements[*loc as usize],
                    loc,
                    &root
                ));
            }
        }
    }
    // Range proofs over the unpruned suffix still verify after pruning, and
    // after a second round of appends plus further pruning.
    fn ranges_provable_after_pruning<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let mut elements: Vec<_> = (0..49).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let prune_loc = Location::<F>::new(32);
        mem.prune(prune_loc).unwrap();
        assert_eq!(mem.bounds().start, prune_loc);
        let root = mem.root();
        for i in 0..elements.len() - 1 {
            if Location::<F>::new(i as u64) < prune_loc {
                continue;
            }
            for j in (i + 2)..elements.len() {
                let range = Location::new(i as u64)..Location::new(j as u64);
                let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
                assert!(
                    range_proof.verify_range_inclusion(
                        &hasher,
                        &elements[range.to_usize_range()],
                        range.start,
                        root,
                    ),
                    "valid range proof over remaining elements should verify successfully",
                );
            }
        }
        // Append more leaves, prune deeper, and re-check a trailing range.
        let new_elements: Vec<_> = (0..37).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &new_elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        elements.extend(new_elements);
        mem.prune(Location::new(66)).unwrap();
        assert_eq!(mem.bounds().start, Location::new(66));
        let updated_root = mem.root();
        let range = Location::new(elements.len() as u64 - 10)..Location::new(elements.len() as u64);
        let range_proof = mem.range_proof(&hasher, range.clone()).unwrap();
        assert!(
            range_proof.verify_range_inclusion(
                &hasher,
                &elements[range.to_usize_range()],
                range.start,
                updated_root,
            ),
            "valid range proof over remaining elements after 2 pruning rounds should verify",
        );
    }
    // Round-trips every sub-range proof through the codec and checks that
    // truncated, extended, or over-limit payloads fail to decode.
    fn proof_serialization<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..25).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        for i in 0..elements.len() {
            for j in i + 1..elements.len() {
                let range = Location::new(i as u64)..Location::new(j as u64);
                let proof = mem.range_proof(&hasher, range).unwrap();
                let expected_size = proof.encode_size();
                let serialized_proof = proof.encode();
                assert_eq!(
                    serialized_proof.len(),
                    expected_size,
                    "serialized proof should have expected size"
                );
                let max_digests = proof.digests.len();
                let deserialized_proof =
                    Proof::<F, D>::decode_cfg(serialized_proof, &max_digests).unwrap();
                assert_eq!(
                    proof, deserialized_proof,
                    "deserialized proof should match source proof"
                );
                let serialized_proof = proof.encode();
                let serialized_proof = serialized_proof.slice(0..serialized_proof.len() - 1);
                assert!(
                    Proof::<F, D>::decode_cfg(serialized_proof, &max_digests).is_err(),
                    "proof should not deserialize with truncated data"
                );
                let mut serialized_proof = proof.encode_mut();
                serialized_proof.extend_from_slice(&[0; 10]);
                let serialized_proof = serialized_proof;
                assert!(
                    Proof::<F, D>::decode_cfg(serialized_proof, &max_digests).is_err(),
                    "proof should not deserialize with extra data"
                );
                let actual_digests = proof.digests.len();
                if actual_digests > 0 {
                    // A max_digests cap below the actual count must reject.
                    let too_small = actual_digests - 1;
                    let serialized_proof = proof.encode();
                    assert!(
                        Proof::<F, D>::decode_cfg(serialized_proof, &too_small).is_err(),
                        "proof should not deserialize with max_digests too small"
                    );
                }
            }
        }
    }
    // Multi-proofs verify regardless of element order (and with duplicates),
    // and fail for wrong sizes, locations, elements, or roots; the empty
    // multi-proof is valid only for the empty tree with no digests.
    fn multi_proof_generation_and_verify<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..20).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = mem.root();
        let locations = &[Location::new(0), Location::new(5), Location::new(10)];
        let nodes_for_multi_proof =
            nodes_required_for_multi_proof(mem.leaves(), locations).expect("test locations valid");
        let digests = nodes_for_multi_proof
            .into_iter()
            .map(|pos| mem.get_node(pos).unwrap())
            .collect();
        let multi_proof = Proof {
            leaves: mem.leaves(),
            digests,
        };
        assert_eq!(multi_proof.leaves, mem.leaves());
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(10)),
            ],
            root
        ));
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[10], Location::new(10)),
                (elements[5], Location::new(5)),
                (elements[0], Location::new(0)),
            ],
            root
        ));
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[0], Location::new(0)),
                (elements[10], Location::new(10)),
                (elements[5], Location::new(5)),
            ],
            root
        ));
        let mut wrong_size_proof = multi_proof.clone();
        wrong_size_proof.leaves = Location::new(*F::MAX_LEAVES + 2);
        assert!(!wrong_size_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(10)),
            ],
            root,
        ));
        assert!(!multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(1)),
                (elements[5], Location::new(6)),
                (elements[10], Location::new(11)),
            ],
            root,
        ));
        let wrong_elements = [
            vec![255u8, 254u8, 253u8],
            vec![252u8, 251u8, 250u8],
            vec![249u8, 248u8, 247u8],
        ];
        let wrong_verification = multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (wrong_elements[0].as_slice(), Location::new(0)),
                (wrong_elements[1].as_slice(), Location::new(5)),
                (wrong_elements[2].as_slice(), Location::new(10)),
            ],
            root,
        );
        assert!(!wrong_verification, "Should fail with wrong elements");
        let wrong_verification = multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(1000)),
            ],
            root,
        );
        assert!(
            !wrong_verification,
            "Should fail with out of range elements"
        );
        let wrong_root = test_digest(99);
        assert!(!multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[5], Location::new(5)),
                (elements[10], Location::new(10)),
            ],
            &wrong_root
        ));
        let hasher = H::new();
        let empty_mem = Mem::<F, D>::new(&hasher);
        let empty_root = empty_mem.root();
        let empty_proof: Proof<F, D> = Proof::default();
        assert!(empty_proof.verify_multi_inclusion(
            &hasher,
            &[] as &[(D, Location<F>)],
            empty_root
        ));
        let malformed_proof: Proof<F, D> = Proof {
            leaves: Location::new(0),
            digests: vec![test_digest(0)],
        };
        assert!(!malformed_proof.verify_multi_inclusion(
            &hasher,
            &[] as &[(D, Location<F>)],
            empty_root
        ));
    }
    // A multi-proof over two adjacent leaves shares nodes and is strictly
    // smaller than the two separate single-element proofs combined.
    fn multi_proof_deduplication<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<_> = (0..30).map(test_digest).collect();
        let batch = {
            let mut batch = mem.new_batch();
            for element in &elements {
                batch = batch.add(&hasher, element);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let proof1 = mem.proof(&hasher, Location::new(0)).unwrap();
        let proof2 = mem.proof(&hasher, Location::new(1)).unwrap();
        let total_digests_separate = proof1.digests.len() + proof2.digests.len();
        let locations = &[Location::new(0), Location::new(1)];
        let multi_proof_nodes =
            nodes_required_for_multi_proof(mem.leaves(), locations).expect("test locations valid");
        let digests = multi_proof_nodes
            .into_iter()
            .map(|pos| mem.get_node(pos).unwrap())
            .collect();
        let multi_proof = Proof {
            leaves: mem.leaves(),
            digests,
        };
        assert!(multi_proof.digests.len() < total_digests_separate);
        let root = mem.root();
        assert!(multi_proof.verify_multi_inclusion(
            &hasher,
            &[
                (elements[0], Location::new(0)),
                (elements[1], Location::new(1))
            ],
            root
        ));
    }
    // Tampering with only the `leaves` field of a valid proof must make
    // verification fail (the leaf count is bound into the root).
    fn proof_leaves_malleability<F: Family>() {
        let hasher = H::new();
        let mut mem = Mem::<F, D>::new(&hasher);
        let elements: Vec<D> = (0..252u16)
            .map(|i| <Sha256 as commonware_cryptography::Hasher>::hash(&i.to_be_bytes()))
            .collect();
        let batch = {
            let mut batch = mem.new_batch();
            for e in &elements {
                batch = batch.add(&hasher, e);
            }
            batch.merkleize(&mem, &hasher)
        };
        mem.apply_batch(&batch).unwrap();
        let root = mem.root();
        let loc = Location::new(240);
        let proof = mem.proof(&hasher, loc).unwrap();
        assert!(proof.verify_element_inclusion(&hasher, &elements[240], loc, root));
        let mut tampered = proof.clone();
        tampered.leaves = Location::new(249);
        assert_ne!(tampered, proof);
        assert!(
            !tampered.verify_element_inclusion(&hasher, &elements[240], loc, root),
            "proof with tampered leaves field must not verify"
        );
    }
    // Blueprint construction rejects empty and out-of-bounds ranges, and
    // multi-proof planning rejects an empty location list.
    fn blueprint_errors<F: Family>() {
        let leaves = Location::<F>::new(10);
        assert!(matches!(
            Blueprint::<F>::new(leaves, Location::new(3)..Location::new(3)),
            Err(crate::merkle::Error::Empty)
        ));
        assert!(matches!(
            Blueprint::<F>::new(leaves, Location::new(0)..Location::new(11)),
            Err(crate::merkle::Error::RangeOutOfBounds(_))
        ));
        assert!(matches!(
            nodes_required_for_multi_proof::<F>(leaves, &[]),
            Err(crate::merkle::Error::Empty)
        ));
    }
    // For trees of 1..=64 leaves, every single-element proof reconstructs
    // the tree's root exactly.
    fn single_element_proof_reconstruction<F: Family>() {
        for n in 1u64..=64 {
            let hasher = H::new();
            let mem = build_raw::<F>(&hasher, n);
            let root = *mem.root();
            for loc_idx in 0..n {
                let proof = mem
                    .proof(&hasher, Location::new(loc_idx))
                    .unwrap_or_else(|e| panic!("n={n}, loc={loc_idx}: build failed: {e:?}"));
                let elements = [loc_idx.to_be_bytes()];
                let start_loc = Location::new(loc_idx);
                let reconstructed = proof
                    .reconstruct_root(&hasher, &elements, start_loc)
                    .unwrap_or_else(|e| panic!("n={n}, loc={loc_idx}: reconstruct failed: {e:?}"));
                assert_eq!(reconstructed, root, "n={n}, loc={loc_idx}: root mismatch");
            }
        }
    }
    // For trees of 2..=32 leaves, proofs over representative ranges (full,
    // first, last, leading, trailing) reconstruct the root.
    fn range_proof_reconstruction<F: Family>() {
        for n in 2u64..=32 {
            let hasher = H::new();
            let mem = build_raw::<F>(&hasher, n);
            let root = *mem.root();
            let ranges: Vec<(u64, u64)> = vec![
                (0, n),
                (0, 1),
                (n - 1, n),
                (0, n.min(3)),
                (n.saturating_sub(3), n),
            ];
            for (start, end) in ranges {
                if start >= end || end > n {
                    continue;
                }
                let proof = mem
                    .range_proof(&hasher, Location::new(start)..Location::new(end))
                    .unwrap_or_else(|e| panic!("n={n}, range={start}..{end}: build failed: {e:?}"));
                let elements: Vec<_> = (start..end).map(|i| i.to_be_bytes()).collect();
                let start_loc = Location::new(start);
                let reconstructed = proof
                    .reconstruct_root(&hasher, &elements, start_loc)
                    .unwrap_or_else(|e| {
                        panic!("n={n}, range={start}..{end}: reconstruct failed: {e}")
                    });
                assert_eq!(
                    reconstructed, root,
                    "n={n}, range={start}..{end}: root mismatch"
                );
            }
        }
    }
fn verify_element_inclusion<F: Family>() {
    // A valid single-element proof must verify; the same proof checked
    // against a different element must not.
    for n in 1u64..=32 {
        let hasher = H::new();
        let mem = build_raw::<F>(&hasher, n);
        let root = *mem.root();
        for loc_idx in 0..n {
            let loc = Location::new(loc_idx);
            let proof = mem.proof(&hasher, loc).unwrap();
            let genuine = loc_idx.to_be_bytes();
            assert!(
                proof.verify_element_inclusion(&hasher, &genuine, loc, &root),
                "n={n}, loc={loc_idx}: verification failed"
            );
            // An element that was never inserted must be rejected.
            let bogus = (loc_idx + 1000).to_be_bytes();
            assert!(
                !proof.verify_element_inclusion(&hasher, &bogus, loc, &root),
                "n={n}, loc={loc_idx}: wrong element should not verify"
            );
        }
    }
}
fn full_range<F: Family>() {
    // A proof covering every leaf needs no sibling digests at all: the
    // verifier can rebuild the whole tree from the elements alone.
    for n in 1u64..=32 {
        let hasher = H::new();
        let mem = build_raw::<F>(&hasher, n);
        let root = *mem.root();
        let proof = mem
            .range_proof(&hasher, Location::new(0)..Location::new(n))
            .unwrap();
        let elements: Vec<_> = (0..n).map(|i| i.to_be_bytes()).collect();
        assert_eq!(
            proof
                .reconstruct_root(&hasher, &elements, Location::new(0))
                .unwrap(),
            root,
            "n={n}: full range failed"
        );
        assert_eq!(
            proof.digests.len(),
            0,
            "n={n}: full range proof should have 0 digests"
        );
    }
}
fn empty_proof_verifies_empty_tree<F: Family>() {
    // A default (empty) proof is valid for the empty tree only at the zero
    // starting location; any other start location must be rejected.
    let hasher = H::new();
    let mem = Mem::<F, D>::new(&hasher);
    let root = *mem.root();
    let proof = Proof::<F, D>::default();
    let no_elements: &[&[u8]] = &[];
    assert!(proof.verify_range_inclusion(&hasher, no_elements, Location::new(0), &root));
    assert!(!proof.verify_range_inclusion(&hasher, no_elements, Location::new(1), &root));
}
fn every_element_contributes_to_root<F: Family>() {
    // Flipping any single element in a verified range must break verification:
    // no element may be ignored during reconstruction.
    for n in [8u64, 13, 20, 32] {
        let hasher = H::new();
        let mem = build_raw::<F>(&hasher, n);
        let root = *mem.root();
        // Interior range [1, n-1) so digests exist on both boundaries.
        let (lo, hi) = (1, n - 1);
        let proof = mem
            .range_proof(&hasher, Location::new(lo)..Location::new(hi))
            .unwrap();
        let elements: Vec<_> = (lo..hi).map(|i| i.to_be_bytes()).collect();
        assert!(
            proof.verify_range_inclusion(&hasher, &elements, Location::new(lo), &root),
            "n={n}: valid range should verify"
        );
        for flip_idx in 0..elements.len() {
            let mut corrupted = elements.clone();
            corrupted[flip_idx][0] ^= 0xFF;
            assert!(
                !proof.verify_range_inclusion(&hasher, &corrupted, Location::new(lo), &root),
                "n={n}: tampered element at index {flip_idx} should not verify"
            );
        }
    }
}
fn multi_proof_generation_and_verify_raw<F: Family>() {
let hasher = H::new();
let mem = build_raw::<F>(&hasher, 20);
let root = *mem.root();
let locations = &[Location::new(0), Location::new(5), Location::new(10)];
let nodes =
nodes_required_for_multi_proof(mem.leaves(), locations).expect("valid locations");
let digests = nodes
.into_iter()
.map(|pos| mem.get_node(pos).unwrap())
.collect();
let multi_proof = Proof {
leaves: mem.leaves(),
digests,
};
assert!(multi_proof.verify_multi_inclusion(
&hasher,
&[
(0u64.to_be_bytes(), Location::new(0)),
(5u64.to_be_bytes(), Location::new(5)),
(10u64.to_be_bytes(), Location::new(10)),
],
&root
));
assert!(multi_proof.verify_multi_inclusion(
&hasher,
&[
(10u64.to_be_bytes(), Location::new(10)),
(5u64.to_be_bytes(), Location::new(5)),
(0u64.to_be_bytes(), Location::new(0)),
],
&root
));
assert!(!multi_proof.verify_multi_inclusion(
&hasher,
&[
(99u64.to_be_bytes(), Location::new(0)),
(5u64.to_be_bytes(), Location::new(5)),
(10u64.to_be_bytes(), Location::new(10)),
],
&root
));
let wrong_root = hasher.digest(b"wrong");
assert!(!multi_proof.verify_multi_inclusion(
&hasher,
&[
(0u64.to_be_bytes(), Location::new(0)),
(5u64.to_be_bytes(), Location::new(5)),
(10u64.to_be_bytes(), Location::new(10)),
],
&wrong_root
));
let hasher2 = H::new();
let empty_mem = Mem::<F, D>::new(&hasher2);
let empty_proof: Proof<F, D> = Proof::default();
assert!(empty_proof.verify_multi_inclusion(
&hasher2,
&[] as &[([u8; 8], Location<F>)],
empty_mem.root()
));
let malformed_proof: Proof<F, D> = Proof {
leaves: Location::new(0),
digests: vec![test_digest(0)],
};
assert!(!malformed_proof.verify_multi_inclusion(
&hasher2,
&[] as &[([u8; 8], Location<F>)],
empty_mem.root()
));
}
fn tampered_proof_digests_rejected<F: Family>() {
    // Flipping a single bit in any digest carried by a proof must invalidate
    // that proof.
    for n in [8u64, 13, 20, 32] {
        let hasher = H::new();
        let mem = build_raw::<F>(&hasher, n);
        let root = *mem.root();
        // Probe the first, middle, and last leaf of each tree.
        for loc_idx in [0, n / 2, n - 1] {
            let loc = Location::new(loc_idx);
            let proof = mem.proof(&hasher, loc).unwrap();
            let element = loc_idx.to_be_bytes();
            // Sanity check: the untouched proof verifies.
            assert!(proof.verify_element_inclusion(&hasher, &element, loc, &root));
            // Tamper with each digest in turn (lowest bit of the first byte).
            for digest_idx in 0..proof.digests.len() {
                let mut tampered = proof.clone();
                tampered.digests[digest_idx].0[0] ^= 1;
                assert!(
                    !tampered.verify_element_inclusion(&hasher, &element, loc, &root),
                    "n={n}, loc={loc_idx}: tampered digest[{digest_idx}] should not verify"
                );
            }
        }
    }
}
fn no_duplicate_positions<F: Family>() {
    use alloc::collections::BTreeSet;
    // The blueprint for a single-leaf proof must never reference the same
    // node position twice across its fold prefix and fetched nodes.
    for n in 1u64..=64 {
        let hasher = H::new();
        let mem = build_raw::<F>(&hasher, n);
        let leaves = mem.leaves();
        for loc in 0..n {
            let loc = Location::new(loc);
            let bp = Blueprint::<F>::new(leaves, loc..loc + 1).unwrap();
            // Gather every referenced position, then dedupe via a set.
            let positions: Vec<Position<F>> = bp
                .fold_prefix
                .iter()
                .chain(bp.fetch_nodes.iter())
                .copied()
                .collect();
            let unique: BTreeSet<_> = positions.iter().copied().collect();
            assert_eq!(
                positions.len(),
                unique.len(),
                "n={n}, loc={loc}: duplicate positions"
            );
        }
    }
}
// Concrete `#[test]` entry points instantiating each generic helper above for
// the `mmr::Family` tree type.
#[test]
fn mmr_empty_proof() {
    empty_proof::<mmr::Family>();
}
#[test]
fn mmr_verify_element() {
    verify_element::<mmr::Family>();
}
#[test]
fn mmr_verify_range() {
    verify_range::<mmr::Family>();
}
// Uses `#[test_traced]` — presumably enables trace logging; see the macro's
// definition elsewhere in the project.
#[test_traced]
fn mmr_retained_nodes_provable_after_pruning() {
    retained_nodes_provable_after_pruning::<mmr::Family>();
}
#[test]
fn mmr_ranges_provable_after_pruning() {
    ranges_provable_after_pruning::<mmr::Family>();
}
#[test]
fn mmr_proof_serialization() {
    proof_serialization::<mmr::Family>();
}
#[test]
fn mmr_multi_proof_generation_and_verify() {
    multi_proof_generation_and_verify::<mmr::Family>();
}
#[test]
fn mmr_multi_proof_deduplication() {
    multi_proof_deduplication::<mmr::Family>();
}
#[test]
fn mmr_proof_leaves_malleability() {
    proof_leaves_malleability::<mmr::Family>();
}
#[test]
fn mmr_blueprint_errors() {
    blueprint_errors::<mmr::Family>();
}
#[test]
fn mmr_single_element_proof_reconstruction() {
    single_element_proof_reconstruction::<mmr::Family>();
}
#[test]
fn mmr_range_proof_reconstruction() {
    range_proof_reconstruction::<mmr::Family>();
}
#[test]
fn mmr_verify_element_inclusion() {
    verify_element_inclusion::<mmr::Family>();
}
#[test]
fn mmr_full_range() {
    full_range::<mmr::Family>();
}
#[test]
fn mmr_empty_proof_verifies_empty_tree() {
    empty_proof_verifies_empty_tree::<mmr::Family>();
}
#[test]
fn mmr_every_element_contributes_to_root() {
    every_element_contributes_to_root::<mmr::Family>();
}
#[test]
fn mmr_multi_proof_generation_and_verify_raw() {
    multi_proof_generation_and_verify_raw::<mmr::Family>();
}
#[test]
fn mmr_tampered_proof_digests_rejected() {
    tampered_proof_digests_rejected::<mmr::Family>();
}
#[test]
fn mmr_no_duplicate_positions() {
    no_duplicate_positions::<mmr::Family>();
}
// Concrete `#[test]` entry points instantiating each generic helper above for
// the `mmb::Family` tree type (same suite as the `mmr_*` tests).
#[test]
fn mmb_empty_proof() {
    empty_proof::<mmb::Family>();
}
#[test]
fn mmb_verify_element() {
    verify_element::<mmb::Family>();
}
#[test]
fn mmb_verify_range() {
    verify_range::<mmb::Family>();
}
// Uses `#[test_traced]` — presumably enables trace logging; see the macro's
// definition elsewhere in the project.
#[test_traced]
fn mmb_retained_nodes_provable_after_pruning() {
    retained_nodes_provable_after_pruning::<mmb::Family>();
}
#[test]
fn mmb_ranges_provable_after_pruning() {
    ranges_provable_after_pruning::<mmb::Family>();
}
#[test]
fn mmb_proof_serialization() {
    proof_serialization::<mmb::Family>();
}
#[test]
fn mmb_multi_proof_generation_and_verify() {
    multi_proof_generation_and_verify::<mmb::Family>();
}
#[test]
fn mmb_multi_proof_deduplication() {
    multi_proof_deduplication::<mmb::Family>();
}
#[test]
fn mmb_proof_leaves_malleability() {
    proof_leaves_malleability::<mmb::Family>();
}
#[test]
fn mmb_blueprint_errors() {
    blueprint_errors::<mmb::Family>();
}
#[test]
fn mmb_single_element_proof_reconstruction() {
    single_element_proof_reconstruction::<mmb::Family>();
}
#[test]
fn mmb_range_proof_reconstruction() {
    range_proof_reconstruction::<mmb::Family>();
}
#[test]
fn mmb_verify_element_inclusion() {
    verify_element_inclusion::<mmb::Family>();
}
#[test]
fn mmb_full_range() {
    full_range::<mmb::Family>();
}
#[test]
fn mmb_empty_proof_verifies_empty_tree() {
    empty_proof_verifies_empty_tree::<mmb::Family>();
}
#[test]
fn mmb_every_element_contributes_to_root() {
    every_element_contributes_to_root::<mmb::Family>();
}
#[test]
fn mmb_multi_proof_generation_and_verify_raw() {
    multi_proof_generation_and_verify_raw::<mmb::Family>();
}
#[test]
fn mmb_tampered_proof_digests_rejected() {
    tampered_proof_digests_rejected::<mmb::Family>();
}
#[test]
fn mmb_no_duplicate_positions() {
    no_duplicate_positions::<mmb::Family>();
}
}