use alloc::vec::Vec;
use core::mem;
use num::Integer;
use p3_maybe_rayon::prelude::*;
use super::{
EmptySubtreeRoots, InnerNode, InnerNodes, Leaves, MerkleError, MutationSet, NodeIndex,
SMT_DEPTH, Smt, SmtLeaf, SparseMerkleTreeReader, Word,
};
use crate::merkle::smt::{Map, NodeMutation, NodeMutations, SmtLeafError};
#[cfg(test)]
mod tests;
/// Changed leaves grouped by bottom-level subtree: one column-sorted `Vec<SubtreeLeaf>` per
/// subtree, as produced by `SubtreeLeavesIter`.
pub(in crate::merkle::smt) type MutatedSubtreeLeaves = Vec<Vec<SubtreeLeaf>>;
impl Smt {
    /// Builds a tree from `entries`, computing inner nodes in parallel.
    ///
    /// `entries` may be given in any order; they are sorted by leaf index before the tree is
    /// built. Entries whose value is `EMPTY_VALUE` do not create leaf entries. Returns the
    /// default (empty) tree if `entries` is empty or yields no non-empty leaf.
    ///
    /// # Errors
    /// Propagates errors from leaf construction, e.g. a key repeated within one leaf (see
    /// `check_for_duplicate_keys`) or a leaf holding more entries than `SmtLeaf` allows.
    pub(crate) fn with_entries_concurrent(
        entries: impl IntoIterator<Item = (Word, Word)>,
    ) -> Result<Self, MerkleError> {
        let entries: Vec<(Word, Word)> = entries.into_iter().collect();

        if entries.is_empty() {
            return Ok(Self::default());
        }

        let (inner_nodes, leaves) = Self::build_subtrees(entries)?;

        // Every leaf was empty, so no inner node was created.
        if inner_nodes.is_empty() {
            return Ok(Self::default());
        }

        let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
        Ok(Self::from_raw_parts(inner_nodes, leaves, root))
    }

    /// Same as `with_entries_concurrent`, but `entries` must already be sorted by leaf index.
    ///
    /// The ordering requirement is only checked by a debug assertion in
    /// `process_sorted_pairs_to_leaves`.
    ///
    /// # Errors
    /// Same as `with_entries_concurrent`.
    pub(crate) fn with_sorted_entries_concurrent(
        entries: impl IntoIterator<Item = (Word, Word)>,
    ) -> Result<Self, MerkleError> {
        let entries: Vec<(Word, Word)> = entries.into_iter().collect();

        if entries.is_empty() {
            return Ok(Self::default());
        }

        let (inner_nodes, leaves) = build_subtrees_from_sorted_entries(entries)?;

        // Every leaf was empty, so no inner node was created.
        if inner_nodes.is_empty() {
            return Ok(Self::default());
        }

        let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
        Ok(Self::from_raw_parts(inner_nodes, leaves, root))
    }

    /// Computes, in parallel, the mutations that applying `kv_pairs` to this tree would cause,
    /// without modifying the tree.
    ///
    /// Pairs whose value equals the key's current value are ignored. If no leaf changes, the
    /// returned set has `old_root == new_root` and no node mutations.
    ///
    /// # Errors
    /// Returns an error if `kv_pairs` repeats a key (see `check_for_duplicate_keys`) or if a
    /// prospective leaf would hold too many entries.
    pub(crate) fn compute_mutations_concurrent(
        &self,
        kv_pairs: impl IntoIterator<Item = (Word, Word)>,
    ) -> Result<MutationSet<SMT_DEPTH, Word, Word>, MerkleError>
    where
        Self: Sized + Sync,
    {
        let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
        // NOTE(review): this sorts by the key itself, while `process_sorted_pairs_to_leaves`
        // debug-asserts ordering by leaf index; this relies on `Word`'s ordering agreeing with
        // leaf-index ordering.
        sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| *key);
        Self::check_for_duplicate_keys(&sorted_kv_pairs)?;

