use alloy_primitives::B256;
use std::collections::HashMap;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use crate::{error::Result, Blake3Hasher, Hasher, Stem, SubIndex, TreeKey, UbtError, STEM_LEN};
#[cfg(feature = "parallel")]
// Minimum number of distinct stems before the parallel builder fans out to
// rayon; below this, per-task scheduling overhead outweighs the hashing work.
// NOTE(review): 100 looks like a heuristic — confirm it was benchmarked.
const PARALLEL_STEM_THRESHOLD: usize = 100;
/// Computes a unified-binary-tree root hash from a sorted stream of
/// `(TreeKey, B256)` entries without materializing the whole tree:
/// only one stem's values are buffered at a time.
pub struct StreamingTreeBuilder<H: Hasher = Blake3Hasher> {
// Hash backend used for leaf, inner-node, and stem-node hashing.
hasher: H,
}
impl<H: Hasher> Default for StreamingTreeBuilder<H> {
/// Equivalent to [`StreamingTreeBuilder::new`].
fn default() -> Self {
Self::new()
}
}
impl<H: Hasher> StreamingTreeBuilder<H> {
/// Creates a builder with a default-constructed hasher.
pub fn new() -> Self {
Self {
hasher: H::default(),
}
}
/// Creates a builder that uses the provided hasher instance.
pub fn with_hasher(hasher: H) -> Self {
Self { hasher }
}
/// Computes the tree root hash from a stream of entries sorted by
/// `(stem, subindex)`.
///
/// Returns `B256::ZERO` for an empty stream, and also when every supplied
/// value is zero (nothing survives stem grouping).
#[must_use = "callers should handle errors and use the computed root hash"]
pub fn build_root_hash(
    &self,
    entries: impl IntoIterator<Item = (TreeKey, B256)>,
) -> Result<B256> {
    let mut stream = entries.into_iter().peekable();
    // An entirely empty stream short-circuits to the empty-tree root.
    if stream.peek().is_none() {
        return Ok(B256::ZERO);
    }
    match self.collect_stem_hashes(&mut stream) {
        hashes if hashes.is_empty() => Ok(B256::ZERO),
        hashes => self.build_tree_hash(&hashes, 0),
    }
}
/// Parallel variant of [`Self::build_root_hash`].
///
/// Entries are first grouped per stem; the per-stem subtree hashing — the
/// expensive, independent part — is fanned out over rayon once the number of
/// stems reaches `PARALLEL_STEM_THRESHOLD`, otherwise it runs serially.
#[cfg(feature = "parallel")]
#[must_use = "callers should handle errors and use the computed root hash"]
pub fn build_root_hash_parallel(
    &self,
    entries: impl IntoIterator<Item = (TreeKey, B256)>,
) -> Result<B256> {
    let mut stream = entries.into_iter().peekable();
    if stream.peek().is_none() {
        return Ok(B256::ZERO);
    }
    let groups = Self::collect_stem_groups(&mut stream);
    if groups.is_empty() {
        return Ok(B256::ZERO);
    }
    // One closure shared by both branches so serial and parallel paths can't drift.
    let hash_group = |(stem, values): (Stem, HashMap<SubIndex, B256>)| {
        (stem, self.compute_stem_hash(&stem, &values))
    };
    let mut stem_hashes: Vec<(Stem, B256)> = if groups.len() >= PARALLEL_STEM_THRESHOLD {
        groups.into_par_iter().map(hash_group).collect()
    } else {
        groups.into_iter().map(hash_group).collect()
    };
    // Defensive re-sort: build_tree_hash relies on stem order for partitioning.
    stem_hashes.sort_by(|a, b| a.0.cmp(&b.0));
    self.build_tree_hash(&stem_hashes, 0)
}
/// Groups consecutive sorted entries by stem, discarding zero values.
///
/// Stems whose surviving value set ends up empty produce no group. Debug
/// builds verify the required `(stem, subindex)` ordering with `debug_assert!`.
#[cfg(feature = "parallel")]
fn collect_stem_groups<I: Iterator<Item = (TreeKey, B256)>>(
    entries: &mut std::iter::Peekable<I>,
) -> Vec<(Stem, HashMap<SubIndex, B256>)> {
    let mut groups: Vec<(Stem, HashMap<SubIndex, B256>)> = Vec::new();
    // The group currently being accumulated, if any.
    let mut active: Option<(Stem, HashMap<SubIndex, B256>)> = None;
    #[cfg(debug_assertions)]
    let mut prev_key: Option<TreeKey> = None;
    for (key, value) in entries.by_ref() {
        #[cfg(debug_assertions)]
        {
            if let Some(prev) = prev_key {
                debug_assert!(
                    (prev.stem, prev.subindex) < (key.stem, key.subindex),
                    "Entries must be sorted: {prev:?} should come before {key:?}",
                );
            }
            prev_key = Some(key);
        }
        let stem_changed = active.as_ref().map_or(true, |(stem, _)| *stem != key.stem);
        if stem_changed {
            // Flush the finished group, but only if it kept any non-zero values.
            if let Some((stem, values)) = active.take() {
                if !values.is_empty() {
                    groups.push((stem, values));
                }
            }
            active = Some((key.stem, HashMap::new()));
        }
        if !value.is_zero() {
            if let Some((_, values)) = active.as_mut() {
                values.insert(key.subindex, value);
            }
        }
    }
    // Flush the trailing group.
    if let Some((stem, values)) = active {
        if !values.is_empty() {
            groups.push((stem, values));
        }
    }
    groups
}
fn collect_stem_hashes<I: Iterator<Item = (TreeKey, B256)>>(
&self,
entries: &mut std::iter::Peekable<I>,
) -> Vec<(Stem, B256)> {
let mut stem_hashes: Vec<(Stem, B256)> = Vec::new();
let mut current_stem: Option<Stem> = None;
let mut current_values: HashMap<SubIndex, B256> = HashMap::new();
#[cfg(debug_assertions)]
let mut prev_key: Option<TreeKey> = None;
for (key, value) in entries.by_ref() {
#[cfg(debug_assertions)]
{
if let Some(prev) = prev_key {
debug_assert!(
(prev.stem, prev.subindex) < (key.stem, key.subindex),
"Entries must be sorted: {prev:?} should come before {key:?}",
);
}
prev_key = Some(key);
}
match current_stem {
Some(stem) if stem == key.stem => {
if !value.is_zero() {
current_values.insert(key.subindex, value);
}
}
Some(stem) => {
if !current_values.is_empty() {
let hash = self.compute_stem_hash(&stem, ¤t_values);
stem_hashes.push((stem, hash));
}
current_values.clear();
current_stem = Some(key.stem);
if !value.is_zero() {
current_values.insert(key.subindex, value);
}
}
None => {
current_stem = Some(key.stem);
if !value.is_zero() {
current_values.insert(key.subindex, value);
}
}
}
}
if let Some(stem) = current_stem {
if !current_values.is_empty() {
let hash = self.compute_stem_hash(&stem, ¤t_values);
stem_hashes.push((stem, hash));
}
}
stem_hashes
}
/// Hashes one stem's 256-slot subtree down to a single stem-node hash.
///
/// Each present value is leaf-hashed into its subindex slot (absent slots stay
/// `B256::ZERO`), the 256 slots are folded pairwise over 8 levels, and the
/// resulting subtree root is combined with the stem bytes.
fn compute_stem_hash(&self, stem: &Stem, values: &HashMap<SubIndex, B256>) -> B256 {
    let mut nodes = [B256::ZERO; 256];
    for (&slot, &val) in values {
        nodes[slot as usize] = self.hasher.hash_32(&val);
    }
    // Fold pairs in place: 256 -> 128 -> ... -> 1.
    let mut width = 256;
    while width > 1 {
        width /= 2;
        for i in 0..width {
            nodes[i] = self.hasher.hash_64(&nodes[2 * i], &nodes[2 * i + 1]);
        }
    }
    self.hasher.hash_stem_node(stem.as_bytes(), &nodes[0])
}
/// Recursively hashes sorted `(stem, hash)` pairs into the tree root.
///
/// At each depth, stems are partitioned by the bit at that depth (zeros left,
/// ones right). A lone stem returns its hash directly, and two all-zero
/// children collapse to `B256::ZERO`.
///
/// # Errors
/// Returns [`UbtError::TreeDepthExceeded`] if recursion passes the last stem
/// bit (`STEM_LEN * 8`) while more than one stem remains.
fn build_tree_hash(&self, stem_hashes: &[(Stem, B256)], depth: usize) -> Result<B256> {
    match stem_hashes {
        [] => return Ok(B256::ZERO),
        [(_, only)] => return Ok(*only),
        _ => {}
    }
    if depth >= STEM_LEN * 8 {
        return Err(UbtError::TreeDepthExceeded { depth });
    }
    // partition_point relies on the input being sorted by stem.
    let pivot = stem_hashes.partition_point(|(stem, _)| !stem.bit_at(depth));
    let left = self.build_tree_hash(&stem_hashes[..pivot], depth + 1)?;
    let right = self.build_tree_hash(&stem_hashes[pivot..], depth + 1)?;
    Ok(if left.is_zero() && right.is_zero() {
        B256::ZERO
    } else {
        self.hasher.hash_64(&left, &right)
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::UnifiedBinaryTree;

    /// Reference root: insert the same sorted entries into the full
    /// `UnifiedBinaryTree` and take its root hash.
    fn reference_root(entries: Vec<(TreeKey, B256)>) -> B256 {
        let mut tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
        tree.insert_batch(entries).unwrap();
        tree.root_hash().unwrap()
    }

    /// Sorts entries by `(stem, subindex)`, as the streaming builder requires.
    fn sort_entries(entries: &mut Vec<(TreeKey, B256)>) {
        entries.sort_by(|a, b| (a.0.stem, a.0.subindex).cmp(&(b.0.stem, b.0.subindex)));
    }

    #[test]
    fn test_streaming_empty() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        let no_entries: Vec<(TreeKey, B256)> = Vec::new();
        assert_eq!(builder.build_root_hash(no_entries).unwrap(), B256::ZERO);
    }

    #[test]
    fn test_streaming_single_entry() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        let key = TreeKey::from_bytes(B256::repeat_byte(0x01));
        let value = B256::repeat_byte(0x42);
        let streaming_root = builder.build_root_hash(vec![(key, value)]).unwrap();
        // Single `insert` (not `insert_batch`) mirrors the reference path exactly.
        let mut tree: UnifiedBinaryTree<Blake3Hasher> = UnifiedBinaryTree::new();
        tree.insert(key, value);
        assert_eq!(streaming_root, tree.root_hash().unwrap());
    }

    #[test]
    fn test_streaming_matches_tree() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        // 10 stems x 5 subindices each.
        let mut entries: Vec<(TreeKey, B256)> = (0u8..10)
            .flat_map(|i| {
                let mut stem_bytes = [0u8; 31];
                stem_bytes[0] = i * 10;
                let stem = Stem::new(stem_bytes);
                (0u8..5).map(move |j| (TreeKey::new(stem, j), B256::repeat_byte(i + j)))
            })
            .collect();
        sort_entries(&mut entries);
        let streaming_root = builder.build_root_hash(entries.clone()).unwrap();
        assert_eq!(streaming_root, reference_root(entries));
    }

