use alloc::{collections::BTreeSet, vec::Vec};
use super::{EmptySubtreeRoots, MerkleError, NodeIndex, SmtLeaf, SmtProof, Word};
use crate::{
Map,
merkle::smt::{LeafIndex, SMT_DEPTH, SmtLeafError, SmtProofError, forest::store::SmtStore},
};
mod store;
#[cfg(test)]
mod tests;
/// A collection of sparse Merkle trees, each identified by its root, whose nodes are all kept
/// in a single [`SmtStore`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SmtForest {
    /// Roots of the trees tracked by this forest. The root of the empty tree is never inserted
    /// here; `contains_root` accepts it unconditionally.
    roots: BTreeSet<Word>,
    /// Node storage shared by every tree in the forest, queried by root and node index.
    store: SmtStore,
    /// Full contents of non-empty leaves, keyed by the leaf's hash.
    leaves: Map<Word, SmtLeaf>,
}
impl Default for SmtForest {
fn default() -> Self {
Self::new()
}
}
impl SmtForest {
pub fn new() -> SmtForest {
let roots = BTreeSet::new();
let store = SmtStore::new();
let leaves = Map::new();
SmtForest { roots, store, leaves }
}
pub fn open(&self, root: Word, key: Word) -> Result<SmtProof, MerkleError> {
if !self.contains_root(root) {
return Err(MerkleError::RootNotInStore(root));
}
let leaf_index = NodeIndex::from(LeafIndex::from(key));
let proof = self.store.get_path(root, leaf_index)?;
let path = proof.path.try_into()?;
let leaf_hash = proof.value;
let leaf = if leaf_hash == crate::EMPTY_WORD {
SmtLeaf::new_empty(LeafIndex::from(key))
} else {
let Some(leaf) = self.leaves.get(&leaf_hash).cloned() else {
return Err(MerkleError::UntrackedKey(key));
};
leaf
};
SmtProof::new(path, leaf).map_err(|error| match error {
SmtProofError::InvalidMerklePathLength(depth) => MerkleError::InvalidPathLength(depth),
SmtProofError::InvalidKeyForProof
| SmtProofError::ValueMismatch { .. }
| SmtProofError::ConflictingRoots { .. }
| SmtProofError::ValuePresent { .. } => unreachable!(),
})
}
pub fn insert(&mut self, root: Word, key: Word, value: Word) -> Result<Word, MerkleError> {
self.batch_insert(root, vec![(key, value)])
}
pub fn batch_insert(
&mut self,
root: Word,
entries: impl IntoIterator<Item = (Word, Word)> + Clone,
) -> Result<Word, MerkleError> {
if !self.contains_root(root) {
return Err(MerkleError::RootNotInStore(root));
}
let indices = entries
.clone()
.into_iter()
.map(|(key, _)| LeafIndex::from(key))
.collect::<BTreeSet<_>>();
let mut new_leaves = Map::new();
for index in indices {
let node_index = NodeIndex::from(index);
let current_hash = self.store.get_node(root, node_index)?;
let current_leaf = self
.leaves
.get(¤t_hash)
.cloned()
.unwrap_or_else(|| SmtLeaf::new_empty(index));
new_leaves.insert(index, (current_hash, current_leaf));
}
for (key, value) in entries {
let index = LeafIndex::from(key);
let (_old_hash, leaf) = new_leaves.get_mut(&index).unwrap();
if value == crate::EMPTY_WORD {
let _ = leaf.remove(key);
} else {
leaf.insert(key, value).map_err(to_merkle_error)?;
}
}
new_leaves = new_leaves
.into_iter()
.filter_map(|(key, (old_hash, leaf))| {
let new_hash = leaf.hash();
if new_hash == old_hash {
None
} else {
Some((key, (new_hash, leaf)))
}
})
.collect();
let new_leaf_entries =
new_leaves.iter().map(|(index, leaf)| (NodeIndex::from(*index), leaf.0));
let new_root = self.store.set_leaves(root, new_leaf_entries)?;
for (leaf_hash, leaf) in new_leaves.into_values() {
if leaf_hash != crate::EMPTY_WORD {
self.leaves.insert(leaf_hash, leaf);
}
}
if new_root != *EmptySubtreeRoots::entry(SMT_DEPTH, 0) {
self.roots.insert(new_root);
}
Ok(new_root)
}
pub fn pop_smts(&mut self, roots: impl IntoIterator<Item = Word>) {
let roots = roots
.into_iter()
.filter(|root| {
self.roots.contains(root)
})
.collect::<Vec<_>>();
for root in &roots {
self.roots.remove(root);
}
for leaf in self.store.remove_roots(roots) {
self.leaves.remove(&leaf);
}
}
fn contains_root(&self, root: Word) -> bool {
self.roots.contains(&root) || *EmptySubtreeRoots::entry(SMT_DEPTH, 0) == root
}
}
/// Converts a leaf-level error raised during `batch_insert` into a [`MerkleError`].
///
/// Only `TooManyLeafEntries` is expected from leaf insertion here; any other variant is
/// treated as an internal invariant violation and panics.
fn to_merkle_error(err: SmtLeafError) -> MerkleError {
    if let SmtLeafError::TooManyLeafEntries { actual } = err {
        MerkleError::TooManyLeafEntries { actual }
    } else {
        unreachable!("other SmtLeafError variants should not be possible here")
    }
}