use crate::trie::error::{Result, TrieError};
use super::bitpath::BitPath;
use super::codec::{hash_internal, hash_leaf, wrap_hash};
use super::{EMPTY_HASH, SmtHandle, SmtNode, TREE_DEPTH};
/// Proof about a single key hash, as produced by `prove` and consumed by
/// `verify_proof`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SmtProof {
/// Hash of the key the proof is about; `verify_proof` derives the bit path
/// it follows from this value.
pub key_hash: [u8; 32],
/// `Some(value)` when `prove` found a leaf stored under `key_hash`,
/// `None` when it did not.
pub value: Option<Vec<u8>>,
/// Sibling hashes indexed by depth. `prove` creates `TREE_DEPTH` entries
/// initialised to `EMPTY_HASH`; `verify_proof` rejects any other length.
pub siblings: Vec<[u8; 32]>,
}
/// Builds an [`SmtProof`] for `key_hash` against the tree rooted at `root`.
///
/// The sibling list starts as `TREE_DEPTH` copies of `EMPTY_HASH`;
/// `prove_walk` overwrites the entries it visits and reports the stored
/// value, if any.
///
/// # Errors
///
/// Returns whatever error `prove_walk` reports.
pub fn prove(root: &SmtHandle, key_hash: &[u8; 32]) -> Result<SmtProof> {
    let path = BitPath::from_hash(key_hash);
    let mut siblings = vec![EMPTY_HASH; TREE_DEPTH];
    // The mutable borrow of `siblings` ends when the walk returns, so the
    // vector can be moved into the proof afterwards.
    prove_walk(root, &path, key_hash, 0, &mut siblings).map(|value| SmtProof {
        key_hash: *key_hash,
        value,
        siblings,
    })
}
/// Recursive worker behind `prove`.
///
/// Descends from `handle` along `key_path`, starting at bit index `depth`,
/// and records sibling hashes in `siblings` (indexed by depth). Returns
/// `Ok(Some(value))` when a leaf stored under `key_hash` is reached and
/// `Ok(None)` when the walk ends without finding one.
///
/// # Errors
///
/// * `TrieError::InvalidState` when a computed depth reaches `TREE_DEPTH`.
/// * Any error from `expect_hash` on a child handle, or from `to_hash32`
///   when that hash is not 32 bytes long.
fn prove_walk(
    handle: &SmtHandle,
    key_path: &BitPath,
    key_hash: &[u8; 32],
    depth: usize,
    siblings: &mut [[u8; 32]],
) -> Result<Option<Vec<u8>>> {
    match handle.node() {
        // Nothing is stored here; the sibling slots keep their initial values.
        SmtNode::Empty => Ok(None),

        SmtNode::Leaf {
            key_hash: stored_kh,
            value,
            ..
        } => {
            if stored_kh == key_hash {
                return Ok(Some(value.clone()));
            }

            // A leaf for a different key sits on this route. Find the first
            // bit at or below `depth` where the two key paths disagree.
            let stored_path = BitPath::from_hash(stored_kh);
            let shared = key_path
                .slice(depth, TREE_DEPTH)
                .common_prefix(&stored_path.slice(depth, TREE_DEPTH));
            let fork = depth + shared;
            if fork >= TREE_DEPTH {
                return Err(TrieError::InvalidState(
                    "identical key hashes in non-membership proof".into(),
                ));
            }

            // The other leaf becomes the sibling at the fork. `wrap_hash` is
            // given its hash plus the stored key's bits below the fork
            // (presumably to lift the hash up to that level — confirm in codec).
            let stored_leaf = hash_leaf(stored_kh, value);
            let tail = stored_path.slice(fork + 1, TREE_DEPTH);
            siblings[fork] = wrap_hash(stored_leaf, &tail);
            Ok(None)
        }

        SmtNode::Internal { path, left, right } => {
            let remaining = key_path.slice(depth, TREE_DEPTH);

            if remaining.starts_with(path) {
                // The key follows this node's path, so branch on the bit
                // just past it.
                let split_depth = depth + path.len();
                if split_depth >= TREE_DEPTH {
                    return Err(TrieError::InvalidState(
                        "SMT depth exceeded 256 bits".into(),
                    ));
                }
                let (toward_key, other) = if key_path.bit_at(split_depth) == 0 {
                    (left, right)
                } else {
                    (right, left)
                };
                siblings[split_depth] = to_hash32(other.expect_hash()?)?;
                return prove_walk(toward_key, key_path, key_hash, split_depth + 1, siblings);
            }

            // The key leaves this node's path part-way through: the whole
            // node becomes the sibling at the point of divergence.
            let shared = remaining.common_prefix(path);
            let fork = depth + shared;
            if fork >= TREE_DEPTH {
                return Err(TrieError::InvalidState(
                    "depth overflow in divergent path".into(),
                ));
            }
            let left_hash = to_hash32(left.expect_hash()?)?;
            let right_hash = to_hash32(right.expect_hash()?)?;
            let subtree = hash_internal(&left_hash, &right_hash);
            // Path bits below the fork, handed to `wrap_hash` with the hash.
            let tail = path.slice(shared + 1, path.len());
            siblings[fork] = wrap_hash(subtree, &tail);
            Ok(None)
        }
    }
}
/// Recomputes a root hash from `proof` and compares it with `root_hash`.
///
/// The starting hash is `hash_leaf(key_hash, value)` when the proof carries
/// a value and `EMPTY_HASH` otherwise. It is then combined with each sibling
/// from the deepest level up to depth 0; the key's bit at each depth decides
/// whether the running hash is the first or second argument of
/// `hash_internal`.
///
/// Returns `Ok(true)` when the recomputed hash equals `root_hash`, and
/// `Ok(false)` otherwise.
///
/// # Errors
///
/// Returns `TrieError::InvalidState` when `proof.siblings` does not contain
/// exactly `TREE_DEPTH` entries.
pub fn verify_proof(root_hash: &[u8; 32], proof: &SmtProof) -> Result<bool> {
    if proof.siblings.len() != TREE_DEPTH {
        return Err(TrieError::InvalidState(format!(
            "proof must have {} siblings, got {}",
            TREE_DEPTH,
            proof.siblings.len()
        )));
    }

    let key_path = BitPath::from_hash(&proof.key_hash);
    let mut current = match &proof.value {
        Some(v) => hash_leaf(&proof.key_hash, v),
        None => EMPTY_HASH,
    };

    // The length check above guarantees the indices run over 0..TREE_DEPTH;
    // walk them deepest-first.
    for (depth, sibling) in proof.siblings.iter().enumerate().rev() {
        current = if key_path.bit_at(depth) == 0 {
            // Bit 0: the running hash is the first (left) operand.
            hash_internal(&current, sibling)
        } else {
            hash_internal(sibling, &current)
        };
    }

    Ok(current == *root_hash)
}
/// Copies `slice` into a fixed 32-byte array.
///
/// # Errors
///
/// Returns `TrieError::InvalidState` when `slice` is not exactly 32 bytes
/// long.
fn to_hash32(slice: &[u8]) -> Result<[u8; 32]> {
    let attempt: std::result::Result<[u8; 32], _> = slice.try_into();
    match attempt {
        Ok(hash) => Ok(hash),
        Err(_) => Err(TrieError::InvalidState("expected 32-byte hash".into())),
    }
}