use sha2::{Digest, Sha256};
use super::types::{AttestationRecord, MerkleProof};
/// A binary Merkle tree over attestation record hashes.
///
/// `levels[0]` holds the leaf hashes. Each following level holds the parent hashes of the
/// level below it, and the last level holds the single root hash. A tree built from no
/// records has no levels at all.
#[derive(Debug, Clone)]
pub struct MerkleTree {
    levels: Vec<Vec<String>>,
}
impl MerkleTree {
    /// Builds a tree whose leaves are the `record_hash` values of `records`, in order.
    ///
    /// Each parent is `hash_pair(left, right)`. When a level has an odd number of nodes, its
    /// last node is paired with itself. An empty `records` slice yields a tree with no
    /// levels, whose [`root`](Self::root) is `None`.
    ///
    /// NOTE(review): there are two weaknesses in the hash scheme.
    /// - Pairing the trailing node with itself means e.g. `[a, b, c]` and `[a, b, c, c]`
    ///   produce the same root (the weakness behind CVE-2012-2459).
    /// - Leaf and interior hashes are not domain-separated.
    ///
    /// Changing either would alter every root already produced, so it is left as is.
    /// Verifiers that care should obtain `total_leaves` from a trusted source.
    pub fn build(records: &[AttestationRecord]) -> Self {
        if records.is_empty() {
            return Self { levels: Vec::new() };
        }
        let leaves: Vec<String> = records.iter().map(|r| r.record_hash.clone()).collect();
        let mut levels = vec![leaves];
        loop {
            let current = levels
                .last()
                .expect("levels always holds at least the leaf level");
            if current.len() <= 1 {
                break;
            }
            let next: Vec<String> = current
                .chunks(2)
                .map(|pair| {
                    let left = &pair[0];
                    // An odd trailing node has no partner and is hashed with itself.
                    let right = pair.get(1).unwrap_or(left);
                    hash_pair(left, right)
                })
                .collect();
            levels.push(next);
        }
        Self { levels }
    }

    /// Returns the root hash, or `None` for a tree built from no records.
    pub fn root(&self) -> Option<&str> {
        self.levels.last()?.first().map(String::as_str)
    }

    /// Produces an inclusion proof for the leaf at `leaf_index`.
    ///
    /// Returns `None` if the tree is empty or `leaf_index` is out of range.
    /// `proof_hashes` is ordered from the leaf level upward. Each entry pairs the sibling
    /// hash with `true` when the sibling sits to the right, meaning the running hash is
    /// hashed first.
    pub fn generate_proof(&self, leaf_index: usize) -> Option<MerkleProof> {
        let (root_level, lower_levels) = self.levels.split_last()?;
        let leaves = self.levels.first()?;
        let leaf_hash = leaves.get(leaf_index)?.clone();
        let root_hash = root_level.first()?.clone();

        let mut proof_hashes = Vec::with_capacity(lower_levels.len());
        let mut index = leaf_index;
        for level in lower_levels {
            // Even positions pair with their right neighbour, odd ones with their left;
            // `index ^ 1` addresses that neighbour. A trailing even node with no right
            // neighbour is paired with itself, mirroring `build`.
            let is_right_sibling = index.is_multiple_of(2);
            let sibling = level.get(index ^ 1).unwrap_or(&level[index]);
            proof_hashes.push((sibling.clone(), is_right_sibling));
            index /= 2;
        }

        Some(MerkleProof {
            leaf_hash,
            leaf_index,
            proof_hashes,
            root_hash,
            total_leaves: leaves.len(),
        })
    }

