use alloy_primitives::B256;
use crate::{error::Result, Hasher, Stem, StemNode, TreeKey, UbtError, STEM_LEN};
/// Side of its parent that the running hash occupies at a [`ProofNode::Internal`] step.
///
/// In `Proof::compute_root`, `Left` hashes `(current, sibling)` and `Right` hashes
/// `(sibling, current)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Direction {
/// The running hash is the left child; the sibling is on the right.
Left,
/// The running hash is the right child; the sibling is on the left.
Right,
}
/// One step of a [`Proof`] path, applied to the running hash by `Proof::compute_root`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ProofNode {
/// Internal-node step: combine the running hash with `sibling`; `direction` says
/// which side of the parent the running hash sits on.
Internal { sibling: B256, direction: Direction },
/// Stem-node step: `subtree_siblings` are the sibling hashes inside the stem's
/// 256-leaf value subtree, leaf level first, as produced by `generate_stem_proof`.
/// `compute_root` rejects any length other than 8.
Stem {
stem: Stem,
subtree_siblings: Vec<B256>,
},
/// A stem other than the key's, plus a hash that `compute_root` adopts as the running
/// hash (rejected if `stem` equals the key's stem). Presumably used to prove that the
/// key's stem is absent — confirm against the proof generator.
Extension { stem: Stem, stem_hash: B256 },
}
/// A single-key proof: the claimed `value` for `key` plus the `path` of nodes needed
/// to recompute the tree root.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Proof {
/// Key whose value is being proven.
pub key: TreeKey,
/// Claimed value; `None` makes `compute_root` start from the zero hash (empty slot).
pub value: Option<B256>,
/// Proof nodes, consumed in order by `compute_root`, starting from the leaf.
pub path: Vec<ProofNode>,
}
impl Proof {
    /// Creates a proof that `key` maps to `value`, authenticated by `path`.
    ///
    /// `path` is consumed in order by [`Proof::compute_root`], starting from the leaf.
    pub fn new(key: TreeKey, value: Option<B256>, path: Vec<ProofNode>) -> Self {
        Self { key, value, path }
    }

    /// Recomputes the root from this proof and compares it with `expected_root`.
    ///
    /// # Errors
    ///
    /// Returns [`UbtError::InvalidProof`] when the proof is structurally malformed
    /// (see [`Proof::compute_root`]). A well-formed proof for a different root yields
    /// `Ok(false)`.
    pub fn verify<H: Hasher>(&self, hasher: &H, expected_root: &B256) -> Result<bool> {
        let computed_root = self.compute_root(hasher)?;
        Ok(&computed_root == expected_root)
    }

    /// Folds the claimed value and every node of `path` into a root hash.
    ///
    /// The leaf starts as `hash_32(value)` for a present value, or the zero hash for an
    /// absent one. Each node then transforms the running hash:
    /// - `Internal`: hashes it together with `sibling`, on the side given by `direction`;
    /// - `Stem`: merkleizes it through the stem's 256-leaf value subtree along the key's
    ///   subindex, then combines the subtree root with the stem bytes;
    /// - `Extension`: replaces it with `stem_hash`.
    ///
    /// # Errors
    ///
    /// Returns [`UbtError::InvalidProof`] when:
    /// - a `Stem` node does not carry exactly 8 subtree siblings;
    /// - a `Stem` node's stem differs from the key's stem;
    /// - an `Extension` node's stem equals the key's stem;
    /// - an `Extension` node appears in a proof that claims a value.
    pub fn compute_root<H: Hasher>(&self, hasher: &H) -> Result<B256> {
        let mut current_hash = match &self.value {
            Some(v) => hasher.hash_32(v),
            None => B256::ZERO,
        };
        for node in &self.path {
            current_hash = match node {
                // NOTE(review): `direction` is taken from the proof as supplied. It is not
                // cross-checked against the key's stem bits here.
                ProofNode::Internal { sibling, direction } => match direction {
                    Direction::Left => hasher.hash_64(&current_hash, sibling),
                    Direction::Right => hasher.hash_64(sibling, &current_hash),
                },
                ProofNode::Stem {
                    stem,
                    subtree_siblings,
                } => {
                    if subtree_siblings.len() != 8 {
                        return Err(UbtError::InvalidProof(format!(
                            "stem subtree siblings length must be 8, got {}",
                            subtree_siblings.len()
                        )));
                    }
                    // A stem node only authenticates values of its own stem. Accepting a
                    // different stem would let one key's value vouch for another key.
                    if stem != &self.key.stem {
                        return Err(UbtError::InvalidProof(
                            "Stem proof with mismatching stem".to_string(),
                        ));
                    }
                    // Walk the 256-leaf value subtree from the leaf upwards. Level `k`
                    // pairs nodes differing in bit `k` of the subindex (LSB first), which
                    // matches the sibling order produced by `generate_stem_proof`.
                    let mut subtree_hash = current_hash;
                    for (level, sibling) in subtree_siblings.iter().enumerate() {
                        subtree_hash = if (self.key.subindex >> level) & 1 == 0 {
                            hasher.hash_64(&subtree_hash, sibling)
                        } else {
                            hasher.hash_64(sibling, &subtree_hash)
                        };
                    }
                    hasher.hash_stem_node(stem.as_bytes(), &subtree_hash)
                }
                ProofNode::Extension { stem, stem_hash } => {
                    if stem == &self.key.stem {
                        return Err(UbtError::InvalidProof(
                            "Extension proof with matching stem".to_string(),
                        ));
                    }
                    // The running hash (and therefore any claimed value) is discarded
                    // here, so an extension can only back a proof of absence.
                    if self.value.is_some() {
                        return Err(UbtError::InvalidProof(
                            "Extension proof with a value".to_string(),
                        ));
                    }
                    *stem_hash
                }
            };
        }
        Ok(current_hash)
    }

