mod leaf_hash;
pub mod persistent;
pub mod proof;
pub use leaf_hash::coin_record_hash;
pub use proof::{verify_coin_proof, SparseMerkleProof};
use std::collections::HashMap;
use std::sync::OnceLock;
use chia_protocol::Bytes32;
use chia_sha2::Sha256;
use crate::storage::schema;
use crate::storage::{StorageBackend, WriteBatch};
pub use persistent::{MerkleNodePersistOp, MERKLE_STATE_ROOT_META_KEY};
/// Height of the sparse Merkle tree: the number of levels between the root (depth 0)
/// and the leaves (depth `SMT_HEIGHT`).
///
/// Keys are 32 bytes and one key bit is consumed per level, hence 256.
pub const SMT_HEIGHT: usize = 256;
/// Computes the domain-separated hash of a leaf payload.
///
/// The digest is taken over a single `0x00` byte followed by `data`. Internal nodes
/// use a `0x01` prefix (see `merkle_node_hash`), so a leaf preimage can never be
/// confused with a node preimage.
#[inline]
pub fn merkle_leaf_hash(data: &[u8]) -> Bytes32 {
    let mut sha = Sha256::new();
    // Leaf domain tag first, payload second.
    sha.update([0x00]);
    sha.update(data);

    let mut digest = [0u8; 32];
    digest.copy_from_slice(&sha.finalize());
    Bytes32::from(digest)
}
/// Computes the domain-separated hash of an internal node from its two children.
///
/// The digest is taken over a single `0x01` byte, then the left child hash, then the
/// right child hash. Child order matters: swapping `left` and `right` changes the
/// preimage.
#[inline]
pub fn merkle_node_hash(left: &Bytes32, right: &Bytes32) -> Bytes32 {
    let mut sha = Sha256::new();
    // Node domain tag first, then the children left-to-right.
    sha.update([0x01]);
    sha.update(left.as_ref());
    sha.update(right.as_ref());

    let mut digest = [0u8; 32];
    digest.copy_from_slice(&sha.finalize());
    Bytes32::from(digest)
}
/// Payload hashed with `merkle_leaf_hash` to produce the hash of an unoccupied leaf
/// (level 0 of the empty-hash table).
const EMPTY_LEAF_SENTINEL: [u8; 32] = [0u8; 32];
/// Lazily initialized table of empty-subtree hashes, one entry per level
/// `0..=SMT_HEIGHT` (257 entries). Filled on first use by `get_empty_hashes`.
static EMPTY_HASHES: OnceLock<[Bytes32; 257]> = OnceLock::new();
/// Returns the hash of a completely empty subtree spanning `level` levels.
///
/// `level == 0` is the empty leaf and `level == SMT_HEIGHT` is the root of an empty
/// tree. Values come from a table that is built once on first use.
///
/// # Panics
///
/// Panics if `level` is greater than `SMT_HEIGHT`.
#[inline]
pub fn empty_hash(level: usize) -> Bytes32 {
    // Validate before touching the table so an invalid level never triggers its initialization.
    assert!(
        level <= SMT_HEIGHT,
        "level {} exceeds SMT_HEIGHT {}",
        level,
        SMT_HEIGHT
    );
    let table = get_empty_hashes();
    table[level]
}
/// Returns the process-wide table of empty-subtree hashes, building it on first call.
///
/// Entry 0 is the hash of the empty-leaf sentinel; entry `i` (for `i >= 1`) is the
/// node hash of two copies of entry `i - 1`.
fn get_empty_hashes() -> &'static [Bytes32; 257] {
    EMPTY_HASHES.get_or_init(|| {
        let mut table = [Bytes32::default(); SMT_HEIGHT + 1];
        table[0] = merkle_leaf_hash(&EMPTY_LEAF_SENTINEL);

        // Carry the previous level's hash forward instead of re-indexing the table.
        let mut below = table[0];
        for slot in table.iter_mut().skip(1) {
            below = merkle_node_hash(&below, &below);
            *slot = below;
        }
        table
    })
}
/// Errors produced by sparse Merkle tree mutation, persistence loading and proof generation.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum MerkleError {
/// An insert targeted a key that is already present in the tree.
#[error("key already exists: {0}")]
KeyAlreadyExists(Bytes32),
/// An update or removal targeted a key that is not present in the tree.
#[error("key not found: {0}")]
KeyNotFound(Bytes32),
/// The metadata lookup for the persisted state root returned no value.
#[error("persisted merkle state root missing from metadata column family")]
PersistedRootMissing,
/// The persisted state root was found but is not exactly 32 bytes; carries the actual length.
#[error("persisted merkle state root has invalid length: {0} bytes (expected 32)")]
InvalidPersistedRootLength(usize),
/// The root recomputed from the supplied leaves differs from the root read from storage.
#[error("persisted merkle root mismatch: disk={disk:?} recomputed={recomputed:?}")]
PersistedRootMismatch { disk: Bytes32, recomputed: Bytes32 },
/// The storage backend reported an error; carries its stringified message.
#[error("storage error during merkle load: {0}")]
Storage(String),
/// A proof was requested while the cached root is stale.
/// NOTE(review): not raised in this file — presumably used by the `proof` module; confirm.
#[error("cannot generate coin proof while tree is dirty; call root() after mutations")]
ProofRequiresCleanTree,
}
/// In-memory sparse Merkle tree of height `SMT_HEIGHT` with a lazily recomputed root.
///
/// Mutations only touch the leaf map and invalidate the cached root; hashing is
/// deferred until `root()` (or `root_observed()`) is called.
#[derive(Debug, Clone)]
pub struct SparseMerkleTree {
/// Leaf key -> leaf value. The value is used as-is as the hash at depth `SMT_HEIGHT`.
leaves: HashMap<Bytes32, Bytes32>,
/// Cached root hash; `None` after a mutation until `root()` recomputes it.
root_hash: Option<Bytes32>,
/// Pending node writes/deletes keyed by `schema::merkle_node_key(depth, path)`.
/// Rebuilt by `root()` and drained by `flush_to_batch()`.
dirty_merkle_nodes: HashMap<[u8; 33], MerkleNodePersistOp>,
}
impl Default for SparseMerkleTree {
fn default() -> Self {
Self::new()
}
}
impl SparseMerkleTree {
/// Creates an empty tree.
///
/// The root cache starts populated with the empty-tree root, and there are no
/// pending node persistence operations.
pub fn new() -> Self {
    let empty_root = empty_hash(SMT_HEIGHT);
    Self {
        leaves: HashMap::new(),
        root_hash: Some(empty_root),
        dirty_merkle_nodes: HashMap::new(),
    }
}
/// Inserts a batch of new `(key, value)` leaves, all-or-nothing.
///
/// Every key is validated before anything is written, so on error the tree is left
/// untouched. On success with a non-empty batch the cached root is invalidated; an
/// empty batch is a no-op.
///
/// # Errors
///
/// Returns [`MerkleError::KeyAlreadyExists`] for the first key that is either already
/// present in the tree or repeated earlier in `entries`. Without the in-batch check a
/// repeated key would pass validation and its earlier value would be silently
/// overwritten by the later one.
pub fn batch_insert(&mut self, entries: &[(Bytes32, Bytes32)]) -> Result<(), MerkleError> {
    if entries.is_empty() {
        return Ok(());
    }

    // Tracks keys seen so far in this batch so intra-batch duplicates are rejected too.
    let mut seen = std::collections::HashSet::with_capacity(entries.len());
    for (key, _) in entries {
        if self.leaves.contains_key(key) || !seen.insert(*key) {
            return Err(MerkleError::KeyAlreadyExists(*key));
        }
    }

    self.leaves.extend(entries.iter().copied());
    // Leaves changed, so the cached root is stale.
    self.root_hash = None;
    Ok(())
}
/// Replaces the values of a batch of existing leaves, all-or-nothing.
///
/// Every key is validated before anything is written. If a key appears more than
/// once in `entries`, the last value wins. A non-empty batch invalidates the cached
/// root.
///
/// # Errors
///
/// Returns [`MerkleError::KeyNotFound`] for the first key in `entries` that is not
/// present in the tree; the tree is left untouched in that case.
pub fn batch_update(&mut self, entries: &[(Bytes32, Bytes32)]) -> Result<(), MerkleError> {
    let first_missing = entries
        .iter()
        .find(|(key, _)| !self.leaves.contains_key(key));
    if let Some((missing, _)) = first_missing {
        return Err(MerkleError::KeyNotFound(*missing));
    }

    self.leaves.extend(entries.iter().copied());
    if !entries.is_empty() {
        // Leaves changed, so the cached root is stale.
        self.root_hash = None;
    }
    Ok(())
}
/// Removes a batch of existing leaves, all-or-nothing.
///
/// Every key is validated before anything is removed. A non-empty batch invalidates
/// the cached root.
///
/// # Errors
///
/// Returns [`MerkleError::KeyNotFound`] for the first key in `keys` that is not
/// present in the tree; the tree is left untouched in that case.
pub fn batch_remove(&mut self, keys: &[Bytes32]) -> Result<(), MerkleError> {
    if let Some(missing) = keys.iter().find(|key| !self.leaves.contains_key(*key)) {
        return Err(MerkleError::KeyNotFound(*missing));
    }

    keys.iter().for_each(|key| {
        self.leaves.remove(key);
    });
    if !keys.is_empty() {
        // Leaves changed, so the cached root is stale.
        self.root_hash = None;
    }
    Ok(())
}
/// Returns the tree root, recomputing and caching it if the cache is stale.
///
/// When a recomputation happens, the pending node-operation map is rebuilt from
/// scratch: any operations recorded earlier and not yet flushed are discarded and
/// replaced by the operations produced by this pass.
pub fn root(&mut self) -> Bytes32 {
    match self.root_hash {
        Some(cached) => cached,
        None => {
            self.dirty_merkle_nodes.clear();
            let entries: Vec<(&Bytes32, &Bytes32)> = self.leaves.iter().collect();
            // The root sits at depth 0 with an all-zero path prefix.
            let fresh = Self::compute_subtree_hash_core(
                &entries,
                0,
                &Bytes32::default(),
                &mut self.dirty_merkle_nodes,
                true,
            );
            self.root_hash = Some(fresh);
            fresh
        }
    }
}
/// Returns the tree root without mutating the tree.
///
/// An empty tree always yields the empty-tree root. Otherwise the cached root is
/// returned when available; if the cache is stale the root is recomputed into a
/// throwaway scratch map, leaving both the cache and the pending node operations
/// untouched.
pub fn root_observed(&self) -> Bytes32 {
    match (self.leaves.is_empty(), self.root_hash) {
        (true, _) => empty_hash(SMT_HEIGHT),
        (false, Some(cached)) => cached,
        (false, None) => {
            let entries: Vec<(&Bytes32, &Bytes32)> = self.leaves.iter().collect();
            // Recording is disabled; the map only satisfies the helper's signature.
            let mut scratch = HashMap::new();
            Self::compute_subtree_hash_core(
                &entries,
                0,
                &Bytes32::default(),
                &mut scratch,
                false,
            )
        }
    }
}
/// Reports whether the cached root is stale for a non-empty tree.
///
/// An empty tree is never reported dirty, even when the root cache has been
/// invalidated (for example after removing the last leaf).
pub fn is_dirty(&self) -> bool {
    if self.leaves.is_empty() {
        return false;
    }
    self.root_hash.is_none()
}
/// Returns the number of leaves currently stored in the tree.
pub fn len(&self) -> usize {
self.leaves.len()
}
/// Returns `true` when the tree holds no leaves.
pub fn is_empty(&self) -> bool {
self.leaves.is_empty()
}
/// Returns `true` when a leaf is stored under `key`.
pub fn contains_key(&self, key: &Bytes32) -> bool {
self.leaves.contains_key(key)
}
/// Returns the leaf value stored under `key`, or `None` if the key is absent.
pub fn get(&self, key: &Bytes32) -> Option<&Bytes32> {
self.leaves.get(key)
}
/// Borrows the pending node persistence operations.
///
/// The map is rebuilt whenever `root()` recomputes the root and is emptied by
/// `flush_to_batch()` or `clear_dirty()`.
pub fn dirty_nodes(&self) -> &HashMap<[u8; 33], MerkleNodePersistOp> {
&self.dirty_merkle_nodes
}
/// Discards all pending node persistence operations without writing them anywhere.
pub fn clear_dirty(&mut self) {
self.dirty_merkle_nodes.clear();
}
/// Stages the tree's pending node operations and current root into `batch`.
///
/// The root is brought up to date first (via `root()`). Every pending operation is
/// then moved into the batch as a put or delete in the merkle-node column family,
/// leaving the pending map empty, and finally the root is written under the
/// merkle state-root metadata key.
///
/// # Errors
///
/// The current implementation has no failing path and always returns `Ok(())`.
pub fn flush_to_batch(&mut self, batch: &mut WriteBatch) -> Result<(), MerkleError> {
    let state_root = self.root();

    // Take ownership of the pending ops so the tree's map is left empty.
    let pending = std::mem::take(&mut self.dirty_merkle_nodes);
    pending.into_iter().for_each(|(node_key, op)| match op {
        MerkleNodePersistOp::Put(node_hash) => {
            batch.put(schema::CF_MERKLE_NODES, &node_key, node_hash.as_ref());
        }
        MerkleNodePersistOp::Delete => {
            batch.delete(schema::CF_MERKLE_NODES, &node_key);
        }
    });

    let root_meta_key = schema::metadata_key(MERKLE_STATE_ROOT_META_KEY);
    batch.put(schema::CF_METADATA, &root_meta_key, state_root.as_ref());
    Ok(())
}
/// Builds a tree from `leaves` and verifies it against the root persisted in `store`.
///
/// The persisted root is read from the metadata column family, the root of the
/// supplied leaves is recomputed, and the two must match. On success the returned
/// tree has its root cache set to the persisted root and no pending node operations.
///
/// # Errors
///
/// * [`MerkleError::Storage`] if the backend lookup fails.
/// * [`MerkleError::PersistedRootMissing`] if no root is stored.
/// * [`MerkleError::InvalidPersistedRootLength`] if the stored value is not 32 bytes.
/// * [`MerkleError::PersistedRootMismatch`] if the recomputed root differs from the stored one.
pub fn load_from_store(
    store: &dyn StorageBackend,
    leaves: HashMap<Bytes32, Bytes32>,
) -> Result<Self, MerkleError> {
    let root_meta_key = schema::metadata_key(MERKLE_STATE_ROOT_META_KEY);
    let stored = match store.get(schema::CF_METADATA, &root_meta_key) {
        Ok(Some(bytes)) => bytes,
        Ok(None) => return Err(MerkleError::PersistedRootMissing),
        Err(e) => return Err(MerkleError::Storage(e.to_string())),
    };

    let stored_len = stored.len();
    if stored_len != 32 {
        return Err(MerkleError::InvalidPersistedRootLength(stored_len));
    }
    let mut raw = [0u8; 32];
    raw.copy_from_slice(&stored);
    let disk_root = Bytes32::from(raw);

    // Start with a stale cache so the root is recomputed from the leaves.
    let mut tree = Self {
        leaves,
        root_hash: None,
        dirty_merkle_nodes: HashMap::new(),
    };
    let recomputed = tree.root_observed();
    if recomputed == disk_root {
        tree.root_hash = Some(disk_root);
        Ok(tree)
    } else {
        Err(MerkleError::PersistedRootMismatch {
            disk: disk_root,
            recomputed,
        })
    }
}
/// Recursively hashes the subtree rooted at (`depth`, `path`) that contains `leaves`.
///
/// * `leaves` — the `(key, value)` pairs that fall inside this subtree.
/// * `depth` — distance from the root; `SMT_HEIGHT` is the leaf level.
/// * `path` — the path prefix identifying this node.
/// * `dirty_out` — receives one persist op per visited node when `record_dirty` is set.
/// * `record_dirty` — when `false`, `dirty_out` is never written.
///
/// An empty subtree short-circuits to the precomputed empty hash for its height. At
/// the leaf level the stored leaf value is returned directly; no persist op is
/// recorded for leaves. Otherwise the leaves are split on key bit `depth`
/// (0 = left, 1 = right) and the two child hashes are combined.
fn compute_subtree_hash_core(
    leaves: &[(&Bytes32, &Bytes32)],
    depth: usize,
    path: &Bytes32,
    dirty_out: &mut HashMap<[u8; 33], MerkleNodePersistOp>,
    record_dirty: bool,
) -> Bytes32 {
    let Some(first) = leaves.first() else {
        // Nothing below this node: use the precomputed empty hash for this height.
        let empty = empty_hash(SMT_HEIGHT - depth);
        if record_dirty {
            record_merkle_persist_op(dirty_out, depth, path, empty);
        }
        return empty;
    };

    if depth == SMT_HEIGHT {
        // Leaf level: the stored value is the hash.
        return *first.1;
    }

    let (zero_side, one_side): (Vec<_>, Vec<_>) = leaves
        .iter()
        .copied()
        .partition(|&(key, _)| !Self::get_bit(key, depth));

    let left_hash = Self::compute_subtree_hash_core(
        &zero_side,
        depth + 1,
        &child_path(path, depth, false),
        dirty_out,
        record_dirty,
    );
    let right_hash = Self::compute_subtree_hash_core(
        &one_side,
        depth + 1,
        &child_path(path, depth, true),
        dirty_out,
        record_dirty,
    );

    let combined = merkle_node_hash(&left_hash, &right_hash);
    if record_dirty {
        record_merkle_persist_op(dirty_out, depth, path, combined);
    }
    combined
}
/// Returns bit `n` of `key`, counting from the most significant bit of byte 0.
///
/// Bit 0 is the top bit of the first byte; bit 255 is the lowest bit of the last byte.
///
/// # Panics
///
/// Panics if `n / 8` is outside the key's byte range.
#[inline]
fn get_bit(key: &Bytes32, n: usize) -> bool {
    // 0x80 is the MSB; shifting right by the in-byte offset selects bit `n % 8`, MSB-first.
    let mask = 0x80u8 >> (n % 8);
    key.as_ref()[n / 8] & mask != 0
}
}
/// Derives the path of a child node from its parent's path.
///
/// Copies `base`, clears every bit from position `depth` to the end (bits are numbered
/// MSB-first from byte 0), and then sets bit `depth` when `go_right` is true. The
/// result keeps the first `depth` bits of `base` followed by the branch bit and zeros.
///
/// # Panics
///
/// Panics if `go_right` is true and `depth` is 256 or more (bit outside the 32 bytes).
fn child_path(base: &Bytes32, depth: usize, go_right: bool) -> Bytes32 {
    let mut bytes: [u8; 32] = base.as_ref().try_into().expect("Bytes32 is 32 bytes");
    let byte_idx = depth / 8;
    let shift = depth % 8;

    if byte_idx < bytes.len() {
        // Keep only the `shift` most significant bits of the byte holding bit `depth`...
        bytes[byte_idx] &= !(0xFFu8 >> shift);
        // ...and zero every byte after it.
        for b in bytes[byte_idx + 1..].iter_mut() {
            *b = 0;
        }
    }

    if go_right {
        bytes[byte_idx] |= 0x80u8 >> shift;
    }
    Bytes32::from(bytes)
}
/// Records the persistence operation for the node at (`depth`, `path`) with value `hash`.
///
/// A node whose hash equals the empty-subtree hash for its height is recorded as a
/// delete; any other hash is recorded as a put. An existing entry for the same node
/// key is overwritten. Calls with `depth >= SMT_HEIGHT` are ignored, so leaf-level
/// entries are never recorded.
fn record_merkle_persist_op(
    dirty: &mut HashMap<[u8; 33], MerkleNodePersistOp>,
    depth: usize,
    path: &Bytes32,
    hash: Bytes32,
) {
    if depth >= SMT_HEIGHT {
        return;
    }

    let op = if hash == empty_hash(SMT_HEIGHT - depth) {
        MerkleNodePersistOp::Delete
    } else {
        MerkleNodePersistOp::Put(hash)
    };
    // The guard above keeps `depth` below 256, so the cast to `u8` is lossless.
    dirty.insert(schema::merkle_node_key(depth as u8, path), op);
}