use alloc::collections::btree_set::BTreeSet;
use commonware_codec::{EncodeSize, Read, ReadExt, ReadRangeExt, Write};
use commonware_cryptography::{Digest, Hasher};
use commonware_runtime::{Buf, BufMut};
use commonware_utils::{non_empty_vec, vec::NonEmptyVec};
use thiserror::Error;
/// Per-item multiplier used by `Proof::read_cfg` to cap how many sibling digests it will
/// decode: at most `max_items * MAX_LEVELS` (saturating), where `max_items` is the decode config.
///
/// NOTE(review): this is a loose bound. `levels_in_tree` yields at most 33 levels for a `u32`
/// leaf count, so a well-formed proof carries at most 32 siblings per covered item.
pub const MAX_LEVELS: usize = u8::MAX as usize;
/// Errors returned when generating or verifying Merkle proofs.
#[derive(Error, Debug)]
pub enum Error {
/// A position or range endpoint is out of bounds for the tree/proof, or the request is
/// malformed (e.g. `start > end`, or a non-zero position with no leaves).
#[error("invalid position: {0}")]
InvalidPosition(u32),
/// The recomputed root did not match the expected root; holds `(computed, expected)`
/// rendered with `to_string`.
#[error("invalid proof: {0} != {1}")]
InvalidProof(String, String),
/// No positions/leaves were supplied, or the tree has no leaves to prove against.
#[error("no leaves")]
NoLeaves,
/// The proof's sibling list does not match the shape implied by `leaf_count` and the
/// positions being verified (too few or too many siblings).
#[error("unaligned proof")]
UnalignedProof,
/// The same position was supplied more than once.
#[error("duplicate position: {0}")]
DuplicatePosition(u32),
}
/// Accumulates leaves for a [`Tree`], hashing each one together with its position as it is added.
pub struct Builder<H: Hasher> {
// Hasher reused for every leaf hash and, in `build`, for constructing the whole tree.
hasher: H,
// Position-bound leaf hashes (`hash(position_be || leaf)`), in insertion order.
leaves: Vec<H::Digest>,
}
impl<H: Hasher> Builder<H> {
    /// Creates an empty builder, pre-allocating room for `leaves` leaves.
    pub fn new(leaves: usize) -> Self {
        Self {
            hasher: H::new(),
            leaves: Vec::with_capacity(leaves),
        }
    }

    /// Appends `leaf` and returns its position (its zero-based insertion index).
    ///
    /// The stored leaf node is `hash(position_be || leaf)`, which binds each leaf to its
    /// position in the tree.
    ///
    /// # Panics
    ///
    /// Panics if the builder already holds `u32::MAX` leaves. The tree root commits to the leaf
    /// count as a `u32`, so the count *after* this insertion must still fit in a `u32`;
    /// otherwise the count would silently wrap (2^32 leaves would be committed as 0).
    pub fn add(&mut self, leaf: &H::Digest) -> u32 {
        let position = u32::try_from(self.leaves.len())
            .ok()
            .filter(|&next| next < u32::MAX)
            .expect("too many leaves");
        self.hasher.update(&position.to_be_bytes());
        self.hasher.update(leaf);
        self.leaves.push(self.hasher.finalize());
        position
    }

    /// Consumes the builder and constructs the [`Tree`] over all added leaves.
    pub fn build(self) -> Tree<H::Digest> {
        Tree::new(self.hasher, self.leaves)
    }
}
/// A binary Merkle tree whose root commits to both the leaf nodes and the leaf count.
#[derive(Clone, Debug)]
pub struct Tree<D: Digest> {
// `true` when built from zero leaves; `levels` then holds a single placeholder leaf.
empty: bool,
// `levels[0]` holds the leaf nodes; each subsequent level holds the parents of the previous
// one, ending with a single-node level.
levels: NonEmptyVec<NonEmptyVec<D>>,
// `hash(leaf_count as big-endian u32 || top node of `levels`)`.
root: D,
}
impl<D: Digest> Tree<D> {
fn new<H: Hasher<Digest = D>>(mut hasher: H, mut leaves: Vec<D>) -> Self {
let mut empty = false;
let leaf_count = leaves.len() as u32;
if leaves.is_empty() {
leaves.push(hasher.finalize());
empty = true;
}
let mut levels = non_empty_vec![non_empty_vec![@leaves]];
let mut current_level = levels.last();
while !current_level.is_singleton() {
let mut next_level = Vec::with_capacity(current_level.len().get().div_ceil(2));
for chunk in current_level.chunks(2) {
hasher.update(&chunk[0]);
if chunk.len() == 2 {
hasher.update(&chunk[1]);
} else {
hasher.update(&chunk[0]);
};
next_level.push(hasher.finalize());
}
levels.push(non_empty_vec![@next_level]);
current_level = levels.last();
}
let tree_root = levels.last().first();
hasher.update(&leaf_count.to_be_bytes());
hasher.update(tree_root);
let root = hasher.finalize();
Self {
empty,
levels,
root,
}
}
pub const fn root(&self) -> D {
self.root
}
pub fn proof(&self, position: u32) -> Result<Proof<D>, Error> {
self.multi_proof(core::iter::once(position))
}
pub fn range_proof(&self, start: u32, end: u32) -> Result<Proof<D>, Error> {
if self.empty {
if start == 0 && end == 0 {
return Ok(Proof::default());
}
return Err(Error::InvalidPosition(start));
}
if start > end {
return Err(Error::InvalidPosition(start));
}
let leaf_count = self.levels.first().len().get() as u32;
if start >= leaf_count {
return Err(Error::InvalidPosition(start));
}
if end >= leaf_count {
return Err(Error::InvalidPosition(end));
}
let sibling_positions = siblings_required_for_range_proof(leaf_count, start, end)?;
let siblings: Vec<D> = sibling_positions
.iter()
.map(|&(level, index)| self.levels[level][index])
.collect();
Ok(Proof {
leaf_count,
siblings,
})
}
pub fn multi_proof<I, P>(&self, positions: I) -> Result<Proof<D>, Error>
where
I: IntoIterator<Item = P>,
P: core::borrow::Borrow<u32>,
{
let mut positions = positions.into_iter().peekable();
let first = *positions.peek().ok_or(Error::NoLeaves)?.borrow();
if self.empty {
return Err(Error::InvalidPosition(first));
}
let leaf_count = self.levels.first().len().get() as u32;
let sibling_positions =
siblings_required_for_multi_proof(leaf_count, positions.map(|p| *p.borrow()))?;
let siblings: Vec<D> = sibling_positions
.iter()
.map(|&(level, index)| self.levels[level][index])
.collect();
Ok(Proof {
leaf_count,
siblings,
})
}
}
/// An inclusion proof for one or more leaves of a [`Tree`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Proof<D: Digest> {
/// Number of leaves in the tree the proof was generated from; the root commits to this value.
pub leaf_count: u32,
/// Sibling digests needed to recompute the root, ordered by level (leaf level first) and then
/// by index within each level.
pub siblings: Vec<D>,
}
impl<D: Digest> Default for Proof<D> {
    /// The empty proof: zero leaves and no siblings (what an empty tree yields for `0..=0`).
    ///
    /// Implemented by hand because `#[derive(Default)]` would add an unwanted `D: Default` bound.
    fn default() -> Self {
        let siblings: Vec<D> = Vec::new();
        Self {
            leaf_count: 0,
            siblings,
        }
    }
}
impl<D: Digest> Write for Proof<D> {
    /// Encodes the leaf count followed by the length-prefixed sibling list.
    fn write(&self, writer: &mut impl BufMut) {
        // Exhaustive destructuring: adding a field without encoding it becomes a compile error.
        let Self {
            leaf_count,
            siblings,
        } = self;
        leaf_count.write(writer);
        siblings.write(writer);
    }
}
impl<D: Digest> Read for Proof<D> {
    /// Maximum number of items the decoded proof is expected to cover.
    type Cfg = usize;

