use super::{EmptySubtreeRoots, LeafIndex, SMT_DEPTH};
use crate::{
EMPTY_WORD, Word,
merkle::{
InnerNodeInfo, MerkleError, NodeIndex, SparseMerklePath,
smt::{InnerNode, InnerNodes, Leaves, SmtLeaf, SmtLeafError, SmtProof},
},
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
#[cfg(test)]
mod tests;
/// A partial Sparse Merkle Tree (SMT) of depth [`SMT_DEPTH`].
///
/// It stores only some of a full tree's leaves, together with the inner nodes needed to
/// authenticate those leaves against `root`.
///
/// NOTE(review): when the `serde` feature is enabled, the derived `Deserialize` implementation
/// fills these fields directly. It does not run the consistency checks that
/// [`Deserializable::read_from`] performs, so serde input should be trusted or validated
/// separately.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct PartialSmt {
    // Root of the tree this partial SMT is a view of.
    root: Word,
    // Total number of key-value entries across all stored leaves.
    num_entries: usize,
    // Stored leaves, keyed by leaf position.
    leaves: Leaves<SmtLeaf>,
    // Stored inner nodes, keyed by node index.
    inner_nodes: InnerNodes,
}
impl PartialSmt {
    /// The value associated with a key that has no entry in the tree.
    pub const EMPTY_VALUE: Word = EMPTY_WORD;