    #[test]
    fn test_streaming_many_stems() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        // 100 distinct stems, one entry each.
        let mut entries: Vec<(TreeKey, B256)> = (1u8..=100)
            .map(|i| {
                let mut stem_bytes = [0u8; 31];
                stem_bytes[0] = i;
                stem_bytes[15] = i.wrapping_mul(7);
                (TreeKey::new(Stem::new(stem_bytes), 0), B256::repeat_byte(i))
            })
            .collect();
        sort_entries(&mut entries);
        let streaming_root = builder.build_root_hash(entries.clone()).unwrap();
        assert_eq!(streaming_root, reference_root(entries));
    }

    #[test]
    fn test_tree_depth_exceeded_returns_error() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        let stem1 = Stem::new([0u8; 31]);
        let mut stem2_bytes = [0u8; 31];
        stem2_bytes[0] = 1;
        let stem2 = Stem::new(stem2_bytes);
        let stem_hashes = vec![(stem1, B256::repeat_byte(1)), (stem2, B256::repeat_byte(2))];
        // Two stems starting at the maximum depth must hit the error path.
        let err = builder
            .build_tree_hash(&stem_hashes, STEM_LEN * 8)
            .unwrap_err();
        assert!(matches!(err, UbtError::TreeDepthExceeded { depth } if depth == STEM_LEN * 8));
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_matches_serial() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        let mut entries: Vec<(TreeKey, B256)> = Vec::new();
        // 50 stems x 10 subindices, with non-zero values throughout.
        for i in 0u8..50 {
            let mut stem_bytes = [0u8; 31];
            stem_bytes[0] = i;
            stem_bytes[10] = i.wrapping_mul(3);
            stem_bytes[20] = i.wrapping_mul(7);
            let stem = Stem::new(stem_bytes);
            for j in 0u8..10 {
                entries.push((
                    TreeKey::new(stem, j),
                    B256::repeat_byte(i.wrapping_add(j).wrapping_mul(2).max(1)),
                ));
            }
        }
        sort_entries(&mut entries);
        let serial_root = builder.build_root_hash(entries.clone()).unwrap();
        let parallel_root = builder.build_root_hash_parallel(entries).unwrap();
        assert_eq!(
            parallel_root, serial_root,
            "Parallel and serial should produce identical root hashes"
        );
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_empty() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        let no_entries: Vec<(TreeKey, B256)> = Vec::new();
        assert_eq!(
            builder.build_root_hash_parallel(no_entries).unwrap(),
            B256::ZERO
        );
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_parallel_single_entry() {
        let builder: StreamingTreeBuilder<Blake3Hasher> = StreamingTreeBuilder::new();
        let key = TreeKey::from_bytes(B256::repeat_byte(0x01));
        let value = B256::repeat_byte(0x42);
        let entries = vec![(key, value)];
        let parallel_root = builder.build_root_hash_parallel(entries.clone()).unwrap();
        assert_eq!(parallel_root, builder.build_root_hash(entries).unwrap());
    }
}