    /// Decodes a proof, refusing sibling lists longer than `max_items * MAX_LEVELS`
    /// (saturating) so that untrusted input cannot force an oversized allocation.
    fn read_cfg(
        reader: &mut impl Buf,
        max_items: &Self::Cfg,
    ) -> Result<Self, commonware_codec::Error> {
        let leaf_count = u32::read(reader)?;
        let sibling_limit = MAX_LEVELS.saturating_mul(*max_items);
        Ok(Self {
            leaf_count,
            siblings: Vec::<D>::read_range(reader, ..=sibling_limit)?,
        })
    }
}
impl<D: Digest> EncodeSize for Proof<D> {
    /// Encoded length: the leaf count plus the length-prefixed sibling list.
    fn encode_size(&self) -> usize {
        let Self {
            leaf_count,
            siblings,
        } = self;
        leaf_count.encode_size() + siblings.encode_size()
    }
}
#[cfg(feature = "arbitrary")]
impl<D: Digest> arbitrary::Arbitrary<'_> for Proof<D>
where
    D: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws an arbitrary (not necessarily valid) proof for fuzzing.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Draw fields in declaration order so byte consumption from `u` stays stable.
        let leaf_count: u32 = u.arbitrary()?;
        let siblings: Vec<D> = u.arbitrary()?;
        Ok(Self {
            leaf_count,
            siblings,
        })
    }
}
/// Number of levels (leaf level through top level, inclusive) in a tree of `leaf_count` leaves.
///
/// This equals `ceil(log2(leaf_count)) + 1`; both `0` and `1` leaves yield a single level.
const fn levels_in_tree(leaf_count: u32) -> usize {
    // The bit width of the largest leaf index equals the tree's height above the leaves.
    let highest_index = leaf_count.saturating_sub(1);
    let height = u32::BITS - highest_index.leading_zeros();
    (height + 1) as usize
}
/// Computes the `(level, index)` coordinates of every sibling digest a multi-proof must carry.
///
/// Starting from the requested leaf positions, walks up the tree one level at a time. At each
/// level, a tracked node needs its sibling unless:
/// - the sibling is itself tracked at that level (the verifier can compute it), or
/// - the node is the unpaired last node of its level (it is hashed with itself).
///
/// The returned set iterates by level, then by index, which matches the order in which
/// `Proof::verify_multi_inclusion` consumes siblings.
///
/// # Errors
///
/// - [`Error::InvalidPosition`] if any position is `>= leaf_count`.
/// - [`Error::DuplicatePosition`] if a position appears more than once.
/// - [`Error::NoLeaves`] if `positions` is empty.
fn siblings_required_for_multi_proof(
    leaf_count: u32,
    positions: impl IntoIterator<Item = u32>,
) -> Result<BTreeSet<(usize, usize)>, Error> {
    let mut current = BTreeSet::new();
    for pos in positions {
        if pos >= leaf_count {
            return Err(Error::InvalidPosition(pos));
        }
        if !current.insert(pos as usize) {
            return Err(Error::DuplicatePosition(pos));
        }
    }
    if current.is_empty() {
        return Err(Error::NoLeaves);
    }
    let mut sibling_positions = BTreeSet::new();
    let levels_count = levels_in_tree(leaf_count);
    let mut level_size = leaf_count as usize;
    for level in 0..levels_count - 1 {
        for &index in &current {
            let sibling_index = if index.is_multiple_of(2) {
                // An even node at the end of an odd-sized level has no right sibling; it is
                // paired with itself, signalled here by `sibling_index == index`.
                if index + 1 < level_size {
                    index + 1
                } else {
                    index
                }
            } else {
                index - 1
            };
            if sibling_index != index && !current.contains(&sibling_index) {
                sibling_positions.insert((level, sibling_index));
            }
        }
        // Move the tracked set up to the parent level (pairs collapse into one parent).
        current = current.iter().map(|idx| idx / 2).collect();
        level_size = level_size.div_ceil(2);
    }
    Ok(sibling_positions)
}
/// Computes the `(level, index)` coordinates of the sibling digests a range proof over the
/// inclusive leaf range `start..=end` must carry.
///
/// Because the range is contiguous, only its two boundaries can need outside help at each level:
/// an odd left boundary needs its left neighbour, and an even right boundary needs its right
/// neighbour (unless it is the unpaired last node of the level, which is hashed with itself).
///
/// # Errors
///
/// - [`Error::NoLeaves`] if `leaf_count == 0`.
/// - [`Error::InvalidPosition`] with `start` if `start > end` or `start >= leaf_count`.
/// - [`Error::InvalidPosition`] with `end` if `end >= leaf_count`.
fn siblings_required_for_range_proof(
    leaf_count: u32,
    start: u32,
    end: u32,
) -> Result<BTreeSet<(usize, usize)>, Error> {
    if leaf_count == 0 {
        return Err(Error::NoLeaves);
    }
    if start > end || start >= leaf_count {
        return Err(Error::InvalidPosition(start));
    }
    if end >= leaf_count {
        return Err(Error::InvalidPosition(end));
    }

    let mut required = BTreeSet::new();
    let (mut lo, mut hi) = (start as usize, end as usize);
    let mut width = leaf_count as usize;
    let top_level = levels_in_tree(leaf_count) - 1;
    for level in 0..top_level {
        let left_boundary_is_odd = !lo.is_multiple_of(2);
        if left_boundary_is_odd {
            required.insert((level, lo - 1));
        }
        let right_neighbour = hi + 1;
        if hi.is_multiple_of(2) && right_neighbour < width {
            required.insert((level, right_neighbour));
        }
        lo /= 2;
        hi /= 2;
        width = width.div_ceil(2);
    }
    Ok(required)
}
impl<D: Digest> Proof<D> {
pub fn verify_element_inclusion<H: Hasher<Digest = D>>(
&self,
hasher: &mut H,
leaf: &D,
mut position: u32,
root: &D,
) -> Result<(), Error> {
if position >= self.leaf_count {
return Err(Error::InvalidPosition(position));
}
hasher.update(&position.to_be_bytes());
hasher.update(leaf);
let mut computed = hasher.finalize();
let mut level_size = self.leaf_count as usize;
let mut sibling_iter = self.siblings.iter();
while level_size > 1 {
let is_last_odd = position.is_multiple_of(2) && position as usize + 1 >= level_size;
let (left_node, right_node) = if is_last_odd {
(&computed, &computed)
} else if position.is_multiple_of(2) {
let sibling = sibling_iter.next().ok_or(Error::UnalignedProof)?;
(&computed, sibling)
} else {
let sibling = sibling_iter.next().ok_or(Error::UnalignedProof)?;
(sibling, &computed)
};
hasher.update(left_node);
hasher.update(right_node);
computed = hasher.finalize();
position /= 2;
level_size = level_size.div_ceil(2);
}
if sibling_iter.next().is_some() {
return Err(Error::UnalignedProof);
}
hasher.update(&self.leaf_count.to_be_bytes());
hasher.update(&computed);
let finalized = hasher.finalize();
if finalized == *root {
Ok(())
} else {
Err(Error::InvalidProof(finalized.to_string(), root.to_string()))
}
}
pub fn verify_multi_inclusion<H: Hasher<Digest = D>>(
&self,
hasher: &mut H,
elements: &[(D, u32)],
root: &D,
) -> Result<(), Error> {
if elements.is_empty() {
if self.leaf_count == 0 && self.siblings.is_empty() {
let empty_tree_root = hasher.finalize();
hasher.update(&0u32.to_be_bytes());
hasher.update(&empty_tree_root);
let finalized = hasher.finalize();
if finalized == *root {
return Ok(());
} else {
return Err(Error::InvalidProof(finalized.to_string(), root.to_string()));
}
}
return Err(Error::NoLeaves);
}
let mut sorted: Vec<(u32, D)> = Vec::with_capacity(elements.len());
for (leaf, position) in elements {
if *position >= self.leaf_count {
return Err(Error::InvalidPosition(*position));
}
hasher.update(&position.to_be_bytes());
hasher.update(leaf);
sorted.push((*position, hasher.finalize()));
}
sorted.sort_unstable_by_key(|(pos, _)| *pos);
for i in 1..sorted.len() {
if sorted[i - 1].0 == sorted[i].0 {
return Err(Error::DuplicatePosition(sorted[i].0));
}
}
let levels = levels_in_tree(self.leaf_count);
let mut level_size = self.leaf_count;
let mut sibling_iter = self.siblings.iter();
let mut current = sorted;
let mut next_level: Vec<(u32, D)> = Vec::with_capacity(current.len());
for _ in 0..levels - 1 {
let mut idx = 0;
while idx < current.len() {
let (pos, digest) = current[idx];
let parent_pos = pos / 2;
let (left, right) = if pos.is_multiple_of(2) {
let left = digest;
let right = if idx + 1 < current.len() && current[idx + 1].0 == pos + 1 {
idx += 1;
current[idx].1
} else if pos + 1 >= level_size {
left
} else {
*sibling_iter.next().ok_or(Error::UnalignedProof)?
};
(left, right)
} else {
let right = digest;
let left = *sibling_iter.next().ok_or(Error::UnalignedProof)?;
(left, right)
};
hasher.update(&left);
hasher.update(&right);
next_level.push((parent_pos, hasher.finalize()));
idx += 1;
}
core::mem::swap(&mut current, &mut next_level);
next_level.clear();
level_size = level_size.div_ceil(2);
}
if sibling_iter.next().is_some() {
return Err(Error::UnalignedProof);
}
if current.len() != 1 {
return Err(Error::UnalignedProof);
}
let tree_root = current[0].1;
hasher.update(&self.leaf_count.to_be_bytes());
hasher.update(&tree_root);
let finalized = hasher.finalize();
if finalized == *root {
Ok(())
} else {
Err(Error::InvalidProof(finalized.to_string(), root.to_string()))
}
}
pub fn verify_range_inclusion<H: Hasher<Digest = D>>(
&self,
hasher: &mut H,
position: u32,
leaves: &[D],
root: &D,
) -> Result<(), Error> {
if leaves.is_empty() && position != 0 {
return Err(Error::InvalidPosition(position));
}
if !leaves.is_empty() {
let leaves_len =
u32::try_from(leaves.len()).map_err(|_| Error::InvalidPosition(position))?;
let end = position
.checked_add(leaves_len - 1)
.ok_or(Error::InvalidPosition(position))?;
if end >= self.leaf_count {
return Err(Error::InvalidPosition(end));
}
}
let elements: Vec<(D, u32)> = leaves
.iter()
.enumerate()
.map(|(i, leaf)| (*leaf, position + i as u32))
.collect();
self.verify_multi_inclusion(hasher, &elements, root)
}
}
#[cfg(test)]
mod tests {
use super::*;
use commonware_codec::{Decode, Encode};
use commonware_cryptography::sha256::{Digest, Sha256};
use rstest::rstest;
#[test]
fn issue_2837_regression() {
let digests: Vec<Digest> = (0..255u32)
.map(|i| Sha256::hash(&i.to_be_bytes()))
.collect();
let mut builder = Builder::<Sha256>::new(255);
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let original_proof = tree.proof(0).unwrap();
assert_eq!(original_proof.leaf_count, 255);
let mut hasher = Sha256::default();
assert!(
original_proof
.verify_element_inclusion(&mut hasher, &digests[0], 0, &root)
.is_ok(),
"Original proof should verify"
);
let malleated_proof = Proof {
leaf_count: 254,
siblings: original_proof.siblings.clone(),
};
let result = malleated_proof.verify_element_inclusion(&mut hasher, &digests[0], 0, &root);
assert!(
result.is_err(),
"Malleated proof with wrong leaf_count must fail verification"
);
}
#[test]
fn test_tampered_proof_no_siblings() {
let txs = [b"tx1", b"tx2", b"tx3", b"tx4"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let element = &digests[0];
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut proof = tree.proof(0).unwrap();
proof.siblings = Vec::new();
let mut hasher = Sha256::default();
assert!(proof
.verify_element_inclusion(&mut hasher, element, 0, &root)
.is_err());
}
#[test]
fn test_tampered_proof_extra_sibling() {
let txs = [b"tx1", b"tx2", b"tx3", b"tx4"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let element = &digests[0];
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut proof = tree.proof(0).unwrap();
proof.siblings.push(*element);
let mut hasher = Sha256::default();
assert!(proof
.verify_element_inclusion(&mut hasher, element, 0, &root)
.is_err());
}
#[test]
fn test_invalid_proof_wrong_element() {
let txs = [b"tx1", b"tx2", b"tx3", b"tx4"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let proof = tree.proof(2).unwrap();
let mut hasher = Sha256::default();
let wrong_leaf = Sha256::hash(b"wrong_tx");
assert!(proof
.verify_element_inclusion(&mut hasher, &wrong_leaf, 2, &root)
.is_err());
}
#[test]
fn test_invalid_proof_wrong_index() {
let txs = [b"tx1", b"tx2", b"tx3", b"tx4"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let proof = tree.proof(1).unwrap();
let mut hasher = Sha256::default();
assert!(proof
.verify_element_inclusion(&mut hasher, &digests[1], 2, &root)
.is_err());
}
#[test]
fn test_invalid_proof_wrong_root() {
let txs = [b"tx1", b"tx2", b"tx3", b"tx4"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let proof = tree.proof(0).unwrap();
let mut hasher = Sha256::default();
let wrong_root = Sha256::hash(b"wrong_root");
assert!(proof
.verify_element_inclusion(&mut hasher, &digests[0], 0, &wrong_root)
.is_err());
}
#[test]
fn test_invalid_proof_serialization_truncated() {
let txs = [b"tx1", b"tx2", b"tx3"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let proof = tree.proof(1).unwrap();
let mut serialized = proof.encode();
serialized.truncate(serialized.len() - 1);
assert!(Proof::<Digest>::decode_cfg(&mut serialized, &1).is_err());
}
#[test]
fn test_invalid_proof_serialization_extra() {
let txs = [b"tx1", b"tx2", b"tx3"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let proof = tree.proof(1).unwrap();
let mut serialized = proof.encode_mut();
serialized.extend_from_slice(&[0u8]);
assert!(Proof::<Digest>::decode_cfg(&mut serialized, &1).is_err());
}
#[test]
fn test_invalid_proof_modified_hash() {
let txs = [b"tx1", b"tx2", b"tx3", b"tx4"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut proof = tree.proof(2).unwrap();
let mut hasher = Sha256::default();
proof.siblings[0] = Sha256::hash(b"modified");
assert!(proof
.verify_element_inclusion(&mut hasher, &digests[2], 2, &root)
.is_err());
}
#[test]
fn test_odd_tree_duplicate_index_proof() {
let txs = [b"tx1", b"tx2", b"tx3"];
let digests: Vec<Digest> = txs.iter().map(|tx| Sha256::hash(*tx)).collect();
let mut builder = Builder::<Sha256>::new(txs.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let proof = tree.proof(2).unwrap();
let mut hasher = Sha256::default();
assert!(proof
.verify_element_inclusion(&mut hasher, &digests[2], 2, &root)
.is_ok());
assert!(tree.proof(3).is_err());
assert!(proof
.verify_element_inclusion(&mut hasher, &digests[2], 3, &root)
.is_err());
}
#[test]
fn test_range_proof_basic() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let range_proof = tree.range_proof(2, 5).unwrap();
let mut hasher = Sha256::default();
let range_leaves = &digests[2..6];
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_ok());
let mut serialized = range_proof.encode();
let deserialized = Proof::<Digest>::decode_cfg(&mut serialized, &4).unwrap();
assert!(deserialized
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_ok());
}
#[test]
fn test_range_proof_single_element() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
for (i, digest) in digests.iter().enumerate() {
let range_proof = tree.range_proof(i as u32, i as u32).unwrap();
let mut hasher = Sha256::default();
let result =
range_proof.verify_range_inclusion(&mut hasher, i as u32, &[*digest], &root);
assert!(result.is_ok());
}
}
#[test]
fn test_range_proof_full_tree() {
let digests: Vec<Digest> = (0..7u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let range_proof = tree.range_proof(0, (digests.len() - 1) as u32).unwrap();
let mut hasher = Sha256::default();
assert!(range_proof
.verify_range_inclusion(&mut hasher, 0, &digests, &root)
.is_ok());
}
#[test]
fn test_range_proof_edge_cases() {
let digests: Vec<Digest> = (0..15u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
let range_proof = tree.range_proof(0, 7).unwrap();
assert!(range_proof
.verify_range_inclusion(&mut hasher, 0, &digests[0..8], &root)
.is_ok());
let range_proof = tree.range_proof(8, 14).unwrap();
assert!(range_proof
.verify_range_inclusion(&mut hasher, 8, &digests[8..15], &root)
.is_ok());
let range_proof = tree.range_proof(13, 14).unwrap();
assert!(range_proof
.verify_range_inclusion(&mut hasher, 13, &digests[13..15], &root)
.is_ok());
}
#[test]
fn test_range_proof_invalid_range() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
assert!(tree.range_proof(8, 8).is_err()); assert!(tree.range_proof(0, 8).is_err()); assert!(tree.range_proof(5, 8).is_err()); assert!(tree.range_proof(2, 1).is_err()); }
#[test]
fn test_range_proof_tampering() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let range_proof = tree.range_proof(2, 4).unwrap();
let mut hasher = Sha256::default();
let range_leaves = &digests[2..5];
let wrong_leaves = vec![
Sha256::hash(b"wrong1"),
Sha256::hash(b"wrong2"),
Sha256::hash(b"wrong3"),
];
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, &wrong_leaves, &root)
.is_err());
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, &digests[2..4], &root)
.is_err());
let mut tampered_proof = range_proof.clone();
assert!(!tampered_proof.siblings.is_empty());
tampered_proof.siblings[0] = Sha256::hash(b"tampered");
assert!(tampered_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_err());
let wrong_root = Sha256::hash(b"wrong_root");
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &wrong_root)
.is_err());
}
#[test]
fn test_range_proof_various_sizes() {
for tree_size in [1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32, 63, 64] {
let digests: Vec<Digest> = (0..tree_size as u32)
.map(|i| Sha256::hash(&i.to_be_bytes()))
.collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
for range_size in 1..=tree_size.min(8) {
for start in 0..=(tree_size - range_size) {
let range_proof = tree
.range_proof(start as u32, (start + range_size - 1) as u32)
.unwrap();
let end = start + range_size;
assert!(
range_proof
.verify_range_inclusion(
&mut hasher,
start as u32,
&digests[start..end],
&root
)
.is_ok(),
"Failed for tree_size={tree_size}, start={start}, range_size={range_size}"
);
}
}
}
}
#[test]
fn test_range_proof_malicious_wrong_position() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let range_proof = tree.range_proof(2, 4).unwrap();
let mut hasher = Sha256::default();
let range_leaves = &digests[2..5];
assert!(range_proof
.verify_range_inclusion(&mut hasher, 3, range_leaves, &root)
.is_err());
assert!(range_proof
.verify_range_inclusion(&mut hasher, 1, range_leaves, &root)
.is_err());
}
#[test]
fn test_range_proof_malicious_reordered_leaves() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let range_proof = tree.range_proof(2, 4).unwrap();
let mut hasher = Sha256::default();
let reordered_leaves = vec![digests[3], digests[2], digests[4]];
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, &reordered_leaves, &root)
.is_err());
}
#[test]
fn test_range_proof_malicious_extra_siblings() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut range_proof = tree.range_proof(2, 3).unwrap();
let mut hasher = Sha256::default();
let range_leaves = &digests[2..4];
range_proof.siblings.push(Sha256::hash(b"extra"));
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_err());
}
#[test]
fn test_range_proof_malicious_missing_siblings() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut range_proof = tree.range_proof(2, 2).unwrap();
let mut hasher = Sha256::default();
let range_leaves = &digests[2..3];
assert!(!range_proof.siblings.is_empty());
range_proof.siblings.pop();
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_err());
}
#[test]
fn test_range_proof_integer_overflow_protection() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
assert!(tree.range_proof(u32::MAX, u32::MAX).is_err());
assert!(tree.range_proof(u32::MAX - 1, u32::MAX).is_err());
assert!(tree.range_proof(7, u32::MAX).is_err());
}
#[test]
fn test_range_proof_malicious_wrong_tree_structure() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut range_proof = tree.range_proof(2, 3).unwrap();
let mut hasher = Sha256::default();
let range_leaves = &digests[2..4];
range_proof.siblings.push(Sha256::hash(b"fake_level"));
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_err());
let mut range_proof = tree.range_proof(2, 2).unwrap();
let range_leaves = &digests[2..3];
assert!(!range_proof.siblings.is_empty());
range_proof.siblings.pop();
assert!(range_proof
.verify_range_inclusion(&mut hasher, 2, range_leaves, &root)
.is_err());
}
#[test]
fn test_range_proof_boundary_conditions() {
for tree_size in [1, 2, 4, 8, 16, 32] {
let digests: Vec<Digest> = (0..tree_size as u32)
.map(|i| Sha256::hash(&i.to_be_bytes()))
.collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
let proof = tree.range_proof(0, 0).unwrap();
assert!(proof
.verify_range_inclusion(&mut hasher, 0, &digests[0..1], &root)
.is_ok());
let last_idx = tree_size - 1;
let proof = tree.range_proof(last_idx as u32, last_idx as u32).unwrap();
assert!(proof
.verify_range_inclusion(
&mut hasher,
last_idx as u32,
&digests[last_idx..tree_size],
&root
)
.is_ok());
let proof = tree.range_proof(0, (tree_size - 1) as u32).unwrap();
assert!(proof
.verify_range_inclusion(&mut hasher, 0, &digests, &root)
.is_ok());
}
}
#[test]
fn test_empty_tree_proof() {
let builder = Builder::<Sha256>::new(0);
let tree = builder.build();
assert!(tree.proof(0).is_err());
assert!(tree.proof(1).is_err());
assert!(tree.proof(100).is_err());
}
#[test]
fn test_empty_tree_range_proof() {
let builder = Builder::<Sha256>::new(0);
let tree = builder.build();
let root = tree.root();
let range_proof = tree.range_proof(0, 0).unwrap();
assert!(range_proof.siblings.is_empty());
assert_eq!(range_proof, Proof::default());
let invalid_ranges = vec![
(0, 1),
(0, 10),
(1, 1),
(1, 2),
(5, 5),
(10, 10),
(0, u32::MAX),
(u32::MAX, u32::MAX),
];
for (start, end) in invalid_ranges {
assert!(tree.range_proof(start, end).is_err());
}
let mut hasher = Sha256::default();
let empty_leaves: &[Digest] = &[];
assert!(range_proof
.verify_range_inclusion(&mut hasher, 0, empty_leaves, &root)
.is_ok());
let non_empty_leaves = vec![Sha256::hash(b"leaf")];
assert!(range_proof
.verify_range_inclusion(&mut hasher, 0, &non_empty_leaves, &root)
.is_err());
let wrong_root = Sha256::hash(b"wrong");
assert!(range_proof
.verify_range_inclusion(&mut hasher, 0, empty_leaves, &wrong_root)
.is_err());
assert!(range_proof
.verify_range_inclusion(&mut hasher, 1, empty_leaves, &root)
.is_err());
}
#[test]
fn test_empty_range_proof_serialization() {
let proof = Proof::<Digest>::default();
let mut serialized = proof.encode();
let deserialized = Proof::<Digest>::decode_cfg(&mut serialized, &0).unwrap();
assert_eq!(proof, deserialized);
}
#[test]
fn test_empty_tree_root_consistency() {
let mut roots = Vec::new();
for _ in 0..5 {
let builder = Builder::<Sha256>::new(0);
let tree = builder.build();
roots.push(tree.root());
}
for i in 1..roots.len() {
assert_eq!(roots[0], roots[i]);
}
let mut hasher = Sha256::default();
hasher.update(0u32.to_be_bytes().as_slice());
hasher.update(Sha256::hash(b"").as_ref());
let expected_root = hasher.finalize();
assert_eq!(roots[0], expected_root);
}
#[rstest]
#[case::need_left_sibling(1, 2)] #[case::need_right_sibling(4, 4)] #[case::full_tree(0, 16)] fn test_range_proof_siblings_usage(#[case] start: u32, #[case] count: u32) {
let digests: Vec<Digest> = (0..16u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
let range_proof = tree.range_proof(start, start + count - 1).unwrap();
let end = start as usize + count as usize;
assert!(range_proof
.verify_range_inclusion(&mut hasher, start, &digests[start as usize..end], &root)
.is_ok());
for sibling_idx in 0..range_proof.siblings.len() {
let mut tampered_proof = range_proof.clone();
tampered_proof.siblings[sibling_idx] = Sha256::hash(b"tampered");
assert!(tampered_proof
.verify_range_inclusion(&mut hasher, start, &digests[start as usize..end], &root)
.is_err());
}
}
#[rstest]
fn test_range_proof_duplicate_node_edge_cases(
#[values(3, 5, 7, 9, 11, 13, 15)] tree_size: usize,
) {
let digests: Vec<Digest> = (0..tree_size as u32)
.map(|i| Sha256::hash(&i.to_be_bytes()))
.collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
let start = tree_size - 2;
let proof = tree
.range_proof(start as u32, (tree_size - 1) as u32)
.unwrap();
assert!(proof
.verify_range_inclusion(&mut hasher, start as u32, &digests[start..tree_size], &root)
.is_ok());
}
#[test]
fn test_multi_proof_basic() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let positions = [0, 3, 5];
let multi_proof = tree.multi_proof(positions).unwrap();
let mut hasher = Sha256::default();
let elements: Vec<(Digest, u32)> = positions
.iter()
.map(|&p| (digests[p as usize], p))
.collect();
assert!(multi_proof
.verify_multi_inclusion(&mut hasher, &elements, &root)
.is_ok());
}
#[test]
fn test_multi_proof_single_element() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
for (i, digest) in digests.iter().enumerate() {
let multi_proof = tree.multi_proof([i as u32]).unwrap();
let elements = [(*digest, i as u32)];
assert!(
multi_proof
.verify_multi_inclusion(&mut hasher, &elements, &root)
.is_ok(),
"Failed for position {i}"
);
}
}
#[test]
fn test_multi_proof_all_elements() {
let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
let mut builder = Builder::<Sha256>::new(digests.len());
for digest in &digests {
builder.add(digest);
}
let tree = builder.build();
let root = tree.root();
let mut hasher = Sha256::default();
let positions: Vec<u32> = (0..digests.len() as u32).collect();
let multi_proof = tree.multi_proof(&positions).unwrap();
let elements: Vec<(Digest, u32)> = positions
.iter()
.map(|&p| (digests[p as usize], p))
.collect();
assert!(multi_proof
.verify_multi_inclusion(&mut hasher, &elements, &root)
.is_ok());
assert!(multi_proof.siblings.is_empty());
}
#[test]
fn test_multi_proof_adjacent_elements() {
    // Leaves 2 and 3 are paired together when the first internal level is built.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [2, 3];
    let proof = tree.multi_proof(positions).unwrap();
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(outcome.is_ok());
}
#[test]
fn test_multi_proof_sparse_positions() {
    // Widely spread positions (both ends and both halves of a 16-leaf tree).
    let leaves: Vec<Digest> = (0..16u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 4] = [0, 7, 8, 15];
    let proof = tree.multi_proof(positions).unwrap();
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(outcome.is_ok());
}
#[test]
fn test_multi_proof_empty_tree() {
    // A tree built from zero leaves: an empty request reports `NoLeaves`, and any concrete
    // position is out of range.
    let tree = Builder::<Sha256>::new(0).build();

    let no_positions = tree.multi_proof(std::iter::empty::<u32>());
    assert!(matches!(no_positions, Err(Error::NoLeaves)));

    let first_position = tree.multi_proof([0]);
    assert!(matches!(first_position, Err(Error::InvalidPosition(0))));
}
#[test]
fn test_multi_proof_empty_positions() {
    // Even on a populated tree, asking for a proof of zero positions yields `NoLeaves`.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let tree = {
        let mut builder = Builder::<Sha256>::new(leaves.len());
        leaves.iter().for_each(|leaf| {
            builder.add(leaf);
        });
        builder.build()
    };

    let outcome = tree.multi_proof(std::iter::empty::<u32>());
    assert!(matches!(outcome, Err(Error::NoLeaves)));
}
#[test]
fn test_multi_proof_duplicate_positions_error() {
    // A position listed twice is rejected, and the error names the repeated position.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let tree = {
        let mut builder = Builder::<Sha256>::new(leaves.len());
        leaves.iter().for_each(|leaf| {
            builder.add(leaf);
        });
        builder.build()
    };

    let repeated_pair = tree.multi_proof([1, 1]);
    assert!(matches!(repeated_pair, Err(Error::DuplicatePosition(1))));

    let repeated_inside = tree.multi_proof([0, 2, 2, 5]);
    assert!(matches!(repeated_inside, Err(Error::DuplicatePosition(2))));
}
#[test]
fn test_multi_proof_unsorted_input() {
    // Positions and elements supplied out of order must still produce a verifiable proof.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 3] = [5, 0, 3];
    let proof = tree.multi_proof(positions).unwrap();
    // Elements keep the same (unsorted) order as `positions`.
    let unsorted_elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &unsorted_elements, &root);
    assert!(outcome.is_ok());
}
#[test]
fn test_multi_proof_various_sizes() {
    // Exercise multi-proofs across power-of-two and odd-sized trees.
    //
    // Every size proves at least its last leaf on its own. Without that case the size-1
    // entry would build a tree and skip both guarded checks below, testing nothing.
    for tree_size in [1, 2, 3, 4, 5, 7, 8, 15, 16, 31, 32] {
        let digests: Vec<Digest> = (0..tree_size as u32)
            .map(|i| Sha256::hash(&i.to_be_bytes()))
            .collect();
        let mut builder = Builder::<Sha256>::new(digests.len());
        for digest in &digests {
            builder.add(digest);
        }
        let tree = builder.build();
        let root = tree.root();
        let mut hasher = Sha256::default();
        let last = (tree_size - 1) as u32;

        // Position sets to prove for this size.
        let mut cases: Vec<Vec<u32>> = vec![vec![last]];
        if tree_size >= 2 {
            // Both extremes of the tree.
            cases.push(vec![0, last]);
        }
        if tree_size >= 4 {
            // Every other leaf.
            cases.push((0..tree_size as u32).step_by(2).collect());
        }

        for positions in cases {
            let multi_proof = tree.multi_proof(&positions).unwrap();
            let elements: Vec<(Digest, u32)> = positions
                .iter()
                .map(|&p| (digests[p as usize], p))
                .collect();
            assert!(
                multi_proof
                    .verify_multi_inclusion(&mut hasher, &elements, &root)
                    .is_ok(),
                "Failed for tree_size={tree_size}, positions={positions:?}"
            );
        }
    }
}
#[test]
fn test_multi_proof_wrong_elements() {
    // Replacing one proven leaf's digest (keeping its position) must fail verification.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 3] = [0, 3, 5];
    let proof = tree.multi_proof(positions).unwrap();
    let mut wrong_elements = positions.map(|pos| (leaves[pos as usize], pos));
    // Forge the digest claimed for position 0.
    wrong_elements[0].0 = Sha256::hash(b"wrong1");
    let outcome = proof.verify_multi_inclusion(&mut hasher, &wrong_elements, &root);
    assert!(outcome.is_err());
}
#[test]
fn test_multi_proof_wrong_positions() {
    // Claiming a genuine leaf at a different position must fail verification.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 3] = [0, 3, 5];
    let proof = tree.multi_proof(positions).unwrap();
    let mut wrong_positions = positions.map(|pos| (leaves[pos as usize], pos));
    // Leaf 0's digest, mislabeled as position 1.
    wrong_positions[0].1 = 1;
    let outcome = proof.verify_multi_inclusion(&mut hasher, &wrong_positions, &root);
    assert!(outcome.is_err());
}
#[test]
fn test_multi_proof_wrong_root() {
    // Correct elements and proof checked against an unrelated root must fail.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let mut hasher = Sha256::default();

    let positions: [u32; 3] = [0, 3, 5];
    let proof = tree.multi_proof(positions).unwrap();
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let unrelated_root = Sha256::hash(b"wrong_root");
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &unrelated_root);
    assert!(outcome.is_err());
}
#[test]
fn test_multi_proof_tampering() {
    // Altering, appending, or removing a sibling must each invalidate a valid proof.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [0, 5];
    let proof = tree.multi_proof(positions).unwrap();
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    assert!(!proof.siblings.is_empty());

    let mut altered = proof.clone();
    altered.siblings[0] = Sha256::hash(b"tampered");
    let altered_outcome = altered.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(altered_outcome.is_err());

    let mut padded = proof.clone();
    padded.siblings.push(Sha256::hash(b"extra"));
    let padded_outcome = padded.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(padded_outcome.is_err());

    let mut shortened = proof;
    shortened.siblings.pop();
    let shortened_outcome = shortened.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(shortened_outcome.is_err());
}
#[test]
fn test_multi_proof_deduplication() {
    // Siblings shared between proven paths should be emitted once, so a multi-proof is
    // smaller than the concatenation of the individual single-leaf proofs.
    let leaves: Vec<Digest> = (0..16u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();

    let positions = [0u32, 1, 8, 9];
    let mut individual_siblings = 0usize;
    for &pos in &positions {
        individual_siblings += tree.proof(pos).unwrap().siblings.len();
    }
    let multi_proof = tree.multi_proof(positions).unwrap();
    let combined = multi_proof.siblings.len();
    assert!(
        combined < individual_siblings,
        "Multi-proof ({}) should have fewer siblings than sum of individual proofs ({})",
        combined,
        individual_siblings
    );
}
#[test]
fn test_multi_proof_serialization() {
    // Encode -> decode must round-trip exactly, and the decoded proof must still verify.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 3] = [0, 3, 5];
    let proof = tree.multi_proof(positions).unwrap();
    let decoded = Proof::<Digest>::decode_cfg(proof.encode(), &positions.len()).unwrap();
    assert_eq!(proof, decoded);

    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = decoded.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(outcome.is_ok());
}
#[test]
fn test_multi_proof_serialization_truncated() {
    // Dropping the final byte of an encoded proof must make decoding fail.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let tree = {
        let mut builder = Builder::<Sha256>::new(leaves.len());
        leaves.iter().for_each(|leaf| {
            builder.add(leaf);
        });
        builder.build()
    };

    let positions: [u32; 3] = [0, 3, 5];
    let mut encoded = tree.multi_proof(positions).unwrap().encode();
    let shortened_len = encoded.len() - 1;
    encoded.truncate(shortened_len);
    let decoded = Proof::<Digest>::decode_cfg(&mut encoded, &positions.len());
    assert!(decoded.is_err());
}
#[test]
fn test_multi_proof_serialization_extra() {
    // A trailing byte after a well-formed encoded proof must make decoding fail.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let tree = {
        let mut builder = Builder::<Sha256>::new(leaves.len());
        leaves.iter().for_each(|leaf| {
            builder.add(leaf);
        });
        builder.build()
    };

    let positions: [u32; 3] = [0, 3, 5];
    let mut encoded = tree.multi_proof(positions).unwrap().encode_mut();
    let trailing = [0u8];
    encoded.extend_from_slice(&trailing);
    let decoded = Proof::<Digest>::decode_cfg(&mut encoded, &positions.len());
    assert!(decoded.is_err());
}
#[test]
fn test_multi_proof_decode_insufficient_data() {
    // Hand-built header announcing one sibling, with no sibling bytes after it: decoding
    // must stop with `EndOfBuffer`.
    // NOTE(review): assumes the encoding starts with leaf_count (u32) and then the sibling
    // count (usize) — confirm against `Proof`'s `Write` impl.
    let mut header: Vec<u8> = Vec::new();
    for field in [8u32.encode(), 1usize.encode()] {
        header.extend_from_slice(&field);
    }
    let err = Proof::<Digest>::decode_cfg(header.as_slice(), &1).unwrap_err();
    assert!(matches!(err, commonware_codec::Error::EndOfBuffer));
}
#[test]
fn test_multi_proof_invalid_position() {
    // Positions at or beyond the leaf count are rejected, and the error names the position.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let tree = {
        let mut builder = Builder::<Sha256>::new(leaves.len());
        leaves.iter().for_each(|leaf| {
            builder.add(leaf);
        });
        builder.build()
    };

    let just_past_end = tree.multi_proof([0, 8]);
    assert!(matches!(just_past_end, Err(Error::InvalidPosition(8))));

    let far_past_end = tree.multi_proof([100]);
    assert!(matches!(far_past_end, Err(Error::InvalidPosition(100))));
}
#[test]
fn test_multi_proof_verify_invalid_position() {
    // Verification must fail when an element claims a position outside the tree.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [0, 3];
    let proof = tree.multi_proof(positions).unwrap();
    let mut invalid_elements = positions.map(|pos| (leaves[pos as usize], pos));
    // Leaf 3's digest claimed at position 100.
    invalid_elements[1].1 = 100;
    let outcome = proof.verify_multi_inclusion(&mut hasher, &invalid_elements, &root);
    assert!(outcome.is_err());
}
#[test]
fn test_multi_proof_odd_tree_sizes() {
    // Odd leaf counts leave a node without a partner at some level; proving the first and
    // last leaves exercises that path.
    for tree_size in [3, 5, 7, 9, 11, 13, 15] {
        let leaves: Vec<Digest> = (0..tree_size as u32)
            .map(|idx| Sha256::hash(&idx.to_be_bytes()))
            .collect();
        let mut builder = Builder::<Sha256>::new(leaves.len());
        leaves.iter().for_each(|leaf| {
            builder.add(leaf);
        });
        let tree = builder.build();
        let root = tree.root();
        let mut hasher = Sha256::default();

        let last = (tree_size - 1) as u32;
        let positions = [0, last];
        let proof = tree.multi_proof(positions).unwrap();
        let elements = positions.map(|pos| (leaves[pos as usize], pos));
        let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
        assert!(outcome.is_ok(), "Failed for tree_size={tree_size}");
    }
}
#[test]
fn test_multi_proof_verify_empty_elements() {
    // A non-empty proof must not verify against an empty element list.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let proof = tree.multi_proof([0u32, 3]).unwrap();
    let no_elements: &[(Digest, u32)] = &[];
    let outcome = proof.verify_multi_inclusion(&mut hasher, no_elements, &root);
    assert!(outcome.is_err());
}
#[test]
fn test_multi_proof_default_verify() {
    // `Proof::default()` with no elements verifies against the root of an empty tree and
    // against nothing else.
    let default_proof = Proof::<Digest>::default();
    let no_elements: &[(Digest, u32)] = &[];
    let empty_tree = Builder::<Sha256>::new(0).build();
    let empty_root = empty_tree.root();
    let mut hasher = Sha256::default();

    let against_empty = default_proof.verify_multi_inclusion(&mut hasher, no_elements, &empty_root);
    assert!(against_empty.is_ok());

    let unrelated_root = Sha256::hash(b"not_empty");
    let against_other =
        default_proof.verify_multi_inclusion(&mut hasher, no_elements, &unrelated_root);
    assert!(against_other.is_err());
}
#[test]
fn test_multi_proof_single_leaf_tree() {
    // One-leaf tree: the proof reports leaf_count 1, has no siblings, accepts the real leaf
    // at position 0, and rejects both a different digest and a different position.
    let only_leaf = Sha256::hash(b"only_leaf");
    let tree = {
        let mut builder = Builder::<Sha256>::new(1);
        builder.add(&only_leaf);
        builder.build()
    };
    let root = tree.root();
    let mut hasher = Sha256::default();

    let proof = tree.multi_proof([0]).unwrap();
    assert_eq!(proof.leaf_count, 1);
    assert!(
        proof.siblings.is_empty(),
        "Single leaf tree should have no siblings"
    );

    let genuine = [(only_leaf, 0u32)];
    let genuine_outcome = proof.verify_multi_inclusion(&mut hasher, &genuine, &root);
    assert!(
        genuine_outcome.is_ok(),
        "Single leaf multi-proof verification failed"
    );

    let forged_digest = [(Sha256::hash(b"wrong"), 0u32)];
    let forged_digest_outcome = proof.verify_multi_inclusion(&mut hasher, &forged_digest, &root);
    assert!(
        forged_digest_outcome.is_err(),
        "Should fail with wrong digest"
    );

    let forged_position = [(only_leaf, 1u32)];
    let forged_position_outcome =
        proof.verify_multi_inclusion(&mut hasher, &forged_position, &root);
    assert!(
        forged_position_outcome.is_err(),
        "Should fail with invalid position"
    );
}
#[test]
fn test_multi_proof_malicious_leaf_count_zero() {
    // A proof whose leaf_count has been overwritten with 0 must be rejected.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [0, 3];
    let mut proof = tree.multi_proof(positions).unwrap();
    proof.leaf_count = 0;
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(outcome.is_err());
}
#[test]
fn test_multi_proof_malicious_leaf_count_larger() {
    // A proof whose leaf_count has been inflated beyond the real tree must be rejected.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [0, 3];
    let mut proof = tree.multi_proof(positions).unwrap();
    let honest_leaf_count = proof.leaf_count;
    proof.leaf_count = 1000;
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(
        outcome.is_err(),
        "Should reject proof with inflated leaf_count ({} -> {})",
        honest_leaf_count,
        proof.leaf_count
    );
}
#[test]
fn test_multi_proof_malicious_leaf_count_smaller() {
    // A proof whose leaf_count has been shrunk (8 -> 4) must be rejected, even though
    // both proven positions are still in range for the smaller count.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [0, 3];
    let mut proof = tree.multi_proof(positions).unwrap();
    proof.leaf_count = 4;
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(
        outcome.is_err(),
        "Should reject proof with deflated leaf_count"
    );
}
#[test]
fn test_multi_proof_mismatched_element_count() {
    // A proof built for two positions must reject both a subset and a superset of elements.
    let leaves: Vec<Digest> = (0..8u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let proof = tree.multi_proof([0u32, 3]).unwrap();

    let too_few = [(leaves[0], 0u32)];
    let too_few_outcome = proof.verify_multi_inclusion(&mut hasher, &too_few, &root);
    assert!(
        too_few_outcome.is_err(),
        "Should reject when fewer elements provided than proof was generated for"
    );

    let too_many = [0u32, 3, 5].map(|pos| (leaves[pos as usize], pos));
    let too_many_outcome = proof.verify_multi_inclusion(&mut hasher, &too_many, &root);
    assert!(
        too_many_outcome.is_err(),
        "Should reject when more elements provided than proof was generated for"
    );
}
#[test]
fn test_multi_proof_swapped_siblings() {
    // Sibling order is part of the proof: exchanging two siblings must invalidate it.
    //
    // The precondition (at least two siblings) is asserted rather than used as an `if`
    // guard. With a guard, a regression that shrank the proof would let this test pass
    // without checking anything.
    let digests: Vec<Digest> = (0..8u32).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
    let mut builder = Builder::<Sha256>::new(digests.len());
    for digest in &digests {
        builder.add(digest);
    }
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();
    let positions = [0, 5];
    let mut multi_proof = tree.multi_proof(positions).unwrap();
    assert!(
        multi_proof.siblings.len() >= 2,
        "expected at least two siblings to swap, got {}",
        multi_proof.siblings.len()
    );
    multi_proof.siblings.swap(0, 1);
    let elements: Vec<(Digest, u32)> = positions
        .iter()
        .map(|&p| (digests[p as usize], p))
        .collect();
    assert!(
        multi_proof
            .verify_multi_inclusion(&mut hasher, &elements, &root)
            .is_err(),
        "Should reject proof with swapped siblings"
    );
}
#[test]
fn test_multi_proof_dos_large_leaf_count() {
    // A proof claiming `u32::MAX` leaves must be rejected (and must not hang or exhaust
    // memory while doing so).
    let leaves: Vec<Digest> = (0..4u32)
        .map(|idx| Sha256::hash(&idx.to_be_bytes()))
        .collect();
    let mut builder = Builder::<Sha256>::new(leaves.len());
    leaves.iter().for_each(|leaf| {
        builder.add(leaf);
    });
    let tree = builder.build();
    let root = tree.root();
    let mut hasher = Sha256::default();

    let positions: [u32; 2] = [0, 2];
    let mut proof = tree.multi_proof(positions).unwrap();
    proof.leaf_count = u32::MAX;
    let elements = positions.map(|pos| (leaves[pos as usize], pos));
    let outcome = proof.verify_multi_inclusion(&mut hasher, &elements, &root);
    assert!(outcome.is_err(), "Should reject malicious large leaf_count");
}
#[cfg(feature = "arbitrary")]
mod conformance {
    use super::*;
    use commonware_codec::conformance::CodecConformance;
    use commonware_conformance::Conformance;
    use commonware_cryptography::sha256::Digest as Sha256Digest;

    /// Builds an `n`-leaf tree whose leaf `i` is `Sha256::hash(i.to_be_bytes())`, checks
    /// each leaf's single-element proof, and returns the tree's root.
    ///
    /// For every leaf, the proof must verify both as produced and after an encode/decode
    /// round-trip. It must fail once a sibling is replaced, appended, or removed. A proof
    /// for position `n` must be refused.
    ///
    /// NOTE(review): `i` is a `usize`, so `to_be_bytes` is platform-width dependent; the
    /// committed conformance roots presumably assume a 64-bit target — confirm.
    fn test_merkle_tree(n: usize) -> Digest {
        let digests: Vec<Digest> = (0..n).map(|i| Sha256::hash(&i.to_be_bytes())).collect();
        let mut builder = Builder::<Sha256>::new(n);
        digests.iter().for_each(|digest| {
            builder.add(digest);
        });
        let tree = builder.build();
        let root = tree.root();
        let mut hasher = Sha256::default();

        for (i, leaf) in (0u32..).zip(&digests) {
            // True when `candidate` proves `leaf` at position `i` under `root`.
            let verifies = |candidate: &Proof<Digest>, hasher: &mut Sha256| {
                candidate
                    .verify_element_inclusion(hasher, leaf, i, &root)
                    .is_ok()
            };

            let proof = tree.proof(i).unwrap();
            assert!(
                verifies(&proof, &mut hasher),
                "correct fail for size={n} leaf={i}"
            );

            let deserialized = Proof::<Digest>::decode_cfg(proof.encode(), &1).unwrap();
            assert!(
                verifies(&deserialized, &mut hasher),
                "deserialize fail for size={n} leaf={i}"
            );

            let has_siblings = !proof.siblings.is_empty();

            if has_siblings {
                let mut update_tamper = proof.clone();
                update_tamper.siblings[0] = Sha256::hash(b"tampered");
                assert!(
                    !verifies(&update_tamper, &mut hasher),
                    "modify fail for size={n} leaf={i}"
                );
            }

            let mut add_tamper = proof.clone();
            add_tamper.siblings.push(Sha256::hash(b"tampered"));
            assert!(
                !verifies(&add_tamper, &mut hasher),
                "add fail for size={n} leaf={i}"
            );

            if has_siblings {
                let mut remove_tamper = proof.clone();
                remove_tamper.siblings.pop();
                assert!(
                    !verifies(&remove_tamper, &mut hasher),
                    "remove fail for size={n} leaf={i}"
                );
            }
        }

        // One past the last leaf is out of range.
        assert!(tree.proof(n as u32).is_err());
        root
    }

    /// Commits to the root of a `seed`-leaf tree so root computation stays stable.
    struct RootConformance;

    impl Conformance for RootConformance {
        async fn commit(seed: u64) -> Vec<u8> {
            test_merkle_tree(seed as usize).to_vec()
        }
    }

    commonware_conformance::conformance_tests! {
        CodecConformance<Proof<Sha256Digest>>,
        RootConformance => 200
    }
}
}