use alloc::collections::BTreeMap;
use miden_crypto::Word;
use miden_crypto::merkle::smt::{LeafIndex, PartialSmt, SMT_DEPTH, SmtLeaf, SmtProof};
use miden_crypto::merkle::{InnerNodeInfo, MerkleError};
use crate::account::{StorageMap, StorageMapKey, StorageMapWitness};
use crate::utils::serde::{
ByteReader,
ByteWriter,
Deserializable,
DeserializationError,
Serializable,
};
/// A partial view of a [`StorageMap`], backed by a [`PartialSmt`] plus the raw (unhashed)
/// key-value pairs that have been added to it.
///
/// It can be built empty from just a root, from a set of [`StorageMapWitness`]es, or from a full
/// [`StorageMap`].
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct PartialStorageMap {
    /// Partial sparse Merkle tree; lookups into it use the hashed form of a `StorageMapKey`.
    partial_smt: PartialSmt,
    /// Raw map keys tracked by this partial map, together with their values.
    entries: BTreeMap<StorageMapKey, Word>,
}
impl PartialStorageMap {
    /// Creates a partial storage map with the given `root` that tracks no entries yet.
    ///
    /// Witnesses for individual keys can be added afterwards with [`Self::add`].
    pub fn new(root: Word) -> Self {
        PartialStorageMap {
            partial_smt: PartialSmt::new(root),
            entries: BTreeMap::new(),
        }
    }

    /// Creates a partial storage map from the provided witnesses.
    ///
    /// The raw key-value pairs of every witness are recorded as tracked entries, and the
    /// witnesses' Merkle proofs are combined into the underlying partial SMT.
    ///
    /// # Errors
    ///
    /// Returns the error reported by [`PartialSmt::from_proofs`] if the proofs cannot be
    /// combined.
    pub fn with_witnesses(
        witnesses: impl IntoIterator<Item = StorageMapWitness>,
    ) -> Result<Self, MerkleError> {
        let mut entries = BTreeMap::new();
        // The raw entries must be captured before each witness is consumed by the proof
        // conversion; `from_proofs` drives this iterator, so every witness is visited exactly once.
        let partial_smt = PartialSmt::from_proofs(witnesses.into_iter().map(|witness| {
            entries.extend(witness.entries());
            SmtProof::from(witness)
        }))?;

        Ok(PartialStorageMap { partial_smt, entries })
    }

    /// Converts a full [`StorageMap`] into a partial storage map that tracks all of its entries.
    pub fn new_full(storage_map: StorageMap) -> Self {
        let partial_smt = PartialSmt::from(storage_map.smt);
        let entries = storage_map.entries;

        PartialStorageMap { partial_smt, entries }
    }

    /// Creates a partial storage map that only commits to the root of `storage_map` and tracks
    /// none of its entries.
    pub fn new_minimal(storage_map: &StorageMap) -> Self {
        Self::new(storage_map.root())
    }

    /// Returns a reference to the underlying partial SMT.
    pub fn partial_smt(&self) -> &PartialSmt {
        &self.partial_smt
    }

    /// Returns the root of the underlying partial SMT.
    pub fn root(&self) -> Word {
        self.partial_smt.root()
    }

    /// Looks up the value stored under `key` by querying the partial SMT with the key's hash.
    ///
    /// Returns `None` whenever the partial SMT reports an error for that hash (presumably because
    /// the corresponding leaf is not tracked — see `PartialSmt::get_value`).
    pub fn get(&self, key: &StorageMapKey) -> Option<Word> {
        let hash_word = key.hash().as_word();
        self.partial_smt.get_value(&hash_word).ok()
    }

    /// Opens the partial SMT at the hash of `key` and returns a witness pairing the proof with
    /// the raw key.
    ///
    /// The value placed in the witness comes from the tracked entries; if `key` is not among
    /// them, `Word::default()` is used instead.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `PartialSmt::open`, presumably when the key's leaf is not
    /// tracked by the partial SMT.
    pub fn open(&self, key: &StorageMapKey) -> Result<StorageMapWitness, MerkleError> {
        let smt_proof = self.partial_smt.open(&key.hash().as_word())?;
        let value = self.entries.get(key).copied().unwrap_or_default();

        Ok(StorageMapWitness::new_unchecked(smt_proof, [(*key, value)]))
    }

    /// Returns an iterator over the leaves of the underlying partial SMT.
    pub fn leaves(&self) -> impl Iterator<Item = (LeafIndex<SMT_DEPTH>, &SmtLeaf)> {
        self.partial_smt.leaves()
    }

    /// Returns an iterator over the tracked raw key-value pairs, in ascending key order.
    pub fn entries(&self) -> impl Iterator<Item = (&StorageMapKey, &Word)> {
        self.entries.iter()
    }

    /// Returns an iterator over the inner nodes of the underlying partial SMT.
    pub fn inner_nodes(&self) -> impl Iterator<Item = InnerNodeInfo> + '_ {
        self.partial_smt.inner_nodes()
    }

    /// Adds a witness to this partial storage map.
    ///
    /// The witness's proof is added to the partial SMT. Only if that succeeds are the witness's
    /// raw key-value pairs recorded as tracked entries, overwriting existing values for equal keys.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `PartialSmt::add_proof`; in that case the tracked entries
    /// are left untouched.
    pub fn add(&mut self, witness: StorageMapWitness) -> Result<(), MerkleError> {
        // Snapshot the entries before `witness` is consumed by the proof conversion, but commit
        // them only after the proof is accepted. Otherwise a rejected proof would leave `entries`
        // holding keys that the partial SMT does not track.
        let new_entries: BTreeMap<StorageMapKey, Word> =
            witness.entries().map(|(key, value)| (*key, *value)).collect();

        self.partial_smt.add_proof(SmtProof::from(witness))?;
        self.entries.extend(new_entries);

        Ok(())
    }
}
impl Serializable for PartialStorageMap {
    /// Encodes the map as: the partial SMT, then the number of tracked entries, then each tracked
    /// raw key in ascending order.
    ///
    /// Values are not written; the `Deserializable` impl recovers them from the partial SMT.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { partial_smt, entries } = self;

        partial_smt.write_into(target);
        target.write_usize(entries.len());
        for key in entries.keys() {
            target.write(key);
        }
    }
}
impl Deserializable for PartialStorageMap {
    /// Decodes a map produced by `write_into`.
    ///
    /// The partial SMT is read first, then the entry count, then that many raw keys. The value
    /// for each key is looked up in the partial SMT using the key's hash. Duplicate keys in the
    /// input keep the value of their last occurrence.
    ///
    /// # Errors
    ///
    /// Propagates any read error, and returns `DeserializationError::InvalidValue` if a key's
    /// hash cannot be resolved in the partial SMT. Decoding stops at the first failure.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let partial_smt: PartialSmt = source.read()?;
        let num_entries: usize = source.read()?;

        let entries = (0..num_entries)
            .map(|_| {
                let key: StorageMapKey = source.read()?;
                let hashed_map_key: Word = key.hash().into();

                partial_smt
                    .get_value(&hashed_map_key)
                    .map(|value| (key, value))
                    .map_err(|err| {
                        DeserializationError::InvalidValue(format!(
                            "failed to find map key {key} in partial SMT: {err}"
                        ))
                    })
            })
            .collect::<Result<BTreeMap<_, _>, DeserializationError>>()?;

        Ok(Self { partial_smt, entries })
    }
}