use std::collections::HashMap;
use chia_protocol::Bytes32;
use serde::{Deserialize, Serialize};
use super::{child_path, empty_hash, merkle_node_hash, MerkleError, SparseMerkleTree, SMT_HEIGHT};
/// Proof for a single key of the sparse Merkle tree.
///
/// The same structure covers both cases: `value` is `Some` when the key has a
/// leaf in the tree and `None` when it does not (the proof then runs against
/// the empty leaf hash).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SparseMerkleProof {
/// Key whose root-to-leaf path this proof covers; its bits select the
/// left/right position at each depth during verification.
pub key: Bytes32,
/// Leaf hash stored under `key`, or `None` if the key has no leaf.
/// Verification substitutes `empty_hash(0)` for `None`.
pub value: Option<Bytes32>,
/// Sibling hashes along the path, indexed by depth: `siblings[0]` is the
/// sibling directly below the root and the last entry is the sibling of the
/// leaf. Verification requires exactly `SMT_HEIGHT` entries.
pub siblings: Vec<Bytes32>,
}
impl SparseMerkleProof {
    /// Recomputes the root implied by this proof and compares it with
    /// `expected_root`.
    ///
    /// Starting from [`Self::leaf_value`], the hash is folded upward one level
    /// at a time, from the deepest sibling (`siblings[SMT_HEIGHT - 1]`) to the
    /// one just below the root (`siblings[0]`). At each depth the matching bit
    /// of `key` decides on which side the running hash is placed.
    ///
    /// Returns `false` when `siblings` does not contain exactly `SMT_HEIGHT`
    /// entries or when the recomputed root differs from `expected_root`.
    #[must_use]
    pub fn verify(&self, expected_root: &Bytes32) -> bool {
        // A proof of the wrong length can never describe a full path.
        if self.siblings.len() != SMT_HEIGHT {
            return false;
        }

        let mut current = self.leaf_value();
        // Walk from the leaf level back up to the root.
        for (depth, sibling) in self.siblings.iter().enumerate().rev() {
            let bit = SparseMerkleTree::get_bit_public(&self.key, depth);
            current = if bit {
                // Bit set: the running hash goes in as the second argument.
                merkle_node_hash(sibling, &current)
            } else {
                // Bit clear: the running hash goes in as the first argument.
                merkle_node_hash(&current, sibling)
            };
        }

        current == *expected_root
    }

    /// Returns the hash used for the leaf position of this proof: the stored
    /// `value` when present, otherwise `empty_hash(0)`.
    #[must_use]
    pub fn leaf_value(&self) -> Bytes32 {
        self.value.unwrap_or_else(|| empty_hash(0))
    }
}
#[inline]
#[must_use]
pub fn verify_coin_proof(proof: &SparseMerkleProof, expected_root: &Bytes32) -> bool {
proof.verify(expected_root)
}
impl SparseMerkleTree {
    /// Public wrapper around the private `get_bit` accessor, returning bit
    /// `n` of `key`. Used by proof verification, which lives outside the
    /// tree's own methods.
    #[inline]
    pub fn get_bit_public(key: &Bytes32, n: usize) -> bool {
        Self::get_bit(key, n)
    }

    /// Builds the proof for `key` by descending from the root to the leaf
    /// level and hashing, at every depth, the subtree that is *not* on the
    /// key's path.
    ///
    /// The resulting `siblings` vector has `SMT_HEIGHT` entries ordered by
    /// depth (index 0 is the sibling directly below the root). `value` is the
    /// leaf stored under `key`, or `None` when the key is absent.
    fn build_sparse_proof_for_key(&self, key: &Bytes32) -> SparseMerkleProof {
        let mut siblings = Vec::with_capacity(SMT_HEIGHT);
        // Leaves still located under the current node on the key's path.
        let mut current_leaves: Vec<(&Bytes32, &Bytes32)> = self.leaves.iter().collect();
        // Path prefix of the current node. It is carried across iterations
        // instead of being rebuilt from the root at every depth, which keeps
        // the number of `child_path` calls linear in `SMT_HEIGHT`.
        let mut path = Bytes32::default();

        for depth in 0..SMT_HEIGHT {
            let bit = Self::get_bit(key, depth);

            // Leaves whose bit at this depth is clear go left, the rest right.
            let (left_leaves, right_leaves): (Vec<_>, Vec<_>) = current_leaves
                .into_iter()
                .partition(|(k, _)| !Self::get_bit(k, depth));
            let (on_path, off_path) = if bit {
                (right_leaves, left_leaves)
            } else {
                (left_leaves, right_leaves)
            };

            let sibling_path = child_path(&path, depth, !bit);
            // Scratch map required by the helper; its contents are discarded.
            // NOTE(review): the trailing `false` presumably disables node
            // caching/recording — confirm against `compute_subtree_hash_core`.
            let mut sink = HashMap::new();
            let sibling_hash = Self::compute_subtree_hash_core(
                &off_path,
                depth + 1,
                &sibling_path,
                &mut sink,
                false,
            );
            siblings.push(sibling_hash);

            // Step down into the child that contains `key`.
            path = child_path(&path, depth, bit);
            current_leaves = on_path;
        }

        SparseMerkleProof {
            key: *key,
            value: self.leaves.get(key).copied(),
            siblings,
        }
    }

    /// Returns the proof for `coin_id`.
    ///
    /// # Errors
    ///
    /// Returns [`MerkleError::ProofRequiresCleanTree`] when `is_dirty()`
    /// reports that the tree is dirty.
    pub fn get_coin_proof(&self, coin_id: &Bytes32) -> Result<SparseMerkleProof, MerkleError> {
        if self.is_dirty() {
            return Err(MerkleError::ProofRequiresCleanTree);
        }
        Ok(self.build_sparse_proof_for_key(coin_id))
    }

    /// Returns the proof for `key` without checking `is_dirty()`.
    ///
    /// NOTE(review): unlike [`Self::get_coin_proof`], this never fails; if the
    /// tree is dirty the proof may not match a previously computed root —
    /// confirm that callers account for this.
    pub fn get_proof(&self, key: &Bytes32) -> SparseMerkleProof {
        self.build_sparse_proof_for_key(key)
    }
}