    /// Checks that `proof` is well-formed and hashes from `leaf_hash` up to `root_hash`.
    ///
    /// A proof is well-formed when both of these hold:
    /// - its length and left/right directions match the path that `leaf_index` takes
    ///   through a tree of `total_leaves` leaves;
    /// - any self-paired trailing node repeats the running hash.
    ///
    /// Without these checks `leaf_index` and `total_leaves` would be unauthenticated.
    ///
    /// This only checks internal consistency. The caller must still compare
    /// `proof.root_hash` against a root it trusts.
    pub fn verify_proof(proof: &MerkleProof) -> bool {
        if proof.leaf_index >= proof.total_leaves {
            return false;
        }
        let mut current = proof.leaf_hash.clone();
        let mut index = proof.leaf_index;
        let mut level_len = proof.total_leaves;
        for (sibling_hash, is_right_sibling) in &proof.proof_hashes {
            // More steps than a tree of `total_leaves` leaves has levels.
            if level_len <= 1 {
                return false;
            }
            let node_is_left = index.is_multiple_of(2);
            // The claimed direction must match the node's position at this level.
            if *is_right_sibling != node_is_left {
                return false;
            }
            // The last node of an odd-length level can only be paired with itself.
            if node_is_left && index + 1 >= level_len && *sibling_hash != current {
                return false;
            }
            current = if *is_right_sibling {
                hash_pair(&current, sibling_hash)
            } else {
                hash_pair(sibling_hash, &current)
            };
            index /= 2;
            level_len = level_len.div_ceil(2);
        }
        // Stopping below the root (too few steps) is also a malformed proof.
        level_len == 1 && current == proof.root_hash
    }
}
/// Hashes the concatenation `left || right` with SHA-256.
///
/// The result is a `"sha256:"`-prefixed lowercase hex digest. No separator is inserted
/// between the two inputs. For interior nodes (always `"sha256:"` plus 64 hex chars)
/// the concatenation is unambiguous. NOTE(review): for leaves this holds only if
/// `record_hash` values are fixed-length — confirm.
fn hash_pair(left: &str, right: &str) -> String {
    let digest = {
        let mut sha = Sha256::new();
        for part in [left, right] {
            sha.update(part.as_bytes());
        }
        sha.finalize()
    };
    let hex_digest = hex::encode(digest);
    format!("sha256:{hex_digest}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attestation::chain::AttestationChain;
    use crate::storage::Storage;
    use chrono::Utc;

    fn make_record(hash: &str, name: &str) -> AttestationRecord {
        AttestationRecord {
            id: None,
            document_hash: format!("sha256:{hash}"),
            document_name: name.to_string(),
            document_size: 0,
            ingested_at: Utc::now(),
            agent_id: None,
            memory_ids: vec![],
            previous_hash: "genesis".to_string(),
            record_hash: format!("sha256:{hash}-record"),
            signature: None,
            metadata: serde_json::Value::Object(serde_json::Map::new()),
            created_at: None,
        }
    }

    /// Builds one record per hash, named `0.txt`, `1.txt`, ...
    fn make_records(hashes: &[&str]) -> Vec<AttestationRecord> {
        hashes
            .iter()
            .enumerate()
            .map(|(i, h)| make_record(h, &format!("{i}.txt")))
            .collect()
    }

    #[test]
    fn test_empty_tree() {
        let tree = MerkleTree::build(&[]);
        assert!(tree.root().is_none());
        assert!(tree.generate_proof(0).is_none());
    }

    #[test]
    fn test_single_leaf() {
        let rec = make_record("aabbcc", "a.txt");
        let tree = MerkleTree::build(&[rec.clone()]);
        assert_eq!(tree.root(), Some(rec.record_hash.as_str()));
        let proof = tree.generate_proof(0).unwrap();
        assert_eq!(proof.leaf_hash, rec.record_hash);
        assert_eq!(proof.total_leaves, 1);
        assert!(MerkleTree::verify_proof(&proof));
    }

    #[test]
    fn test_two_leaves() {
        let r1 = make_record("aa", "a.txt");
        let r2 = make_record("bb", "b.txt");
        let tree = MerkleTree::build(&[r1.clone(), r2.clone()]);
        let expected_root = hash_pair(&r1.record_hash, &r2.record_hash);
        assert_eq!(tree.root(), Some(expected_root.as_str()));
        let proof0 = tree.generate_proof(0).unwrap();
        assert!(MerkleTree::verify_proof(&proof0));
        let proof1 = tree.generate_proof(1).unwrap();
        assert!(MerkleTree::verify_proof(&proof1));
    }

    #[test]
    fn test_three_leaves_odd() {
        let records = make_records(&["aa", "bb", "cc"]);
        let tree = MerkleTree::build(&records);
        assert!(tree.root().is_some());
        for i in 0..3 {
            let proof = tree.generate_proof(i).unwrap();
            assert!(MerkleTree::verify_proof(&proof), "proof {i} failed");
        }
    }

    #[test]
    fn test_three_leaves_duplicates_last_node() {
        let records = make_records(&["aa", "bb", "cc"]);
        let tree = MerkleTree::build(&records);
        let left = hash_pair(&records[0].record_hash, &records[1].record_hash);
        let right = hash_pair(&records[2].record_hash, &records[2].record_hash);
        let expected_root = hash_pair(&left, &right);
        assert_eq!(tree.root(), Some(expected_root.as_str()));
        assert_eq!(tree.generate_proof(2).unwrap().total_leaves, 3);
    }

    #[test]
    fn test_four_leaves_root_and_proof_depth() {
        let records = make_records(&["aa", "bb", "cc", "dd"]);
        let tree = MerkleTree::build(&records);
        let leaf = |i: usize| records[i].record_hash.as_str();
        let expected_root = hash_pair(&hash_pair(leaf(0), leaf(1)), &hash_pair(leaf(2), leaf(3)));
        assert_eq!(tree.root(), Some(expected_root.as_str()));
        for i in 0..4 {
            let proof = tree.generate_proof(i).unwrap();
            assert_eq!(proof.proof_hashes.len(), 2, "proof {i} has wrong depth");
            assert!(MerkleTree::verify_proof(&proof), "proof {i} failed");
        }
    }

    #[test]
    fn test_out_of_range_index() {
        let tree = MerkleTree::build(&make_records(&["aa", "bb"]));
        assert!(tree.generate_proof(2).is_none());
    }

    #[test]
    fn test_proof_from_real_chain() {
        let storage = Storage::open_in_memory().unwrap();
        let chain = AttestationChain::new(storage);
        let r1 = chain.log_document(b"doc1", "d1.txt", None, &[], None).unwrap();
        let r2 = chain.log_document(b"doc2", "d2.txt", None, &[], None).unwrap();
        let r3 = chain.log_document(b"doc3", "d3.txt", None, &[], None).unwrap();
        let tree = MerkleTree::build(&[r1, r2, r3]);
        for i in 0..3 {
            let proof = tree.generate_proof(i).unwrap();
            assert!(MerkleTree::verify_proof(&proof), "real chain proof {i} failed");
        }
    }

    #[test]
    fn test_tampered_proof_fails() {
        let tree = MerkleTree::build(&make_records(&["aa", "bb"]));
        let mut proof = tree.generate_proof(0).unwrap();
        proof.root_hash = format!("sha256:{}", "0".repeat(64));
        assert!(!MerkleTree::verify_proof(&proof));
    }

    #[test]
    fn test_tampered_sibling_fails() {
        let tree = MerkleTree::build(&make_records(&["aa", "bb"]));
        let mut proof = tree.generate_proof(0).unwrap();
        proof.proof_hashes[0].0 = "sha256:bogus".to_string();
        assert!(!MerkleTree::verify_proof(&proof));
    }

    #[test]
    fn test_tampered_leaf_fails() {
        let tree = MerkleTree::build(&make_records(&["aa", "bb"]));
        let mut proof = tree.generate_proof(1).unwrap();
        proof.leaf_hash = "sha256:bogus".to_string();
        assert!(!MerkleTree::verify_proof(&proof));
    }
}