use serde::{Deserialize, Serialize};
use crate::entry::AuditEntry;
use crate::hasher::ChainHasher;
/// Hash tree over audit-entry hashes, kept as one flat vector of levels.
///
/// Instances are created by `MerkleTree::build`, which never produces an
/// empty tree.
#[derive(Debug, Clone)]
pub struct MerkleTree {
/// Every level concatenated bottom-up: the leaf hashes come first and the
/// root hash is the final element.
nodes: Vec<String>,
/// Number of leaves, i.e. the length of the first level stored in `nodes`.
leaf_count: usize,
}
/// Inclusion proof for a single leaf, produced by `MerkleTree::proof` and
/// checked by `verify_proof`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct MerkleProof {
/// Leaf position the proof was requested for.
pub index: usize,
/// Hash of the leaf at `index`; the starting value when the path is folded.
pub leaf_hash: String,
/// Sibling hashes ordered from the leaf level upward.
pub path: Vec<ProofNode>,
/// Root the folded path is compared against.
pub root: String,
}
/// One step of an inclusion proof: a sibling hash and where it is placed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ProofNode {
/// Hash of the sibling node at this level.
pub hash: String,
/// Whether the sibling is the left or the right input of the pair hash.
pub side: Side,
}
/// Position of a proof sibling relative to the running hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Side {
/// The sibling is hashed first, followed by the running hash.
Left,
/// The running hash is hashed first, followed by the sibling.
Right,
}
impl MerkleTree {
    /// Builds a tree over the hashes of `entries`, in slice order.
    ///
    /// Each level is formed by hashing adjacent pairs of the level below. When
    /// a level has an odd number of nodes, the last node is paired with itself.
    /// All levels are stored bottom-up in a single vector, so the root ends up
    /// as the final element.
    ///
    /// Returns `None` when `entries` is empty.
    pub fn build(entries: &[AuditEntry]) -> Option<Self> {
        if entries.is_empty() {
            return None;
        }
        let leaf_count = entries.len();
        let mut current_level: Vec<String> = entries
            .iter()
            .map(|entry| entry.hash().to_owned())
            .collect();
        let mut nodes = Vec::new();
        while current_level.len() > 1 {
            let next_level: Vec<String> = current_level
                .chunks(2)
                .map(|pair| {
                    let left = &pair[0];
                    // An unpaired trailing node is combined with itself.
                    let right = pair.get(1).unwrap_or(left);
                    hash_pair(left, right)
                })
                .collect();
            nodes.extend(current_level);
            current_level = next_level;
        }
        // Only the root is left at this point; it becomes the last stored node.
        nodes.extend(current_level);
        Some(Self { nodes, leaf_count })
    }

    /// Returns the root hash of the stored tree (the last stored node).
    ///
    /// This root is computed with the "pair the odd node with itself" rule
    /// used by [`MerkleTree::build`]. [`MerkleTree::canonical_root`] splits
    /// odd-sized ranges differently, so the two are not generally equal when
    /// the leaf count is not a power of two.
    #[inline]
    #[must_use]
    pub fn root(&self) -> &str {
        self.nodes.last().map_or("", String::as_str)
    }

    /// Returns the number of leaves the tree was built from.
    #[inline]
    #[must_use]
    pub fn leaf_count(&self) -> usize {
        self.leaf_count
    }

    /// Produces an inclusion proof for the leaf at `index`.
    ///
    /// The path lists one sibling per level, from the leaf level upward. Where
    /// a node has no sibling (last node of an odd-sized level), its own hash is
    /// recorded as the sibling, matching how the level was built.
    ///
    /// Returns `None` when `index` is not a valid leaf position.
    pub fn proof(&self, index: usize) -> Option<MerkleProof> {
        if index >= self.leaf_count {
            return None;
        }
        let mut path = Vec::new();
        // Offset of the current level inside `nodes`, and that level's width.
        let mut level_start = 0;
        let mut level_size = self.leaf_count;
        let mut idx = index;
        while level_size > 1 {
            let (sibling_idx, side) = if idx & 1 == 0 {
                (idx + 1, Side::Right)
            } else {
                (idx - 1, Side::Left)
            };
            // Past the end of the level means the node was paired with itself.
            let source_idx = if sibling_idx < level_size {
                sibling_idx
            } else {
                idx
            };
            path.push(ProofNode {
                hash: self.nodes[level_start + source_idx].clone(),
                side,
            });
            level_start += level_size;
            level_size = level_size.div_ceil(2);
            idx /= 2;
        }
        Some(MerkleProof {
            index,
            leaf_hash: self.nodes[index].clone(),
            path,
            root: self.root().to_owned(),
        })
    }

