use sha2::{Sha256, Digest};
use serde::{Deserialize, Serialize};
/// Which side of the running hash a proof step's sibling sits on.
///
/// `Right` means the parent is `SHA-256(current || sibling)`; `Left` means the
/// parent is `SHA-256(sibling || current)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Direction {
    /// The sibling is the left-hand child; the node being proven is on the right.
    Left,
    /// The sibling is the right-hand child; the node being proven is on the left.
    Right,
}

/// One level of an inclusion proof: the sibling hash and the side it is on.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ProofStep {
    /// Side of the sibling relative to the node being proven.
    pub direction: Direction,
    /// Hex-encoded 32-byte sibling hash.
    pub hash: String,
}

/// Evidence that a leaf is part of a Merkle tree with a given root.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InclusionProof {
    /// Zero-based position of the leaf in insertion order.
    pub leaf_index: usize,
    /// Hex-encoded leaf hash (SHA-256 of the artifact id).
    pub leaf_hash: String,
    /// Sibling hashes ordered from the leaf level upward; empty for a one-leaf tree.
    pub path: Vec<ProofStep>,
}
/// An append-only binary Merkle tree over SHA-256 hashes of artifact ids.
///
/// Only the 32-byte leaf hashes are stored; interior nodes are recomputed on
/// demand. `Default` yields an empty tree, identical to `MerkleTree::new()`.
#[derive(Debug, Clone, Default)]
pub struct MerkleTree {
    /// Leaf hashes in insertion order.
    leaves: Vec<[u8; 32]>,
}
impl MerkleTree {
    /// Creates an empty tree with no leaves.
    pub fn new() -> Self {
        Self { leaves: Vec::new() }
    }

    /// Appends a leaf for `artifact_id` and returns its zero-based index.
    ///
    /// The stored leaf is `SHA-256(artifact_id)`; the raw id is not retained.
    /// Indices are assigned in insertion order.
    pub fn append(&mut self, artifact_id: &str) -> usize {
        let hash = Sha256::digest(artifact_id.as_bytes());
        self.leaves.push(hash.into());
        self.leaves.len() - 1
    }

    /// Returns the number of leaves.
    pub fn len(&self) -> usize {
        self.leaves.len()
    }

    /// Returns `true` when no leaves have been appended.
    pub fn is_empty(&self) -> bool {
        self.leaves.is_empty()
    }

    /// Returns the Merkle root, or `None` for an empty tree.
    ///
    /// A one-leaf tree's root is the leaf hash itself. Interior nodes are
    /// `SHA-256(left || right)`; on a level with an odd node count the last node
    /// is paired with itself.
    ///
    /// NOTE(review): because of that padding, leaf lists that differ only by a
    /// duplicated trailing leaf (e.g. `[a, b, c]` vs `[a, b, c, c]`) share a root.
    /// Leaves and interior nodes are also hashed without domain-separation
    /// prefixes. Changing either would change every existing root.
    pub fn root(&self) -> Option<[u8; 32]> {
        if self.leaves.is_empty() {
            return None;
        }
        Some(self.compute_root(&self.leaves))
    }

    /// Returns the number of hashing levels above the leaves: `ceil(log2(len))`,
    /// or 0 for zero or one leaf. Every inclusion proof has this many steps.
    pub fn height(&self) -> usize {
        // Integer ceil(log2(n)). Avoids the f64 round-trip, which loses precision
        // above 2^53. `next_power_of_two` cannot overflow here because a
        // `Vec<[u8; 32]>` is limited to `isize::MAX / 32` elements.
        self.leaves.len().next_power_of_two().trailing_zeros() as usize
    }