    /// Returns the approximate serialized size of this proof in bytes.
    ///
    /// Fixed part: 32 + 1 bytes for the key and 33 bytes for the optional value
    /// (NOTE(review): presumably a presence tag plus 32 bytes — confirm against the
    /// encoder). Each node then adds:
    /// - `Internal`: 32 + 1 bytes (sibling plus, presumably, a direction byte);
    /// - `Stem`: `STEM_LEN` + 32 per subtree sibling;
    /// - `Extension`: `STEM_LEN` + 32.
    pub fn size(&self) -> usize {
        let fixed = 32 + 1 + 33;
        let nodes: usize = self
            .path
            .iter()
            .map(|node| match node {
                ProofNode::Internal { .. } => 32 + 1,
                ProofNode::Stem {
                    subtree_siblings, ..
                } => STEM_LEN + subtree_siblings.len() * 32,
                ProofNode::Extension { .. } => STEM_LEN + 32,
            })
            .sum();
        fixed + nodes
    }
}
/// Returns the value at `subindex` of `stem_node` and the sibling hashes that
/// authenticate it inside the stem's 256-leaf value subtree.
///
/// Leaves are `hash_32(value)` for occupied slots and the zero hash for empty ones;
/// each parent is `hash_64(left, right)`. The returned vector holds one sibling per
/// level (8 in total), ordered from the leaf level upwards.
#[must_use]
pub fn generate_stem_proof<H: Hasher>(
    stem_node: &StemNode,
    subindex: u8,
    hasher: &H,
) -> (Option<B256>, Vec<B256>) {
    let proven_value = stem_node.get_value(subindex);

    // Leaf layer of the value subtree.
    let mut layer = [B256::ZERO; 256];
    for (&slot, &raw_value) in &stem_node.values {
        layer[slot as usize] = hasher.hash_32(&raw_value);
    }

    let mut siblings = Vec::with_capacity(8);
    let mut position = usize::from(subindex);
    let mut width = layer.len();
    // Collapse the layer in place one level at a time. Before each collapse, record
    // the sibling of the proven leaf's ancestor at that level.
    while width > 1 {
        siblings.push(layer[position ^ 1]);
        width /= 2;
        for parent in 0..width {
            let (left, right) = (layer[2 * parent], layer[2 * parent + 1]);
            layer[parent] = hasher.hash_64(&left, &right);
        }
        position >>= 1;
    }

    (proven_value, siblings)
}
/// Proof data covering several keys at once.
///
/// NOTE(review): how `nodes` and `stems` relate to `keys` is not shown in this file.
/// Confirm against the code that builds and verifies multi-proofs.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MultiProof {
/// Keys covered by this proof (`len()` reports this count).
pub keys: Vec<TreeKey>,
/// Claimed values; presumably index-aligned with `keys` — confirm.
pub values: Vec<Option<B256>>,
/// Node hashes carried by the proof.
pub nodes: Vec<B256>,
/// Stems carried by the proof.
pub stems: Vec<Stem>,
}
impl MultiProof {
    /// Returns an empty multi-proof, identical to `MultiProof::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of keys covered by this multi-proof.
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// `true` when the multi-proof covers no keys.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Approximate encoded size in bytes: 32 per key, 33 per optional value,
    /// 32 per node hash and `STEM_LEN` per stem.
    pub fn size(&self) -> usize {
        const KEY_BYTES: usize = 32;
        const VALUE_BYTES: usize = 33;
        const NODE_BYTES: usize = 32;
        [
            self.keys.len() * KEY_BYTES,
            self.values.len() * VALUE_BYTES,
            self.nodes.len() * NODE_BYTES,
            self.stems.len() * STEM_LEN,
        ]
        .iter()
        .sum()
    }
}
/// Bundle of `(key, value)` pairs together with a [`MultiProof`].
///
/// NOTE(review): the name suggests `pre_values` are values before a state transition
/// and that `proof` authenticates them. Confirm with the producer of witnesses.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Witness {
/// `(key, value)` pairs; `Witness::size` counts 32 + 32 bytes per entry.
pub pre_values: Vec<(TreeKey, B256)>,
/// Multi-proof carried alongside the pre-values.
pub proof: MultiProof,
}
impl Witness {
    /// Returns an empty witness, identical to `Witness::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Approximate encoded size in bytes: the embedded multi-proof plus a 32-byte key
    /// and a 32-byte value for every pre-value entry.
    pub fn size(&self) -> usize {
        const PRE_VALUE_BYTES: usize = 32 + 32;
        let proof_bytes = self.proof.size();
        proof_bytes + PRE_VALUE_BYTES * self.pre_values.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Sha256Hasher;

    /// Stores `value` at `subindex` in a fresh stem node and returns a single-node proof
    /// for it, together with the stem node's own hash as the expected root.
    fn build_stem_proof<H: Hasher>(
        hasher: &H,
        stem: Stem,
        subindex: u8,
        value: B256,
    ) -> (Proof, B256) {
        let mut stem_node = StemNode::new(stem);
        stem_node.set_value(subindex, value);
        let (proven, subtree_siblings) = generate_stem_proof(&stem_node, subindex, hasher);
        assert_eq!(proven, Some(value));
        let proof = Proof::new(
            TreeKey::new(stem, subindex),
            Some(value),
            vec![ProofNode::Stem {
                stem,
                subtree_siblings,
            }],
        );
        let expected_root = stem_node.hash(hasher);
        (proof, expected_root)
    }

    #[test]
    fn test_stem_proof_generation() {
        let hasher = Sha256Hasher;
        let mut stem_node = StemNode::new(Stem::new([0u8; 31]));
        for (slot, fill) in [(0u8, 0x42u8), (5, 0x43)] {
            stem_node.set_value(slot, B256::repeat_byte(fill));
        }
        let (value, siblings) = generate_stem_proof(&stem_node, 0, &hasher);
        assert_eq!(value, Some(B256::repeat_byte(0x42)));
        assert_eq!(siblings.len(), 8);
    }

    #[test]
    fn test_proof_size() {
        let key = TreeKey::from_bytes(B256::repeat_byte(0x01));
        let internal_step = ProofNode::Internal {
            sibling: B256::ZERO,
            direction: Direction::Left,
        };
        let stem_step = ProofNode::Stem {
            stem: key.stem,
            subtree_siblings: vec![B256::ZERO; 8],
        };
        let proof = Proof::new(
            key,
            Some(B256::repeat_byte(0x42)),
            vec![internal_step, stem_step],
        );
        let size = proof.size();
        println!("Proof size: {} bytes", size);
        assert!(size > 0);
    }

    #[test]
    fn test_proof_verify_simple() {
        let hasher = Sha256Hasher;
        let (proof, expected_root) =
            build_stem_proof(&hasher, Stem::new([0u8; 31]), 0, B256::repeat_byte(0x42));
        assert!(matches!(proof.verify(&hasher, &expected_root), Ok(true)));
    }

    #[test]
    fn test_compute_root_rejects_invalid_stem_sibling_lengths() {
        let hasher = Sha256Hasher;
        let stem = Stem::new([0u8; 31]);
        let value = B256::repeat_byte(0x42);
        let key = TreeKey::new(stem, 0);

        for bad_len in [0usize, 7, 9] {
            let malformed = Proof::new(
                key,
                Some(value),
                vec![ProofNode::Stem {
                    stem,
                    subtree_siblings: vec![B256::ZERO; bad_len],
                }],
            );
            assert!(matches!(
                malformed.compute_root(&hasher),
                Err(UbtError::InvalidProof(_))
            ));
        }

        let (well_formed, expected_root) = build_stem_proof(&hasher, stem, 0, value);
        match well_formed.path.first() {
            Some(ProofNode::Stem {
                subtree_siblings, ..
            }) => assert_eq!(subtree_siblings.len(), 8),
            _ => panic!("expected a stem proof node"),
        }
        assert_eq!(well_formed.compute_root(&hasher).unwrap(), expected_root);
    }
}