use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use miden_protocol::account::{
AccountId,
AccountStorage,
StorageMap,
StorageMapKey,
StorageMapWitness,
StorageSlotContent,
};
use miden_protocol::asset::{Asset, AssetVault, AssetVaultKey, AssetWitness};
use miden_protocol::crypto::merkle::smt::{SMT_DEPTH, Smt, SmtForest};
use miden_protocol::crypto::merkle::{EmptySubtreeRoots, MerkleError};
use miden_protocol::{EMPTY_WORD, Word};
use super::StoreError;
/// A forest of sparse Merkle trees (asset vaults and storage maps) together with the
/// bookkeeping needed to know which tree roots are still referenced by an account.
///
/// Root sets can be staged, committed, or discarded per account; a root is released from the
/// forest only once no tracked root set references it any more.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct AccountSmtForest {
/// Underlying forest that stores the nodes of every inserted vault and storage map tree.
forest: SmtForest,
/// Roots currently recorded for each account (includes staged roots that are not committed
/// yet).
account_roots: BTreeMap<AccountId, Vec<Word>>,
/// Per-account stack of root sets displaced by staging; the last element is the most
/// recently displaced set.
pending_old_roots: BTreeMap<AccountId, Vec<Vec<Word>>>,
/// How many times each root appears across the root sets tracked above. Entries are removed
/// when their count reaches zero.
root_refcounts: BTreeMap<Word, usize>,
}
impl AccountSmtForest {
    /// Creates an empty forest that tracks no accounts and no roots.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the roots currently recorded for `account_id`, or `None` if the account is not
    /// tracked.
    ///
    /// Roots set through [`Self::stage_roots`] are visible here before they are committed.
    pub fn get_roots(&self, account_id: &AccountId) -> Option<&Vec<Word>> {
        self.account_roots.get(account_id)
    }

    /// Opens the vault tree rooted at `vault_root` at `vault_key` and returns the asset stored
    /// there together with its witness.
    ///
    /// # Errors
    ///
    /// Returns an error if the forest cannot open `vault_root` at the key, if the key is missing
    /// from the proof or maps to [`EMPTY_WORD`] (both reported as
    /// [`MerkleError::UntrackedKey`]), or if the asset or the witness cannot be built from the
    /// proof data.
    pub fn get_asset_and_witness(
        &self,
        vault_root: Word,
        vault_key: AssetVaultKey,
    ) -> Result<(Asset, AssetWitness), StoreError> {
        let vault_key_word = vault_key.into();
        let proof = self.forest.open(vault_root, vault_key_word)?;
        let asset_word =
            proof.get(&vault_key_word).ok_or(MerkleError::UntrackedKey(vault_key_word))?;

        // An empty value means "no asset under this key"; report it like an untracked key.
        if asset_word == EMPTY_WORD {
            return Err(MerkleError::UntrackedKey(vault_key_word).into());
        }

        let asset = Asset::from_key_value_words(vault_key_word, asset_word)?;
        let witness = AssetWitness::new(proof)?;
        Ok((asset, witness))
    }

    /// Returns a witness for `key` in the storage map tree rooted at `map_root`.
    ///
    /// The tree is opened at the hashed form of `key`, which is how storage map entries are
    /// keyed when they are inserted into the forest.
    ///
    /// # Errors
    ///
    /// Returns an error if the forest cannot open `map_root` at the hashed key or if the
    /// witness cannot be constructed from the resulting proof.
    pub fn get_storage_map_item_witness(
        &self,
        map_root: Word,
        key: StorageMapKey,
    ) -> Result<StorageMapWitness, StoreError> {
        let hashed_key = key.hash().as_word();
        let proof = self.forest.open(map_root, hashed_key)?;
        Ok(StorageMapWitness::new(proof, [key])?)
    }

