use sha2::{Sha256, Digest};
use serde::{Deserialize, Serialize};
/// Side on which a proof step's sibling hash sits relative to the running hash.
///
/// `MerkleTree::verify_proof` uses this to decide the concatenation order
/// before hashing each step.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Direction {
/// The sibling is the left child: it is hashed before the running hash.
Left,
/// The sibling is the right child: it is hashed after the running hash.
Right,
}
/// One level of an inclusion proof: a sibling hash and the side it sits on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProofStep {
/// Side of the sibling relative to the node being proven.
pub direction: Direction,
/// Hex-encoded sibling hash; verification rejects values that do not decode to exactly 32 bytes.
pub hash: String,
}
/// Algorithm identifier that `MerkleTree::verify_proof` assumes when a proof
/// carries no `algorithm` value.
pub const MERKLE_ALGORITHM_V1: &str = "sha256-duplicate-last";
/// Algorithm identifier that `MerkleTree::inclusion_proof` stamps on the proofs it generates.
// NOTE(review): the name suggests RFC 9162, but the tree code in this file hashes
// `left || right` with no 0x00/0x01 domain-separation prefix — confirm the intended
// level of compatibility with that RFC.
pub const MERKLE_ALGORITHM_V2: &str = "sha256-rfc9162";
/// Proof that a leaf is part of a Merkle tree with a given root.
///
/// Produced by `MerkleTree::inclusion_proof` and checked by `MerkleTree::verify_proof`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InclusionProof {
/// Zero-based index of the leaf the proof was generated for.
// NOTE(review): `verify_proof` does not read this field, so it is informational only.
pub leaf_index: usize,
/// Hex-encoded hash of the leaf.
pub leaf_hash: String,
/// Sibling hashes ordered from the leaf level up towards the root. Levels at
/// which the node has no sibling contribute no step.
pub path: Vec<ProofStep>,
/// Identifier of the tree algorithm. Omitted from serialized output when
/// `None`, and `None` when the field is missing on input.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub algorithm: Option<String>,
}
/// Append-only Merkle tree that stores one 32-byte hash per leaf.
///
/// `Default` yields an empty tree, the same as `MerkleTree::new()`.
#[derive(Debug, Clone, Default)]
pub struct MerkleTree {
    /// Leaf hashes in insertion order.
    leaves: Vec<[u8; 32]>,
}
impl MerkleTree {
    /// Creates an empty tree.
    pub fn new() -> Self {
        Self { leaves: Vec::new() }
    }

    /// Appends a leaf for `artifact_id` and returns its zero-based index.
    ///
    /// The stored leaf is the SHA-256 digest of the id's UTF-8 bytes.
    pub fn append(&mut self, artifact_id: &str) -> usize {
        let index = self.leaves.len();
        self.leaves.push(Sha256::digest(artifact_id.as_bytes()).into());
        index
    }

    /// Returns the number of leaves.
    pub fn len(&self) -> usize {
        self.leaves.len()
    }

    /// Returns `true` when the tree has no leaves.
    pub fn is_empty(&self) -> bool {
        self.leaves.is_empty()
    }

    /// Returns the root hash, or `None` for an empty tree.
    ///
    /// A single-leaf tree's root is the leaf hash itself.
    pub fn root(&self) -> Option<[u8; 32]> {
        if self.leaves.is_empty() {
            None
        } else {
            Some(self.compute_root(&self.leaves))
        }
    }

    /// Returns the number of levels above the leaves: `0` for zero or one
    /// leaf, otherwise `ceil(log2(len))`.
    pub fn height(&self) -> usize {
        let count = self.leaves.len();
        if count <= 1 {
            return 0;
        }
        // Bit length of `count - 1` equals ceil(log2(count)) for count >= 2.
        // Integer arithmetic keeps this exact; `f64` cannot represent every
        // `usize` and rounds large counts.
        (usize::BITS - (count - 1).leading_zeros()) as usize
    }

    /// Builds an inclusion proof for the leaf at `leaf_index`.
    ///
    /// Returns `None` when `leaf_index` is out of range. The proof lists one
    /// sibling per level, from the leaves upwards; a level where the node is
    /// an unpaired last node contributes no step because that node is carried
    /// up unchanged.
    pub fn inclusion_proof(&self, leaf_index: usize) -> Option<InclusionProof> {
        let leaf = *self.leaves.get(leaf_index)?;

        let mut path = Vec::new();
        let mut idx = leaf_index;
        let mut level = self.leaves.clone();
        while level.len() > 1 {
            if idx % 2 == 1 {
                // Odd position: the sibling is the left neighbour.
                path.push(ProofStep {
                    direction: Direction::Left,
                    hash: hex::encode(level[idx - 1]),
                });
            } else if let Some(right) = level.get(idx + 1) {
                // Even position with a right neighbour.
                path.push(ProofStep {
                    direction: Direction::Right,
                    hash: hex::encode(right),
                });
            }
            level = Self::parent_level(&level);
            idx /= 2;
        }

        Some(InclusionProof {
            leaf_index,
            leaf_hash: hex::encode(leaf),
            path,
            algorithm: Some(MERKLE_ALGORITHM_V2.to_string()),
        })
    }

