use crate::hash::hash_vc_node;
/// Binary Merkle tree over 32-byte node hashes, stored one layer at a time.
#[derive(Clone, Debug)]
pub(crate) struct MerkleTree {
    /// `layers[0]` is the leaf layer (padded to a power-of-two length); every
    /// following layer is half as long, and the final layer holds only the root.
    layers: Vec<Vec<[u8; 32]>>,
}

impl MerkleTree {
    /// Builds the tree bottom-up from already-hashed leaves.
    ///
    /// The leaf layer is extended to the next power of two by repeating the
    /// last leaf hash; each parent is `hash_vc_node(left, right)` of two
    /// adjacent children.
    ///
    /// # Panics
    ///
    /// Panics if `leaf_hashes` is empty.
    pub(crate) fn build_from_hashes(mut leaf_hashes: Vec<[u8; 32]>) -> Self {
        assert!(
            !leaf_hashes.is_empty(),
            "MerkleTree::build_from_hashes: empty"
        );

        // Pad with copies of the final leaf so every layer pairs up evenly.
        let padded_len = leaf_hashes.len().next_power_of_two();
        let filler = leaf_hashes[leaf_hashes.len() - 1];
        leaf_hashes.resize(padded_len, filler);

        let mut layers = Vec::new();
        let mut current = leaf_hashes;
        while current.len() > 1 {
            let parents: Vec<[u8; 32]> = current
                .chunks(2)
                .map(|pair| hash_vc_node(&pair[0], &pair[1]))
                .collect();
            // Store the finished layer and continue from its parents.
            layers.push(std::mem::replace(&mut current, parents));
        }
        layers.push(current);

        Self { layers }
    }

    /// Returns the root hash, i.e. the single entry of the topmost layer.
    pub(crate) fn root(&self) -> [u8; 32] {
        let top = self.layers.last().unwrap();
        top[0]
    }