    /// Makes `new_roots` the current roots of `account_id` while remembering the previous root
    /// set so the change can later be committed or discarded.
    ///
    /// If the account had no roots before, nothing is pushed onto the pending stack.
    pub fn stage_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
        increment_refcounts(&mut self.root_refcounts, &new_roots);
        if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
            self.pending_old_roots.entry(account_id).or_default().push(old_roots);
        }
    }

    /// Commits every staged change of `account_id`: all displaced root sets are dropped and
    /// roots that are no longer referenced are released from the forest.
    ///
    /// Does nothing if the account has no pending root sets.
    pub fn commit_roots(&mut self, account_id: AccountId) {
        if let Some(old_roots_stack) = self.pending_old_roots.remove(&account_id) {
            for old_roots in old_roots_stack {
                let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
                self.safe_pop_smts(to_pop);
            }
        }
    }

    /// Reverts the most recent staged change of `account_id`.
    ///
    /// The most recently displaced root set becomes current again. When no displaced set is
    /// pending, the account's roots are removed entirely. In both cases the root set that was
    /// current is dereferenced and roots that are no longer referenced are released from the
    /// forest.
    pub fn discard_roots(&mut self, account_id: AccountId) {
        // Most recently displaced root set, if any stage is pending.
        let old_roots = self.pending_old_roots.get_mut(&account_id).and_then(Vec::pop);

        // Restore the previous set (or drop the account) and take out the discarded set.
        let new_roots = match old_roots {
            Some(old_roots) => self.account_roots.insert(account_id, old_roots),
            None => self.account_roots.remove(&account_id),
        };

        if let Some(new_roots) = new_roots {
            let to_pop = decrement_refcounts(&mut self.root_refcounts, &new_roots);
            self.safe_pop_smts(to_pop);
        }

        // Never leave an empty stack behind: `replace_roots` treats any entry as "pending".
        if self.pending_old_roots.get(&account_id).is_some_and(Vec::is_empty) {
            self.pending_old_roots.remove(&account_id);
        }
    }

    /// Replaces the roots of `account_id` with `new_roots` immediately, without staging.
    ///
    /// The previous roots are dereferenced right away and released from the forest if nothing
    /// else references them. New roots are referenced before the old ones are dereferenced, so
    /// a root present in both sets is never released in between.
    ///
    /// # Panics
    ///
    /// Panics if the account has staged changes that were neither committed nor discarded.
    pub fn replace_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
        assert!(
            !self.pending_old_roots.contains_key(&account_id),
            "cannot replace roots while staged changes are pending for account {account_id}"
        );
        increment_refcounts(&mut self.root_refcounts, &new_roots);
        if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
            let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
            self.safe_pop_smts(to_pop);
        }
    }

    /// Applies asset additions and removals to the vault tree rooted at `root` and returns the
    /// root of the resulting tree.
    ///
    /// Removed keys are written as [`EMPTY_WORD`]. When both iterators are empty the forest is
    /// left untouched and `root` is returned. This method does not register the returned root
    /// for any account.
    ///
    /// # Errors
    ///
    /// Returns an error if the forest rejects the batch insertion.
    pub fn update_asset_nodes(
        &mut self,
        root: Word,
        new_assets: impl Iterator<Item = Asset>,
        removed_vault_keys: impl Iterator<Item = AssetVaultKey>,
    ) -> Result<Word, StoreError> {
        let entries: Vec<(Word, Word)> = new_assets
            .map(|asset| {
                let key: Word = asset.vault_key().into();
                let value = asset.to_value_word();
                (key, value)
            })
            .chain(removed_vault_keys.map(|vault_key| (vault_key.into(), EMPTY_WORD)))
            .collect();

        if entries.is_empty() {
            return Ok(root);
        }

        Ok(self.forest.batch_insert(root, entries)?)
    }

    /// Applies storage map updates to the tree rooted at `root` and returns the root of the
    /// resulting tree.
    ///
    /// Keys are hashed before insertion. When `entries` is empty the forest is left untouched
    /// and `root` is returned. This method does not register the returned root for any account.
    ///
    /// # Errors
    ///
    /// Returns an error if the forest rejects the batch insertion.
    pub fn update_storage_map_nodes(
        &mut self,
        root: Word,
        entries: impl Iterator<Item = (StorageMapKey, Word)>,
    ) -> Result<Word, StoreError> {
        let entries: Vec<(Word, Word)> =
            entries.map(|(key, value)| (key.hash().as_word(), value)).collect();

        if entries.is_empty() {
            return Ok(root);
        }

        Ok(self.forest.batch_insert(root, entries)?)
    }

    /// Inserts all assets of `vault` into the forest as a tree built on top of the empty root.
    ///
    /// An empty vault inserts nothing. In debug builds the resulting root is checked against
    /// the root of a standalone [`Smt`] built from the same entries.
    ///
    /// # Errors
    ///
    /// Returns an error if the standalone tree cannot be built from the vault's assets or if
    /// the forest rejects the batch insertion.
    pub fn insert_asset_nodes(&mut self, vault: &AssetVault) -> Result<(), StoreError> {
        let smt = Smt::with_entries(vault.assets().map(|asset| {
            let key: Word = asset.vault_key().into();
            let value = asset.to_value_word();
            (key, value)
        }))?;

        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
        let entries: Vec<(Word, Word)> = smt.entries().map(|(k, v)| (*k, *v)).collect();
        if entries.is_empty() {
            return Ok(());
        }

        let new_root = self.forest.batch_insert(empty_root, entries)?;
        debug_assert_eq!(new_root, smt.root());
        Ok(())
    }

    /// Inserts the nodes of every map slot found in `storage`; value slots are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while inserting one of the maps.
    pub fn insert_storage_map_nodes(&mut self, storage: &AccountStorage) -> Result<(), StoreError> {
        let maps = storage.slots().iter().filter_map(|slot| match slot.content() {
            StorageSlotContent::Map(map) => Some(map),
            StorageSlotContent::Value(_) => None,
        });

        for map in maps {
            self.insert_storage_map_nodes_for_map(map)?;
        }
        Ok(())
    }

    /// Inserts the storage map nodes and then the vault nodes of an account state.
    ///
    /// No roots are registered for any account by this method.
    ///
    /// # Errors
    ///
    /// Returns an error if inserting the storage maps or the vault fails.
    pub fn insert_account_state(
        &mut self,
        vault: &AssetVault,
        storage: &AccountStorage,
    ) -> Result<(), StoreError> {
        self.insert_storage_map_nodes(storage)?;
        self.insert_asset_nodes(vault)?;
        Ok(())
    }

    /// Inserts an account state and stages its roots for `account_id`
    /// (see [`Self::stage_roots`]).
    ///
    /// # Errors
    ///
    /// Returns an error if inserting the state fails; in that case no roots are staged.
    pub fn insert_and_stage_account_state(
        &mut self,
        account_id: AccountId,
        vault: &AssetVault,
        storage: &AccountStorage,
    ) -> Result<(), StoreError> {
        self.insert_account_state(vault, storage)?;
        let roots = Self::collect_account_roots(vault, storage);
        self.stage_roots(account_id, roots);
        Ok(())
    }

    /// Inserts an account state and immediately replaces the roots of `account_id` with the
    /// roots of that state (see [`Self::replace_roots`]).
    ///
    /// # Errors
    ///
    /// Returns an error if inserting the state fails; in that case the roots are not replaced.
    ///
    /// # Panics
    ///
    /// Panics if the account has staged changes pending.
    pub fn insert_and_register_account_state(
        &mut self,
        account_id: AccountId,
        vault: &AssetVault,
        storage: &AccountStorage,
    ) -> Result<(), StoreError> {
        self.insert_account_state(vault, storage)?;
        let roots = Self::collect_account_roots(vault, storage);
        self.replace_roots(account_id, roots);
        Ok(())
    }

    /// Inserts all entries of `map` into the forest as a tree built on top of the empty root,
    /// keyed by the hash of each map key.
    ///
    /// An empty map inserts nothing.
    ///
    /// # Errors
    ///
    /// Returns an error if the forest rejects the batch insertion.
    pub fn insert_storage_map_nodes_for_map(&mut self, map: &StorageMap) -> Result<(), StoreError> {
        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
        let entries: Vec<(Word, Word)> =
            map.entries().map(|(k, v)| (k.hash().as_word(), *v)).collect();
        if entries.is_empty() {
            return Ok(());
        }

        self.forest.batch_insert(empty_root, entries)?;
        Ok(())
    }

    /// Collects the roots of an account state: the vault root first, followed by the root of
    /// every map slot in storage order.
    fn collect_account_roots(vault: &AssetVault, storage: &AccountStorage) -> Vec<Word> {
        let mut roots = vec![vault.root()];
        for slot in storage.slots() {
            if let StorageSlotContent::Map(map) = slot.content() {
                roots.push(map.root());
            }
        }
        roots
    }

    /// Releases `roots` from the forest, skipping the empty-tree root.
    ///
    /// Empty vaults and empty storage maps contribute the empty-tree root to an account's root
    /// set, so it is reference counted like any other root. It is never inserted into the
    /// forest by this type (the insert methods return early for empty trees), so it must not
    /// be handed to the forest for removal when its count drops to zero.
    fn safe_pop_smts(&mut self, roots: impl IntoIterator<Item = Word>) {
        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
        self.forest.pop_smts(roots.into_iter().filter(|root| *root != empty_root));
    }
}
/// Adds one reference for every occurrence of a root in `roots`.
///
/// Roots that are not tracked yet start at zero before being incremented, so a root listed
/// several times is counted once per occurrence.
fn increment_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) {
    roots.iter().copied().for_each(|root| {
        let count = refcounts.entry(root).or_default();
        *count += 1;
    });
}
/// Drops one reference for every occurrence of a root in `roots` and returns the roots whose
/// count reached zero, in the order they were released.
///
/// Released roots are removed from `refcounts`. Roots that are not tracked are ignored, which
/// also covers a repeated root whose entry was already removed earlier in the same call.
fn decrement_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) -> Vec<Word> {
    let mut released = Vec::new();
    for &root in roots {
        let Some(remaining) = refcounts.get_mut(&root) else {
            continue;
        };
        *remaining -= 1;
        if *remaining == 0 {
            refcounts.remove(&root);
            released.push(root);
        }
    }
    released
}
// Unit tests for the root staging / commit / discard life cycle and for root reference
// counting. A root is considered "live" when a storage map witness can still be opened for it.
#[cfg(test)]
mod tests {
use miden_protocol::testing::account_id::{
ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET,
ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET,
};
use miden_protocol::{ONE, ZERO};
use super::*;
// First of two distinct account ids used by the tests.
fn account_a() -> AccountId {
AccountId::try_from(ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET).unwrap()
}
// Second account id, used where two accounts must share a root.
fn account_b() -> AccountId {
AccountId::try_from(ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET).unwrap()
}
// Builds a single-entry storage map, inserts its nodes into the forest and returns its root.
fn insert_map(forest: &mut AccountSmtForest, key: Word, value: Word) -> Word {
let mut map = StorageMap::new();
map.insert(StorageMapKey::new(key), value).unwrap();
forest.insert_storage_map_nodes_for_map(&map).unwrap();
map.root()
}
// Returns true if a witness for `key` can still be opened under `root`.
fn root_is_live(forest: &AccountSmtForest, root: Word, key: Word) -> bool {
forest.get_storage_map_item_witness(root, StorageMapKey::new(key)).is_ok()
}
// Committing a staged change keeps the new roots and releases the displaced ones.
#[test]
fn stage_then_commit_releases_old_roots() {
let mut forest = AccountSmtForest::new();
let id = account_a();
let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
let val: Word = [ONE, ONE, ONE, ONE].into();
let root1 = insert_map(&mut forest, key1, val);
let root2 = insert_map(&mut forest, key2, val);
forest.replace_roots(id, vec![root1]);
assert_eq!(forest.get_roots(&id), Some(&vec![root1]));
// While the change is only staged, both the old and the new root must stay available.
forest.stage_roots(id, vec![root2]);
assert_eq!(forest.get_roots(&id), Some(&vec![root2]));
assert!(root_is_live(&forest, root1, key1));
assert!(root_is_live(&forest, root2, key2));
// After the commit only the new root remains.
forest.commit_roots(id);
assert_eq!(forest.get_roots(&id), Some(&vec![root2]));
assert!(!root_is_live(&forest, root1, key1));
assert!(root_is_live(&forest, root2, key2));
}
// Discarding a staged change restores the previous roots and releases the staged ones.
#[test]
fn stage_then_discard_restores_old_roots() {
let mut forest = AccountSmtForest::new();
let id = account_a();
let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
let val: Word = [ONE, ONE, ONE, ONE].into();
let root1 = insert_map(&mut forest, key1, val);
let root2 = insert_map(&mut forest, key2, val);
forest.replace_roots(id, vec![root1]);
forest.stage_roots(id, vec![root2]);
forest.discard_roots(id);
assert_eq!(forest.get_roots(&id), Some(&vec![root1]));
assert!(root_is_live(&forest, root1, key1));
assert!(!root_is_live(&forest, root2, key2));
}
// A root referenced by two accounts is released only after both stop referencing it.
#[test]
fn shared_root_survives_single_account_replacement() {
let mut forest = AccountSmtForest::new();
let id1 = account_a();
let id2 = account_b();
let key: Word = [ONE, ZERO, ZERO, ZERO].into();
let val: Word = [ONE, ONE, ONE, ONE].into();
let shared_root = insert_map(&mut forest, key, val);
forest.replace_roots(id1, vec![shared_root]);
forest.replace_roots(id2, vec![shared_root]);
let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
let other_root = insert_map(&mut forest, key2, val);
// The first account moves away; the second one still references the shared root.
forest.replace_roots(id1, vec![other_root]);
assert!(root_is_live(&forest, shared_root, key));
// The last reference goes away, so the shared root is released.
forest.replace_roots(id2, vec![other_root]);
assert!(!root_is_live(&forest, shared_root, key));
}
// Each discard undoes exactly one staged change, most recent first.
#[test]
fn multiple_stages_discard_one_at_a_time() {
let mut forest = AccountSmtForest::new();
let id = account_a();
let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
let val: Word = [ONE, ONE, ONE, ONE].into();
let root_a = insert_map(&mut forest, key_a, val);
let root_b = insert_map(&mut forest, key_b, val);
let root_c = insert_map(&mut forest, key_c, val);
forest.replace_roots(id, vec![root_a]);
forest.stage_roots(id, vec![root_b]);
forest.stage_roots(id, vec![root_c]);
assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));
// First discard: back to root_b, root_c is released.
forest.discard_roots(id);
assert_eq!(forest.get_roots(&id), Some(&vec![root_b]));
assert!(!root_is_live(&forest, root_c, key_c));
assert!(root_is_live(&forest, root_b, key_b));
assert!(root_is_live(&forest, root_a, key_a));
// Second discard: back to root_a, root_b is released.
forest.discard_roots(id);
assert_eq!(forest.get_roots(&id), Some(&vec![root_a]));
assert!(!root_is_live(&forest, root_b, key_b));
assert!(root_is_live(&forest, root_a, key_a));
}
// A single commit settles every staged change and releases all displaced root sets.
#[test]
fn multiple_stages_commit_releases_all_old() {
let mut forest = AccountSmtForest::new();
let id = account_a();
let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
let val: Word = [ONE, ONE, ONE, ONE].into();
let root_a = insert_map(&mut forest, key_a, val);
let root_b = insert_map(&mut forest, key_b, val);
let root_c = insert_map(&mut forest, key_c, val);
forest.replace_roots(id, vec![root_a]);
forest.stage_roots(id, vec![root_b]);
forest.stage_roots(id, vec![root_c]);
forest.commit_roots(id);
assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));
assert!(!root_is_live(&forest, root_a, key_a));
assert!(!root_is_live(&forest, root_b, key_b));
assert!(root_is_live(&forest, root_c, key_c));
}
// A root present in both the old and the new root set must survive the commit.
#[test]
fn unchanged_root_survives_stage_commit() {
let mut forest = AccountSmtForest::new();
let id = account_a();
let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
let val: Word = [ONE, ONE, ONE, ONE].into();
let shared_root = insert_map(&mut forest, key1, val);
let changing_root = insert_map(&mut forest, key2, val);
forest.replace_roots(id, vec![shared_root, changing_root]);
let key3: Word = [ZERO, ZERO, ONE, ZERO].into();
let new_root = insert_map(&mut forest, key3, val);
// Only `changing_root` is swapped out; `shared_root` is listed in both sets.
forest.stage_roots(id, vec![shared_root, new_root]);
forest.commit_roots(id);
assert!(root_is_live(&forest, shared_root, key1));
assert!(!root_is_live(&forest, changing_root, key2));
assert!(root_is_live(&forest, new_root, key3));
}
}