    /// Checks that `proof` links `artifact_id` to the root `root_hex`.
    ///
    /// Returns `false` when the proof names an unknown algorithm, when
    /// `proof.leaf_hash` is not the SHA-256 of `artifact_id`, when any sibling
    /// hash is not 64 hex characters, or when the recomputed root differs from
    /// `root_hex`. A proof without an algorithm is treated as
    /// `MERKLE_ALGORITHM_V1`.
    ///
    /// Hex values are compared case-insensitively, matching the sibling
    /// decoding, which accepts both upper- and lower-case digits.
    /// `proof.leaf_index` is not consulted.
    pub fn verify_proof(
        root_hex: &str,
        artifact_id: &str,
        proof: &InclusionProof,
    ) -> bool {
        let algo = proof.algorithm.as_deref().unwrap_or(MERKLE_ALGORITHM_V1);
        if algo != MERKLE_ALGORITHM_V1 && algo != MERKLE_ALGORITHM_V2 {
            return false;
        }

        let mut current: [u8; 32] = Sha256::digest(artifact_id.as_bytes()).into();
        if !hex::encode(current).eq_ignore_ascii_case(&proof.leaf_hash) {
            return false;
        }

        for step in &proof.path {
            // Decoding into a fixed-size buffer rejects any input that is not
            // exactly 32 bytes of valid hex, without allocating.
            let mut sibling = [0u8; 32];
            if hex::decode_to_slice(&step.hash, &mut sibling).is_err() {
                return false;
            }
            current = match step.direction {
                Direction::Right => Self::hash_pair(&current, &sibling),
                Direction::Left => Self::hash_pair(&sibling, &current),
            };
        }

        hex::encode(current).eq_ignore_ascii_case(root_hex)
    }

    /// Reduces `leaves` to a single root hash.
    ///
    /// # Panics
    ///
    /// Panics if `leaves` is empty; `root` checks for that before calling.
    fn compute_root(&self, leaves: &[[u8; 32]]) -> [u8; 32] {
        if let [only] = leaves {
            return *only;
        }
        let mut level = Self::parent_level(leaves);
        while level.len() > 1 {
            level = Self::parent_level(&level);
        }
        level[0]
    }

    /// Hashes two child nodes into their parent: `SHA-256(left || right)`.
    fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(left);
        hasher.update(right);
        hasher.finalize().into()
    }

    /// Builds the level above `level`: adjacent pairs are hashed together and
    /// an unpaired last node is carried up unchanged.
    fn parent_level(level: &[[u8; 32]]) -> Vec<[u8; 32]> {
        level
            .chunks(2)
            .map(|pair| match pair {
                [left, right] => Self::hash_pair(left, right),
                [lone] => *lone,
                _ => unreachable!("chunks(2) yields slices of length 1 or 2"),
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tree holding `ids` in order and returns it with its hex-encoded root.
    fn tree_with_root<I, S>(ids: I) -> (MerkleTree, String)
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut tree = MerkleTree::new();
        for id in ids {
            tree.append(id.as_ref());
        }
        let root_hex = hex::encode(tree.root().unwrap());
        (tree, root_hex)
    }

    /// With one leaf, the root is the leaf hash itself.
    #[test]
    fn single_leaf_root_is_leaf_hash() {
        let mut tree = MerkleTree::new();
        tree.append("art_abc123");
        let expected = Sha256::digest(b"art_abc123");
        let root = tree.root().unwrap();
        assert_eq!(root, expected.as_slice());
    }

    /// A proof for an interior leaf of a four-leaf tree verifies against the root.
    #[test]
    fn inclusion_proof_verifies() {
        let (tree, root) = tree_with_root(&["art_a", "art_b", "art_c", "art_d"]);
        let proof = tree.inclusion_proof(1).unwrap();
        assert!(MerkleTree::verify_proof(&root, "art_b", &proof));
    }

    /// A valid proof must not verify for a different artifact id.
    #[test]
    fn wrong_artifact_fails_verification() {
        let (tree, root) = tree_with_root(&["art_a", "art_b"]);
        let proof = tree.inclusion_proof(0).unwrap();
        assert!(!MerkleTree::verify_proof(&root, "art_WRONG", &proof));
    }

    /// Replacing a sibling hash invalidates the proof.
    #[test]
    fn tampered_sibling_fails() {
        let (tree, root) = tree_with_root(&["art_a", "art_b"]);
        let mut proof = tree.inclusion_proof(0).unwrap();
        proof.path[0].hash = "0".repeat(64);
        assert!(!MerkleTree::verify_proof(&root, "art_a", &proof));
    }

    /// The unpaired last leaf of a five-leaf tree still yields a verifying proof.
    #[test]
    fn odd_number_of_leaves() {
        let (tree, root) = tree_with_root((0..5).map(|i| format!("art_{}", i)));
        let proof = tree.inclusion_proof(4).unwrap();
        assert!(MerkleTree::verify_proof(&root, "art_4", &proof));
    }
}