use super::heap_impl::Heap;
use super::resource::{Resource, ResourceId};
use sha2::{Digest, Sha256};
/// Which side of the running hash a [`ProofStep`]'s sibling sits on.
///
/// When folding a proof, `Left` hashes `sibling || current` and `Right`
/// hashes `current || sibling`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// The sibling is the left operand of the pair hash.
    Left,
    /// The sibling is the right operand of the pair hash.
    Right,
}
/// One level of a Merkle inclusion proof: a sibling hash and the side it sits on.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProofStep {
    /// Side of the running hash on which `sibling_hash` is combined.
    pub direction: Direction,
    /// Hash of the sibling node at this level of the tree.
    pub sibling_hash: [u8; 32],
}
/// Inclusion proof for a single leaf of a Merkle tree.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleProof {
    /// Hash of the leaf whose membership is being proven.
    pub leaf_hash: [u8; 32],
    /// Sibling hashes ordered from the leaf level upward, excluding the root level.
    pub path: Vec<ProofStep>,
    /// Root the proof claims to hash up to. It travels with the proof, so
    /// compare it against a root you trust before relying on it.
    pub root: [u8; 32],
}
impl MerkleProof {
    /// Returns `true` if folding `path` over `leaf_hash` reproduces `self.root`.
    ///
    /// This only checks internal consistency: `root` is carried inside the
    /// proof, so a forged proof can carry a matching forged root. To check
    /// membership in a specific tree, use [`MerkleProof::verify_against`] with a
    /// root obtained from a trusted source.
    pub fn verify(&self) -> bool {
        self.verify_against(&self.root)
    }

    /// Returns `true` if folding `path` over `leaf_hash` yields `expected_root`.
    ///
    /// The proof's own `root` field is ignored.
    pub fn verify_against(&self, expected_root: &[u8; 32]) -> bool {
        self.compute_root() == *expected_root
    }