    /// The root of a tree in which every leaf is empty.
    pub const EMPTY_ROOT: Word = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);

    /// Creates a partial SMT with the given `root` that tracks no leaves or inner nodes.
    ///
    /// Leaves and their paths can be added later with [`Self::add_path`] or
    /// [`Self::add_proof`]. Both accept only data that opens against `root`.
    pub fn new(root: Word) -> Self {
        Self {
            root,
            num_entries: 0,
            leaves: Leaves::<SmtLeaf>::default(),
            inner_nodes: InnerNodes::default(),
        }
    }

    /// Builds a partial SMT from a sequence of proofs.
    ///
    /// The root is taken from the first proof. Every later proof must open against that same
    /// root. An empty iterator yields [`Self::default`], whose root is [`Self::EMPTY_ROOT`].
    ///
    /// # Errors
    /// Returns [`MerkleError::ConflictingRoots`] if any proof after the first computes a
    /// different root.
    pub fn from_proofs<I>(proofs: I) -> Result<Self, MerkleError>
    where
        I: IntoIterator<Item = SmtProof>,
    {
        let mut proofs = proofs.into_iter();

        let Some(first_proof) = proofs.next() else {
            return Ok(Self::default());
        };

        // The first proof defines the root, so it is added without a root check.
        let mut partial_smt = Self::default();
        let (path, leaf) = first_proof.into_parts();
        let path_root = partial_smt.add_path_unchecked(leaf, path);
        partial_smt.root = path_root;

        for proof in proofs {
            partial_smt.add_proof(proof)?;
        }

        Ok(partial_smt)
    }

    /// Returns the root of the tree.
    pub fn root(&self) -> Word {
        self.root
    }

    /// Returns an opening (the leaf plus its sparse Merkle path) for `key`.
    ///
    /// # Errors
    /// Returns [`MerkleError::UntrackedKey`] if the leaf that `key` maps to is not tracked.
    pub fn open(&self, key: &Word) -> Result<SmtProof, MerkleError> {
        let leaf = self.get_leaf(key)?;
        let merkle_path = self.get_path(key);

        Ok(SmtProof::new_unchecked(merkle_path, leaf))
    }

    /// Returns the leaf that `key` maps to.
    ///
    /// # Errors
    /// Returns [`MerkleError::UntrackedKey`] if that leaf is not tracked.
    pub fn get_leaf(&self, key: &Word) -> Result<SmtLeaf, MerkleError> {
        self.get_tracked_leaf(key).ok_or(MerkleError::UntrackedKey(*key))
    }

    /// Returns the value stored under `key`, or [`Self::EMPTY_VALUE`] if the tracked leaf has
    /// no entry for it.
    ///
    /// # Errors
    /// Returns [`MerkleError::UntrackedKey`] if the leaf that `key` maps to is not tracked.
    pub fn get_value(&self, key: &Word) -> Result<Word, MerkleError> {
        self.get_tracked_leaf(key)
            .map(|leaf| leaf.get_value(key).unwrap_or(EMPTY_WORD))
            .ok_or(MerkleError::UntrackedKey(*key))
    }

    /// Returns an iterator over all stored inner nodes, as their hash and two children.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.inner_nodes.values().map(|e| InnerNodeInfo {
            value: e.hash(),
            left: e.left,
            right: e.right,
        })
    }

    /// Returns an iterator over all stored inner nodes together with their node index.
    pub fn inner_node_indices(&self) -> impl Iterator<Item = (NodeIndex, InnerNode)> + '_ {
        self.inner_nodes.iter().map(|(idx, inner)| (*idx, inner.clone()))
    }

    /// Returns an iterator over the stored leaves together with their leaf index.
    ///
    /// Leaves added or updated through this type's mutating methods are only stored while they
    /// hold at least one entry.
    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
        self.leaves
            .iter()
            .map(|(leaf_index, leaf)| (LeafIndex::new_max_depth(*leaf_index), leaf))
    }

    /// Returns an iterator over every key-value entry in the stored leaves.
    pub fn entries(&self) -> impl Iterator<Item = &(Word, Word)> {
        self.leaves().flat_map(|(_, leaf)| leaf.entries())
    }

    /// Returns the number of stored leaves.
    pub fn num_leaves(&self) -> usize {
        self.leaves.len()
    }

    /// Returns the total number of key-value entries across all stored leaves.
    pub fn num_entries(&self) -> usize {
        self.num_entries
    }

    /// Returns `true` if at least one leaf is stored.
    pub fn tracks_leaves(&self) -> bool {
        !self.leaves.is_empty()
    }

    /// Sets `key` to `value` and recomputes the path to the root.
    ///
    /// Inserting [`Self::EMPTY_VALUE`] removes the entry for `key`. Returns the previous value,
    /// or [`Self::EMPTY_VALUE`] if there was none.
    ///
    /// # Errors
    /// - [`MerkleError::UntrackedKey`] if the leaf that `key` maps to is not tracked.
    /// - [`MerkleError::TooManyLeafEntries`] if the leaf cannot hold another entry.
    ///
    /// # Panics
    /// Panics if `SmtLeaf::insert` reports an error other than `TooManyLeafEntries`.
    pub fn insert(&mut self, key: Word, value: Word) -> Result<Word, MerkleError> {
        let current_leaf = self.get_tracked_leaf(&key).ok_or(MerkleError::UntrackedKey(key))?;
        let leaf_index = current_leaf.index();
        let previous_value = current_leaf.get_value(&key).unwrap_or(EMPTY_WORD);
        let prev_entries = current_leaf.num_entries();

        // A tracked leaf may be implicitly empty, in which case it is not in the map yet.
        let leaf = self
            .leaves
            .entry(leaf_index.position())
            .or_insert_with(|| SmtLeaf::new_empty(leaf_index));

        if value != EMPTY_WORD {
            leaf.insert(key, value).map_err(|e| match e {
                SmtLeafError::TooManyLeafEntries { actual } => {
                    MerkleError::TooManyLeafEntries { actual }
                },
                other => panic!("unexpected SmtLeaf::insert error: {:?}", other),
            })?;
        } else {
            leaf.remove(key);
        }

        let current_entries = leaf.num_entries();
        let new_leaf_hash = leaf.hash();
        self.num_entries = self.num_entries + current_entries - prev_entries;

        // Empty leaves are represented implicitly rather than stored.
        if current_entries == 0 {
            self.leaves.remove(&leaf_index.position());
        }

        self.recompute_nodes_from_leaf_to_root(leaf_index, new_leaf_hash);

        Ok(previous_value)
    }

    /// Adds the leaf and path contained in `proof`. See [`Self::add_path`].
    ///
    /// # Errors
    /// Returns [`MerkleError::ConflictingRoots`] if the proof does not open against the
    /// current root. The partial SMT is left unchanged in that case.
    pub fn add_proof(&mut self, proof: SmtProof) -> Result<(), MerkleError> {
        let (path, leaf) = proof.into_parts();
        self.add_path(leaf, path)
    }

    /// Adds a leaf and its sibling path.
    ///
    /// The path is hashed from `leaf` up to a root. If that root equals the current root, the
    /// leaf and every inner node on the path are stored, replacing any previously stored
    /// values at the same positions.
    ///
    /// # Errors
    /// Returns [`MerkleError::ConflictingRoots`] if the root computed from `leaf` and `path`
    /// differs from the current root. The partial SMT is left unchanged in that case.
    pub fn add_path(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Result<(), MerkleError> {
        let (nodes, path_root) = Self::compute_path_nodes(&leaf, path);

        // Check before mutating anything, so a rejected path cannot corrupt the tree.
        if self.root != path_root {
            return Err(MerkleError::ConflictingRoots {
                expected_root: self.root,
                actual_root: path_root,
            });
        }

        self.commit_path(leaf, nodes);
        Ok(())
    }

    /// Stores `leaf` and the inner nodes along `path` without comparing against the current
    /// root. Returns the root computed from the path.
    ///
    /// `self.root` is not modified.
    fn add_path_unchecked(&mut self, leaf: SmtLeaf, path: SparseMerklePath) -> Word {
        let (nodes, path_root) = Self::compute_path_nodes(&leaf, path);
        self.commit_path(leaf, nodes);
        path_root
    }

    /// Hashes `path` from `leaf` up to the root without touching `self`.
    ///
    /// Returns the inner nodes met on the way up, keyed by node index, together with the
    /// resulting root.
    fn compute_path_nodes(leaf: &SmtLeaf, path: SparseMerklePath) -> (InnerNodes, Word) {
        let mut nodes = InnerNodes::default();
        let mut current_index = leaf.index().index;
        let mut node_hash_at_current_index = leaf.hash();

        for sibling_hash in path {
            // The current node is the left child exactly when its sibling is the right child.
            let is_sibling_right = current_index.sibling().is_position_odd();
            current_index.move_up();

            let new_parent_node = if is_sibling_right {
                InnerNode {
                    left: node_hash_at_current_index,
                    right: sibling_hash,
                }
            } else {
                InnerNode {
                    left: sibling_hash,
                    right: node_hash_at_current_index,
                }
            };

            node_hash_at_current_index = new_parent_node.hash();
            nodes.insert(current_index, new_parent_node);
        }

        (nodes, node_hash_at_current_index)
    }

    /// Stores `leaf` and the given inner nodes, keeping `num_entries` in sync.
    ///
    /// A leaf without entries removes any stored leaf at its position instead of being stored.
    fn commit_path(&mut self, leaf: SmtLeaf, nodes: InnerNodes) {
        let position = leaf.index().index.position();
        let prev_entries = self.leaves.get(&position).map(SmtLeaf::num_entries).unwrap_or(0);
        let current_entries = leaf.num_entries();

        if current_entries > 0 {
            self.leaves.insert(position, leaf);
        } else {
            self.leaves.remove(&position);
        }
        self.num_entries = self.num_entries + current_entries - prev_entries;

        for (index, node) in nodes {
            self.insert_inner_node(index, node);
        }
    }

    /// Returns the leaf that `key` maps to, if its contents are known.
    ///
    /// That is one of three cases:
    /// - a stored leaf;
    /// - an empty leaf when the whole tree is empty;
    /// - an empty leaf when some node on the path from the root is the empty-subtree root for
    ///   its depth.
    ///
    /// Returns `None` if the walk reaches an inner node that is not stored, or reaches a
    /// non-empty leaf that is not stored.
    fn get_tracked_leaf(&self, key: &Word) -> Option<SmtLeaf> {
        let leaf_index = Self::key_to_leaf_index(key);

        if let Some(leaf) = self.leaves.get(&leaf_index.position()) {
            return Some(leaf.clone());
        }

        if self.root == Self::EMPTY_ROOT {
            return Some(SmtLeaf::new_empty(leaf_index));
        }

        // Walk from the root toward the leaf. Bit `i` of the leaf position selects the child
        // at depth `SMT_DEPTH - i`.
        let target: NodeIndex = leaf_index.into();
        let mut index = NodeIndex::root();
        for i in (0..SMT_DEPTH).rev() {
            let inner_node = self.get_inner_node(index)?;
            let is_right = target.is_nth_bit_odd(i);
            let child_hash = if is_right { inner_node.right } else { inner_node.left };

            // An empty subtree contains only empty leaves.
            if child_hash == *EmptySubtreeRoots::entry(SMT_DEPTH, SMT_DEPTH - i) {
                return Some(SmtLeaf::new_empty(leaf_index));
            }

            index = if is_right {
                index.right_child()
            } else {
                index.left_child()
            };
        }

        None
    }

    /// Maps `key` to its leaf index using the key's most significant element (`key[3]`).
    fn key_to_leaf_index(key: &Word) -> LeafIndex<SMT_DEPTH> {
        let most_significant_felt = key[3];
        LeafIndex::new_max_depth(most_significant_felt.as_canonical_u64())
    }

    /// Returns a copy of the stored inner node at `index`, if any.
    fn get_inner_node(&self, index: NodeIndex) -> Option<InnerNode> {
        self.inner_nodes.get(&index).cloned()
    }

    /// Returns the stored inner node at `index`, or the empty-subtree node for its depth.
    fn get_inner_node_or_empty(&self, index: NodeIndex) -> InnerNode {
        self.get_inner_node(index)
            .unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
    }

    /// Stores `inner_node` at `index`.
    ///
    /// If the node equals the empty-subtree node for that depth, any stored entry is removed
    /// instead, so empty subtrees stay implicit.
    fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
        if inner_node == EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()) {
            self.inner_nodes.remove(&index);
        } else {
            self.inner_nodes.insert(index, inner_node);
        }
    }

    /// Returns the sibling path for the leaf that `key` maps to.
    ///
    /// Nodes that are not stored are read as empty subtrees.
    ///
    /// # Panics
    /// Panics if the sibling hashes cannot form a `SparseMerklePath`. This is not expected for
    /// paths of this tree.
    fn get_path(&self, key: &Word) -> SparseMerklePath {
        let index = NodeIndex::from(Self::key_to_leaf_index(key));
        SparseMerklePath::from_sized_iter(index.proof_indices().map(|idx| self.get_node_hash(idx)))
            .expect("path should be valid since it's from a valid SMT")
    }

    /// Returns the hash of the node at `index`.
    ///
    /// The root's hash is `self.root`. Any other node's hash is read from the matching child
    /// slot of its parent, which defaults to the empty-subtree node when the parent is not
    /// stored.
    fn get_node_hash(&self, index: NodeIndex) -> Word {
        if index.is_root() {
            return self.root;
        }

        let InnerNode { left, right } = self.get_inner_node_or_empty(index.parent());
        if index.is_position_odd() { right } else { left }
    }

    /// Rewrites every inner node from the leaf at `leaf_index` up to the root, using
    /// `leaf_hash` as the new leaf hash, and updates `self.root`.
    fn recompute_nodes_from_leaf_to_root(
        &mut self,
        leaf_index: LeafIndex<SMT_DEPTH>,
        leaf_hash: Word,
    ) {
        let mut index: NodeIndex = leaf_index.into();
        let mut node_hash = leaf_hash;

        for _ in 0..index.depth() {
            let is_right = index.is_position_odd();
            index.move_up();

            let InnerNode { left, right } = self.get_inner_node_or_empty(index);
            let new_node = if is_right {
                InnerNode { left, right: node_hash }
            } else {
                InnerNode { left: node_hash, right }
            };

            // Use the same hash as the rest of the tree (paths, validation) for consistency.
            node_hash = new_node.hash();
            self.insert_inner_node(index, new_node);
        }

        self.root = node_hash;
    }

    /// Checks that stored nodes and leaves agree with each other and with the root.
    ///
    /// Every stored inner node must hash to the value its parent (or the root) records for it.
    /// Every stored leaf must sit at its own leaf position and hash to the value its parent
    /// records for it.
    fn validate(&self) -> Result<(), DeserializationError> {
        for (&idx, node) in &self.inner_nodes {
            let node_hash = node.hash();
            let expected_hash = self.get_node_hash(idx);

            if node_hash != expected_hash {
                return Err(DeserializationError::InvalidValue(
                    "inner node hash is inconsistent with parent".into(),
                ));
            }
        }

        for (&leaf_pos, leaf) in &self.leaves {
            // The empty-leaf hash does not depend on position, so a misplaced leaf could pass
            // the hash check below. Check the position explicitly.
            if leaf.index().position() != leaf_pos {
                return Err(DeserializationError::InvalidValue(
                    "leaf is stored at a position that does not match its index".into(),
                ));
            }

            let leaf_index = LeafIndex::<SMT_DEPTH>::new_max_depth(leaf_pos);
            let node_index: NodeIndex = leaf_index.into();
            let leaf_hash = leaf.hash();
            let expected_hash = self.get_node_hash(node_index);

            if leaf_hash != expected_hash {
                return Err(DeserializationError::InvalidValue(
                    "leaf hash is inconsistent with parent inner node".into(),
                ));
            }
        }

        Ok(())
    }
}
impl Default for PartialSmt {
fn default() -> Self {
Self::new(Self::EMPTY_ROOT)
}
}
impl From<super::Smt> for PartialSmt {
fn from(smt: super::Smt) -> Self {
Self {
root: smt.root(),
num_entries: smt.num_entries(),
leaves: smt.leaves().map(|(idx, leaf)| (idx.position(), leaf.clone())).collect(),
inner_nodes: smt.inner_node_indices().collect(),
}
}
}
impl Serializable for PartialSmt {
    /// Writes the partial SMT in three parts:
    /// 1. the root;
    /// 2. the number of leaves, followed by each leaf as a `u64` position and then the leaf;
    /// 3. the number of inner nodes, followed by each node as its index and then the node.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { leaves, inner_nodes, .. } = self;

        target.write(self.root());

        target.write_usize(leaves.len());
        leaves.iter().for_each(|(position, leaf)| {
            target.write_u64(*position);
            target.write(leaf);
        });

        target.write_usize(inner_nodes.len());
        inner_nodes.iter().for_each(|(node_index, inner_node)| {
            target.write(node_index);
            target.write(inner_node);
        });
    }
}
impl Deserializable for PartialSmt {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let root: Word = source.read()?;
let mut leaves = Leaves::<SmtLeaf>::default();
for _ in 0..source.read_usize()? {
let pos: u64 = source.read()?;
let leaf: SmtLeaf = source.read()?;
leaves.insert(pos, leaf);
}
let mut inner_nodes = InnerNodes::default();
for _ in 0..source.read_usize()? {
let idx: NodeIndex = source.read()?;
let node: InnerNode = source.read()?;
inner_nodes.insert(idx, node);
}
let num_entries = leaves.values().map(SmtLeaf::num_entries).sum();
let partial = Self { root, num_entries, leaves, inner_nodes };
partial.validate()?;
Ok(partial)
}
}