use serde::{Deserialize, Serialize};
use crate::{DocumentId, HashAlgorithm, Hasher, Result};
/// A node of a binary Merkle tree.
///
/// A node with neither child set is a leaf (see `MerkleNode::is_leaf`);
/// nodes produced by `MerkleNode::branch` carry both children.
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MerkleNode {
/// Hash stored at this node.
pub hash: DocumentId,
/// Left child, if any. Left out of the serialized form when `None` and
/// defaulted to `None` when missing on deserialization.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub left: Option<Box<MerkleNode>>,
/// Right child, if any. Left out of the serialized form when `None` and
/// defaulted to `None` when missing on deserialization.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub right: Option<Box<MerkleNode>>,
}
impl MerkleNode {
    /// Wraps `hash` in a node that has no children.
    #[must_use]
    pub fn leaf(hash: DocumentId) -> Self {
        Self {
            left: None,
            right: None,
            hash,
        }
    }

    /// Joins two nodes under a new parent.
    ///
    /// The parent's hash is computed with `algorithm` over the bytes of the
    /// left child's hex digest immediately followed by the right child's hex
    /// digest. Both children are moved into the new node.
    #[must_use]
    pub fn branch(left: MerkleNode, right: MerkleNode, algorithm: HashAlgorithm) -> Self {
        let left_hex = left.hash.hex_digest();
        let right_hex = right.hash.hex_digest();
        // Order matters: left digest first, then right, with no separator.
        let preimage = format!("{left_hex}{right_hex}");
        Self {
            hash: Hasher::hash(algorithm, preimage.as_bytes()),
            left: Some(Box::new(left)),
            right: Some(Box::new(right)),
        }
    }