        let (mut subtree_leaves, new_pairs) =
            self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs)?;

        // No leaf changed: nothing to propagate upward.
        if subtree_leaves.is_empty() {
            return Ok(MutationSet {
                old_root: self.root(),
                new_root: self.root(),
                node_mutations: NodeMutations::default(),
                new_pairs,
            });
        }

        let mut node_mutations = NodeMutations::default();

        // Walk up the tree one SUBTREE_DEPTH-tall layer at a time, processing the subtrees of
        // each layer in parallel; each layer's subtree roots become the next layer's leaves.
        for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
            let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
                .into_par_iter()
                .map(|subtree| {
                    debug_assert!(subtree.is_sorted() && !subtree.is_empty());
                    self.build_subtree_mutations(subtree, SMT_DEPTH, depth)
                })
                .unzip();

            subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
            node_mutations.extend(mutations_per_subtree.into_iter().flatten());

            debug_assert!(!subtree_leaves.is_empty());
        }

        // After the top layer exactly one subtree with one root remains.
        let new_root = subtree_leaves[0][0].hash;

        let mutation_set = MutationSet {
            old_root: self.root(),
            new_root,
            node_mutations,
            new_pairs,
        };

        // At least one leaf changed, so there are both node mutations and new pairs.
        debug_assert!(
            !mutation_set.node_mutations().is_empty() && !mutation_set.new_pairs().is_empty()
        );

        Ok(mutation_set)
    }

    /// Computes the mutations for one subtree whose bottom layer is at `bottom_depth`.
    ///
    /// Siblings that are not present in `leaves` are read from this tree's existing inner
    /// nodes. Every parent on the way up is recorded: as `Addition` with the recomputed node,
    /// or as `Removal` when its hash equals the empty-subtree root for its depth. Returns those
    /// mutations together with the subtree root as a `SubtreeLeaf`.
    fn build_subtree_mutations(
        &self,
        mut leaves: Vec<SubtreeLeaf>,
        tree_depth: u8,
        bottom_depth: u8,
    ) -> (NodeMutations, SubtreeLeaf)
    where
        Self: Sized,
    {
        debug_assert!(bottom_depth <= tree_depth);
        debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
        debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));

        let subtree_root_depth = bottom_depth - SUBTREE_DEPTH;
        let mut node_mutations: NodeMutations = Default::default();
        let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);

        // `current_depth` is the parent level being computed; `next_depth` holds `leaves`.
        for current_depth in (subtree_root_depth..bottom_depth).rev() {
            debug_assert!(current_depth <= bottom_depth);

            let next_depth = current_depth + 1;
            let mut iter = leaves.drain(..).peekable();

            while let Some(first_leaf) = iter.next() {
                // Valid because `next_depth` never exceeds the depth of the tree.
                let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent();
                let parent_node = self.get_inner_node(parent_index);
                let combined_node = fetch_sibling_pair(&mut iter, first_leaf, parent_node);
                let combined_hash = combined_node.hash();

                let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth);

                // Propagate the parent even when it is empty so the level above is updated.
                next_leaves.push(SubtreeLeaf {
                    col: parent_index.position(),
                    hash: combined_hash,
                });

                node_mutations.insert(
                    parent_index,
                    if combined_hash != empty_hash {
                        NodeMutation::Addition(combined_node)
                    } else {
                        NodeMutation::Removal
                    },
                );
            }
            drop(iter);
            leaves = mem::take(&mut next_leaves);
        }

        debug_assert_eq!(leaves.len(), 1);
        let root_leaf = leaves.pop().unwrap();
        (node_mutations, root_leaf)
    }

    /// Sorts `entries` by leaf index and builds all inner nodes and leaves from them.
    ///
    /// # Errors
    /// Propagates errors from `build_subtrees_from_sorted_entries`.
    fn build_subtrees(mut entries: Vec<(Word, Word)>) -> Result<(InnerNodes, Leaves), MerkleError> {
        entries.par_sort_unstable_by_key(|item| {
            let index = Self::key_to_leaf_index(&item.0);
            index.position()
        });
        build_subtrees_from_sorted_entries(entries)
    }

    /// Groups pairs sorted by leaf index into leaves (via `pairs_to_leaf`) and returns the
    /// leaves together with their hashes grouped by subtree.
    ///
    /// # Errors
    /// Propagates the first error returned by `pairs_to_leaf`.
    pub(in crate::merkle::smt) fn sorted_pairs_to_leaves(
        pairs: Vec<(Word, Word)>,
    ) -> Result<PairComputations<u64, SmtLeaf>, MerkleError> {
        process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
    }

    /// Builds one leaf from all pairs that map to it.
    ///
    /// Pairs whose value is `EMPTY_VALUE` denote absent keys and are not stored. Returns
    /// `Ok(None)` when no non-empty pair remains.
    ///
    /// # Errors
    /// Returns an error if a key is repeated (see `check_for_duplicate_keys`) or if the leaf
    /// would hold more entries than `SmtLeaf` allows (`MerkleError::TooManyLeafEntries`).
    ///
    /// # Panics
    /// Panics if `pairs` is empty, or on any other `SmtLeaf::new_multiple` error.
    fn pairs_to_leaf(mut pairs: Vec<(Word, Word)>) -> Result<Option<SmtLeaf>, MerkleError> {
        assert!(!pairs.is_empty());

        if pairs.len() > 1 {
            // Sorting makes repeated keys adjacent and puts the leaf entries in key order.
            pairs.sort_by(|(key_1, _), (key_2, _)| key_1.cmp(key_2));
            Self::check_for_duplicate_keys(&pairs)?;
        }

        // Empty values are dropped regardless of how many keys share the leaf. The duplicate
        // check above runs first so a key given both an empty and a non-empty value is still
        // rejected.
        pairs.retain(|(_, value)| *value != Self::EMPTY_VALUE);

        match pairs.len() {
            0 => Ok(None),
            1 => {
                let (key, value) = pairs.pop().expect("exactly one pair remains");
                Ok(Some(SmtLeaf::new_single(key, value)))
            },
            _ => SmtLeaf::new_multiple(pairs).map(Some).map_err(|e| match e {
                SmtLeafError::TooManyLeafEntries { actual } => {
                    MerkleError::TooManyLeafEntries { actual }
                },
                other => panic!("unexpected SmtLeaf::new_multiple error: {:?}", other),
            }),
        }
    }

    /// Applies sorted `pairs` to copies of this tree's leaves and returns the changed leaves
    /// (hashed and grouped by subtree) together with the pairs whose value actually changed.
    ///
    /// # Errors
    /// Returns `MerkleError::TooManyLeafEntries` if a prospective leaf would be too large.
    ///
    /// # Panics
    /// Panics on any other error from `construct_prospective_leaf`.
    fn sorted_pairs_to_mutated_subtree_leaves(
        &self,
        pairs: Vec<(Word, Word)>,
    ) -> Result<(MutatedSubtreeLeaves, Map<Word, Word>), MerkleError> {
        let mut new_pairs = Map::new();
        let accumulator = process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
            // All pairs in `leaf_pairs` map to the same leaf, so the first key locates it.
            let mut leaf = self.get_leaf(&leaf_pairs[0].0);
            let mut leaf_changed = false;
            for (key, value) in leaf_pairs {
                // `get_value` only returns `None` for a key outside this leaf, which the
                // grouping in `process_sorted_pairs_to_leaves` rules out.
                let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| {
                    leaf.get_value(&key).unwrap()
                });
                if value != old_value {
                    leaf = self.construct_prospective_leaf(leaf, &key, &value).map_err(
                        |e| match e {
                            SmtLeafError::TooManyLeafEntries { actual } => {
                                MerkleError::TooManyLeafEntries { actual }
                            },
                            other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
                        },
                    )?;
                    new_pairs.insert(key, value);
                    leaf_changed = true;
                }
            }
            // Only changed leaves are reported.
            if leaf_changed {
                Ok(Some(leaf))
            } else {
                Ok(None)
            }
        });
        Ok((accumulator?.leaves, new_pairs))
    }
}
/// Height, in levels, of each subtree that the concurrent algorithms process as one unit.
pub(in crate::merkle::smt) const SUBTREE_DEPTH: u8 = 8;
/// Number of columns spanned by the bottom layer of one subtree, i.e. `2^SUBTREE_DEPTH`.
pub(in crate::merkle::smt) const COLS_PER_SUBTREE: u64 = 1u64 << SUBTREE_DEPTH;
/// A node fed into (or produced by) subtree construction: its column and its hash.
///
/// Field order matters: the derived `Ord` compares `col` first, which is what the
/// `subtree.is_sorted()` debug assertions in this module rely on.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct SubtreeLeaf {
    /// Horizontal position of the node at its depth (the value given to
    /// `NodeIndex::new_unchecked` together with that depth).
    pub col: u64,
    /// Hash of the node: a leaf hash at the bottom layer, an inner-node hash above it.
    pub hash: Word,
}
/// Result of `process_sorted_pairs_to_leaves`; in this module it is instantiated as
/// `PairComputations<u64, SmtLeaf>`.
#[derive(Debug, Clone)]
pub(in crate::merkle::smt) struct PairComputations<K, L> {
    /// Leaves returned by the per-leaf callback, keyed by leaf column. Columns for which the
    /// callback returned `None` are absent.
    pub nodes: Map<K, L>,
    /// Hashes of the leaves in `nodes`, sorted by column and grouped into one `Vec` per
    /// bottom-level subtree.
    pub leaves: Vec<Vec<SubtreeLeaf>>,
}
impl<K, L> Default for PairComputations<K, L> {
fn default() -> Self {
Self {
nodes: Default::default(),
leaves: Default::default(),
}
}
}
/// Iterator that drains a `Vec<SubtreeLeaf>` and yields it in chunks, one chunk per run of
/// leaves that fall in the same `COLS_PER_SUBTREE`-wide column range.
///
/// Callers in this module pass leaves sorted by column.
#[derive(Debug)]
pub(in crate::merkle::smt) struct SubtreeLeavesIter<'s> {
    // Peekable so that `next()` can stop at a subtree boundary without consuming the leaf.
    leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
