use alloc::{string::ToString, vec::Vec};
use super::{
EMPTY_WORD, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
MutationSet, NodeIndex, SparseMerklePath, SparseMerkleTree, SparseMerkleTreeReader, Word,
};
mod error;
pub use error::{SmtLeafError, SmtProofError};
mod leaf;
pub use leaf::SmtLeaf;
mod proof;
pub use proof::SmtProof;
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
#[cfg(feature = "concurrent")]
pub(in crate::merkle::smt) mod concurrent;
#[cfg(test)]
mod tests;
/// Depth of the [`Smt`]; used as the depth parameter of the sparse Merkle tree traits
/// implemented in this module and returned by [`Smt::depth`].
pub const SMT_DEPTH: u8 = 64;
/// Limit on the number of entries in a single leaf.
// NOTE(review): this constant is not referenced in this file; presumably it is enforced by
// `SmtLeaf::insert` (see the `TooManyLeafEntries` errors) — confirm in the `leaf` module.
pub const MAX_LEAF_ENTRIES: usize = 1024;
/// Leaf storage of the parent module specialised to [`SmtLeaf`]; in this file it is keyed by leaf
/// position.
type Leaves = super::Leaves<SmtLeaf>;
/// Sparse Merkle tree of depth [`SMT_DEPTH`] mapping [`Word`] keys to [`Word`] values.
///
/// Several keys may map to the same leaf: the leaf index of a key is derived from a single
/// element of the key (see `key_to_leaf_index`), and each [`SmtLeaf`] tracks its own entries.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
// Current root of the tree, as returned by `root()` and updated through `set_root()`.
root: Word,
// Total number of key-value entries across all stored leaves; kept in sync by
// `perform_insert` / `perform_remove` and recomputed from the leaves in `from_raw_parts`.
num_entries: usize,
// Stored leaves keyed by leaf position; a leaf is dropped once a removal empties it.
leaves: Leaves,
// Stored inner nodes keyed by `NodeIndex`; nodes equal to the empty-subtree node of their depth
// are not stored (see `insert_inner_node`) and are synthesised on lookup by `get_inner_node`.
inner_nodes: InnerNodes,
}
impl Smt {
/// The value reported for keys that have no entry in the tree.
///
/// Mirrors the `EMPTY_VALUE` of the [`SparseMerkleTreeReader`] implementation for [`Smt`], which
/// is `EMPTY_WORD`.
pub const EMPTY_VALUE: Word = <Self as SparseMerkleTreeReader<SMT_DEPTH>>::EMPTY_VALUE;
/// Returns a new, empty [`Smt`].
///
/// The root is `EmptySubtreeRoots::entry(SMT_DEPTH, 0)` (the same expression that defines
/// `EMPTY_ROOT`); no leaves or inner nodes are stored and the entry count is zero.
pub fn new() -> Self {
    Self {
        root: *EmptySubtreeRoots::entry(SMT_DEPTH, 0),
        num_entries: 0,
        leaves: Default::default(),
        inner_nodes: Default::default(),
    }
}
/// Builds a new [`Smt`] containing the provided key-value pairs.
///
/// With the `concurrent` feature enabled the work is delegated to
/// `with_entries_concurrent`; otherwise `with_entries_sequential` is used.
///
/// # Errors
/// Returns whatever error the selected constructor reports. The sequential constructor rejects
/// keys that are supplied more than once and propagates insertion errors.
pub fn with_entries(
    entries: impl IntoIterator<Item = (Word, Word)>,
) -> Result<Self, MerkleError> {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "concurrent")]
    let tree: Result<Self, MerkleError> = Self::with_entries_concurrent(entries);
    #[cfg(not(feature = "concurrent"))]
    let tree: Result<Self, MerkleError> = Self::with_entries_sequential(entries);
    tree
}
/// Builds a new [`Smt`] from key-value pairs that the caller provides already sorted.
///
/// With the `concurrent` feature enabled the work is delegated to
/// `with_sorted_entries_concurrent`; otherwise the ordinary sequential constructor is used,
/// which inserts the pairs one by one and does not depend on their order.
// NOTE(review): the exact ordering expected by the concurrent path is defined outside this
// file — confirm before relying on it.
///
/// # Errors
/// Returns whatever error the selected constructor reports.
pub fn with_sorted_entries(
    entries: impl IntoIterator<Item = (Word, Word)>,
) -> Result<Self, MerkleError> {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "concurrent")]
    let tree: Result<Self, MerkleError> = Self::with_sorted_entries_concurrent(entries);
    #[cfg(not(feature = "concurrent"))]
    let tree: Result<Self, MerkleError> = Self::with_entries_sequential(entries);
    tree
}
/// Builds a new [`Smt`] by inserting the provided key-value pairs one at a time.
///
/// # Errors
/// - [`MerkleError::DuplicateValuesForIndex`] (carrying the leaf position of the key) if a key
///   is supplied more than once, including when its earlier value was `EMPTY_WORD`.
/// - Any error returned by [`Smt::insert`].
#[cfg(any(not(feature = "concurrent"), fuzzing, feature = "fuzzing", test))]
fn with_entries_sequential(
    entries: impl IntoIterator<Item = (Word, Word)>,
) -> Result<Self, MerkleError> {
    use alloc::collections::BTreeSet;

    let mut tree = Self::new();
    // Keys inserted with `EMPTY_WORD` are remembered here: a later insert of the same key is
    // checked against this set in addition to the previous value returned by `insert`.
    let mut emptied_keys = BTreeSet::new();

    entries
        .into_iter()
        .try_for_each(|(key, value)| -> Result<(), MerkleError> {
            let previous = tree.insert(key, value)?;

            let already_seen = previous != EMPTY_WORD || emptied_keys.contains(&key);
            if already_seen {
                let position = LeafIndex::<SMT_DEPTH>::from(key).position();
                return Err(MerkleError::DuplicateValuesForIndex(position));
            }

            if value == EMPTY_WORD {
                emptied_keys.insert(key);
            }
            Ok(())
        })?;

    Ok(tree)
}
/// Assembles an [`Smt`] directly from its inner nodes, leaves and root.
///
/// The entry count is recomputed by summing the number of entries of every provided leaf. No
/// validation is performed other than the debug-only check described below.
///
/// # Panics
/// In builds with debug assertions enabled, panics if `root` differs from the hash of the inner
/// node stored at the root index (or from `EMPTY_ROOT` when no such node is stored).
pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: Word) -> Self {
    if cfg!(debug_assertions) {
        let root_node_hash = match inner_nodes.get(&NodeIndex::root()) {
            Some(root_node) => root_node.hash(),
            None => Self::EMPTY_ROOT,
        };
        assert_eq!(root_node_hash, root);
    }

    let num_entries: usize = leaves.values().map(|leaf| leaf.num_entries()).sum();

    Self { root, num_entries, leaves, inner_nodes }
}
/// Returns the depth of the tree, which is always [`SMT_DEPTH`].
pub const fn depth(&self) -> u8 {
SMT_DEPTH
}
/// Returns the current root of the tree.
///
/// Inherent shortcut for the [`SparseMerkleTreeReader`] method of the same name, so callers do
/// not need the trait in scope.
pub fn root(&self) -> Word {
<Self as SparseMerkleTreeReader<SMT_DEPTH>>::root(self)
}
/// Returns the number of leaves currently stored in the tree.
pub fn num_leaves(&self) -> usize {
self.leaves.len()
}
/// Returns the tracked number of key-value entries in the tree.
///
/// This reads the cached counter maintained on insertion and removal; it does not walk the
/// leaves.
pub fn num_entries(&self) -> usize {
self.num_entries
}
/// Returns the leaf that `key` maps to.
///
/// Delegates to the [`SparseMerkleTreeReader`] implementation, which returns a clone of the
/// stored leaf or, when none is stored, an empty leaf for the key's leaf index.
pub fn get_leaf(&self, key: &Word) -> SmtLeaf {
<Self as SparseMerkleTreeReader<SMT_DEPTH>>::get_leaf(self, key)
}
/// Returns a clone of the leaf stored at `index`, or `None` if no leaf is stored there.
pub fn get_leaf_by_index(&self, index: LeafIndex<SMT_DEPTH>) -> Option<SmtLeaf> {
    let position = index.position();
    let stored_leaf = self.leaves.get(&position)?;
    Some(stored_leaf.clone())
}
/// Returns the value associated with `key`.
///
/// Delegates to the [`SparseMerkleTreeReader`] implementation: `EMPTY_WORD` is returned when no
/// leaf is stored for the key, and the default [`Word`] when the leaf has no value for it.
pub fn get_value(&self, key: &Word) -> Word {
<Self as SparseMerkleTreeReader<SMT_DEPTH>>::get_value(self, key)
}
/// Returns the [`SmtProof`] produced by [`SparseMerkleTreeReader::open`] for `key`.
// NOTE(review): `open` itself is defined on the trait outside this file; this file only shows
// that openings are assembled from a `SparseMerklePath` and an `SmtLeaf`.
pub fn open(&self, key: &Word) -> SmtProof {
<Self as SparseMerkleTreeReader<SMT_DEPTH>>::open(self, key)
}
/// Returns `true` if the root of the tree equals the root of an empty tree.
///
/// In builds with debug assertions enabled, this additionally asserts that the answer agrees
/// with whether the leaf storage is empty.
pub fn is_empty(&self) -> bool {
    let root_is_empty = self.root == Self::EMPTY_ROOT;
    debug_assert_eq!(self.leaves.is_empty(), root_is_empty);
    root_is_empty
}
/// Returns an iterator over the stored leaves, each paired with its leaf index.
pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
    self.leaves.iter().map(|(position, leaf)| {
        // Stored keys are raw leaf positions; wrap them back into typed indices.
        let index = LeafIndex::new_max_depth(*position);
        (index, leaf)
    })
}
/// Returns an iterator over the key-value pairs of every stored leaf.
///
/// Pairs are yielded leaf by leaf, in the iteration order of the leaf storage.
pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
    self.leaves.values().flat_map(|leaf| leaf.entries())
}
/// Returns an iterator over the stored inner nodes, described as [`InnerNodeInfo`] values
/// holding each node's hash together with its left and right children.
pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
    self.inner_nodes.values().map(|node| {
        let value = node.hash();
        InnerNodeInfo { value, left: node.left, right: node.right }
    })
}
/// Returns an iterator over the stored inner nodes, each cloned and paired with its
/// [`NodeIndex`].
pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
    self.inner_nodes.iter().map(|(index, node)| {
        let owned_node = node.clone();
        (*index, owned_node)
    })
}
/// Inserts `value` at `key` through [`SparseMerkleTree::insert`] and returns the value the trait
/// method reports as previously associated with the key.
///
/// Callers in this file treat a returned `EMPTY_WORD` as "no previous value". The storage hook
/// `insert_value` handles a `value` equal to [`Self::EMPTY_VALUE`] as a removal of the key.
///
/// # Errors
/// Returns any error produced by the trait's `insert`.
pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
}
/// Computes the [`MutationSet`] corresponding to `kv_pairs` without modifying the tree.
///
/// With the `concurrent` feature enabled the work is delegated to
/// `compute_mutations_concurrent`; otherwise the [`SparseMerkleTreeReader`] implementation is
/// used.
///
/// # Errors
/// Returns whatever error the selected implementation reports.
pub fn compute_mutations(
    &self,
    kv_pairs: impl IntoIterator<Item = (Word, Word)>,
) -> Result<MutationSet<SMT_DEPTH, Word, Word>, MerkleError> {
    // Exactly one of the two bindings below survives conditional compilation.
    #[cfg(feature = "concurrent")]
    let mutations: Result<MutationSet<SMT_DEPTH, Word, Word>, MerkleError> =
        self.compute_mutations_concurrent(kv_pairs);
    #[cfg(not(feature = "concurrent"))]
    let mutations: Result<MutationSet<SMT_DEPTH, Word, Word>, MerkleError> =
        <Self as SparseMerkleTreeReader<SMT_DEPTH>>::compute_mutations(self, kv_pairs);
    mutations
}
/// Applies a previously computed [`MutationSet`] to this tree through
/// [`SparseMerkleTree::apply_mutations`].
///
/// # Errors
/// Returns any error produced by the trait method.
// NOTE(review): the validation performed on `mutations` lives in the trait, outside this file.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, Word, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}
/// Applies a previously computed [`MutationSet`] to this tree through
/// [`SparseMerkleTree::apply_mutations_with_reversion`] and returns the [`MutationSet`] that the
/// trait method produces.
// NOTE(review): presumably the returned set undoes the applied mutations — the trait method is
// defined outside this file; confirm there.
///
/// # Errors
/// Returns any error produced by the trait method.
pub fn apply_mutations_with_reversion(
&mut self,
mutations: MutationSet<SMT_DEPTH, Word, Word>,
) -> Result<MutationSet<SMT_DEPTH, Word, Word>, MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations_with_reversion(self, mutations)
}
/// Stores a non-empty `value` under `key` and keeps the entry counter in sync.
///
/// Returns whatever [`SmtLeaf::insert`] reports for an existing leaf, or `None` when a fresh
/// single-entry leaf had to be created.
///
/// # Errors
/// Returns [`MerkleError::TooManyLeafEntries`] when the leaf rejects the insertion for that
/// reason.
///
/// # Panics
/// Panics if the leaf reports any other error. In builds with debug assertions enabled, also
/// panics if `value` equals [`Self::EMPTY_VALUE`].
fn perform_insert(&mut self, key: Word, value: Word) -> Result<Option<Word>, MerkleError> {
    debug_assert_ne!(value, Self::EMPTY_VALUE);

    let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
    let position = leaf_index.position();

    if let Some(leaf) = self.leaves.get_mut(&position) {
        let entries_before = leaf.num_entries();
        let previous = match leaf.insert(key, value) {
            Ok(previous) => previous,
            Err(SmtLeafError::TooManyLeafEntries { actual }) => {
                return Err(MerkleError::TooManyLeafEntries { actual });
            },
            Err(other) => panic!("unexpected SmtLeaf::insert error: {:?}", other),
        };
        // The counter only grows by the number of entries the leaf actually gained.
        self.num_entries += leaf.num_entries() - entries_before;
        return Ok(previous);
    }

    // No leaf at this position yet: start one holding just this pair.
    self.leaves.insert(position, SmtLeaf::Single((key, value)));
    self.num_entries += 1;
    Ok(None)
}
/// Removes `key` from its leaf and keeps the entry counter in sync.
///
/// Returns the old value reported by [`SmtLeaf::remove`], or `None` when no leaf is stored for
/// the key. A leaf that the removal reports as empty is dropped from storage.
fn perform_remove(&mut self, key: Word) -> Option<Word> {
    let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
    let position = leaf_index.position();

    // Nothing stored at this position means there is nothing to remove.
    let leaf = self.leaves.get_mut(&position)?;

    let entries_before = leaf.num_entries();
    let (old_value, is_empty) = leaf.remove(key);
    self.num_entries -= entries_before - leaf.num_entries();

    if is_empty {
        self.leaves.remove(&position);
    }
    old_value
}
}
/// Read-only sparse Merkle tree interface for [`Smt`], with [`Word`] keys and values,
/// [`SmtLeaf`] leaves and [`SmtProof`] openings.
impl SparseMerkleTreeReader<SMT_DEPTH> for Smt {
    type Key = Word;
    type Value = Word;
    type Leaf = SmtLeaf;
    type Opening = SmtProof;

