use alloy_primitives::B256;
use std::collections::HashMap;
use crate::{Hasher, Stem, SubIndex};
/// A node of the binary trie.
///
/// The default value is [`Node::Empty`], representing an absent subtree.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub enum Node {
    /// An absent subtree; its hash is [`B256::ZERO`].
    #[default]
    Empty,
    /// A branch with a left and a right child.
    Internal(InternalNode),
    /// A node holding values that share a common stem.
    Stem(StemNode),
    /// A node holding a single value.
    Leaf(LeafNode),
}

impl Node {
    /// Computes this node's hash with `hasher`.
    ///
    /// Non-empty variants delegate to the wrapped node's own `hash`; an
    /// [`Node::Empty`] node always yields [`B256::ZERO`] without touching the hasher.
    pub fn hash<H: Hasher>(&self, hasher: &H) -> B256 {
        match self {
            Node::Internal(internal) => internal.hash(hasher),
            Node::Stem(stem_node) => stem_node.hash(hasher),
            Node::Leaf(leaf) => leaf.hash(hasher),
            Node::Empty => B256::ZERO,
        }
    }

    /// Returns `true` only for the [`Node::Empty`] variant.
    pub fn is_empty(&self) -> bool {
        matches!(self, Node::Empty)
    }
}
/// A branch node owning exactly two children.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct InternalNode {
    /// Left child.
    pub left: Box<Node>,
    /// Right child.
    pub right: Box<Node>,
}

impl InternalNode {
    /// Builds a branch that takes ownership of `left` and `right`.
    pub fn new(left: Node, right: Node) -> Self {
        let (left, right) = (Box::new(left), Box::new(right));
        Self { left, right }
    }

    /// Hashes the branch as `hash_64(left.hash(), right.hash())`.
    ///
    /// Both children are hashed recursively (left first, then right) on every call.
    pub fn hash<H: Hasher>(&self, hasher: &H) -> B256 {
        hasher.hash_64(&self.left.hash(hasher), &self.right.hash(hasher))
    }
}
/// A node holding the values stored under one stem, keyed by [`SubIndex`].
///
/// [`StemNode::set_value`] never stores a zero value: writing zero deletes the slot.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StemNode {
    /// Stem shared by every value in this node.
    pub stem: Stem,
    /// Stored values keyed by sub-index.
    pub values: HashMap<SubIndex, B256>,
}

impl StemNode {
    /// Number of value slots under a stem, i.e. leaves of the internal subtree.
    const WIDTH: usize = 256;

    /// Creates a stem node for `stem` with no values.
    pub fn new(stem: Stem) -> Self {
        Self {
            values: HashMap::new(),
            stem,
        }
    }

    /// Writes `value` at `subindex`; writing [`B256::ZERO`] removes the slot instead.
    pub fn set_value(&mut self, subindex: SubIndex, value: B256) {
        if !value.is_zero() {
            self.values.insert(subindex, value);
            return;
        }
        self.values.remove(&subindex);
    }

    /// Returns the value stored at `subindex`, or `None` if the slot is empty.
    pub fn get_value(&self, subindex: SubIndex) -> Option<B256> {
        self.values.get(&subindex).copied()
    }

    /// Computes the stem node's hash.
    ///
    /// Each stored value is hashed with `hash_32` into its slot, and empty slots stay
    /// [`B256::ZERO`]. The 256 slots are then folded pairwise with `hash_64`, over eight
    /// halvings, into one subtree root. `hash_stem_node` combines that root with the
    /// stem bytes.
    pub fn hash<H: Hasher>(&self, hasher: &H) -> B256 {
        let mut nodes = [B256::ZERO; Self::WIDTH];
        for (subindex, value) in self.values.iter() {
            nodes[*subindex as usize] = hasher.hash_32(value);
        }

        // Reduce in place: the parent of slots (2p, 2p + 1) goes into slot p.
        // Slot p is written only after every slot >= 2p has been read for this
        // pass's earlier parents, so no unread input is clobbered.
        let mut width = Self::WIDTH;
        while width > 1 {
            width /= 2;
            for parent in 0..width {
                let (lhs, rhs) = (nodes[2 * parent], nodes[2 * parent + 1]);
                nodes[parent] = hasher.hash_64(&lhs, &rhs);
            }
        }

        hasher.hash_stem_node(self.stem.as_bytes(), &nodes[0])
    }
}
/// A node wrapping a single 32-byte value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LeafNode {
    /// The stored value.
    pub value: B256,
}