    /// Collects the sibling hashes needed to recompute the root from the
    /// leaves at `sorted_unique_indices`.
    ///
    /// Hashes are emitted layer by layer starting at the leaves, in ascending
    /// node order within a layer. A sibling is left out whenever it is itself
    /// one of the nodes being tracked on that layer.
    ///
    /// # Panics
    ///
    /// Panics if a required sibling position lies outside its layer. Debug
    /// builds additionally assert that the indices are strictly ascending.
    pub(crate) fn multiproof(&self, sorted_unique_indices: &[usize]) -> Vec<[u8; 32]> {
        use std::collections::BTreeSet;
        debug_assert!(
            sorted_unique_indices.windows(2).all(|w| w[0] < w[1]),
            "multiproof: indices must be sorted-ascending and deduplicated"
        );

        let mut level: BTreeSet<usize> = sorted_unique_indices.iter().copied().collect();
        let mut siblings = Vec::new();
        // The root layer has no siblings, so it is excluded from the walk.
        let below_root = &self.layers[..self.layers.len() - 1];
        for layer in below_root {
            siblings.extend(
                level
                    .iter()
                    .map(|&node| node ^ 1)
                    .filter(|sib| !level.contains(sib))
                    .map(|sib| layer[sib]),
            );
            level = level.iter().map(|&node| node >> 1).collect();
        }
        siblings
    }
}
/// Counts how many sibling hashes a pruned multiproof for
/// `sorted_unique_indices` contains in a tree with `depth` layers below the
/// root, without touching any hash data.
///
/// On each layer a tracked node contributes one hash unless its sibling
/// (`node ^ 1`) is tracked as well; the tracked set then moves up to the
/// parents (`node >> 1`).
///
/// Debug builds assert that the indices are strictly ascending.
pub(crate) fn multiproof_size(sorted_unique_indices: &[usize], depth: usize) -> usize {
    use std::collections::BTreeSet;
    debug_assert!(
        sorted_unique_indices.windows(2).all(|w| w[0] < w[1]),
        "multiproof_size: indices must be sorted-ascending and deduplicated"
    );

    let mut level: BTreeSet<usize> = sorted_unique_indices.iter().copied().collect();
    let mut total = 0usize;
    for _ in 0..depth {
        // Nodes whose sibling is not tracked need that sibling from the proof.
        total += level
            .iter()
            .filter(|&&node| !level.contains(&(node ^ 1)))
            .count();
        level = level.iter().map(|&node| node >> 1).collect();
    }
    total
}
/// Verifies a pruned multiproof against `root`.
///
/// * `num_leaves` – leaf count of the committed tree; the tree is treated as
///   padded to the next power of two, which fixes the number of layers.
/// * `sorted_unique_indices` – strictly ascending leaf positions being opened.
/// * `leaf_hashes` – the hash claimed for each index, in the same order.
/// * `proof` – sibling hashes, layer by layer from the leaves upward and in
///   ascending node order within a layer; siblings that are themselves
///   derivable from the opened leaves are not included.
///
/// Returns `true` only if recomputing the tree from the supplied data consumes
/// the proof exactly and ends in a single node equal to `root`. Malformed
/// input yields `false` rather than a panic: mismatched lengths, unsorted or
/// repeated indices, an empty index set, `num_leaves == 0`, a `num_leaves`
/// whose padded size does not fit in `usize`, or an index at or beyond the
/// padded leaf count.
pub(crate) fn verify_multiproof(
    root: [u8; 32],
    num_leaves: usize,
    sorted_unique_indices: &[usize],
    leaf_hashes: &[[u8; 32]],
    proof: &[[u8; 32]],
) -> bool {
    if num_leaves == 0 || sorted_unique_indices.len() != leaf_hashes.len() {
        return false;
    }
    if !sorted_unique_indices.windows(2).all(|w| w[0] < w[1]) {
        return false;
    }
    let capacity = match num_leaves.checked_next_power_of_two() {
        Some(capacity) => capacity,
        None => return false,
    };
    // Indices are ascending, so bounding the last one bounds them all. Without
    // this check an index `i >= capacity` would follow the same left/right
    // path as `i % capacity` and be accepted as an opening of that leaf.
    match sorted_unique_indices.last() {
        Some(&largest) if largest < capacity => {}
        _ => return false,
    }
    let depth = capacity.trailing_zeros() as usize;

    // Known nodes of the current layer, kept in ascending position order.
    let mut level: Vec<(usize, [u8; 32])> = sorted_unique_indices
        .iter()
        .copied()
        .zip(leaf_hashes.iter().copied())
        .collect();
    let mut supplied = proof.iter().copied();

    for _ in 0..depth {
        let mut parents = Vec::with_capacity(level.len());
        let mut nodes = level.into_iter().peekable();
        while let Some((i, node_hash)) = nodes.next() {
            let sib = i ^ 1;
            // A known sibling is always the immediately following entry
            // (even `i`, then `i + 1`); otherwise it must come from the proof.
            let sib_hash = match nodes.next_if(|&(next_i, _)| next_i == sib) {
                Some((_, next_hash)) => next_hash,
                None => match supplied.next() {
                    Some(h) => h,
                    None => return false,
                },
            };
            let (l, r) = if i & 1 == 0 {
                (node_hash, sib_hash)
            } else {
                (sib_hash, node_hash)
            };
            parents.push((i >> 1, crate::hash::hash_vc_node(&l, &r)));
        }
        level = parents;
    }

    // Left-over proof elements mean the proof does not match the query set.
    supplied.next().is_none() && level.len() == 1 && level[0].1 == root
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Produces `n` pairwise-distinct placeholder leaf hashes: leaf `i` stores
    /// `i` as a little-endian `u64` in its first eight bytes, the rest is zero.
    /// Using the full index (rather than only its low byte) keeps leaves
    /// distinct for trees with more than 256 leaves.
    fn fake_hashes(n: usize) -> Vec<[u8; 32]> {
        (0..n)
            .map(|i| {
                let mut x = [0u8; 32];
                x[..8].copy_from_slice(&(i as u64).to_le_bytes());
                x
            })
            .collect()
    }

    /// 13 leaves are padded to 16, giving layers of 16, 8, 4, 2 and 1 nodes;
    /// the padding repeats the last real leaf.
    #[test]
    fn pads_to_power_of_two() {
        let hashes = fake_hashes(13);
        let tree = MerkleTree::build_from_hashes(hashes.clone());
        assert_eq!(tree.layers.len(), 5);
        assert_eq!(tree.layers[0].len(), 16);
        assert_eq!(&tree.layers[0][..13], &hashes[..]);
        assert!(tree.layers[0][13..].iter().all(|h| *h == hashes[12]));
    }

    /// Proofs produced by `multiproof` verify, and their length agrees with
    /// `multiproof_size`, for several index patterns.
    #[test]
    fn multiproof_round_trip() {
        let n = 64;
        let hashes = fake_hashes(n);
        let tree = MerkleTree::build_from_hashes(hashes.clone());
        let root = tree.root();
        let probes: &[&[usize]] = &[
            &[0, 1, 2, 3, 4, 5, 6, 7],
            &[0, 8, 16, 24, 32, 40, 48, 56],
            &[3, 7, 11, 15, 19, 23, 27, 31],
            &[0, 1],
            &[0, 63],
        ];
        for indices in probes {
            let mut sorted: Vec<usize> = indices.to_vec();
            sorted.sort_unstable();
            sorted.dedup();
            let proof = tree.multiproof(&sorted);
            let leaf_hashes: Vec<[u8; 32]> = sorted.iter().map(|&i| hashes[i]).collect();
            let expected_size = multiproof_size(&sorted, tree.layers.len() - 1);
            assert_eq!(
                proof.len(),
                expected_size,
                "multiproof size mismatch for {indices:?}"
            );
            assert!(
                verify_multiproof(root, n, &sorted, &leaf_hashes, &proof),
                "multiproof verify failed for {indices:?}"
            );
        }
    }

    /// Opening every leaf needs no sibling hashes at all.
    #[test]
    fn multiproof_dense_query_set_emits_no_siblings() {
        let n = 64;
        let hashes = fake_hashes(n);
        let tree = MerkleTree::build_from_hashes(hashes.clone());
        let root = tree.root();
        let indices: Vec<usize> = (0..n).collect();
        let proof = tree.multiproof(&indices);
        assert!(
            proof.is_empty(),
            "full-query multiproof should be empty, got {} hashes",
            proof.len()
        );
        assert!(verify_multiproof(root, n, &indices, &hashes, &proof));
    }

    /// A modified leaf, a modified proof element, a truncated proof and an
    /// over-long proof are all rejected, while the untouched data verifies.
    #[test]
    fn multiproof_tampering_rejected() {
        let n = 64;
        let hashes = fake_hashes(n);
        let tree = MerkleTree::build_from_hashes(hashes.clone());
        let root = tree.root();
        let indices: Vec<usize> = vec![3, 17, 42, 50];
        let proof = tree.multiproof(&indices);
        let leaf_hashes: Vec<[u8; 32]> = indices.iter().map(|&i| hashes[i]).collect();
        assert!(verify_multiproof(root, n, &indices, &leaf_hashes, &proof));

        let mut bad_leaves = leaf_hashes.clone();
        bad_leaves[0][0] ^= 1;
        assert!(!verify_multiproof(root, n, &indices, &bad_leaves, &proof));

        assert!(!proof.is_empty());
        let mut bad_proof = proof.clone();
        bad_proof[0][0] ^= 1;
        assert!(!verify_multiproof(root, n, &indices, &leaf_hashes, &bad_proof));

        let truncated = &proof[..proof.len() - 1];
        assert!(!verify_multiproof(root, n, &indices, &leaf_hashes, truncated));

        let mut extended = proof.clone();
        extended.push([0u8; 32]);
        assert!(!verify_multiproof(root, n, &indices, &leaf_hashes, &extended));
    }

    /// The pruned proof has exactly the length predicted by `multiproof_size`
    /// and is shorter than one full authentication path per opened leaf.
    #[test]
    fn multiproof_savings_match_formula() {
        let n = 1024;
        let tree = MerkleTree::build_from_hashes(fake_hashes(n));
        let depth = tree.layers.len() - 1;
        assert_eq!(depth, 10);
        let mut indices: Vec<usize> = vec![3, 17, 42, 50, 99, 200, 500, 900];
        indices.sort_unstable();
        let proof = tree.multiproof(&indices);
        assert_eq!(proof.len(), multiproof_size(&indices, depth));
        assert!(
            proof.len() < indices.len() * depth,
            "pruned proof {} must be smaller than naive {}",
            proof.len(),
            indices.len() * depth
        );
    }
}