    /// Value reported for keys without an entry.
    const EMPTY_VALUE: Self::Value = EMPTY_WORD;
    /// Root of a tree that stores nothing.
    const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

    /// Returns the current root.
    fn root(&self) -> Word {
        self.root
    }

    /// Returns a clone of the inner node stored at `index`, or the empty-subtree node for the
    /// depth of `index` when none is stored.
    fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
        match self.inner_nodes.get(&index) {
            Some(stored) => stored.clone(),
            None => EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()),
        }
    }

    /// Returns the value stored under `key`: `EMPTY_WORD` when no leaf is stored for the key,
    /// and the default value when the leaf has no value for it.
    fn get_value(&self, key: &Self::Key) -> Self::Value {
        let position = LeafIndex::<SMT_DEPTH>::from(*key).position();
        self.leaves
            .get(&position)
            .map_or(EMPTY_WORD, |leaf| leaf.get_value(key).unwrap_or_default())
    }

    /// Returns a clone of the leaf `key` maps to, or an empty leaf for the key's index when
    /// none is stored.
    fn get_leaf(&self, key: &Word) -> Self::Leaf {
        let position = LeafIndex::<SMT_DEPTH>::from(*key).position();
        self.leaves
            .get(&position)
            .cloned()
            .unwrap_or_else(|| SmtLeaf::new_empty((*key).into()))
    }

    /// Returns the hash of `leaf`.
    fn hash_leaf(leaf: &Self::Leaf) -> Word {
        leaf.hash()
    }

    /// Returns the leaf that would result from writing `value` under `key` into
    /// `existing_leaf`.
    ///
    /// An empty leaf becomes a single-entry leaf. Otherwise a non-empty `value` is inserted
    /// into the leaf and an `EMPTY_WORD` removes the key from it.
    ///
    /// # Errors
    /// Propagates any error from [`SmtLeaf::insert`].
    fn construct_prospective_leaf(
        &self,
        mut existing_leaf: SmtLeaf,
        key: &Word,
        value: &Word,
    ) -> Result<SmtLeaf, SmtLeafError> {
        debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));

        if matches!(existing_leaf, SmtLeaf::Empty(_)) {
            return Ok(SmtLeaf::new_single(*key, *value));
        }

        if *value == EMPTY_WORD {
            existing_leaf.remove(*key);
        } else {
            existing_leaf.insert(*key, *value)?;
        }
        Ok(existing_leaf)
    }

    /// Maps `key` to its leaf index, using the element at position 3 of the key read as a
    /// canonical `u64`.
    fn key_to_leaf_index(key: &Word) -> LeafIndex<SMT_DEPTH> {
        LeafIndex::new_max_depth(key[3].as_canonical_u64())
    }

    /// Bundles `path` and `leaf` into an [`SmtProof`] via `SmtProof::new_unchecked`.
    fn path_and_leaf_to_opening(path: SparseMerklePath, leaf: SmtLeaf) -> SmtProof {
        SmtProof::new_unchecked(path, leaf)
    }
}
/// Mutating storage hooks of the sparse Merkle tree interface for [`Smt`].
impl SparseMerkleTree<SMT_DEPTH> for Smt {
    /// Overwrites the stored root.
    fn set_root(&mut self, root: Word) {
        self.root = root;
    }

