use crate::error::{RaftError, RaftResult};
/// Domain-separation byte fed to the hasher ahead of a leaf's bytes in `hash_leaf`.
const LEAF_PREFIX: u8 = 0x00;
/// Domain-separation byte fed to the hasher ahead of the two child hashes in `hash_internal`.
const INTERNAL_PREFIX: u8 = 0x01;
/// Returns the root value used for a tree that has no leaves.
///
/// The value is the BLAKE3 digest of the fixed tag `amaters-merkle-empty-v1`,
/// so it is the same on every call.
fn empty_root() -> [u8; 32] {
    let digest = blake3::hash(b"amaters-merkle-empty-v1");
    *digest.as_bytes()
}
/// Hashes a single leaf value into its leaf-layer node.
///
/// The digest covers `LEAF_PREFIX` followed by the 32 leaf bytes, so a leaf
/// node is hashed over a different input prefix than an internal node.
fn hash_leaf(leaf: &[u8; 32]) -> [u8; 32] {
    let digest = blake3::Hasher::new()
        .update(&[LEAF_PREFIX])
        .update(leaf)
        .finalize();
    *digest.as_bytes()
}
/// Combines two child node hashes into their parent node hash.
///
/// The digest covers `INTERNAL_PREFIX`, then `left`, then `right`; the order
/// of the children therefore matters.
fn hash_internal(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    let digest = blake3::Hasher::new()
        .update(&[INTERNAL_PREFIX])
        .update(left)
        .update(right)
        .finalize();
    *digest.as_bytes()
}
/// Builds every layer of the tree, bottom-up.
///
/// The first returned layer holds `hash_leaf` of each input leaf. Each
/// following layer holds `hash_internal` of consecutive pairs from the layer
/// below; a trailing unpaired node is combined with itself. Construction stops
/// once a layer has at most one node, so the last layer holds the root (or is
/// empty when `leaves` is empty).
fn build_layers(leaves: &[[u8; 32]]) -> Vec<Vec<[u8; 32]>> {
    let mut layers: Vec<Vec<[u8; 32]>> = Vec::new();
    let mut current: Vec<[u8; 32]> = leaves.iter().map(hash_leaf).collect();

    while current.len() > 1 {
        let parents: Vec<[u8; 32]> = current
            .chunks(2)
            .map(|pair| match pair {
                [left, right] => hash_internal(left, right),
                // Odd node count: the last node is paired with itself.
                _ => hash_internal(&pair[0], &pair[0]),
            })
            .collect();
        // Store the finished layer and continue with its parents.
        layers.push(std::mem::replace(&mut current, parents));
    }

    layers.push(current);
    layers
}
/// Inclusion proof for one leaf, as produced by `MerkleTree::proof` and
/// consumed by `MerkleTree::verify`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleProof {
/// Sibling node hashes, ordered from the leaf layer upward (one per layer
/// below the root layer).
pub siblings: Vec<[u8; 32]>,
/// Index of the proven leaf; its parity at each level decides whether the
/// running hash is the left or the right input when recombining.
pub index: usize,
}
/// Merkle tree over 32-byte leaves that keeps every layer so proofs can be
/// produced by lookup.
#[derive(Debug, Clone)]
pub struct MerkleTree {
/// The leaf values exactly as passed to `MerkleTree::new` (not yet hashed).
leaves: Vec<[u8; 32]>,
/// Root hash; `empty_root()` when there are no leaves.
root: [u8; 32],
/// Layers as returned by `build_layers` (leaf hashes first, root layer last);
/// left empty when the tree has no leaves.
layers: Vec<Vec<[u8; 32]>>,
}
impl MerkleTree {
pub fn new(leaves: Vec<[u8; 32]>) -> Self {
if leaves.is_empty() {
return Self {
leaves,
root: empty_root(),
layers: Vec::new(),
};
}
let layers = build_layers(&leaves);
let root = match layers.last().and_then(|l| l.first()) {
Some(r) => *r,
None => empty_root(),
};
Self {
leaves,
root,
layers,
}
}
pub fn len(&self) -> usize {
self.leaves.len()
}
pub fn is_empty(&self) -> bool {
self.leaves.is_empty()
}
pub fn root(&self) -> [u8; 32] {
self.root
}
pub fn proof(&self, index: usize) -> RaftResult<MerkleProof> {
if index >= self.leaves.len() {
return Err(RaftError::Other {
message: format!(
"MerkleTree::proof: index {} out of range (len = {})",
index,
self.leaves.len()
),
});
}
if self.leaves.len() == 1 {
return Ok(MerkleProof {
siblings: Vec::new(),
index,
});
}
let mut siblings = Vec::new();
let mut current = index;
for layer in self.layers.iter().take(self.layers.len().saturating_sub(1)) {
let sibling_idx = if current % 2 == 0 {
if current + 1 < layer.len() {
current + 1
} else {
current
}
} else {
current - 1
};
siblings.push(layer[sibling_idx]);
current /= 2;
}
Ok(MerkleProof { siblings, index })
}
pub fn verify(leaf: [u8; 32], proof: &MerkleProof, root: [u8; 32]) -> bool {
let mut current = hash_leaf(&leaf);
let mut idx = proof.index;
for sibling in &proof.siblings {
current = if idx % 2 == 0 {
hash_internal(¤t, sibling)
} else {
hash_internal(sibling, ¤t)
};
idx /= 2;
}
current == root
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `n` deterministic leaves: the first two bytes carry the low
    /// 16 bits of the position, the remaining 30 bytes come from a BLAKE3
    /// digest of the position.
    fn leaves_from_seed(n: usize) -> Vec<[u8; 32]> {
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let digest = blake3::hash(&i.to_le_bytes());
            let mut leaf = [0u8; 32];
            leaf[0] = i as u8;
            leaf[1] = (i >> 8) as u8;
            leaf[2..].copy_from_slice(&digest.as_bytes()[..30]);
            out.push(leaf);
        }
        out
    }

