use alloc::vec::Vec;
use crate::{
Map, Word,
hash::poseidon2::Poseidon2,
merkle::{EmptySubtreeRoots, MerkleError, MerklePath, MerkleProof, NodeIndex, smt::SMT_DEPTH},
};
/// Inner node of a sparse Merkle tree as stored in [`SmtStore`].
///
/// The store keys each node by the value returned from [`ForestInnerNode::hash`].
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq)]
struct ForestInnerNode {
/// Hash of the left child.
left: Word,
/// Hash of the right child.
right: Word,
/// Reference count. `set_leaves` sets it to 1 when the node is first inserted and bumps it
/// when an update reaches a node that is already stored; `remove_node` decrements it and
/// drops the node once it reaches zero.
rc: usize,
}
impl ForestInnerNode {
    /// Returns the hash of this node: the Poseidon2 merge of its left and right child hashes.
    ///
    /// The reference count does not take part in the hash.
    pub fn hash(&self) -> Word {
        let Self { left, right, .. } = *self;
        Poseidon2::merge(&[left, right])
    }
}
/// Hash-addressed storage for the inner nodes of sparse Merkle trees.
///
/// Every inner node is stored once under its own hash and carries a reference count, so a node
/// reached by several updates is kept as a single entry (see `set_leaves` / `remove_roots`).
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub(super) struct SmtStore {
/// Inner nodes keyed by their hash.
nodes: Map<Word, ForestInnerNode>,
}
impl SmtStore {
/// Creates a store whose node map is pre-populated with the entries produced by
/// `empty_hashes()`, i.e. the inner nodes of the empty tree.
pub fn new() -> Self {
    Self { nodes: empty_hashes().collect() }
}
/// Returns the hash found at `index` in the tree rooted at `root`.
///
/// The walk starts at `root` and consumes the bits of `index` from bit `depth - 1` down to
/// bit 0, stepping to the right child when the bit is set and to the left child otherwise.
///
/// # Errors
/// - [`MerkleError::RootNotInStore`] if `root` is not a key of the store.
/// - [`MerkleError::NodeIndexNotFoundInStore`] if a hash reached during the walk has no
///   entry in the store.
pub fn get_node(&self, root: Word, index: NodeIndex) -> Result<Word, MerkleError> {
    // Checked up front so that a lookup at depth 0 still validates the root.
    if !self.nodes.contains_key(&root) {
        return Err(MerkleError::RootNotInStore(root));
    }

    (0..index.depth()).rev().try_fold(root, |current, bit| {
        let Some(inner) = self.nodes.get(&current) else {
            return Err(MerkleError::NodeIndexNotFoundInStore(current, index));
        };
        Ok(if index.is_nth_bit_odd(bit) { inner.right } else { inner.left })
    })
}
pub fn get_path(&self, root: Word, index: NodeIndex) -> Result<MerkleProof, MerkleError> {
let IndexedPath { value, path } = self.get_indexed_path(root, index)?;
let path_iter = path.into_iter().rev().map(|(_, value)| value);
Ok(MerkleProof::new(value, MerklePath::from_iter(path_iter)))
}
/// Walks from `root` down to `index`, returning the hash found there together with the
/// `(index, hash)` pair of every sibling passed on the way, ordered root-first.
///
/// # Errors
/// - [`MerkleError::RootNotInStore`] if `root` is not a key of the store.
/// - [`MerkleError::NodeIndexNotFoundInStore`] if a hash reached during the walk has no
///   entry in the store.
fn get_indexed_path(&self, root: Word, index: NodeIndex) -> Result<IndexedPath, MerkleError> {
    let depth = index.depth();
    let mut siblings = Vec::with_capacity(depth.into());

    // Checked up front so that a lookup at depth 0 still validates the root.
    if !self.nodes.contains_key(&root) {
        return Err(MerkleError::RootNotInStore(root));
    }

    let mut cursor = NodeIndex::root();
    let mut current = root;
    for bit in (0..depth).rev() {
        let Some(inner) = self.nodes.get(&current) else {
            return Err(MerkleError::NodeIndexNotFoundInStore(current, index));
        };
        let (left_index, right_index) = (cursor.left_child(), cursor.right_child());
        if index.is_nth_bit_odd(bit) {
            // Descending right: the left child is the sibling.
            siblings.push((left_index, inner.left));
            cursor = right_index;
            current = inner.right;
        } else {
            // Descending left: the right child is the sibling.
            siblings.push((right_index, inner.right));
            cursor = left_index;
            current = inner.left;
        }
    }

    Ok(IndexedPath { value: current, path: siblings })
}
/// Applies a batch of leaf updates to the tree rooted at `root` and returns the root of the
/// resulting tree.
///
/// The store is only mutated after the whole update has been computed successfully, so an
/// error leaves it untouched. If no entry changes the tree, `root` is returned as is and
/// no reference count is modified.
///
/// When `leaves` contains the same index more than once, the last entry wins. This includes
/// the case where the last entry equals the value currently stored, which cancels any
/// earlier pending update for that index.
///
/// NOTE(review): the ancestor bookkeeping below assumes every index in `leaves` has the same
/// depth (presumably `SMT_DEPTH`) and that `NodeIndex`'s ordering places siblings next to
/// each other; neither is validated here. Confirm against the callers.
///
/// # Errors
/// - [`MerkleError::RootNotInStore`] if `root` is not a key of the store.
/// - Any error from `get_indexed_path` while opening one of the leaves.
/// - [`MerkleError::NodeIndexNotFoundInTree`] if a child hash needed to recompute an
///   ancestor was not collected.
/// - [`MerkleError::NodeIndexNotFoundInStore`] if no new root was computed.
pub fn set_leaves(
    &mut self,
    root: Word,
    leaves: impl IntoIterator<Item = (NodeIndex, Word)>,
) -> Result<Word, MerkleError> {
    self.nodes.get(&root).ok_or(MerkleError::RootNotInStore(root))?;

    // Collect the sibling hashes of every opened path and the leaves that really change.
    let mut nodes_by_index = Map::<NodeIndex, Word>::new();
    let mut leaves_by_index = Map::<NodeIndex, Word>::new();
    for (index, leaf_hash) in leaves {
        let indexed_path = self.get_indexed_path(root, index)?;

        if indexed_path.value == leaf_hash {
            // Writing the value that is already stored is a no-op. It must also cancel an
            // earlier entry for the same index, otherwise that stale entry would win.
            leaves_by_index.remove(&index);
            continue;
        }
        nodes_by_index.extend(indexed_path.path);
        leaves_by_index.insert(index, leaf_hash);
    }

    if leaves_by_index.is_empty() {
        return Ok(root);
    }

    // Only neighbouring leaves can share a parent, so the de-duplication below needs the
    // indices in sorted order. Sort unconditionally rather than relying on the iteration
    // order of whatever map type `Map` resolves to; keys are unique, so an unstable sort
    // is sufficient.
    let mut sorted_leaf_indices = leaves_by_index.keys().copied().collect::<Vec<_>>();
    sorted_leaf_indices.sort_unstable();

    // New leaf values must override sibling values recorded from the current tree.
    nodes_by_index.extend(leaves_by_index);

    // Parents of the updated leaves, without repeats.
    let mut ancestors: Vec<NodeIndex> = Vec::new();
    // Guard value: used as the initial "previous parent" so the first comparison is valid.
    let mut last_ancestor = NodeIndex::new_unchecked(SMT_DEPTH, 0);
    for leaf_index in sorted_leaf_indices {
        let parent = leaf_index.parent();
        if parent != last_ancestor {
            last_ancestor = parent;
            ancestors.push(last_ancestor);
        }
    }

    // `ancestors` doubles as a queue: walk it while appending each entry's parent, skipping
    // a parent that was just appended by the previous entry, until the root is reached.
    let mut index = 0;
    while index < ancestors.len() {
        let node = ancestors[index];
        if node.is_root() {
            break;
        }
        let parent = node.parent();
        if parent != last_ancestor {
            last_ancestor = parent;
            ancestors.push(last_ancestor);
        }
        index += 1;
    }

    // Recompute every affected ancestor, stashing the results until nothing can fail.
    let mut new_nodes: Map<Word, ForestInnerNode> = Map::new();
    for index in ancestors {
        let left_index = index.left_child();
        let right_index = index.right_child();

        let left_value = *nodes_by_index
            .get(&left_index)
            .ok_or(MerkleError::NodeIndexNotFoundInTree(left_index))?;
        let right_value = *nodes_by_index
            .get(&right_index)
            .ok_or(MerkleError::NodeIndexNotFoundInTree(right_index))?;

        let node = ForestInnerNode {
            left: left_value,
            right: right_value,
            rc: 0,
        };
        let new_key = node.hash();
        new_nodes.insert(new_key, node);
        nodes_by_index.insert(index, new_key);
    }

    let new_root = nodes_by_index
        .get(&NodeIndex::root())
        .cloned()
        .ok_or(MerkleError::NodeIndexNotFoundInStore(root, NodeIndex::root()))?;

    /// Moves the nodes reachable from `node` out of `new_nodes` into `store`.
    ///
    /// A hash already present in `store` only has its reference count bumped and is not
    /// descended into, so its children's counts stay unchanged. A hash found in neither
    /// map is left alone.
    fn dfs(
        node: Word,
        store: &mut Map<Word, ForestInnerNode>,
        new_nodes: &mut Map<Word, ForestInnerNode>,
    ) {
        if node == Word::empty() {
            return;
        }
        if let Some(node) = store.get_mut(&node) {
            node.rc += 1;
        } else if let Some(mut smt_node) = new_nodes.remove(&node) {
            smt_node.rc = 1;
            store.insert(node, smt_node);
            dfs(smt_node.left, store, new_nodes);
            dfs(smt_node.right, store, new_nodes);
        }
    }
    dfs(new_root, &mut self.nodes, &mut new_nodes);

    Ok(new_root)
}
/// Releases one reference to `node` and to every node that becomes unreferenced as a result.
///
/// Starting at `node`, each visited hash is handled as follows:
/// - the empty word is skipped;
/// - a hash with no entry in the store is appended to the returned list;
/// - otherwise its reference count is decremented, and if it reaches zero the entry is
///   removed and its left and then right child are visited in turn.
///
/// Returns the hashes that were reached but had no entry in the store, in visiting order.
fn remove_node(&mut self, node: Word) -> Vec<Word> {
    let mut missing = Vec::new();
    // Explicit stack; the right child is pushed first so the left subtree is handled first.
    let mut pending = vec![node];

    while let Some(current) = pending.pop() {
        if current == Word::empty() {
            continue;
        }
        let Some(entry) = self.nodes.get_mut(&current) else {
            missing.push(current);
            continue;
        };

        entry.rc -= 1;
        if entry.rc > 0 {
            // Still referenced elsewhere: keep the node and leave its children alone.
            continue;
        }

        let (left, right) = (entry.left, entry.right);
        self.nodes.remove(&current);
        pending.push(right);
        pending.push(left);
    }

    missing
}
/// Releases one reference to each of the given `roots`, in order, via `remove_node`.
///
/// Returns the concatenation of the lists returned by `remove_node` for every root.
pub fn remove_roots(&mut self, roots: impl IntoIterator<Item = Word>) -> Vec<Word> {
    roots.into_iter().flat_map(|root| self.remove_node(root)).collect()
}
}
/// Yields the `(hash, node)` entries for the inner nodes of an empty tree of depth
/// `SMT_DEPTH`.
///
/// Each adjacent pair in `EmptySubtreeRoots::empty_hashes(SMT_DEPTH)` becomes one entry: the
/// earlier element is the key and the later element is used as both children. Pairs are
/// produced starting from the end of the list, and every node has a reference count of 1.
fn empty_hashes() -> impl Iterator<Item = (Word, ForestInnerNode)> {
    let subtrees = EmptySubtreeRoots::empty_hashes(SMT_DEPTH);
    subtrees.windows(2).rev().map(|pair| {
        let (parent, child) = (pair[0], pair[1]);
        (parent, ForestInnerNode { left: child, right: child, rc: 1 })
    })
}
/// Result of walking from a root down to a node index (see `get_indexed_path`).
struct IndexedPath {
/// Hash found at the requested index.
value: Word,
/// `(index, hash)` of each sibling passed during the walk, ordered from the root's children
/// down to the sibling of the requested index.
path: Vec<(NodeIndex, Word)>,
}