    /// Builds an inclusion proof for the leaf at `leaf_index`.
    ///
    /// Returns `None` when `leaf_index >= self.len()`. The path lists sibling
    /// hashes from the leaf level up to the level just below the root.
    pub fn inclusion_proof(&self, leaf_index: usize) -> Option<InclusionProof> {
        let leaf = *self.leaves.get(leaf_index)?;
        let mut path = Vec::with_capacity(self.height());
        let mut idx = leaf_index;
        let mut level = self.leaves.clone();
        while level.len() > 1 {
            // `idx ^ 1` is the other member of the pair. The last node of an
            // odd-length level has no partner and is paired with itself,
            // matching `parent_level`.
            let sibling = level.get(idx ^ 1).unwrap_or(&level[idx]);
            let direction = if idx % 2 == 0 {
                Direction::Right
            } else {
                Direction::Left
            };
            path.push(ProofStep {
                direction,
                hash: hex::encode(sibling),
            });
            level = Self::parent_level(&level);
            idx /= 2;
        }
        Some(InclusionProof {
            leaf_index,
            leaf_hash: hex::encode(leaf),
            path,
        })
    }

    /// Checks that `proof` shows `artifact_id` is included under `root_hex`.
    ///
    /// Hex inputs are decoded, so upper- and lower-case digits are both accepted.
    /// The proof is rejected in any of these cases:
    /// - a hash is not 32 bytes of valid hex;
    /// - `leaf_hash` is not `SHA-256(artifact_id)`;
    /// - the step directions disagree with the bits of `leaf_index`;
    /// - the recomputed root differs from `root_hex`.
    ///
    /// NOTE(review): the verifier does not know the tree size. Where the last
    /// node was paired with itself, a neighbouring index past the end of the
    /// tree can still produce the same root.
    pub fn verify_proof(root_hex: &str, artifact_id: &str, proof: &InclusionProof) -> bool {
        match (
            Self::decode_hash(root_hex),
            Self::recompute_root(artifact_id, proof),
        ) {
            (Some(expected), Some(computed)) => expected == computed,
            _ => false,
        }
    }

    /// Folds `proof.path` over `SHA-256(artifact_id)`.
    ///
    /// Returns `None` if the proof is malformed or inconsistent with its `leaf_index`.
    fn recompute_root(artifact_id: &str, proof: &InclusionProof) -> Option<[u8; 32]> {
        let mut current: [u8; 32] = Sha256::digest(artifact_id.as_bytes()).into();
        if Self::decode_hash(&proof.leaf_hash)? != current {
            return None;
        }
        let mut idx = proof.leaf_index;
        for step in &proof.path {
            let sibling = Self::decode_hash(&step.hash)?;
            // An even index means the node being proven is a left child, so its
            // sibling must be on the right (and vice versa).
            current = match (&step.direction, idx % 2 == 0) {
                (Direction::Right, true) => Self::hash_pair(&current, &sibling),
                (Direction::Left, false) => Self::hash_pair(&sibling, &current),
                _ => return None,
            };
            idx /= 2;
        }
        // Leftover bits mean `leaf_index >= 2^path.len()`, which this path cannot reach.
        if idx == 0 {
            Some(current)
        } else {
            None
        }
    }

    /// Decodes a hex string that must encode exactly 32 bytes.
    fn decode_hash(hex_str: &str) -> Option<[u8; 32]> {
        let bytes = hex::decode(hex_str).ok()?;
        if bytes.len() != 32 {
            return None;
        }
        let mut out = [0u8; 32];
        out.copy_from_slice(&bytes);
        Some(out)
    }