    /// Building twice from the same leaves gives the same root.
    #[test]
    fn test_merkle_tree_root_deterministic() {
        let leaves = leaves_from_seed(7);
        let tree_a = MerkleTree::new(leaves.clone());
        let tree_b = MerkleTree::new(leaves);
        let (root_a, root_b) = (tree_a.root(), tree_b.root());
        assert_eq!(
            root_a, root_b,
            "two trees built from identical leaves must have identical roots"
        );
    }

    /// Every leaf of a power-of-two sized tree has a verifying proof.
    #[test]
    fn test_merkle_tree_proof_verifies() {
        let leaves = leaves_from_seed(8);
        let tree = MerkleTree::new(leaves.clone());
        let root = tree.root();
        for (i, leaf) in leaves.iter().copied().enumerate() {
            let proof = tree.proof(i).expect("proof must be available");
            let verified = MerkleTree::verify(leaf, &proof, root);
            assert!(
                verified,
                "proof for leaf index {} must verify against the root",
                i
            );
        }
    }

    /// A proof does not verify once the leaf bytes are altered.
    #[test]
    fn test_merkle_tree_proof_fails_on_tampered_leaf() {
        let leaves = leaves_from_seed(6);
        let tree = MerkleTree::new(leaves.clone());
        let root = tree.root();
        let proof = tree.proof(3).expect("proof at index 3");
        let mut tampered = leaves[3];
        // Flip every bit of the first byte.
        tampered[0] = !tampered[0];
        let verified = MerkleTree::verify(tampered, &proof, root);
        assert!(
            !verified,
            "tampered leaf must not verify against the original root"
        );
    }

    /// The empty tree reports itself empty, uses the empty root, and has no proofs.
    #[test]
    fn test_merkle_tree_empty_leaves_root() {
        let no_leaves: Vec<[u8; 32]> = Vec::new();
        let tree = MerkleTree::new(no_leaves);
        assert!(tree.is_empty(), "empty leaves yield is_empty = true");
        assert_eq!(tree.len(), 0);
        let expected_root = empty_root();
        assert_eq!(
            tree.root(),
            expected_root,
            "empty tree root must equal the well-known empty constant"
        );
        let attempt = tree.proof(0);
        assert!(attempt.is_err(), "proof of empty tree must error");
    }

    /// A one-leaf tree's root is the leaf hash and its proof has no siblings.
    #[test]
    fn test_merkle_tree_single_leaf_root() {
        let leaf = [0xa5u8; 32];
        let tree = MerkleTree::new(vec![leaf]);
        assert_eq!(tree.len(), 1);
        let expected_root = hash_leaf(&leaf);
        assert_eq!(
            tree.root(),
            expected_root,
            "single-leaf tree root must equal the leaf hash"
        );
        let proof = tree.proof(0).expect("proof at index 0");
        assert!(proof.siblings.is_empty());
        let root = tree.root();
        assert!(MerkleTree::verify(leaf, &proof, root));
    }

    /// Proofs also verify when the leaf count is odd.
    #[test]
    fn test_merkle_tree_proof_odd_arity() {
        let leaves = leaves_from_seed(5);
        let tree = MerkleTree::new(leaves.clone());
        let root = tree.root();
        for (i, leaf) in leaves.iter().copied().enumerate() {
            let proof = tree.proof(i).expect("proof must be available");
            let verified = MerkleTree::verify(leaf, &proof, root);
            assert!(
                verified,
                "odd-arity tree: proof for leaf {} must verify",
                i
            );
        }
    }

    /// Asking for a proof past the last leaf is an error.
    #[test]
    fn test_merkle_tree_proof_out_of_range() {
        let tree = MerkleTree::new(leaves_from_seed(3));
        let attempt = tree.proof(99);
        assert!(attempt.is_err());
    }

    /// A valid proof does not verify against an unrelated root.
    #[test]
    fn test_merkle_tree_verify_wrong_root_fails() {
        let leaves = leaves_from_seed(4);
        let tree = MerkleTree::new(leaves.clone());
        let proof = tree.proof(1).expect("valid proof");
        let bogus_root = [0xffu8; 32];
        let verified = MerkleTree::verify(leaves[1], &proof, bogus_root);
        assert!(
            !verified,
            "verification against a wrong root must fail"
        );
    }

    /// The root of a two-leaf tree differs from `hash_internal` of the raw leaves.
    #[test]
    fn test_merkle_tree_domain_separation_distinguishes_layers() {
        let a = [0x11u8; 32];
        let b = [0x22u8; 32];
        let root = MerkleTree::new(vec![a, b]).root();
        let attacker_internal = hash_internal(&a, &b);
        assert_ne!(
            root, attacker_internal,
            "domain separation must distinguish leaf-input from internal-input"
        );
    }
}