    /// Hashes `leaf_hash` up through every step of `path`, bottom to top.
    fn compute_root(&self) -> [u8; 32] {
        self.path
            .iter()
            .fold(self.leaf_hash, |current, step| match step.direction {
                Direction::Left => hash_pair(&step.sibling_hash, &current),
                Direction::Right => hash_pair(&current, &step.sibling_hash),
            })
    }
}
/// Hashes two child nodes into their parent: `SHA-256(left || right)`.
fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    // Digesting the 64-byte concatenation in one call produces the same hash
    // as updating with `left` and then `right`.
    let mut concatenated = [0_u8; 64];
    let (head, tail) = concatenated.split_at_mut(32);
    head.copy_from_slice(left);
    tail.copy_from_slice(right);
    Sha256::digest(concatenated).into()
}
/// Leaf hash for an active resource: `SHA-256(rid.hash() || resource.to_bytes())`.
fn hash_leaf(rid: &ResourceId, resource: &Resource) -> [u8; 32] {
    Sha256::new()
        .chain_update(rid.hash())
        .chain_update(resource.to_bytes())
        .finalize()
        .into()
}
/// Root used for a tree with no leaves: the SHA-256 digest of empty input.
fn empty_root() -> [u8; 32] {
    Sha256::new().finalize().into()
}
#[derive(Debug, Clone)]
pub struct MerkleTree {
pub root: [u8; 32],
leaves: Vec<[u8; 32]>,
levels: Vec<Vec<[u8; 32]>>,
}
impl MerkleTree {
pub fn from_leaves(leaves: Vec<[u8; 32]>) -> Self {
if leaves.is_empty() {
return Self {
root: empty_root(),
leaves: Vec::new(),
levels: Vec::new(),
};
}
let mut levels = vec![leaves.clone()];
let mut current_level = leaves.clone();
while current_level.len() > 1 {
if current_level.len() % 2 == 1 {
current_level.push(*current_level.last().expect("non-empty after len check"));
}
let next_level: Vec<[u8; 32]> = current_level
.chunks(2)
.map(|pair| hash_pair(&pair[0], &pair[1]))
.collect();
levels.push(next_level.clone());
current_level = next_level;
}
let root = current_level[0];
Self {
root,
leaves,
levels,
}
}
pub fn from_heap(heap: &Heap) -> Self {
let leaves: Vec<[u8; 32]> = heap
.active_resources()
.map(|(rid, resource)| hash_leaf(rid, resource))
.collect();
Self::from_leaves(leaves)
}
pub fn size(&self) -> usize {
self.leaves.len()
}
pub fn prove(&self, index: usize) -> Option<MerkleProof> {
if index >= self.leaves.len() {
return None;
}
let leaf_hash = self.leaves[index];
let mut path = Vec::new();
let mut current_index = index;
for level in &self.levels[..self.levels.len().saturating_sub(1)] {
let sibling_index = if current_index % 2 == 0 {
current_index + 1
} else {
current_index - 1
};
let sibling_hash = if sibling_index < level.len() {
level[sibling_index]
} else {
level[current_index]
};
let direction = if current_index % 2 == 0 {
Direction::Right
} else {
Direction::Left
};
path.push(ProofStep {
direction,
sibling_hash,
});
current_index /= 2;
}
Some(MerkleProof {
leaf_hash,
path,
root: self.root,
})
}
}
/// Compact commitment to a heap's state: Merkle roots over its active resources
/// and its consumed ids, plus its allocation counter.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HeapCommitment {
    /// Merkle root over the leaf hashes of the heap's active resources.
    pub resource_root: [u8; 32],
    /// Merkle root over the hashes of the heap's consumed resource ids.
    pub nullifier_root: [u8; 32],
    /// The heap's allocation counter at the time of the commitment.
    pub counter: u64,
}
impl HeapCommitment {
    /// Snapshots `heap` into a commitment.
    ///
    /// - `resource_root` is the root of [`MerkleTree::from_heap`] over the active
    ///   resources.
    /// - `nullifier_root` is the root of a tree whose leaves are `rid.hash()` for
    ///   each id yielded by `heap.consumed_ids()`.
    /// - `counter` is `heap.alloc_counter()`.
    ///
    /// NOTE(review): both roots depend on the heap's iteration order. Confirm
    /// that `active_resources` and `consumed_ids` iterate deterministically.
    pub fn from_heap(heap: &Heap) -> Self {
        // Reuse the tree builder so the resource root cannot drift from
        // `MerkleTree::from_heap`.
        let resource_root = MerkleTree::from_heap(heap).root;
        let nullifier_leaves: Vec<[u8; 32]> = heap.consumed_ids().map(|rid| rid.hash()).collect();
        let nullifier_root = MerkleTree::from_leaves(nullifier_leaves).root;
        Self {
            resource_root,
            nullifier_root,
            counter: heap.alloc_counter(),
        }
    }