    /// Reports whether this node has no children at all.
    #[must_use]
    pub fn is_leaf(&self) -> bool {
        matches!((&self.left, &self.right), (None, None))
    }
}
/// A Merkle tree together with the parameters it was built with.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MerkleTree {
/// Top node of the tree; its hash is what `root_hash` returns.
root: MerkleNode,
/// Hash algorithm handed to `MerkleNode::branch` while building the tree.
algorithm: HashAlgorithm,
/// Number of inputs (items or hashes) the constructors were given.
leaf_count: usize,
}
impl MerkleTree {
pub fn from_items<T: AsRef<[u8]>>(items: &[T], algorithm: HashAlgorithm) -> Result<Self> {
if items.is_empty() {
return Err(crate::Error::InvalidManifest {
reason: "Cannot build Merkle tree from empty items".to_string(),
});
}
let leaf_count = items.len();
let mut nodes: Vec<MerkleNode> = items
.iter()
.map(|item| MerkleNode::leaf(Hasher::hash(algorithm, item.as_ref())))
.collect();
while nodes.len() > 1 {
let mut next_level = Vec::with_capacity(nodes.len().div_ceil(2));
let mut iter = nodes.into_iter();
while let Some(left) = iter.next() {
let right = iter.next().unwrap_or_else(|| left.clone());
next_level.push(MerkleNode::branch(left, right, algorithm));
}
nodes = next_level;
}
Ok(Self {
root: nodes.into_iter().next().expect("nodes should not be empty"),
algorithm,
leaf_count,
})
}
pub fn from_hashes(hashes: &[DocumentId], algorithm: HashAlgorithm) -> Result<Self> {
if hashes.is_empty() {
return Err(crate::Error::InvalidManifest {
reason: "Cannot build Merkle tree from empty hashes".to_string(),
});
}
let leaf_count = hashes.len();
let mut nodes: Vec<MerkleNode> =
hashes.iter().map(|h| MerkleNode::leaf(h.clone())).collect();
while nodes.len() > 1 {
let mut next_level = Vec::with_capacity(nodes.len().div_ceil(2));
let mut iter = nodes.into_iter();
while let Some(left) = iter.next() {
let right = iter.next().unwrap_or_else(|| left.clone());
next_level.push(MerkleNode::branch(left, right, algorithm));
}
nodes = next_level;
}
Ok(Self {
root: nodes.into_iter().next().expect("nodes should not be empty"),
algorithm,
leaf_count,
})
}
#[must_use]
pub fn root_hash(&self) -> &DocumentId {
&self.root.hash
}
#[must_use]
pub fn root(&self) -> &MerkleNode {
&self.root
}
#[must_use]
pub fn algorithm(&self) -> HashAlgorithm {
self.algorithm
}
#[must_use]
pub fn leaf_count(&self) -> usize {
self.leaf_count
}
pub fn prove(&self, index: usize) -> Result<super::BlockProof> {
if index >= self.leaf_count {
return Err(crate::Error::InvalidManifest {
reason: format!(
"Index {} out of bounds for tree with {} leaves",
index, self.leaf_count
),
});
}
let mut path = Vec::new();
collect_proof_path(&self.root, index, 0, self.leaf_count, &mut path);
Ok(super::BlockProof {
index,
path,
root_hash: self.root.hash.clone(),
algorithm: self.algorithm,
})
}
}
/// Walks from `node` down towards the leaf slot `target_index`, appending the
/// sibling hash met at each level to `path`.
///
/// `current_start` is the first leaf slot covered by `node` and `level_size`
/// the number of slots it covers. The split point is found by halving
/// `level_size`, so the walk follows the tree's real layout only when each
/// child really covers `level_size / 2` slots.
///
/// Entries are pushed after the recursive call returns, so `path` ends up
/// ordered from the leaf's sibling upwards. The flag in each entry is `true`
/// when the recorded sibling is the right child.
///
/// # Panics
///
/// Panics if a node has exactly one child.
fn collect_proof_path(
    node: &MerkleNode,
    target_index: usize,
    current_start: usize,
    level_size: usize,
    path: &mut Vec<(DocumentId, bool)>,
) {
    let (left_child, right_child) = match (node.left.as_deref(), node.right.as_deref()) {
        // No children: this is a leaf, nothing more to record.
        (None, None) => return,
        (maybe_left, maybe_right) => (
            maybe_left.expect("branch node should have left child"),
            maybe_right.expect("branch node should have right child"),
        ),
    };

    let half = level_size / 2;
    let split = current_start + half;

    // Pick the child to descend into; the other one is the sibling to record.
    let (next_node, next_start, sibling, sibling_is_right) = if target_index < split {
        (left_child, current_start, right_child, true)
    } else {
        (right_child, split, left_child, false)
    };

    collect_proof_path(next_node, target_index, next_start, half, path);
    path.push((sibling.hash.clone(), sibling_is_right));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a SHA-256 tree from `items`, panicking if construction fails.
    fn sha256_tree<T: AsRef<[u8]>>(items: &[T]) -> MerkleTree {
        MerkleTree::from_items(items, HashAlgorithm::Sha256).unwrap()
    }

    /// Four items give four leaves and a real (non-pending) root hash.
    #[test]
    fn test_merkle_tree_from_items() {
        let tree = sha256_tree(&["item1", "item2", "item3", "item4"]);
        assert_eq!(tree.leaf_count(), 4);
        assert!(!tree.root_hash().is_pending());
    }

    /// An odd number of items is accepted and counted as given.
    #[test]
    fn test_merkle_tree_odd_count() {
        let tree = sha256_tree(&["item1", "item2", "item3"]);
        assert_eq!(tree.leaf_count(), 3);
    }

    /// A single item is enough to build a tree.
    #[test]
    fn test_merkle_tree_single_item() {
        let tree = sha256_tree(&["single"]);
        assert_eq!(tree.leaf_count(), 1);
    }

    /// Building from no items is an error.
    #[test]
    fn test_merkle_tree_empty_fails() {
        let no_items: [&str; 0] = [];
        let outcome = MerkleTree::from_items(&no_items, HashAlgorithm::Sha256);
        assert!(outcome.is_err());
    }

    /// The same input always yields the same root hash.
    #[test]
    fn test_merkle_tree_deterministic() {
        let items = ["a", "b", "c", "d"];
        let first = sha256_tree(&items);
        let second = sha256_tree(&items);
        assert_eq!(first.root_hash(), second.root_hash());
    }

    /// Changing one item changes the root hash.
    #[test]
    fn test_merkle_tree_changes_with_content() {
        let original = sha256_tree(&["a", "b", "c", "d"]);
        let altered = sha256_tree(&["a", "b", "c", "e"]);
        assert_ne!(original.root_hash(), altered.root_hash());
    }

    /// A proof for a valid index records that index and has a non-empty path.
    #[test]
    fn test_generate_proof() {
        let tree = sha256_tree(&["item0", "item1", "item2", "item3"]);
        let proof = tree.prove(2).unwrap();
        assert_eq!(proof.index, 2);
        assert!(!proof.path.is_empty());
    }

    /// Asking for a proof past the last leaf is an error.
    #[test]
    fn test_proof_out_of_bounds() {
        let tree = sha256_tree(&["a", "b"]);
        assert!(tree.prove(5).is_err());
    }

    /// A leaf node keeps its hash and has no children.
    #[test]
    fn test_merkle_node_leaf() {
        let digest = Hasher::hash(HashAlgorithm::Sha256, b"test");
        let leaf = MerkleNode::leaf(digest.clone());
        assert!(leaf.is_leaf());
        assert_eq!(leaf.hash, digest);
        assert!(leaf.left.is_none());
        assert!(leaf.right.is_none());
    }

    /// A branch node owns both children and hashes differently from either.
    #[test]
    fn test_merkle_node_branch() {
        let lhs = MerkleNode::leaf(Hasher::hash(HashAlgorithm::Sha256, b"left"));
        let rhs = MerkleNode::leaf(Hasher::hash(HashAlgorithm::Sha256, b"right"));
        let parent = MerkleNode::branch(lhs.clone(), rhs.clone(), HashAlgorithm::Sha256);
        assert!(!parent.is_leaf());
        assert!(parent.left.is_some());
        assert!(parent.right.is_some());
        assert_ne!(parent.hash, lhs.hash);
        assert_ne!(parent.hash, rhs.hash);
    }

    /// Precomputed hashes can be used directly as leaves.
    #[test]
    fn test_merkle_tree_from_hashes() {
        let leaf_hashes: Vec<DocumentId> = ["a", "b", "c", "d"]
            .iter()
            .map(|text| Hasher::hash(HashAlgorithm::Sha256, text.as_bytes()))
            .collect();
        let tree = MerkleTree::from_hashes(&leaf_hashes, HashAlgorithm::Sha256).unwrap();
        assert_eq!(tree.leaf_count(), 4);
        assert!(!tree.root_hash().is_pending());
    }

    /// Building from no hashes is an error.
    #[test]
    fn test_merkle_tree_from_hashes_empty_fails() {
        let no_hashes: Vec<DocumentId> = Vec::new();
        let outcome = MerkleTree::from_hashes(&no_hashes, HashAlgorithm::Sha256);
        assert!(outcome.is_err());
    }

    /// A tree survives a JSON round trip with its observable state intact.
    #[test]
    fn test_merkle_tree_serialization_roundtrip() {
        let tree = sha256_tree(&["a", "b", "c", "d"]);
        let encoded = serde_json::to_string(&tree).unwrap();
        let decoded: MerkleTree = serde_json::from_str(&encoded).unwrap();
        assert_eq!(tree.root_hash(), decoded.root_hash());
        assert_eq!(tree.leaf_count(), decoded.leaf_count());
        assert_eq!(tree.algorithm(), decoded.algorithm());
    }

    /// Each algorithm gives a distinct root hash and is reported back by the tree.
    #[test]
    fn test_merkle_tree_different_algorithms() {
        let items = ["test1", "test2"];
        let algorithms = [
            HashAlgorithm::Sha256,
            HashAlgorithm::Sha384,
            HashAlgorithm::Sha512,
        ];
        let trees: Vec<MerkleTree> = algorithms
            .iter()
            .map(|&algorithm| MerkleTree::from_items(&items, algorithm).unwrap())
            .collect();

        // Every pair of trees must disagree on the root hash.
        for (position, tree) in trees.iter().enumerate() {
            for later in &trees[position + 1..] {
                assert_ne!(tree.root_hash(), later.root_hash());
            }
        }
        for (tree, &expected) in trees.iter().zip(&algorithms) {
            assert_eq!(tree.algorithm(), expected);
        }
    }

    /// A 100-leaf tree can produce a proof for every leaf index.
    #[test]
    fn test_merkle_tree_large() {
        let items: Vec<String> = (0..100).map(|n| format!("item{n}")).collect();
        let tree = sha256_tree(&items);
        assert_eq!(tree.leaf_count(), 100);
        for i in 0..100 {
            let outcome = tree.prove(i);
            assert!(outcome.is_ok(), "Failed to generate proof for index {i}");
        }
    }

    /// With two leaves the root is a branch and each proof has one step.
    #[test]
    fn test_merkle_tree_two_items() {
        let tree = sha256_tree(&["left", "right"]);
        assert_eq!(tree.leaf_count(), 2);
        assert!(!tree.root().is_leaf());
        let first_proof = tree.prove(0).unwrap();
        let second_proof = tree.prove(1).unwrap();
        assert_eq!(first_proof.path.len(), 1);
        assert_eq!(second_proof.path.len(), 1);
    }

    /// `root()` exposes the same node whose hash `root_hash()` returns.
    #[test]
    fn test_merkle_tree_root_accessor() {
        let tree = sha256_tree(&["a", "b"]);
        let top = tree.root();
        assert!(!top.is_leaf());
        assert_eq!(&top.hash, tree.root_hash());
    }
}