use alloy_primitives::B256;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::{error::Result, Hasher, Node, Stem, StemNode, TreeKey, UbtError};
use super::{UnifiedBinaryTree, MAX_DEPTH};
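/// Minimum number of dirty stems before stem hashing is dispatched to rayon.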
#[cfg(feature = "parallel")]
const PARALLEL_STEM_THRESHOLD: usize = 100;
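/// Returns a copy of `value` with the bit at MSB-first position `pos` set.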
fn set_bit_at(mut value: B256, pos: usize) -> B256 {
debug_assert!(pos < 256);
let byte_idx = pos / 8;
let bit_idx = 7 - (pos % 8);
value.0[byte_idx] |= 1 << bit_idx;
value
}
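/// Returns `true` if the first `depth` bits of `value` (MSB-first) equal the
/// first `depth` bits of `prefix`. `depth == 0` matches everything and
/// `depth == 256` requires full equality; larger depths are rejected in
/// release builds in addition to the debug assertion. For example,
/// `0xAA = 0b1010_1010` and `0xAB = 0b1010_1011` agree at depth 7 but
/// diverge at depth 8.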
fn b256_matches_prefix(value: &B256, prefix: &B256, depth: usize) -> bool {
debug_assert!(depth <= 256, "depth must be <= 256, got {depth}");
if depth > 256 {
return false;
}
let full_bytes = depth / 8;
if value.0[..full_bytes] != prefix.0[..full_bytes] {
return false;
}
let rem_bits = depth % 8;
if rem_bits == 0 {
return true;
}
let mask = 0xFFu8 << (8 - rem_bits);
(value.0[full_bytes] & mask) == (prefix.0[full_bytes] & mask)
}
impl<H: Hasher> UnifiedBinaryTree<H> {
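/// Returns the current root hash, rebuilding it only if a stem has been
/// modified since the last computation.
///
/// A minimal usage sketch (`ignore`d since the crate path and re-exports
/// are assumed, not verified):
///
/// ```ignore
/// let mut tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
/// tree.insert(TreeKey::from_bytes(B256::ZERO), B256::repeat_byte(0x11));
/// let root = tree.root_hash().unwrap();
/// ```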
#[must_use = "callers should handle errors and use the computed root hash"]
pub fn root_hash(&mut self) -> Result<B256> {
if self.root_dirty {
self.rebuild_root()?;
self.root_dirty = false;
}
Ok(self.root_hash_cached)
}
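/// Hashes a single stem node, or returns `B256::ZERO` if the stem is absent.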
fn compute_stem_hash(&self, stem: &Stem) -> B256 {
self.stems
.get(stem)
.map_or(B256::ZERO, |stem_node| stem_node.hash(&self.hasher))
}
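/// Computes hashes for the given dirty stems, in parallel once the batch
/// reaches `PARALLEL_STEM_THRESHOLD` entries.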
#[cfg(feature = "parallel")]
fn compute_stem_updates(&self, dirty_stems: &[Stem]) -> Vec<(Stem, B256)> {
if dirty_stems.len() >= PARALLEL_STEM_THRESHOLD {
dirty_stems
.par_iter()
.map(|stem| (*stem, self.compute_stem_hash(stem)))
.collect()
} else {
dirty_stems
.iter()
.map(|stem| (*stem, self.compute_stem_hash(stem)))
.collect()
}
}
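/// Refreshes `stem_hash_cache` for every dirty stem, then rebuilds the root
/// hash and tree, taking the incremental path when it is enabled and a node
/// cache already exists.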
fn rebuild_root(&mut self) -> Result<()> {
let dirty_stems: Vec<_> = self.dirty_stem_hashes.iter().copied().collect();
#[cfg(feature = "parallel")]
{
let stem_updates = self.compute_stem_updates(&dirty_stems);
for (stem, hash) in &stem_updates {
if hash.is_zero() {
self.stem_hash_cache.remove(stem);
} else {
self.stem_hash_cache.insert(*stem, *hash);
}
}
}
#[cfg(not(feature = "parallel"))]
for stem in &dirty_stems {
let hash = self.compute_stem_hash(stem);
if hash.is_zero() {
self.stem_hash_cache.remove(stem);
} else {
self.stem_hash_cache.insert(*stem, hash);
}
}
if self.stem_hash_cache.is_empty() {
self.root = Node::Empty;
self.root_hash_cached = B256::ZERO;
self.node_hash_cache.clear();
self.dirty_stem_hashes.clear();
return Ok(());
}
if self.incremental_enabled && !self.node_hash_cache.is_empty() {
self.rebuild_root_incremental(&dirty_stems)?;
} else {
let mut stem_hashes: Vec<_> =
self.stem_hash_cache.iter().map(|(s, h)| (*s, *h)).collect();
stem_hashes.sort_by_key(|(s, _)| *s);
let root_hash = if self.incremental_enabled {
self.node_hash_cache.clear();
self.build_root_hash_with_cache(&stem_hashes, 0, B256::ZERO)?
} else {
self.build_root_hash_from_stem_hashes(&stem_hashes, 0)?
};
let stems: Vec<_> = stem_hashes.iter().map(|(s, _)| *s).collect();
let root = self.build_tree_from_sorted_stems(&stems, 0)?;
self.root_hash_cached = root_hash;
self.root = root;
}
self.dirty_stem_hashes.clear();
Ok(())
}
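/// Recursively folds a sorted slice of `(stem, hash)` pairs into a subtree
/// hash, splitting on the stem bit at `depth`. Two empty children collapse
/// to `B256::ZERO` instead of `hash_64(ZERO, ZERO)`.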
fn build_root_hash_from_stem_hashes(
&self,
stem_hashes: &[(Stem, B256)],
depth: usize,
) -> Result<B256> {
if stem_hashes.is_empty() {
return Ok(B256::ZERO);
}
if stem_hashes.len() == 1 {
return Ok(stem_hashes[0].1);
}
if depth >= MAX_DEPTH {
return Err(UbtError::TreeDepthExceeded { depth });
}
let split_point = stem_hashes.partition_point(|(s, _)| !s.bit_at(depth));
let (left, right) = stem_hashes.split_at(split_point);
let left_hash = self.build_root_hash_from_stem_hashes(left, depth + 1)?;
let right_hash = self.build_root_hash_from_stem_hashes(right, depth + 1)?;
if left_hash.is_zero() && right_hash.is_zero() {
Ok(B256::ZERO)
} else {
Ok(self.hasher.hash_64(&left_hash, &right_hash))
}
}
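/// Same recursion as `build_root_hash_from_stem_hashes`, but records each
/// node hash in `node_hash_cache` keyed by `(depth, path_prefix)` so that
/// later incremental updates can reuse untouched subtrees.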
fn build_root_hash_with_cache(
&mut self,
stem_hashes: &[(Stem, B256)],
depth: usize,
path_prefix: B256,
) -> Result<B256> {
if stem_hashes.is_empty() {
return Ok(B256::ZERO);
}
if stem_hashes.len() == 1 {
let hash = stem_hashes[0].1;
self.node_hash_cache.insert((depth, path_prefix), hash);
return Ok(hash);
}
if depth >= MAX_DEPTH {
return Err(UbtError::TreeDepthExceeded { depth });
}
let split_point = stem_hashes.partition_point(|(s, _)| !s.bit_at(depth));
let (left, right) = stem_hashes.split_at(split_point);
let left_hash = self.build_root_hash_with_cache(left, depth + 1, path_prefix)?;
let right_prefix = set_bit_at(path_prefix, depth);
let right_hash = self.build_root_hash_with_cache(right, depth + 1, right_prefix)?;
let node_hash = if left_hash.is_zero() && right_hash.is_zero() {
B256::ZERO
} else {
self.hasher.hash_64(&left_hash, &right_hash)
};
self.node_hash_cache.insert((depth, path_prefix), node_hash);
Ok(node_hash)
}
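/// Rebuilds the root while revisiting only subtrees that contain a dirty
/// stem; clean subtrees are answered from `node_hash_cache`.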
fn rebuild_root_incremental(&mut self, dirty_stems: &[Stem]) -> Result<()> {
let mut stem_hashes: Vec<_> = self.stem_hash_cache.iter().map(|(s, h)| (*s, *h)).collect();
stem_hashes.sort_by_key(|(s, _)| *s);
if stem_hashes.is_empty() {
self.root = Node::Empty;
self.root_hash_cached = B256::ZERO;
self.node_hash_cache.clear();
return Ok(());
}
let mut dirty_stems_sorted: Vec<_> = dirty_stems.to_vec();
dirty_stems_sorted.sort();
dirty_stems_sorted.dedup();
let root_hash =
self.incremental_hash_update(&stem_hashes, 0, B256::ZERO, &dirty_stems_sorted)?;
let stems: Vec<_> = stem_hashes.iter().map(|(s, _)| *s).collect();
let root = self.build_tree_from_sorted_stems(&stems, 0)?;
self.root_hash_cached = root_hash;
self.root = root;
Ok(())
}
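/// Core of the incremental rebuild. `dirty_stems` must be sorted so it
/// partitions cleanly at every depth. Subtrees without dirty stems are
/// served from the cache; subtrees that have become empty are pruned from it.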
fn incremental_hash_update(
&mut self,
stem_hashes: &[(Stem, B256)],
depth: usize,
path_prefix: B256,
dirty_stems: &[Stem],
) -> Result<B256> {
if stem_hashes.is_empty() {
if !dirty_stems.is_empty() || self.node_hash_cache.contains_key(&(depth, path_prefix)) {
self.prune_node_hash_cache_subtree(depth, path_prefix);
}
return Ok(B256::ZERO);
}
if dirty_stems.is_empty() {
if let Some(hash) = self.node_hash_cache.get(&(depth, path_prefix)).copied() {
return Ok(hash);
}
return self.build_root_hash_with_cache(stem_hashes, depth, path_prefix);
}
if stem_hashes.len() == 1 {
self.prune_node_hash_cache_descendants(depth, path_prefix);
let hash = stem_hashes[0].1;
self.node_hash_cache.insert((depth, path_prefix), hash);
return Ok(hash);
}
if depth >= MAX_DEPTH {
return Err(UbtError::TreeDepthExceeded { depth });
}
let split_point = stem_hashes.partition_point(|(s, _)| !s.bit_at(depth));
let (left, right) = stem_hashes.split_at(split_point);
#[cfg(debug_assertions)]
{
let mut seen_one = false;
for s in dirty_stems {
if s.bit_at(depth) {
seen_one = true;
} else {
debug_assert!(
!seen_one,
"dirty_stems must be partitioned at depth {depth}",
);
}
}
}
let dirty_split = dirty_stems.partition_point(|s| !s.bit_at(depth));
let (left_dirty, right_dirty) = dirty_stems.split_at(dirty_split);
let right_prefix = set_bit_at(path_prefix, depth);
let left_hash = self.incremental_hash_update(left, depth + 1, path_prefix, left_dirty)?;
let right_hash =
self.incremental_hash_update(right, depth + 1, right_prefix, right_dirty)?;
let node_hash = if left_hash.is_zero() && right_hash.is_zero() {
B256::ZERO
} else {
self.hasher.hash_64(&left_hash, &right_hash)
};
self.node_hash_cache.insert((depth, path_prefix), node_hash);
Ok(node_hash)
}
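/// Evicts cached hashes strictly below `(depth, path_prefix)`, keeping the
/// entry for the node itself.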
fn prune_node_hash_cache_descendants(&mut self, depth: usize, path_prefix: B256) {
if depth > MAX_DEPTH {
self.node_hash_cache.clear();
return;
}
self.node_hash_cache.retain(|(d, prefix), _| {
!(*d > depth && b256_matches_prefix(prefix, &path_prefix, depth))
});
}
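/// Evicts cached hashes for `(depth, path_prefix)` and everything below it.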
fn prune_node_hash_cache_subtree(&mut self, depth: usize, path_prefix: B256) {
if depth > MAX_DEPTH {
self.node_hash_cache.clear();
return;
}
self.node_hash_cache.retain(|(d, prefix), _| {
!(*d >= depth && b256_matches_prefix(prefix, &path_prefix, depth))
});
}
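/// Enables incremental root hashing. All known stems are marked dirty so the
/// next `root_hash` call rebuilds the node cache from scratch.
///
/// A sketch of the intended flow (`ignore`d; `key` and `value` are
/// placeholders):
///
/// ```ignore
/// tree.enable_incremental_mode();
/// tree.root_hash().unwrap(); // populates the node cache
/// tree.insert(key, value);
/// tree.root_hash().unwrap(); // revisits only the dirty path
/// ```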
pub fn enable_incremental_mode(&mut self) {
if !self.incremental_enabled {
self.incremental_enabled = true;
self.node_hash_cache.clear();
if !self.stem_hash_cache.is_empty() {
for stem in self.stem_hash_cache.keys() {
self.dirty_stem_hashes.insert(*stem);
}
self.root_dirty = true;
}
}
}
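/// Disables incremental root hashing and discards the node cache.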
pub fn disable_incremental_mode(&mut self) {
self.incremental_enabled = false;
self.node_hash_cache.clear();
}
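/// Returns `true` if incremental root hashing is enabled.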
pub fn is_incremental_enabled(&self) -> bool {
self.incremental_enabled
}
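/// Returns the number of cached internal node hashes.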
pub fn node_cache_size(&self) -> usize {
self.node_hash_cache.len()
}
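/// Inserts every entry, then rebuilds the root once at the end rather than
/// once per key.
///
/// A minimal sketch (`ignore`d since the crate path is assumed):
///
/// ```ignore
/// let mut tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
/// tree.insert_batch([
///     (TreeKey::from_bytes(B256::ZERO), B256::repeat_byte(0x11)),
///     (TreeKey::from_bytes(B256::repeat_byte(0x80)), B256::repeat_byte(0x22)),
/// ]).unwrap();
/// let root = tree.root_hash().unwrap();
/// ```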
pub fn insert_batch(
&mut self,
entries: impl IntoIterator<Item = (TreeKey, B256)>,
) -> Result<()> {
let mut inserted_any = false;
for (key, value) in entries {
inserted_any = true;
let stem_node = self
.stems
.entry(key.stem)
.or_insert_with(|| StemNode::new(key.stem));
stem_node.set_value(key.subindex, value);
self.dirty_stem_hashes.insert(key.stem);
}
if inserted_any {
self.root_dirty = true;
self.rebuild_root()?;
self.root_dirty = false;
}
Ok(())
}
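/// Like `insert_batch`, but calls `on_progress` with the running entry count
/// after each key is staged.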
pub fn insert_batch_with_progress(
&mut self,
entries: impl IntoIterator<Item = (TreeKey, B256)>,
mut on_progress: impl FnMut(usize),
) -> Result<()> {
let mut count = 0usize;
let mut inserted_any = false;
for (key, value) in entries {
inserted_any = true;
let stem_node = self
.stems
.entry(key.stem)
.or_insert_with(|| StemNode::new(key.stem));
stem_node.set_value(key.subindex, value);
self.dirty_stem_hashes.insert(key.stem);
count += 1;
on_progress(count);
}
if inserted_any {
self.root_dirty = true;
self.rebuild_root()?;
self.root_dirty = false;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Blake3Hasher;
fn b256_from_zero(overrides: &[(usize, u8)]) -> B256 {
let mut bytes = [0u8; 32];
for &(idx, value) in overrides {
assert!(idx < bytes.len(), "byte index out of range: {idx}");
bytes[idx] = value;
}
B256::from(bytes)
}
fn b256_from_fill(fill: u8, overrides: &[(usize, u8)]) -> B256 {
let mut bytes = [fill; 32];
for &(idx, value) in overrides {
assert!(idx < bytes.len(), "byte index out of range: {idx}");
bytes[idx] = value;
}
B256::from(bytes)
}
fn assert_prefix_match(value: B256, prefix_ok: B256, prefix_bad: B256, depth: usize) {
assert!(
b256_matches_prefix(&value, &prefix_ok, depth),
"expected match at depth={depth} (value={value:?}, prefix={prefix_ok:?})",
);
assert!(
!b256_matches_prefix(&value, &prefix_bad, depth),
"expected mismatch at depth={depth} (value={value:?}, prefix={prefix_bad:?})",
);
}
#[test]
fn test_tree_depth_exceeded_returns_error() {
let tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
let stem1 = Stem::new([0u8; 31]);
let mut stem2_bytes = [0u8; 31];
stem2_bytes[0] = 1;
let stem2 = Stem::new(stem2_bytes);
let stem_hashes = vec![(stem1, B256::repeat_byte(1)), (stem2, B256::repeat_byte(2))];
let err = tree
.build_root_hash_from_stem_hashes(&stem_hashes, MAX_DEPTH)
.unwrap_err();
assert!(matches!(err, UbtError::TreeDepthExceeded { depth } if depth == MAX_DEPTH));
}
#[test]
fn test_b256_matches_prefix_depth_0_matches_everything() {
let a = B256::repeat_byte(0xAA);
let b = B256::repeat_byte(0xBB);
assert!(b256_matches_prefix(&a, &b, 0));
}
#[test]
fn test_b256_matches_prefix_depth_256_requires_full_match() {
let a = B256::repeat_byte(0xAA);
let b = B256::repeat_byte(0xAA);
let c = B256::repeat_byte(0xBB);
assert!(b256_matches_prefix(&a, &b, 256));
assert!(!b256_matches_prefix(&a, &c, 256));
}
#[test]
fn test_b256_matches_prefix_partial_depths() {
assert_prefix_match(
b256_from_zero(&[(0, 0x80)]),
b256_from_zero(&[(0, 0x80)]),
b256_from_zero(&[]),
1,
);
assert_prefix_match(
b256_from_zero(&[(0, 0xAA)]),
b256_from_zero(&[(0, 0xAA)]),
b256_from_zero(&[(0, 0xAB)]),
8,
);
assert_prefix_match(
b256_from_zero(&[(0, 0xAA), (1, 0x80)]),
b256_from_zero(&[(0, 0xAA), (1, 0x80)]),
b256_from_zero(&[(0, 0xAA), (1, 0x00)]),
9,
);
assert_prefix_match(
b256_from_zero(&[(0, 0xAA), (1, 0xFE)]),
b256_from_zero(&[(0, 0xAA), (1, 0xFF)]),
b256_from_zero(&[(0, 0xAA), (1, 0x7E)]),
15,
);
assert_prefix_match(
b256_from_fill(0xAA, &[(31, 0xFE)]),
b256_from_fill(0xAA, &[(31, 0xFF)]),
b256_from_fill(0xAA, &[(31, 0x7E)]),
255,
);
}
#[test]
fn test_prune_node_hash_cache_invalid_depth_clears_cache() {
let mut tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
tree.node_hash_cache
.insert((0, B256::ZERO), B256::repeat_byte(1));
tree.prune_node_hash_cache_subtree(MAX_DEPTH + 1, B256::ZERO);
assert!(tree.node_hash_cache.is_empty());
}
#[test]
fn test_incremental_delete_prunes_empty_subtree_cache() {
let mut key_right_bytes = [0u8; 32];
key_right_bytes[0] = 0x80;
let key_left = TreeKey::from_bytes(B256::ZERO);
let key_right = TreeKey::from_bytes(B256::from_slice(&key_right_bytes));
let left_value = B256::repeat_byte(0x11);
let right_value = B256::repeat_byte(0x22);
let mut tree_inc: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
tree_inc.insert(key_left, left_value);
tree_inc.insert(key_right, right_value);
tree_inc.enable_incremental_mode();
tree_inc.root_hash().unwrap();
let right_prefix = set_bit_at(B256::ZERO, 0);
assert!(tree_inc.node_hash_cache.contains_key(&(1, right_prefix)));
tree_inc.delete(&key_right);
let root_inc = tree_inc.root_hash().unwrap();
let mut tree_full: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
tree_full.insert(key_left, left_value);
tree_full.insert(key_right, right_value);
tree_full.delete(&key_right);
let root_full = tree_full.root_hash().unwrap();
assert_eq!(root_inc, root_full);
let has_right_cache_entries =
tree_inc.node_hash_cache.iter().any(|((depth, prefix), _)| {
*depth >= 1 && b256_matches_prefix(prefix, &right_prefix, 1)
});
assert!(!has_right_cache_entries);
}
#[test]
fn test_incremental_hash_update_prunes_cached_empty_subtree_without_dirty_info() {
let mut key_right_bytes = [0u8; 32];
key_right_bytes[0] = 0x80;
let key_left = TreeKey::from_bytes(B256::ZERO);
let key_right = TreeKey::from_bytes(B256::from_slice(&key_right_bytes));
let mut tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
tree.insert(key_left, B256::repeat_byte(0x11));
tree.insert(key_right, B256::repeat_byte(0x22));
tree.enable_incremental_mode();
tree.root_hash().unwrap();
let right_prefix = set_bit_at(B256::ZERO, 0);
assert!(tree.node_hash_cache.contains_key(&(1, right_prefix)));
tree.delete(&key_right);
assert!(tree.get(&key_right).is_none());
tree.dirty_stem_hashes.clear();
let out = tree
.incremental_hash_update(&[], 1, right_prefix, &[])
.unwrap();
assert_eq!(out, B256::ZERO);
let has_right_cache_entries = tree.node_hash_cache.iter().any(|((depth, prefix), _)| {
*depth >= 1 && b256_matches_prefix(prefix, &right_prefix, 1)
});
assert!(!has_right_cache_entries);
}
}