    /// Hashes the commitment as
    /// `SHA-256(resource_root || nullifier_root || counter.to_le_bytes())`.
    pub fn hash(&self) -> [u8; 32] {
        let mut hasher = Sha256::new();
        hasher.update(self.resource_root);
        hasher.update(self.nullifier_root);
        hasher.update(self.counter.to_le_bytes());
        hasher.finalize().into()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic distinct leaves: `SHA-256([i])` for `i` in `0..count`.
    fn sample_leaves(count: u8) -> Vec<[u8; 32]> {
        (0..count).map(|i| Sha256::digest([i]).into()).collect()
    }

    /// Returns the opposite proof direction.
    fn flipped(direction: Direction) -> Direction {
        match direction {
            Direction::Left => Direction::Right,
            Direction::Right => Direction::Left,
        }
    }

    #[test]
    fn test_empty_tree() {
        let tree = MerkleTree::from_leaves(vec![]);
        assert_eq!(tree.root, empty_root());
        assert_eq!(tree.size(), 0);
        assert!(tree.prove(0).is_none());
    }

    #[test]
    fn test_single_leaf() {
        let leaf: [u8; 32] = Sha256::digest(b"hello").into();
        let tree = MerkleTree::from_leaves(vec![leaf]);
        assert_eq!(tree.root, leaf);
        assert_eq!(tree.size(), 1);
        let proof = tree.prove(0).expect("single leaf is provable");
        assert!(proof.path.is_empty());
        assert!(proof.verify());
    }

    #[test]
    fn test_two_leaves() {
        let leaf1: [u8; 32] = Sha256::digest(b"hello").into();
        let leaf2: [u8; 32] = Sha256::digest(b"world").into();
        let tree = MerkleTree::from_leaves(vec![leaf1, leaf2]);
        let expected_root = hash_pair(&leaf1, &leaf2);
        assert_eq!(tree.root, expected_root);
        assert_eq!(tree.size(), 2);
    }

    #[test]
    fn test_four_leaves() {
        let leaves = sample_leaves(4);
        let tree = MerkleTree::from_leaves(leaves.clone());
        let h01 = hash_pair(&leaves[0], &leaves[1]);
        let h23 = hash_pair(&leaves[2], &leaves[3]);
        let expected_root = hash_pair(&h01, &h23);
        assert_eq!(tree.root, expected_root);
        assert_eq!(tree.size(), 4);
    }

    #[test]
    fn test_proof_generation_and_verification() {
        let tree = MerkleTree::from_leaves(sample_leaves(4));
        for i in 0..4 {
            let proof = tree.prove(i).expect("should generate proof");
            assert!(proof.verify(), "proof for leaf {} should verify", i);
        }
    }

    #[test]
    fn test_proof_for_out_of_bounds() {
        let tree = MerkleTree::from_leaves(sample_leaves(4));
        assert!(tree.prove(4).is_none());
        assert!(tree.prove(100).is_none());
    }

    #[test]
    fn test_odd_number_of_leaves() {
        let tree = MerkleTree::from_leaves(sample_leaves(3));
        for i in 0..3 {
            let proof = tree.prove(i).expect("should generate proof");
            assert!(proof.verify(), "proof for leaf {} should verify", i);
        }
    }

    #[test]
    fn test_proofs_verify_for_many_sizes() {
        for count in 1_u8..=9 {
            let leaves = sample_leaves(count);
            let tree = MerkleTree::from_leaves(leaves.clone());
            assert_eq!(tree.size(), usize::from(count));
            for (i, leaf) in leaves.iter().enumerate() {
                let proof = tree.prove(i).expect("index is in range");
                assert_eq!(proof.leaf_hash, *leaf);
                assert_eq!(proof.root, tree.root);
                assert!(proof.verify(), "size {} leaf {} should verify", count, i);
            }
            assert!(tree.prove(usize::from(count)).is_none());
        }
    }

    #[test]
    fn test_tampered_sibling_fails_verification() {
        let tree = MerkleTree::from_leaves(sample_leaves(4));
        let mut proof = tree.prove(0).expect("should generate proof");
        proof.path[0].sibling_hash = [0xff; 32];
        assert!(!proof.verify());
    }

    #[test]
    fn test_wrong_leaf_fails_verification() {
        let leaves = sample_leaves(4);
        let tree = MerkleTree::from_leaves(leaves.clone());
        let mut proof = tree.prove(0).expect("should generate proof");
        proof.leaf_hash = leaves[2];
        assert!(!proof.verify());
    }

    #[test]
    fn test_flipped_direction_fails_verification() {
        let tree = MerkleTree::from_leaves(sample_leaves(4));
        let mut proof = tree.prove(1).expect("should generate proof");
        proof.path[0].direction = flipped(proof.path[0].direction);
        assert!(!proof.verify());
    }

    #[test]
    fn test_heap_merkle_root() {
        let heap = Heap::new();
        let (_, heap) = heap.alloc_channel("Alice", "Bob");
        let (_, heap) = heap.alloc_message("Alice", "Bob", "Hello", vec![], 0);
        let root = MerkleTree::from_heap(&heap).root;
        assert_ne!(root, empty_root());
    }

    #[test]
    fn test_heap_commitment() {
        let heap = Heap::new();
        let (rid, heap) = heap.alloc_channel("Alice", "Bob");
        let heap = heap.consume(&rid).unwrap();
        let commitment = HeapCommitment::from_heap(&heap);
        assert_eq!(commitment.counter, 1);
    }

    #[test]
    fn test_commitment_determinism() {
        let heap1 = Heap::new();
        let (_, heap1) = heap1.alloc_channel("Alice", "Bob");
        let heap2 = Heap::new();
        let (_, heap2) = heap2.alloc_channel("Alice", "Bob");
        let c1 = HeapCommitment::from_heap(&heap1);
        let c2 = HeapCommitment::from_heap(&heap2);
        assert_eq!(c1, c2);
        assert_eq!(c1.hash(), c2.hash());
    }
}