    /// Stores `inner_node` at `index` and returns the node previously stored there, if any.
    ///
    /// A node equal to the empty-subtree node of its depth is not stored; any node already at
    /// `index` is removed instead.
    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) -> Option<InnerNode> {
        let empty_node = EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth());
        if inner_node != empty_node {
            return self.inner_nodes.insert(index, inner_node);
        }
        self.remove_inner_node(index)
    }

    /// Removes and returns the inner node stored at `index`, if any.
    fn remove_inner_node(&mut self, index: NodeIndex) -> Option<InnerNode> {
        self.inner_nodes.remove(&index)
    }

    /// Writes `value` under `key`, treating [`Self::EMPTY_VALUE`] as a removal.
    ///
    /// # Errors
    /// Propagates errors from inserting into the leaf; removal never fails.
    fn insert_value(
        &mut self,
        key: Self::Key,
        value: Self::Value,
    ) -> Result<Option<Self::Value>, MerkleError> {
        if value == Self::EMPTY_VALUE {
            return Ok(self.perform_remove(key));
        }
        self.perform_insert(key, value)
    }
}
impl Default for Smt {
/// Returns an empty tree; equivalent to [`Smt::new`].
fn default() -> Self {
Self::new()
}
}
impl From<Word> for LeafIndex<SMT_DEPTH> {
    /// Derives the leaf index of a key from the element at position 3 of `value`, read as a
    /// canonical `u64`.
    fn from(value: Word) -> Self {
        let index_felt = value[3];
        Self::new_max_depth(index_felt.as_canonical_u64())
    }
}
impl Serializable for Smt {
    /// Writes the number of key-value entries followed by every `(key, value)` pair, in the
    /// order produced by [`Smt::entries`].
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let num_entries = self.entries().count();
        target.write_usize(num_entries);