    /// Produces a proof that the first `old_size` leaves are a prefix of the
    /// current leaves.
    ///
    /// Both roots in the result are canonical roots (see
    /// [`MerkleTree::canonical_root`]), not the value returned by
    /// [`MerkleTree::root`]. When `old_size` equals the current leaf count the
    /// path is empty.
    ///
    /// Returns `None` when `old_size` is zero or exceeds the leaf count.
    pub fn consistency_proof(&self, old_size: usize) -> Option<ConsistencyProof> {
        let new_size = self.leaf_count;
        if old_size == 0 || old_size > new_size {
            return None;
        }
        let old_root = self.canonical_root(old_size)?;
        let new_root = self.canonical_root(new_size)?;
        let mut path = Vec::new();
        if old_size != new_size {
            subproof(old_size, 0, new_size, true, &self.nodes, &mut path);
        }
        Some(ConsistencyProof {
            old_size,
            new_size,
            old_root,
            new_root,
            path,
        })
    }

    /// Computes the canonical root over the first `size` leaves.
    ///
    /// The hash is recomputed from the leaf hashes by recursively splitting the
    /// range at the largest power of two strictly below its length.
    ///
    /// Returns `None` when `size` is zero or exceeds the leaf count.
    #[must_use]
    pub fn canonical_root(&self, size: usize) -> Option<String> {
        let in_range = (1..=self.leaf_count).contains(&size);
        in_range.then(|| canonical_subtree_hash(&self.nodes, 0, size))
    }
}
/// Proof that a tree of `old_size` leaves is a prefix of a tree of `new_size`
/// leaves, produced by `MerkleTree::consistency_proof` and checked by
/// `verify_consistency`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ConsistencyProof {
/// Leaf count of the earlier tree.
pub old_size: usize,
/// Leaf count of the later tree.
pub new_size: usize,
/// Canonical root over the first `old_size` leaves.
pub old_root: String,
/// Canonical root over all `new_size` leaves.
pub new_root: String,
/// Subtree hashes needed to rebuild both roots; empty when the sizes match.
pub path: Vec<String>,
}
/// Checks an inclusion proof by folding its path into a root.
///
/// Starting from `proof.leaf_hash`, each path node is combined with the running
/// hash on the side it declares, and the final value is compared with
/// `proof.root` via `crate::entry::constant_time_eq`.
///
/// `proof.index` is not consulted: only the sides recorded in the path decide
/// how hashes are combined.
#[must_use]
pub fn verify_proof(proof: &MerkleProof) -> bool {
    let computed = proof
        .path
        .iter()
        .fold(proof.leaf_hash.clone(), |current, node| match node.side {
            Side::Left => hash_pair(&node.hash, &current),
            Side::Right => hash_pair(&current, &node.hash),
        });
    crate::entry::constant_time_eq(&computed, &proof.root)
}
/// Checks a consistency proof by rebuilding both roots from its path.
///
/// Rejects proofs whose `old_size` is zero or larger than `new_size`. For equal
/// sizes the proof is valid only when the path is empty and both roots match.
/// Otherwise the path is walked once, maintaining one running hash for the old
/// tree and one for the new tree, and both must end up equal to the roots
/// stated in the proof.
#[must_use]
pub fn verify_consistency(proof: &ConsistencyProof) -> bool {
    let (old_size, new_size) = (proof.old_size, proof.new_size);
    if old_size == 0 || old_size > new_size {
        return false;
    }
    if old_size == new_size {
        return proof.path.is_empty()
            && crate::entry::constant_time_eq(&proof.old_root, &proof.new_root);
    }

    // When the old tree is a complete power-of-two subtree, its root is the
    // implicit first path element and is not carried in `proof.path`.
    let implicit_seed = old_size
        .is_power_of_two()
        .then_some(proof.old_root.as_str());
    let mut hashes = implicit_seed
        .into_iter()
        .chain(proof.path.iter().map(String::as_str));
    let Some(seed) = hashes.next() else {
        return false;
    };

    // Positions of the last leaf of each tree, advanced upward level by level.
    let mut old_cursor = old_size - 1;
    let mut new_cursor = new_size - 1;
    while old_cursor & 1 == 1 {
        old_cursor >>= 1;
        new_cursor >>= 1;
    }

    let mut old_acc = seed.to_owned();
    let mut new_acc = seed.to_owned();
    for sibling in hashes {
        // Path elements left over after the new tree's root was reached.
        if new_cursor == 0 {
            return false;
        }
        if old_cursor & 1 == 1 || old_cursor == new_cursor {
            // The sibling sits on the left and contributes to both roots.
            old_acc = hash_pair(sibling, &old_acc);
            new_acc = hash_pair(sibling, &new_acc);
            while old_cursor != 0 && old_cursor & 1 == 0 {
                old_cursor >>= 1;
                new_cursor >>= 1;
            }
        } else {
            // The sibling sits on the right and exists only in the new tree.
            new_acc = hash_pair(&new_acc, sibling);
        }
        old_cursor >>= 1;
        new_cursor >>= 1;
    }

    new_cursor == 0
        && crate::entry::constant_time_eq(&old_acc, &proof.old_root)
        && crate::entry::constant_time_eq(&new_acc, &proof.new_root)
}
/// Appends to `path` the hashes proving that the first `m` leaves of the range
/// `[start, start + n)` are a prefix of that range.
///
/// `is_complete` is `true` while the range's old part is still the whole old
/// tree; in that case reaching `m == n` adds nothing, because the verifier
/// already holds that root. `nodes` must start with the leaf hashes.
fn subproof(
    m: usize,
    start: usize,
    n: usize,
    is_complete: bool,
    nodes: &[String],
    path: &mut Vec<String>,
) {
    // Base case: the old part covers the whole range, or the range is a single
    // leaf (for which the canonical hash is just that leaf's hash).
    if m == n || n == 1 {
        if !is_complete {
            path.push(canonical_subtree_hash(nodes, start, n));
        }
        return;
    }

    let k = largest_power_of_2_less_than(n);
    let (sibling_start, sibling_len) = if m <= k {
        // The old part lies entirely in the left half; the right half is the sibling.
        subproof(m, start, k, is_complete, nodes, path);
        (start + k, n - k)
    } else {
        // The old part spills into the right half; the left half is the sibling.
        subproof(m - k, start + k, n - k, false, nodes, path);
        (start, k)
    };
    path.push(canonical_subtree_hash(nodes, sibling_start, sibling_len));
}
/// Hashes the `count` leaves starting at `nodes[start]` into a single value.
///
/// Ranges longer than one leaf are split at the largest power of two strictly
/// below `count`, and the two halves are hashed recursively and combined. A
/// single leaf hashes to itself, and an empty range yields an empty string.
fn canonical_subtree_hash(nodes: &[String], start: usize, count: usize) -> String {
    match count {
        0 => String::new(),
        1 => nodes[start].clone(),
        _ => {
            let split = largest_power_of_2_less_than(count);
            let left = canonical_subtree_hash(nodes, start, split);
            let right = canonical_subtree_hash(nodes, start + split, count - split);
            hash_pair(&left, &right)
        }
    }
}
/// Returns the largest power of two strictly smaller than `n`.
///
/// Callers must pass `n > 1`; this is checked in debug builds only.
#[inline]
fn largest_power_of_2_less_than(n: usize) -> usize {
    debug_assert!(n > 1);
    // Start from the top bit and slide it down onto the highest set bit of `n - 1`.
    const TOP_BIT: usize = 1 << (usize::BITS - 1);
    TOP_BIT >> (n - 1).leading_zeros()
}
/// Hashes two node hashes into their parent hash.
///
/// The bytes of `left` are fed to a fresh `ChainHasher` first, then the bytes
/// of `right`, with nothing in between; the result is returned via
/// `finalize_hex`.
#[inline]
fn hash_pair(left: &str, right: &str) -> String {
    let mut digest = ChainHasher::new();
    for half in [left, right] {
        digest.update(half.as_bytes());
    }
    digest.finalize_hex()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entry::EventSeverity;

    /// Appends entry number `i`, chained to the hash of the current last entry
    /// (or to an empty previous hash when `entries` is empty).
    fn push_chained(entries: &mut Vec<AuditEntry>, i: usize) {
        let prev = entries
            .last()
            .map_or_else(String::new, |entry| entry.hash().to_owned());
        entries.push(AuditEntry::new(
            EventSeverity::Info,
            "s",
            format!("e{i}"),
            serde_json::json!({}),
            &prev,
        ));
    }

    /// Builds exactly `n` entries, each chained to its predecessor's hash.
    fn make_entries(n: usize) -> Vec<AuditEntry> {
        let mut entries = Vec::with_capacity(n);
        for i in 0..n {
            push_chained(&mut entries, i);
        }
        entries
    }

    #[test]
    fn build_empty() {
        assert!(MerkleTree::build(&[]).is_none());
    }

    #[test]
    fn build_single_entry() {
        let entries = make_entries(1);
        let tree = MerkleTree::build(&entries).unwrap();
        assert_eq!(tree.leaf_count(), 1);
        assert_eq!(tree.root(), entries[0].hash());
    }

    #[test]
    fn build_two_entries() {
        let entries = make_entries(2);
        let tree = MerkleTree::build(&entries).unwrap();
        assert_eq!(tree.leaf_count(), 2);
        let expected_root = hash_pair(entries[0].hash(), entries[1].hash());
        assert_eq!(tree.root(), expected_root);
    }

    #[test]
    fn build_power_of_two() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        assert_eq!(tree.leaf_count(), 8);
        assert!(!tree.root().is_empty());
    }

    #[test]
    fn build_odd_count() {
        let entries = make_entries(5);
        let tree = MerkleTree::build(&entries).unwrap();
        assert_eq!(tree.leaf_count(), 5);
        assert!(!tree.root().is_empty());
    }

    #[test]
    fn proof_and_verify_all_entries() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        for (i, entry) in entries.iter().enumerate() {
            let proof = tree.proof(i).unwrap();
            assert_eq!(proof.index, i);
            assert_eq!(proof.leaf_hash, entry.hash());
            assert_eq!(proof.root, tree.root());
            assert!(verify_proof(&proof), "proof failed for index {i}");
        }
    }

    #[test]
    fn proof_and_verify_odd_tree() {
        let entries = make_entries(7);
        let tree = MerkleTree::build(&entries).unwrap();
        for i in 0..entries.len() {
            let proof = tree.proof(i).unwrap();
            assert!(verify_proof(&proof), "proof failed for index {i}");
        }
    }

    #[test]
    fn proof_out_of_bounds() {
        let entries = make_entries(4);
        let tree = MerkleTree::build(&entries).unwrap();
        assert!(tree.proof(4).is_none());
        assert!(tree.proof(100).is_none());
    }

    #[test]
    fn tampered_proof_fails() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        let mut proof = tree.proof(3).unwrap();
        proof.leaf_hash = "tampered".to_owned();
        assert!(!verify_proof(&proof));
    }

    #[test]
    fn tampered_path_fails() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        let mut proof = tree.proof(3).unwrap();
        if let Some(node) = proof.path.first_mut() {
            node.hash = "tampered".to_owned();
        }
        assert!(!verify_proof(&proof));
    }

    #[test]
    fn different_entries_different_roots() {
        let entries_a = make_entries(4);
        let mut entries_b = make_entries(4);
        entries_b[2].corrupt_action("different");
        let tree_a = MerkleTree::build(&entries_a).unwrap();
        let tree_b = MerkleTree::build(&entries_b).unwrap();
        assert_ne!(tree_a.root(), tree_b.root());
    }

    #[test]
    fn large_tree() {
        let entries = make_entries(100);
        let tree = MerkleTree::build(&entries).unwrap();
        assert_eq!(tree.leaf_count(), 100);
        for i in [0, 49, 99] {
            let proof = tree.proof(i).unwrap();
            assert!(verify_proof(&proof));
        }
    }

    #[test]
    fn single_entry_proof() {
        let entries = make_entries(1);
        let tree = MerkleTree::build(&entries).unwrap();
        let proof = tree.proof(0).unwrap();
        assert!(proof.path.is_empty());
        assert!(verify_proof(&proof));
    }

    #[test]
    fn consistency_same_size() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        let proof = tree.consistency_proof(8).unwrap();
        assert!(proof.path.is_empty());
        assert_eq!(proof.old_root, proof.new_root);
        assert!(verify_consistency(&proof));
    }

    #[test]
    fn consistency_power_of_two() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        for old_size in 1..=8 {
            let proof = tree.consistency_proof(old_size).unwrap();
            assert_eq!(proof.old_size, old_size);
            assert_eq!(proof.new_size, 8);
            assert!(
                verify_consistency(&proof),
                "consistency proof failed for old_size={old_size}"
            );
        }
    }

    #[test]
    fn consistency_odd_sizes() {
        for n in [3, 5, 7, 9, 11, 13, 15] {
            let entries = make_entries(n);
            let tree = MerkleTree::build(&entries).unwrap();
            for m in 1..=n {
                let proof = tree.consistency_proof(m).unwrap();
                assert!(
                    verify_consistency(&proof),
                    "consistency proof failed for m={m}, n={n}"
                );
            }
        }
    }

    #[test]
    fn consistency_one_to_many() {
        let entries = make_entries(16);
        let tree = MerkleTree::build(&entries).unwrap();
        let proof = tree.consistency_proof(1).unwrap();
        assert!(verify_consistency(&proof));
        assert_eq!(proof.old_root, entries[0].hash());
    }

    #[test]
    fn consistency_invalid_old_size() {
        let entries = make_entries(5);
        let tree = MerkleTree::build(&entries).unwrap();
        assert!(tree.consistency_proof(0).is_none());
        assert!(tree.consistency_proof(6).is_none());
    }

    #[test]
    fn consistency_tampered_path_fails() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        let mut proof = tree.consistency_proof(3).unwrap();
        if let Some(h) = proof.path.first_mut() {
            *h = "tampered".to_owned();
        }
        assert!(!verify_consistency(&proof));
    }

    #[test]
    fn consistency_wrong_old_size_fails() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        let mut proof = tree.consistency_proof(4).unwrap();
        // Claim a different old size than the one the path was generated for.
        proof.old_size = 3;
        assert!(!verify_consistency(&proof));
    }

    #[test]
    fn canonical_root_power_of_two_matches_tree_root() {
        for n in [1, 2, 4, 8, 16, 32] {
            let entries = make_entries(n);
            let tree = MerkleTree::build(&entries).unwrap();
            let canonical = tree.canonical_root(n).unwrap();
            assert_eq!(
                canonical,
                tree.root(),
                "canonical root should match tree root for power-of-2 size {n}"
            );
        }
    }

    #[test]
    fn canonical_root_bounds() {
        let entries = make_entries(5);
        let tree = MerkleTree::build(&entries).unwrap();
        assert!(tree.canonical_root(0).is_none());
        assert!(tree.canonical_root(6).is_none());
        assert!(tree.canonical_root(5).is_some());
    }

    #[test]
    fn canonical_root_prefix_stable() {
        let entries_5 = make_entries(5);
        // Extend a copy of the same five entries so both trees share a prefix;
        // each appended entry is chained to its actual predecessor.
        let mut entries_8 = entries_5.clone();
        for i in 5..8 {
            push_chained(&mut entries_8, i);
        }
        let tree_5 = MerkleTree::build(&entries_5).unwrap();
        let tree_8 = MerkleTree::build(&entries_8).unwrap();
        assert_eq!(tree_5.canonical_root(5), tree_8.canonical_root(5));
    }

    #[test]
    fn consistency_large_tree() {
        let entries = make_entries(100);
        let tree = MerkleTree::build(&entries).unwrap();
        for m in [1, 10, 33, 50, 64, 99, 100] {
            let proof = tree.consistency_proof(m).unwrap();
            assert!(
                verify_consistency(&proof),
                "consistency proof failed for m={m}, n=100"
            );
        }
    }

    #[test]
    fn consistency_serde_roundtrip() {
        let entries = make_entries(8);
        let tree = MerkleTree::build(&entries).unwrap();
        let proof = tree.consistency_proof(3).unwrap();
        let json = serde_json::to_string(&proof).unwrap();
        let back: ConsistencyProof = serde_json::from_str(&json).unwrap();
        assert_eq!(proof, back);
        assert!(verify_consistency(&back));
    }
}