    /// Hashes two child nodes into their parent: `SHA-256(left || right)`.
    fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(left);
        hasher.update(right);
        hasher.finalize().into()
    }

    /// Computes the next level up. A trailing unpaired node is hashed with itself.
    fn parent_level(level: &[[u8; 32]]) -> Vec<[u8; 32]> {
        level
            .chunks(2)
            .map(|pair| match pair {
                [left, right] => Self::hash_pair(left, right),
                [lone] => Self::hash_pair(lone, lone),
                _ => unreachable!("chunks(2) yields slices of length 1 or 2"),
            })
            .collect()
    }

    /// Reduces `leaves` to a single root hash.
    ///
    /// # Panics
    /// Panics if `leaves` is empty; `root` guards against that.
    fn compute_root(&self, leaves: &[[u8; 32]]) -> [u8; 32] {
        let mut level = leaves.to_vec();
        while level.len() > 1 {
            level = Self::parent_level(&level);
        }
        level[0]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tree over `art_0 .. art_{count-1}` and returns it with its ids.
    fn tree_of(count: usize) -> (MerkleTree, Vec<String>) {
        let ids: Vec<String> = (0..count).map(|i| format!("art_{}", i)).collect();
        let mut tree = MerkleTree::new();
        for (expected_index, id) in ids.iter().enumerate() {
            assert_eq!(tree.append(id), expected_index);
        }
        (tree, ids)
    }

    #[test]
    fn single_leaf_root_is_leaf_hash() {
        let mut tree = MerkleTree::new();
        tree.append("art_abc123");
        let root = tree.root().unwrap();
        let expected = Sha256::digest(b"art_abc123");
        assert_eq!(root, expected.as_slice());
    }

    #[test]
    fn inclusion_proof_verifies() {
        let mut tree = MerkleTree::new();
        let ids = ["art_a", "art_b", "art_c", "art_d"];
        for id in &ids {
            tree.append(id);
        }
        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(1).unwrap();
        assert!(MerkleTree::verify_proof(&root, "art_b", &proof));
    }

    #[test]
    fn wrong_artifact_fails_verification() {
        let mut tree = MerkleTree::new();
        tree.append("art_a");
        tree.append("art_b");
        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(0).unwrap();
        assert!(!MerkleTree::verify_proof(&root, "art_WRONG", &proof));
    }

    #[test]
    fn tampered_sibling_fails() {
        let mut tree = MerkleTree::new();
        tree.append("art_a");
        tree.append("art_b");
        let root = hex::encode(tree.root().unwrap());
        let mut proof = tree.inclusion_proof(0).unwrap();
        proof.path[0].hash = "0".repeat(64);
        assert!(!MerkleTree::verify_proof(&root, "art_a", &proof));
    }

    #[test]
    fn odd_number_of_leaves() {
        let (tree, _) = tree_of(5);
        let root = hex::encode(tree.root().unwrap());
        let proof = tree.inclusion_proof(4).unwrap();
        assert!(MerkleTree::verify_proof(&root, "art_4", &proof));
    }

    #[test]
    fn empty_tree_has_no_root_or_proofs() {
        let tree = MerkleTree::new();
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.root(), None);
        assert_eq!(tree.height(), 0);
        assert!(tree.inclusion_proof(0).is_none());
    }

    #[test]
    fn out_of_range_index_has_no_proof() {
        let (tree, _) = tree_of(2);
        assert!(tree.inclusion_proof(2).is_none());
    }

    #[test]
    fn every_leaf_verifies_for_many_sizes() {
        for size in 1..=9 {
            let (tree, ids) = tree_of(size);
            let root = hex::encode(tree.root().unwrap());
            for (index, id) in ids.iter().enumerate() {
                let proof = tree.inclusion_proof(index).unwrap();
                assert_eq!(proof.leaf_index, index);
                assert_eq!(proof.path.len(), tree.height());
                assert!(
                    MerkleTree::verify_proof(&root, id, &proof),
                    "size {} index {}",
                    size,
                    index
                );
            }
        }
    }

    #[test]
    fn height_is_ceil_log2_of_leaf_count() {
        let cases = [(1, 0), (2, 1), (3, 2), (4, 2), (5, 3), (8, 3), (9, 4)];
        for &(count, height) in &cases {
            let (tree, _) = tree_of(count);
            assert_eq!(tree.height(), height, "count {}", count);
        }
    }

    #[test]
    fn proof_does_not_verify_against_other_root() {
        let mut a = MerkleTree::new();
        a.append("art_a");
        a.append("art_b");
        let mut b = MerkleTree::new();
        b.append("art_a");
        b.append("art_c");
        let other_root = hex::encode(b.root().unwrap());
        let proof = a.inclusion_proof(0).unwrap();
        assert!(!MerkleTree::verify_proof(&other_root, "art_a", &proof));
    }
}