        self.entries().for_each(|(key, value)| {
            target.write(key);
            target.write(value);
        });
    }

    /// Returns the size hint of the entry count plus, for every entry, the serialized size of
    /// a key and the size hint of a value.
    fn get_size_hint(&self) -> usize {
        let entries_count = self.entries().count();
        let bytes_per_entry = Word::SERIALIZED_SIZE + EMPTY_WORD.get_size_hint();

        entries_count.get_size_hint() + entries_count * bytes_per_entry
    }
}
impl Deserializable for Smt {
    /// Reads an entry count followed by that many `(key, value)` pairs and rebuilds the tree
    /// with [`Smt::with_entries`].
    ///
    /// # Errors
    /// Propagates read errors from `source`. A failure of [`Smt::with_entries`] is reported as
    /// [`DeserializationError::InvalidValue`] carrying the error's string form.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let num_entries = source.read_usize()?;
        let entries = source
            .read_many_iter(num_entries)?
            .collect::<Result<Vec<(Word, Word)>, _>>()?;

        match Self::with_entries(entries) {
            Ok(smt) => Ok(smt),
            Err(err) => Err(DeserializationError::InvalidValue(err.to_string())),
        }
    }

    /// Returns the smallest number of bytes a serialized [`Smt`] can occupy.
    // NOTE(review): presumably the single byte of an encoded zero entry count — confirm
    // against the `write_usize` encoding.
    fn min_serialized_size() -> usize {
        1
    }
}
// Entry points compiled only for fuzzing builds; they expose the sequential code paths
// regardless of whether the `concurrent` feature is enabled.
#[cfg(any(fuzzing, feature = "fuzzing"))]
impl Smt {
/// Builds an [`Smt`] with the sequential constructor (`with_entries_sequential`).
///
/// # Errors
/// Returns whatever error the sequential constructor reports.
pub fn fuzz_with_entries_sequential(
entries: impl IntoIterator<Item = (Word, Word)>,
) -> Result<Smt, MerkleError> {
Self::with_entries_sequential(entries)
}
/// Computes mutations with the [`SparseMerkleTreeReader`] implementation.
///
/// # Panics
/// Panics if computing the mutations fails.
pub fn fuzz_compute_mutations_sequential(
&self,
kv_pairs: impl IntoIterator<Item = (Word, Word)>,
) -> MutationSet<SMT_DEPTH, Word, Word> {
<Self as SparseMerkleTreeReader<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
.expect("Failed to compute mutations in fuzzing")
}
}
#[cfg(test)]
use crate::Felt;
/// Checks that an empty tree and a two-entry tree survive a serialization round trip and that
/// the size hint matches the number of bytes written.
#[test]
fn test_smt_serialization_deserialization() {
    // Four consecutive field elements starting at `first`.
    let felts = |first: u64| {
        [
            Felt::new_unchecked(first),
            Felt::new_unchecked(first + 1),
            Felt::new_unchecked(first + 2),
            Felt::new_unchecked(first + 3),
        ]
    };
    // Serializes `smt`, reads it back, and compares both the tree and the byte length.
    let assert_round_trip = |smt: &Smt| {
        let bytes = smt.to_bytes();
        assert_eq!(*smt, Smt::read_from_bytes(&bytes).unwrap());
        assert_eq!(bytes.len(), smt.get_size_hint());
    };

    // Empty tree.
    assert_round_trip(&Smt::default());

    // Tree with two entries.
    let smt_leaves_2: [(Word, Word); 2] = [
        (Word::new(felts(105)), felts(5).into()),
        (Word::new(felts(101)), felts(1).into()),
    ];
    let smt = Smt::with_entries(smt_leaves_2).unwrap();
    assert_round_trip(&smt);
}
/// Checks that `with_sorted_entries` and `with_entries` build equal trees from the same input.
#[test]
fn smt_with_sorted_entries() {
    // Four consecutive field elements starting at `first`.
    let felts = |first: u64| {
        [
            Felt::new_unchecked(first),
            Felt::new_unchecked(first + 1),
            Felt::new_unchecked(first + 2),
            Felt::new_unchecked(first + 3),
        ]
    };

    // The key ending in 104 is listed before the key ending in 108.
    let smt_leaves_2: [(Word, Word); 2] = [
        (Word::new(felts(101)), felts(1).into()),
        (Word::new(felts(105)), felts(5).into()),
    ];

    let from_sorted = Smt::with_sorted_entries(smt_leaves_2).unwrap();
    let from_unsorted_api = Smt::with_entries(smt_leaves_2).unwrap();

    assert_eq!(from_sorted, from_unsorted_api);
}