use alloc::vec::Vec;
use core::{
fmt::{self, Display},
hash::Hash,
};
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath};
use crate::{
EMPTY_WORD, Map, Set, Word,
hash::poseidon2::Poseidon2,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
mod full;
pub use full::{MAX_LEAF_ENTRIES, SMT_DEPTH, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError};
#[cfg(feature = "concurrent")]
mod large;
#[cfg(feature = "internal")]
pub use full::concurrent::{SubtreeLeaf, build_subtree_for_bench};
#[cfg(feature = "concurrent")]
pub use large::{
LargeSmt, LargeSmtError, MemoryStorage, MemoryStorageSnapshot, SmtStorage, SmtStorageReader,
StorageError, StorageUpdateParts, StorageUpdates, Subtree, SubtreeError, SubtreeUpdate,
};
#[cfg(feature = "rocksdb")]
pub use large::{RocksDbConfig, RocksDbSnapshotStorage, RocksDbStorage};
mod large_forest;
pub use large_forest::{
Backend, BackendError, Config as ForestConfig,
DEFAULT_MAX_HISTORY_VERSIONS as FOREST_DEFAULT_MAX_HISTORY_VERSIONS, ForestOperation,
InMemoryBackend as ForestInMemoryBackend, LargeSmtForest, LargeSmtForestError, LineageId,
MIN_HISTORY_VERSIONS as FOREST_MIN_HISTORY_VERSIONS, RootInfo, SmtForestUpdateBatch,
SmtUpdateBatch, TreeEntry, TreeId, TreeWithRoot, VersionId,
};
#[cfg(feature = "persistent-forest")]
pub use large_forest::{PersistentBackend as ForestPersistentBackend, PersistentBackendConfig};
mod simple;
pub use simple::{SimpleSmt, SimpleSmtProof};
mod partial;
pub use partial::{NodeValue, PartialSmt, UniqueNodes};
mod forest;
pub use forest::SmtForest;
use miden_field::Felt;
/// Minimum supported depth of a sparse Merkle tree.
pub const SMT_MIN_DEPTH: u8 = 1;
/// Maximum supported depth of a sparse Merkle tree.
pub const SMT_MAX_DEPTH: u8 = 64;
/// Domain-separation constant used when hashing leaves.
/// NOTE(review): the value 0x13af looks like an arbitrary tag; its actual use is not
/// visible in this file — confirm against the leaf-hashing code in the submodules.
pub const LEAF_DOMAIN: Felt = Felt::new_unchecked(0x13af);
/// Storage map from node index to the inner node stored there.
type InnerNodes = Map<NodeIndex, InnerNode>;
/// Storage map from leaf position to a leaf of type `T`.
type Leaves<T> = Map<u64, T>;
/// Pending inner-node changes, keyed by node index.
type NodeMutations = Map<NodeIndex, NodeMutation>;
/// Read-only interface shared by the sparse Merkle tree implementations in this module.
///
/// `DEPTH` is the depth at which leaves live. Implementors supply storage access
/// (`root`, `get_inner_node`, `get_value`, `get_leaf`, ...); the trait derives path
/// retrieval, openings (proofs), and batch mutation computation from those primitives.
pub(crate) trait SparseMerkleTreeReader<const DEPTH: u8> {
    /// Key type used to address entries.
    type Key: Clone + Ord + Eq + Hash;
    /// Value type; `EMPTY_VALUE` marks an absent entry.
    type Value: Clone + PartialEq;
    /// Leaf type (may aggregate several key-value pairs at one leaf index).
    type Leaf: Clone;
    /// Proof type returned by [`Self::open`].
    type Opening;
    /// The value representing an unset entry.
    const EMPTY_VALUE: Self::Value;
    /// The root of an entirely empty tree.
    const EMPTY_ROOT: Word;
    /// Returns the Merkle path from the leaf addressed by `key` up to the root.
    fn get_path(&self, key: &Self::Key) -> SparseMerklePath {
        let index = NodeIndex::from(Self::key_to_leaf_index(key));
        SparseMerklePath::from_sized_iter(
            index.proof_indices().map(|index| self.get_node_hash(index)),
        )
        .expect("failed to convert to SparseMerklePath")
    }
    /// Returns the hash of the node at `index` by reading the parent's inner node
    /// and selecting the matching child.
    fn get_node_hash(&self, index: NodeIndex) -> Word {
        if index.is_root() {
            return self.root();
        }
        let InnerNode { left, right } = self.get_inner_node(index.parent());
        // An odd position means the node is its parent's right child.
        let index_is_right = index.is_position_odd();
        if index_is_right { right } else { left }
    }
    /// Returns an opening (leaf plus Merkle path) of the leaf associated with `key`.
    fn open(&self, key: &Self::Key) -> Self::Opening {
        let leaf = self.get_leaf(key);
        let merkle_path = self.get_path(key);
        Self::path_and_leaf_to_opening(merkle_path, leaf)
    }
    /// Computes the set of mutations that would apply `kv_pairs` to the tree,
    /// without modifying the tree itself.
    ///
    /// # Errors
    /// Returns an error if `kv_pairs` contains duplicate keys, or if an update
    /// would make a leaf exceed its maximum number of entries.
    fn compute_mutations(
        &self,
        kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
    ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError> {
        self.compute_mutations_sequential(kv_pairs)
    }
    /// Sequential implementation of [`Self::compute_mutations`]: for each changed
    /// pair, rebuilds the leaf and walks leaf-to-root, recording an `Addition` or
    /// `Removal` for every inner node along the way.
    fn compute_mutations_sequential(
        &self,
        kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
    ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError> {
        use NodeMutation::*;
        let mut new_root = self.root();
        // Pairs accepted so far in this batch; later reads must see them.
        let mut new_pairs: Map<Self::Key, Self::Value> = Default::default();
        let mut node_mutations: NodeMutations = NodeMutations::new();
        let mut seen_keys: Set<Self::Key> = Set::new();
        for (key, value) in kv_pairs {
            // Reject duplicate keys within the same batch.
            if !seen_keys.insert(key.clone()) {
                return Err(MerkleError::DuplicateValuesForIndex(
                    Self::key_to_leaf_index(&key).position(),
                ));
            }
            // The effective old value accounts for earlier updates in this batch.
            let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
            if value == old_value {
                continue;
            }
            let leaf_index = Self::key_to_leaf_index(&key);
            let mut node_index = NodeIndex::from(leaf_index);
            // Rebuild the leaf as it would look after earlier batch updates that
            // map to this same leaf index.
            let old_leaf = {
                let pairs_at_index = new_pairs
                    .iter()
                    .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
                pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
                    let existing_leaf = acc;
                    self.construct_prospective_leaf(existing_leaf, k, v)
                        .expect("current leaf should be valid")
                })
            };
            let new_leaf =
                self.construct_prospective_leaf(old_leaf, &key, &value).map_err(|e| match e {
                    SmtLeafError::TooManyLeafEntries { actual } => {
                        MerkleError::TooManyLeafEntries { actual }
                    },
                    other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
                })?;
            let mut new_child_hash = Self::hash_leaf(&new_leaf);
            // Walk from the leaf's parent up to the root, recomputing each hash.
            for node_depth in (0..node_index.depth()).rev() {
                let is_right = node_index.is_position_odd();
                node_index.move_up();
                // Prefer nodes already mutated in this batch over stored ones.
                let old_node = node_mutations
                    .get(&node_index)
                    .map(|mutation| match mutation {
                        Addition(node) => node.clone(),
                        Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
                    })
                    .unwrap_or_else(|| self.get_inner_node(node_index));
                let new_node = if is_right {
                    InnerNode {
                        left: old_node.left,
                        right: new_child_hash,
                    }
                } else {
                    InnerNode {
                        left: new_child_hash,
                        right: old_node.right,
                    }
                };
                new_child_hash = new_node.hash();
                // A node equal to the empty-subtree hash need not be stored.
                let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
                let is_removal = new_child_hash == equivalent_empty_hash;
                let new_entry = if is_removal { Removal } else { Addition(new_node) };
                node_mutations.insert(node_index, new_entry);
            }
            new_root = new_child_hash;
            new_pairs.insert(key, value);
        }
        Ok(MutationSet {
            old_root: self.root(),
            new_root,
            node_mutations,
            new_pairs,
        })
    }
    /// Returns the current root of the tree.
    fn root(&self) -> Word;
    /// Returns the inner node at `index`.
    fn get_inner_node(&self, index: NodeIndex) -> InnerNode;
    /// Returns the value associated with `key` (implementations are expected to
    /// return `EMPTY_VALUE` when the key is unset — confirm per implementor).
    fn get_value(&self, key: &Self::Key) -> Self::Value;
    /// Returns the leaf containing `key`.
    fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
    /// Hashes a leaf into the `Word` stored at the leaf's node position.
    fn hash_leaf(leaf: &Self::Leaf) -> Word;
    /// Returns the leaf as it would look with `(key, value)` applied, without
    /// modifying the tree.
    fn construct_prospective_leaf(
        &self,
        existing_leaf: Self::Leaf,
        key: &Self::Key,
        value: &Self::Value,
    ) -> Result<Self::Leaf, SmtLeafError>;
    /// Returns an error if `sorted_kv_pairs` contains adjacent duplicate keys.
    /// The slice must be sorted by key for this check to be exhaustive.
    #[cfg(feature = "concurrent")] fn check_for_duplicate_keys(
        sorted_kv_pairs: &[(Self::Key, Self::Value)],
    ) -> Result<(), MerkleError> {
        if let Some(window) = sorted_kv_pairs.windows(2).find(|w| w[0].0 == w[1].0) {
            return Err(MerkleError::DuplicateValuesForIndex(
                Self::key_to_leaf_index(&window[0].0).position(),
            ));
        }
        Ok(())
    }
    /// Maps a key to the index of the leaf it belongs to.
    fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
    /// Combines a Merkle path and a leaf into this tree's opening type.
    fn path_and_leaf_to_opening(path: SparseMerklePath, leaf: Self::Leaf) -> Self::Opening;
}
/// Mutable extension of [`SparseMerkleTreeReader`]: write access to a sparse Merkle tree.
pub(crate) trait SparseMerkleTree<const DEPTH: u8>: SparseMerkleTreeReader<DEPTH> {
    /// Inserts `value` at `key` and returns the previous value (or `EMPTY_VALUE`
    /// when the key was unset).
    ///
    /// # Errors
    /// Propagates any error from [`Self::insert_value`].
    fn insert(&mut self, key: Self::Key, value: Self::Value) -> Result<Self::Value, MerkleError> {
        let old_value = self.insert_value(key.clone(), value.clone())?.unwrap_or(Self::EMPTY_VALUE);
        // Re-inserting the value the key already had leaves all hashes unchanged.
        if value == old_value {
            return Ok(value);
        }
        let leaf = self.get_leaf(&key);
        let node_index = {
            let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
            leaf_index.into()
        };
        self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
        Ok(old_value)
    }
    /// Recomputes and stores every inner node on the path from `index` to the root,
    /// given the new child hash `node_hash_at_index`, then updates the root.
    fn recompute_nodes_from_index_to_root(
        &mut self,
        mut index: NodeIndex,
        node_hash_at_index: Word,
    ) {
        let mut node_hash = node_hash_at_index;
        for node_depth in (0..index.depth()).rev() {
            let is_right = index.is_position_odd();
            index.move_up();
            let InnerNode { left, right } = self.get_inner_node(index);
            let (left, right) = if is_right {
                (left, node_hash)
            } else {
                (node_hash, right)
            };
            node_hash = Poseidon2::merge(&[left, right]);
            // Nodes equal to the empty-subtree hash are not stored explicitly.
            if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
                self.remove_inner_node(index);
            } else {
                self.insert_inner_node(index, InnerNode { left, right });
            }
        }
        self.set_root(node_hash);
    }
    /// Applies a previously computed [`MutationSet`] to the tree.
    ///
    /// # Errors
    /// Fails if `mutations.old_root` does not match the current root, or if a value
    /// insertion fails. NOTE(review): an `insert_value` error mid-loop leaves the
    /// tree partially mutated — confirm callers treat such an error as fatal.
    fn apply_mutations(
        &mut self,
        mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
    ) -> Result<(), MerkleError>
    where
        Self: Sized,
    {
        use NodeMutation::*;
        let MutationSet {
            old_root,
            node_mutations,
            new_pairs,
            new_root,
        } = mutations;
        // Guard against applying a mutation set computed against a different tree.
        if old_root != self.root() {
            return Err(MerkleError::ConflictingRoots {
                expected_root: self.root(),
                actual_root: old_root,
            });
        }
        for (index, mutation) in node_mutations {
            match mutation {
                Removal => {
                    self.remove_inner_node(index);
                },
                Addition(node) => {
                    self.insert_inner_node(index, node);
                },
            }
        }
        for (key, value) in new_pairs {
            self.insert_value(key, value)?;
        }
        self.set_root(new_root);
        Ok(())
    }
    /// Applies a [`MutationSet`] and returns a reverse mutation set that, when
    /// applied in turn, restores the tree to its prior state.
    ///
    /// # Errors
    /// Same failure modes as [`Self::apply_mutations`].
    fn apply_mutations_with_reversion(
        &mut self,
        mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
    ) -> Result<MutationSet<DEPTH, Self::Key, Self::Value>, MerkleError>
    where
        Self: Sized,
    {
        use NodeMutation::*;
        let MutationSet {
            old_root,
            node_mutations,
            new_pairs,
            new_root,
        } = mutations;
        if old_root != self.root() {
            return Err(MerkleError::ConflictingRoots {
                expected_root: self.root(),
                actual_root: old_root,
            });
        }
        // Record the inverse of every node change as we perform it.
        let mut reverse_mutations = NodeMutations::new();
        for (index, mutation) in node_mutations {
            match mutation {
                Removal => {
                    if let Some(node) = self.remove_inner_node(index) {
                        reverse_mutations.insert(index, Addition(node));
                    }
                },
                Addition(node) => {
                    if let Some(old_node) = self.insert_inner_node(index, node) {
                        reverse_mutations.insert(index, Addition(old_node));
                    } else {
                        reverse_mutations.insert(index, Removal);
                    }
                },
            }
        }
        // Record the displaced values so the reversion can restore them.
        let mut reverse_pairs = Map::new();
        for (key, value) in new_pairs {
            match self.insert_value(key.clone(), value)? {
                Some(old_value) => {
                    reverse_pairs.insert(key, old_value);
                },
                None => {
                    reverse_pairs.insert(key, Self::EMPTY_VALUE);
                },
            }
        }
        self.set_root(new_root);
        // The reversion runs from the new root back to the old one.
        Ok(MutationSet {
            old_root: new_root,
            node_mutations: reverse_mutations,
            new_pairs: reverse_pairs,
            new_root: old_root,
        })
    }
    /// Sets the root of the tree.
    fn set_root(&mut self, root: Word);
    /// Stores an inner node at `index`, returning the node it replaced, if any.
    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode>;
    /// Removes the inner node at `index`, returning it if it was present.
    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode>;
    /// Writes `value` at `key` in leaf storage, returning the previous value, if any.
    fn insert_value(
        &mut self,
        key: Self::Key,
        value: Self::Value,
    ) -> Result<Option<Self::Value>, MerkleError>;
}
/// The two child hashes of an inner (non-leaf) node of a sparse Merkle tree.
#[doc(hidden)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNode {
    /// Hash of the left child.
    pub left: Word,
    /// Hash of the right child.
    pub right: Word,
}
impl InnerNode {
    /// Computes this node's hash by merging its two child hashes with Poseidon2.
    pub fn hash(&self) -> Word {
        let children = [self.left, self.right];
        Poseidon2::merge(&children)
    }
}
/// The index of a leaf in a sparse Merkle tree of depth `DEPTH`.
///
/// A thin wrapper over [`NodeIndex`] with the depth fixed at the type level.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct LeafIndex<const DEPTH: u8> {
    // Expected to sit at depth DEPTH; `new` enforces this, though the derived
    // `Default` does not — NOTE(review): confirm `NodeIndex::default()` is benign here.
    index: NodeIndex,
}
impl<const DEPTH: u8> LeafIndex<DEPTH> {
    /// Creates a leaf index for position `value` in a tree of depth `DEPTH`.
    ///
    /// # Errors
    /// Returns `DepthTooSmall` when `DEPTH < SMT_MIN_DEPTH`; any other invalid
    /// (depth, position) combination is rejected by `NodeIndex::new`.
    pub fn new(value: u64) -> Result<Self, MerkleError> {
        if DEPTH < SMT_MIN_DEPTH {
            return Err(MerkleError::DepthTooSmall(DEPTH));
        }
        Ok(LeafIndex { index: NodeIndex::new(DEPTH, value)? })
    }
    /// Returns the position of this leaf within its depth level.
    pub fn position(&self) -> u64 {
        self.index.position()
    }
}
impl LeafIndex<SMT_MAX_DEPTH> {
    /// Const constructor for a leaf index at the maximum depth (64).
    ///
    /// Uses the unchecked `NodeIndex` constructor: at depth 64 every `u64` value
    /// is a valid position, so no validation is needed.
    pub const fn new_max_depth(value: u64) -> Self {
        LeafIndex {
            index: NodeIndex::new_unchecked(SMT_MAX_DEPTH, value),
        }
    }
}
impl<const DEPTH: u8> From<LeafIndex<DEPTH>> for NodeIndex {
fn from(value: LeafIndex<DEPTH>) -> Self {
value.index
}
}
impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
    type Error = MerkleError;

    /// Converts a node index into a leaf index, failing unless the node sits
    /// exactly at depth `DEPTH`.
    fn try_from(node_index: NodeIndex) -> Result<Self, Self::Error> {
        match node_index.depth() {
            depth if depth == DEPTH => Self::new(node_index.position()),
            provided => Err(MerkleError::InvalidNodeIndexDepth { expected: DEPTH, provided }),
        }
    }
}
impl<const DEPTH: u8> Serializable for LeafIndex<DEPTH> {
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        // Only the inner node index is written; `DEPTH` is carried by the type.
        self.index.write_into(target);
    }
}
impl<const DEPTH: u8> Deserializable for LeafIndex<DEPTH> {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
Ok(Self { index: source.read()? })
}
}
impl<const DEPTH: u8> Display for LeafIndex<DEPTH> {
    // Renders both the type-level depth and the runtime position for diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DEPTH={}, position={}", DEPTH, self.position())
    }
}
/// A single pending change to an inner node of a sparse Merkle tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
    /// The node is to be removed (its subtree became equal to an empty subtree).
    Removal,
    /// The node is to be inserted or replaced with the given inner node.
    Addition(InnerNode),
}
/// A set of changes that transforms a sparse Merkle tree from `old_root` to
/// `new_root`; produced by `compute_mutations` and consumed by `apply_mutations`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
    // Root of the tree the mutations were computed against.
    old_root: Word,
    // Inner-node additions and removals to perform.
    node_mutations: NodeMutations,
    // Key-value pairs to write into the leaves.
    new_pairs: Map<K, V>,
    // Root the tree will have after the mutations are applied.
    new_root: Word,
}
impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
    /// Returns the root the tree will have once this mutation set is applied.
    pub fn root(&self) -> Word {
        self.new_root
    }
    /// Returns the root the tree had when the mutations were computed.
    pub fn old_root(&self) -> Word {
        self.old_root
    }
    /// Returns the inner-node additions and removals in this set.
    pub fn node_mutations(&self) -> &NodeMutations {
        &self.node_mutations
    }
    /// Returns the key-value pairs this set will write.
    pub fn new_pairs(&self) -> &Map<K, V> {
        &self.new_pairs
    }
    /// Returns `true` if applying this set would leave the tree unchanged.
    pub fn is_empty(&self) -> bool {
        self.node_mutations.is_empty()
            && self.new_pairs.is_empty()
            && self.old_root == self.new_root
    }
}
impl Serializable for InnerNode {
    /// Writes the two child hashes, left first; `read_from` relies on this order.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        self.left.write_into(target);
        self.right.write_into(target);
    }
}
impl Deserializable for InnerNode {
    /// Reads the child hashes in the order `write_into` emits them; struct-literal
    /// fields evaluate left to right, so `left` is read before `right`.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        Ok(Self {
            left: source.read()?,
            right: source.read()?,
        })
    }
}
impl Serializable for NodeMutation {
    /// Encodes the variant as a leading bool tag (`true` = addition), with the
    /// inner node following for additions.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let addition = match self {
            NodeMutation::Removal => None,
            NodeMutation::Addition(inner_node) => Some(inner_node),
        };
        target.write_bool(addition.is_some());
        if let Some(inner_node) = addition {
            inner_node.write_into(target);
        }
    }
}
impl Deserializable for NodeMutation {
    /// Mirror of `write_into`: a leading bool selects the variant, with additions
    /// followed by the serialized inner node.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let mutation = match source.read_bool()? {
            false => NodeMutation::Removal,
            true => NodeMutation::Addition(source.read()?),
        };
        Ok(mutation)
    }
}
impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
    for MutationSet<DEPTH, K, V>
{
    /// Wire format: old root, new root, removal indices, (index, node) additions,
    /// then the count-prefixed key-value pairs. `read_from` mirrors this order.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        target.write(self.old_root);
        target.write(self.new_root);
        // Mutations are split by variant so each group serializes homogeneously.
        let inner_removals: Vec<_> = self
            .node_mutations
            .iter()
            .filter(|(_, value)| matches!(value, NodeMutation::Removal))
            .map(|(key, _)| key)
            .collect();
        let inner_additions: Vec<_> = self
            .node_mutations
            .iter()
            .filter_map(|(key, value)| match value {
                NodeMutation::Addition(node) => Some((key, node)),
                _ => None,
            })
            .collect();
        target.write(inner_removals);
        target.write(inner_additions);
        target.write_usize(self.new_pairs.len());
        target.write_many(&self.new_pairs);
    }
}
impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
    for MutationSet<DEPTH, K, V>
{
    /// Mirror of `Serializable::write_into`: roots, removal indices, additions,
    /// then the count-prefixed key-value pairs.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let old_root = source.read()?;
        let new_root = source.read()?;
        let inner_removals: Vec<NodeIndex> = source.read()?;
        let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
        // Recombine the two variant groups into a single mutation map.
        let node_mutations = NodeMutations::from_iter(
            inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
                inner_additions
                    .into_iter()
                    .map(|(index, node)| (index, NodeMutation::Addition(node))),
            ),
        );
        let num_new_pairs = source.read_usize()?;
        let new_pairs: Map<_, _> =
            source.read_many_iter(num_new_pairs)?.collect::<Result<_, _>>()?;
        Ok(Self {
            old_root,
            node_mutations,
            new_pairs,
            new_root,
        })
    }
}