pub(crate) fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
Self { leaves: leaves.drain(..).peekable() }
}
}
impl Iterator for SubtreeLeavesIter<'_> {
    type Item = Vec<SubtreeLeaf>;

    /// Yields the next run of leaves belonging to one subtree, or `None` once the underlying
    /// leaves are exhausted.
    ///
    /// The first remaining leaf always opens a run. Each following leaf joins it while its
    /// column lies below the first column of the subtree after the previously taken leaf's
    /// subtree.
    fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
        let first = self.leaves.next()?;
        let mut prev_col = first.col;
        let mut subtree: Vec<SubtreeLeaf> = Vec::new();
        subtree.push(first);

        loop {
            // First column of the subtree that follows the one containing `prev_col`.
            let boundary = (prev_col / COLS_PER_SUBTREE + 1) * COLS_PER_SUBTREE;
            let Some(leaf) = self.leaves.next_if(|leaf| leaf.col < boundary) else {
                break;
            };
            prev_col = leaf.col;
            subtree.push(leaf);
        }

        Some(subtree)
    }
}
/// Groups key-value pairs that are already sorted by leaf index into per-leaf batches and
/// hashes the resulting leaves.
///
/// `process_leaf` is called once for each run of consecutive pairs that map to the same leaf
/// column. A returned `Some(leaf)` is stored under that column in `nodes`; `None` stores
/// nothing. The stored leaves are then hashed in parallel, sorted by column and grouped into
/// per-subtree `Vec`s in `leaves`.
///
/// # Errors
/// Returns the first error produced by `process_leaf`; later pairs are not processed.
pub(crate) fn process_sorted_pairs_to_leaves<F>(
    pairs: Vec<(Word, Word)>,
    mut process_leaf: F,
) -> Result<PairComputations<u64, SmtLeaf>, MerkleError>
where
    F: FnMut(Vec<(Word, Word)>) -> Result<Option<SmtLeaf>, MerkleError>,
{
    debug_assert!(pairs.is_sorted_by_key(|(key, _)| Smt::key_to_leaf_index(key).position()));

    let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();

    // Pairs seen so far that belong to the current leaf column.
    let mut current_leaf_buffer: Vec<(Word, Word)> = Default::default();

    let mut iter = pairs.into_iter().peekable();
    while let Some((key, value)) = iter.next() {
        let col = Smt::key_to_leaf_index(&key).index.position();
        let peeked_col = iter.peek().map(|(key, _v)| {
            let next_col = Smt::key_to_leaf_index(key).index.position();
            // Fails if `pairs` is not sorted by column.
            debug_assert!(next_col >= col);
            next_col
        });
        current_leaf_buffer.push((key, value));

        // The next pair belongs to the same leaf: keep buffering.
        if peeked_col == Some(col) {
            continue;
        }

        // The next pair starts a new column (or there is none): flush the buffer.
        let leaf_pairs = mem::take(&mut current_leaf_buffer);
        if let Some(leaf) = process_leaf(leaf_pairs)? {
            accumulator.nodes.insert(col, leaf);
        }

        debug_assert!(current_leaf_buffer.is_empty());
    }

    // Hash the leaves in parallel by reference instead of deep-cloning the whole map.
    let mut accumulated_leaves: Vec<SubtreeLeaf> = accumulator
        .nodes
        .par_iter()
        .map(|(&col, leaf)| SubtreeLeaf { col, hash: Smt::hash_leaf(leaf) })
        .collect();

    // The map's iteration order is not relied upon; restore column order explicitly.
    accumulated_leaves.par_sort_by_key(|leaf| leaf.col);
    accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
    Ok(accumulator)
}
/// Builds every inner node of the tree from `entries`, which must already be sorted by leaf
/// index.
///
/// Leaves are formed first. The tree is then assembled bottom-up, one `SUBTREE_DEPTH`-tall
/// layer of subtrees at a time, with the subtrees of each layer built in parallel. Returns the
/// inner nodes together with the non-empty leaves keyed by column.
///
/// # Errors
/// Propagates any error raised while forming the leaves.
fn build_subtrees_from_sorted_entries(
    entries: Vec<(Word, Word)>,
) -> Result<(InnerNodes, Leaves), MerkleError> {
    let PairComputations { nodes: initial_leaves, leaves: mut leaf_subtrees } =
        Smt::sorted_pairs_to_leaves(entries)?;
    let mut accumulated_nodes: InnerNodes = Default::default();

    // No non-empty leaf means there is nothing to build above the bottom layer.
    if initial_leaves.is_empty() {
        return Ok((accumulated_nodes, initial_leaves));
    }

    let layer_depths = (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev();
    for bottom_depth in layer_depths {
        let (node_maps, mut layer_roots): (Vec<Map<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
            .into_par_iter()
            .map(|subtree| {
                debug_assert!(subtree.is_sorted());
                debug_assert!(!subtree.is_empty());
                build_subtree(subtree, SMT_DEPTH, bottom_depth)
            })
            .unzip();

        for node_map in node_maps {
            accumulated_nodes.extend(node_map);
        }

        // This layer's subtree roots are the leaves of the next layer up.
        leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut layer_roots).collect();
        debug_assert!(!leaf_subtrees.is_empty());
    }

    Ok((accumulated_nodes, initial_leaves))
}
/// Builds the inner nodes of one subtree of height `SUBTREE_DEPTH`, bottom-up.
///
/// `leaves` are the non-empty nodes at `bottom_depth`, sorted by column. Any column missing
/// from `leaves` is treated as the empty-subtree root for that depth. Parents whose hash equals
/// the empty-subtree root are neither stored nor carried upward.
///
/// Returns the stored inner nodes keyed by index, together with the subtree root as a
/// `SubtreeLeaf` at depth `bottom_depth - SUBTREE_DEPTH`.
///
/// # Panics
/// - In debug builds, panics if `leaves` repeats a column.
/// - Panics if no root remains, i.e. every node at some level hashed to the empty root.
pub(crate) fn build_subtree(
    mut leaves: Vec<SubtreeLeaf>,
    tree_depth: u8,
    bottom_depth: u8,
) -> (Map<NodeIndex, InnerNode>, SubtreeLeaf) {
    #[cfg(debug_assertions)]
    {
        use alloc::collections::BTreeSet;
        let mut seen_cols = BTreeSet::new();
        if let Some(dup) = leaves.iter().find(|leaf| !seen_cols.insert(leaf.col)) {
            panic!("Duplicate column found in subtree: {}", dup.col);
        }
    }
    debug_assert!(bottom_depth <= tree_depth);
    debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
    debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));

    let root_depth = bottom_depth - SUBTREE_DEPTH;
    let mut inner_nodes: Map<NodeIndex, InnerNode> = Default::default();
    let mut parents: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);

    // `parent_depth` is the level being produced; `child_depth` is the level held in `leaves`.
    for parent_depth in (root_depth..bottom_depth).rev() {
        debug_assert!(parent_depth <= bottom_depth);
        let child_depth = parent_depth + 1;

        let mut children = leaves.drain(..).peekable();
        while let Some(first) = children.next() {
            let (left, right) = if first.col.is_odd() {
                // A right child with no preceding left sibling: the left sibling is empty.
                let empty_left = SubtreeLeaf {
                    col: first.col - 1,
                    hash: *EmptySubtreeRoots::entry(tree_depth, child_depth),
                };
                (empty_left, first)
            } else {
                // A left child: take its right sibling if it comes next, otherwise it is empty.
                let sibling_col = first.col + 1;
                let sibling = match children.next_if(|leaf| leaf.col == sibling_col) {
                    Some(present) => present,
                    None => SubtreeLeaf {
                        col: sibling_col,
                        hash: *EmptySubtreeRoots::entry(tree_depth, child_depth),
                    },
                };
                (first, sibling)
            };

            let parent_index = NodeIndex::new_unchecked(child_depth, left.col).parent();
            let parent = InnerNode { left: left.hash, right: right.hash };
            let parent_hash = parent.hash();

            if parent_hash != *EmptySubtreeRoots::entry(tree_depth, parent_depth) {
                inner_nodes.insert(parent_index, parent);
                parents.push(SubtreeLeaf { col: parent_index.position(), hash: parent_hash });
            }
        }
        // Release the borrow of `leaves` (the drain is exhausted) before swapping.
        drop(children);
        mem::swap(&mut leaves, &mut parents);
    }

    debug_assert_eq!(leaves.len(), 1);
    let root = leaves.pop().unwrap();
    (inner_nodes, root)
}
/// Forms the parent `InnerNode` for `first_leaf` and its sibling.
///
/// - Right child (odd column): the left hash is taken from `parent_node`.
/// - Left child (even column): the right sibling is consumed from `iter` when the next leaf is
///   exactly that sibling; otherwise the right hash is taken from `parent_node`.
pub(crate) fn fetch_sibling_pair(
    iter: &mut core::iter::Peekable<alloc::vec::Drain<SubtreeLeaf>>,
    first_leaf: SubtreeLeaf,
    parent_node: InnerNode,
) -> InnerNode {
    if first_leaf.col.is_odd() {
        return InnerNode { left: parent_node.left, right: first_leaf.hash };
    }

    let sibling_col = first_leaf.col + 1;
    let right = match iter.next_if(|leaf| leaf.col == sibling_col) {
        Some(sibling) => sibling.hash,
        None => parent_node.right,
    };
    InnerNode { left: first_leaf.hash, right }
}
/// `pub` wrapper around the `pub(crate)` `build_subtree`, compiled only with the `internal`
/// feature so benchmarks can call it. Arguments and results are passed through unchanged.
#[cfg(feature = "internal")]
pub fn build_subtree_for_bench(
    leaves: Vec<SubtreeLeaf>,
    tree_depth: u8,
    bottom_depth: u8,
) -> (Map<NodeIndex, InnerNode>, SubtreeLeaf) {
    let (inner_nodes, root) = build_subtree(leaves, tree_depth, bottom_depth);
    (inner_nodes, root)
}