impl LeafNode {
    /// Wraps `value` in a leaf.
    pub const fn new(value: B256) -> Self {
        LeafNode { value }
    }

    /// Hashes the stored value with the hasher's `hash_32`.
    pub fn hash<H: Hasher>(&self, hasher: &H) -> B256 {
        let Self { value } = self;
        hasher.hash_32(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Blake3Hasher, Sha256Hasher};

    #[test]
    fn test_empty_node_hash() {
        let hasher = Blake3Hasher;
        let node: Node = Node::Empty;
        assert_eq!(node.hash(&hasher), B256::ZERO);
        // The derived default is the empty variant.
        assert!(Node::default().is_empty());
    }

    #[test]
    fn test_leaf_node_hash() {
        let hasher = Blake3Hasher;
        let value = B256::repeat_byte(0x42);
        let node = LeafNode::new(value);
        let hash = node.hash(&hasher);
        assert_ne!(hash, B256::ZERO);
    }

    #[test]
    fn test_node_hash_dispatches_to_variant() {
        let hasher = Blake3Hasher;
        let leaf = LeafNode::new(B256::repeat_byte(0x42));
        assert_eq!(Node::Leaf(leaf.clone()).hash(&hasher), leaf.hash(&hasher));
        assert!(!Node::Leaf(leaf).is_empty());
    }

    #[test]
    fn test_internal_node_hash_combines_children() {
        let hasher = Blake3Hasher;
        let leaf = LeafNode::new(B256::repeat_byte(0x42));
        let internal = InternalNode::new(Node::Leaf(leaf.clone()), Node::Empty);
        // An empty right child contributes B256::ZERO to the parent hash.
        let expected = hasher.hash_64(&leaf.hash(&hasher), &B256::ZERO);
        assert_eq!(internal.hash(&hasher), expected);
        assert_eq!(Node::Internal(internal.clone()).hash(&hasher), expected);
    }

    #[test]
    fn test_stem_node_with_value() {
        let hasher = Blake3Hasher;
        let stem = Stem::new([0u8; 31]);
        let mut node: StemNode = StemNode::new(stem);
        node.set_value(0, B256::repeat_byte(0x42));
        assert_eq!(node.get_value(0), Some(B256::repeat_byte(0x42)));
        assert_eq!(node.get_value(1), None);
        let hash = node.hash(&hasher);
        assert_ne!(hash, B256::ZERO);
    }

    #[test]
    fn test_stem_node_set_zero_removes_value() {
        let hasher = Sha256Hasher;
        let mut node = StemNode::new(Stem::new([0u8; 31]));
        let empty_hash = node.hash(&hasher);
        node.set_value(3, B256::repeat_byte(0x42));
        node.set_value(3, B256::ZERO);
        // Writing zero deletes the slot rather than storing a zero value.
        assert_eq!(node.get_value(3), None);
        assert!(node.values.is_empty());
        assert_eq!(node.hash(&hasher), empty_hash);
    }

    #[test]
    fn test_stem_node_hash_sha256() {
        let hasher = Sha256Hasher;
        let stem = Stem::new([0u8; 31]);
        let mut node = StemNode::new(stem);
        node.set_value(0, B256::repeat_byte(0x01));
        let hash = node.hash(&hasher);
        assert_ne!(hash, B256::ZERO);
        let hash2 = node.hash(&hasher);
        assert_eq!(hash, hash2);
    }

    #[test]
    fn test_stem_node_empty_hash() {
        // Previously this test only printed the hash and could never fail;
        // it now pins determinism and sensitivity to stored values.
        let hasher = Sha256Hasher;
        let stem = Stem::new([0u8; 31]);
        let empty = StemNode::new(stem);
        let hash = empty.hash(&hasher);
        assert_eq!(hash, empty.hash(&hasher));

        let mut populated = empty.clone();
        populated.set_value(0, B256::repeat_byte(0x01));
        assert_ne!(hash, populated.hash(&hasher));
    }
}