Skip to main content

miden_client/store/
smt_forest.rs

1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3
4use miden_protocol::account::{
5    AccountId,
6    AccountStorage,
7    StorageMap,
8    StorageMapKey,
9    StorageMapWitness,
10    StorageSlotContent,
11};
12use miden_protocol::asset::{Asset, AssetVault, AssetVaultKey, AssetWitness};
13use miden_protocol::crypto::merkle::smt::{SMT_DEPTH, Smt, SmtForest};
14use miden_protocol::crypto::merkle::{EmptySubtreeRoots, MerkleError};
15use miden_protocol::{EMPTY_WORD, Word};
16
17use super::StoreError;
18
19/// Thin wrapper around `SmtForest` for account vault/storage proofs and updates.
20///
21/// Tracks current SMT roots per account with reference counting to safely pop
22/// roots from the underlying forest when no account references them anymore.
23/// Supports staged updates for transaction rollback via a pending roots stack.
24#[derive(Debug, Default, Clone, Eq, PartialEq)]
25pub struct AccountSmtForest {
26    forest: SmtForest,
27    /// Current roots per account (vault root + storage map roots).
28    account_roots: BTreeMap<AccountId, Vec<Word>>,
29    /// Stack of old roots saved during staging, awaiting commit or undo.
30    pending_old_roots: BTreeMap<AccountId, Vec<Vec<Word>>>,
31    /// Reference count for each SMT root across all accounts.
32    root_refcounts: BTreeMap<Word, usize>,
33}
34
35impl AccountSmtForest {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    // READERS
41    // --------------------------------------------------------------------------------------------
42
43    /// Returns the current roots for an account.
44    pub fn get_roots(&self, account_id: &AccountId) -> Option<&Vec<Word>> {
45        self.account_roots.get(account_id)
46    }
47
48    /// Retrieves the vault asset and its witness for a specific vault key.
49    pub fn get_asset_and_witness(
50        &self,
51        vault_root: Word,
52        vault_key: AssetVaultKey,
53    ) -> Result<(Asset, AssetWitness), StoreError> {
54        let vault_key_word = vault_key.into();
55        let proof = self.forest.open(vault_root, vault_key_word)?;
56        let asset_word =
57            proof.get(&vault_key_word).ok_or(MerkleError::UntrackedKey(vault_key_word))?;
58        if asset_word == EMPTY_WORD {
59            return Err(MerkleError::UntrackedKey(vault_key_word).into());
60        }
61
62        let asset = Asset::from_key_value_words(vault_key_word, asset_word)?;
63        let witness = AssetWitness::new(proof)?;
64        Ok((asset, witness))
65    }
66
67    /// Retrieves the storage map witness for a specific map item.
68    pub fn get_storage_map_item_witness(
69        &self,
70        map_root: Word,
71        key: StorageMapKey,
72    ) -> Result<StorageMapWitness, StoreError> {
73        let hashed_key = key.hash().as_word();
74        let proof = self.forest.open(map_root, hashed_key).map_err(StoreError::from)?;
75        Ok(StorageMapWitness::new(proof, [key])?)
76    }
77
78    // ROOT LIFECYCLE
79    // --------------------------------------------------------------------------------------------
80
81    /// Stages new roots for an account, saving old roots for potential rollback.
82    ///
83    /// The old roots are pushed onto a pending stack and their refcounts are preserved.
84    /// Call [`Self::commit_roots`] to release old roots or [`Self::discard_roots`] to
85    /// restore them.
86    pub fn stage_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
87        increment_refcounts(&mut self.root_refcounts, &new_roots);
88        if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
89            self.pending_old_roots.entry(account_id).or_default().push(old_roots);
90        }
91    }
92
93    /// Commits staged changes: releases all pending old roots for the account.
94    pub fn commit_roots(&mut self, account_id: AccountId) {
95        if let Some(old_roots_stack) = self.pending_old_roots.remove(&account_id) {
96            for old_roots in old_roots_stack {
97                let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
98                self.safe_pop_smts(to_pop);
99            }
100        }
101    }
102
103    /// Discards the most recent staged change: restores old roots and releases new roots.
104    ///
105    /// If there are old roots to restore, the current roots are replaced with them.
106    /// If there are no old roots (i.e., the account was first staged without prior state),
107    /// the current roots are simply removed.
108    pub fn discard_roots(&mut self, account_id: AccountId) {
109        let old_roots = self.pending_old_roots.get_mut(&account_id).and_then(Vec::pop);
110
111        // Release the current (staged) roots and restore old ones if available
112        let new_roots = match old_roots {
113            Some(old_roots) => self.account_roots.insert(account_id, old_roots),
114            None => self.account_roots.remove(&account_id),
115        };
116
117        if let Some(new_roots) = new_roots {
118            let to_pop = decrement_refcounts(&mut self.root_refcounts, &new_roots);
119            self.safe_pop_smts(to_pop);
120        }
121
122        // Clean up empty stack
123        if self.pending_old_roots.get(&account_id).is_some_and(Vec::is_empty) {
124            self.pending_old_roots.remove(&account_id);
125        }
126    }
127
128    /// Replaces roots atomically: sets new roots and immediately releases old roots.
129    ///
130    /// Use this when no rollback is needed (e.g., initial insert, network updates).
131    ///
132    /// # Panics
133    ///
134    /// Panics if there are pending staged changes for the account. Use
135    /// [`Self::commit_roots`] or [`Self::discard_roots`] first.
136    pub fn replace_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
137        assert!(
138            !self.pending_old_roots.contains_key(&account_id),
139            "cannot replace roots while staged changes are pending for account {account_id}"
140        );
141        increment_refcounts(&mut self.root_refcounts, &new_roots);
142        if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
143            let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
144            self.safe_pop_smts(to_pop);
145        }
146    }
147
148    // TREE MUTATORS
149    // --------------------------------------------------------------------------------------------
150
151    /// Updates the SMT forest with the new asset values.
152    pub fn update_asset_nodes(
153        &mut self,
154        root: Word,
155        new_assets: impl Iterator<Item = Asset>,
156        removed_vault_keys: impl Iterator<Item = AssetVaultKey>,
157    ) -> Result<Word, StoreError> {
158        let entries: Vec<(Word, Word)> = new_assets
159            .map(|asset| {
160                let key: Word = asset.vault_key().into();
161                let value = asset.to_value_word();
162                (key, value)
163            })
164            .chain(removed_vault_keys.map(|vault_key| (vault_key.into(), EMPTY_WORD)))
165            .collect();
166
167        if entries.is_empty() {
168            return Ok(root);
169        }
170
171        let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
172        Ok(new_root)
173    }
174
175    /// Updates the SMT forest with the new storage map values.
176    pub fn update_storage_map_nodes(
177        &mut self,
178        root: Word,
179        entries: impl Iterator<Item = (StorageMapKey, Word)>,
180    ) -> Result<Word, StoreError> {
181        let entries: Vec<(Word, Word)> =
182            entries.map(|(key, value)| (key.hash().as_word(), value)).collect();
183
184        if entries.is_empty() {
185            return Ok(root);
186        }
187
188        let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
189        Ok(new_root)
190    }
191
192    /// Inserts the asset vault SMT nodes to the SMT forest.
193    pub fn insert_asset_nodes(&mut self, vault: &AssetVault) -> Result<(), StoreError> {
194        let smt = Smt::with_entries(vault.assets().map(|asset| {
195            let key: Word = asset.vault_key().into();
196            let value = asset.to_value_word();
197            (key, value)
198        }))
199        .map_err(StoreError::from)?;
200
201        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
202        let entries: Vec<(Word, Word)> = smt.entries().map(|(k, v)| (*k, *v)).collect();
203        if entries.is_empty() {
204            return Ok(());
205        }
206        let new_root = self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
207        debug_assert_eq!(new_root, smt.root());
208        Ok(())
209    }
210
211    /// Inserts all storage map SMT nodes to the SMT forest.
212    pub fn insert_storage_map_nodes(&mut self, storage: &AccountStorage) -> Result<(), StoreError> {
213        let maps = storage.slots().iter().filter_map(|slot| match slot.content() {
214            StorageSlotContent::Map(map) => Some(map),
215            StorageSlotContent::Value(_) => None,
216        });
217
218        for map in maps {
219            self.insert_storage_map_nodes_for_map(map)?;
220        }
221        Ok(())
222    }
223
224    /// Inserts the SMT nodes for an account's vault and storage maps into the
225    /// forest, without tracking roots for the account.
226    pub fn insert_account_state(
227        &mut self,
228        vault: &AssetVault,
229        storage: &AccountStorage,
230    ) -> Result<(), StoreError> {
231        self.insert_storage_map_nodes(storage)?;
232        self.insert_asset_nodes(vault)?;
233        Ok(())
234    }
235
236    /// Inserts all SMT nodes for an account's vault and storage, then stages
237    /// the account's roots for later commit or discard.
238    pub fn insert_and_stage_account_state(
239        &mut self,
240        account_id: AccountId,
241        vault: &AssetVault,
242        storage: &AccountStorage,
243    ) -> Result<(), StoreError> {
244        self.insert_account_state(vault, storage)?;
245        let roots = Self::collect_account_roots(vault, storage);
246        self.stage_roots(account_id, roots);
247        Ok(())
248    }
249
250    /// Inserts all SMT nodes for an account's vault and storage, then replaces
251    /// the account's tracked roots atomically.
252    pub fn insert_and_register_account_state(
253        &mut self,
254        account_id: AccountId,
255        vault: &AssetVault,
256        storage: &AccountStorage,
257    ) -> Result<(), StoreError> {
258        self.insert_account_state(vault, storage)?;
259        let roots = Self::collect_account_roots(vault, storage);
260        self.replace_roots(account_id, roots);
261        Ok(())
262    }
263
264    /// Inserts storage map SMT nodes for a specific storage map.
265    pub fn insert_storage_map_nodes_for_map(&mut self, map: &StorageMap) -> Result<(), StoreError> {
266        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
267        let entries: Vec<(Word, Word)> =
268            map.entries().map(|(k, v)| (k.hash().as_word(), *v)).collect();
269        if entries.is_empty() {
270            return Ok(());
271        }
272        self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
273        Ok(())
274    }
275
276    // HELPERS
277    // --------------------------------------------------------------------------------------------
278
279    /// Collects all SMT roots (vault root + storage map roots) for an account's state.
280    fn collect_account_roots(vault: &AssetVault, storage: &AccountStorage) -> Vec<Word> {
281        let mut roots = vec![vault.root()];
282        for slot in storage.slots() {
283            if let StorageSlotContent::Map(map) = slot.content() {
284                roots.push(map.root());
285            }
286        }
287        roots
288    }
289
290    /// Pops SMT roots from the forest that are no longer referenced by any account.
291    fn safe_pop_smts(&mut self, roots: impl IntoIterator<Item = Word>) {
292        self.forest.pop_smts(roots);
293    }
294}
295
/// Increments the reference count of every root in `roots`, starting untracked roots at 1.
///
/// A root that appears several times in `roots` is counted once per occurrence. Generic over the
/// key so it works for any ordered, copyable root type (the forest uses `Word`).
fn increment_refcounts<K: Ord + Copy>(refcounts: &mut BTreeMap<K, usize>, roots: &[K]) {
    for &root in roots {
        *refcounts.entry(root).or_default() += 1;
    }
}
301
/// Decrements the reference count of every root in `roots`, returning the roots whose count
/// dropped to zero.
///
/// Roots that reach zero are removed from `refcounts`, so each is reported at most once even if it
/// appears several times in `roots`. Roots that are not tracked are ignored. Generic over the key
/// so it works for any ordered, copyable root type (the forest uses `Word`).
fn decrement_refcounts<K>(refcounts: &mut BTreeMap<K, usize>, roots: &[K]) -> Vec<K>
where
    K: Ord + Copy,
{
    let mut to_pop = Vec::new();
    for root in roots {
        // Untracked roots (including ones already released earlier in this call) are skipped.
        if let Some(count) = refcounts.get_mut(root) {
            // Tracked entries are always >= 1: an entry is removed as soon as it reaches zero.
            *count -= 1;
            if *count == 0 {
                refcounts.remove(root);
                to_pop.push(*root);
            }
        }
    }
    to_pop
}
316
#[cfg(test)]
mod tests {
    //! Root-lifecycle tests: staging, committing, discarding and replacing roots, and checking
    //! which trees remain openable in the forest afterwards.

    use miden_protocol::testing::account_id::{
        ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET,
        ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET,
    };
    use miden_protocol::{ONE, ZERO};

    use super::*;

    fn account_a() -> AccountId {
        AccountId::try_from(ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET).unwrap()
    }

    fn account_b() -> AccountId {
        AccountId::try_from(ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET).unwrap()
    }

    /// Creates a `StorageMap` with a single entry and inserts its nodes into the forest.
    /// Returns the map's root.
    fn insert_map(forest: &mut AccountSmtForest, key: Word, value: Word) -> Word {
        let mut map = StorageMap::new();
        map.insert(StorageMapKey::new(key), value).unwrap();
        forest.insert_storage_map_nodes_for_map(&map).unwrap();
        map.root()
    }

    /// Returns true if the forest can still serve a proof for the given root.
    fn root_is_live(forest: &AccountSmtForest, root: Word, key: Word) -> bool {
        forest.get_storage_map_item_witness(root, StorageMapKey::new(key)).is_ok()
    }

    #[test]
    fn stage_then_commit_releases_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        // Initial state
        forest.replace_roots(id, vec![root1]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root1]));

        // Stage new roots (apply_delta)
        forest.stage_roots(id, vec![root2]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root2]));

        // Both roots alive during staging (old preserved for rollback)
        assert!(root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));

        // Commit — old roots released
        forest.commit_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root2]));
        assert!(!root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));
    }

    #[test]
    fn stage_then_discard_restores_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        forest.replace_roots(id, vec![root1]);

        // Stage and discard (rollback)
        forest.stage_roots(id, vec![root2]);
        forest.discard_roots(id);

        assert_eq!(forest.get_roots(&id), Some(&vec![root1]));
        assert!(root_is_live(&forest, root1, key1));
        assert!(!root_is_live(&forest, root2, key2));
    }

    #[test]
    fn shared_root_survives_single_account_replacement() {
        let mut forest = AccountSmtForest::new();
        let id1 = account_a();
        let id2 = account_b();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let shared_root = insert_map(&mut forest, key, val);

        // Both accounts reference the same root
        forest.replace_roots(id1, vec![shared_root]);
        forest.replace_roots(id2, vec![shared_root]);

        // Replace id1 with a different root
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let other_root = insert_map(&mut forest, key2, val);
        forest.replace_roots(id1, vec![other_root]);

        // Shared root still alive (id2 still references it)
        assert!(root_is_live(&forest, shared_root, key));

        // Replace id2 too — now shared root should be popped
        forest.replace_roots(id2, vec![other_root]);
        assert!(!root_is_live(&forest, shared_root, key));
    }

    #[test]
    fn multiple_stages_discard_one_at_a_time() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);
        let root_c = insert_map(&mut forest, key_c, val);

        // A -> B -> C
        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));

        // Discard C -> back to B
        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_b]));
        assert!(!root_is_live(&forest, root_c, key_c));
        assert!(root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));

        // Discard B -> back to A
        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_a]));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));
    }

    #[test]
    fn multiple_stages_commit_releases_all_old() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);
        let root_c = insert_map(&mut forest, key_c, val);

        // A -> B -> C, then commit
        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        forest.commit_roots(id);

        // Only C survives
        assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));
        assert!(!root_is_live(&forest, root_a, key_a));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_c, key_c));
    }

    #[test]
    fn unchanged_root_survives_stage_commit() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let shared_root = insert_map(&mut forest, key1, val);
        let changing_root = insert_map(&mut forest, key2, val);

        // Initial: [shared, changing]
        forest.replace_roots(id, vec![shared_root, changing_root]);

        // Delta only changes the second root; shared_root stays
        let key3: Word = [ZERO, ZERO, ONE, ZERO].into();
        let new_root = insert_map(&mut forest, key3, val);
        forest.stage_roots(id, vec![shared_root, new_root]);
        forest.commit_roots(id);

        // shared_root must survive (it's in both old and new)
        assert!(root_is_live(&forest, shared_root, key1));
        // changing_root should be popped
        assert!(!root_is_live(&forest, changing_root, key2));
        // new_root should be alive
        assert!(root_is_live(&forest, new_root, key3));
    }

    #[test]
    fn discard_first_stage_without_prior_state_removes_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let root = insert_map(&mut forest, key, val);

        // No prior roots: staging pushes nothing onto the pending stack.
        forest.stage_roots(id, vec![root]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root]));

        // Discarding removes the account's roots and releases the now-unreferenced tree.
        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), None);
        assert!(!root_is_live(&forest, root, key));
    }

    #[test]
    fn commit_first_stage_without_prior_state_keeps_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let root = insert_map(&mut forest, key, val);

        forest.stage_roots(id, vec![root]);
        forest.commit_roots(id);

        assert_eq!(forest.get_roots(&id), Some(&vec![root]));
        assert!(root_is_live(&forest, root, key));
    }

    #[test]
    #[should_panic(expected = "cannot replace roots while staged changes are pending")]
    fn replace_roots_panics_while_staged_changes_pending() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);

        forest.replace_roots(id, vec![root_a]);
        // Pushes [root_a] onto the pending stack.
        forest.stage_roots(id, vec![root_b]);
        // Must panic: the staged change has been neither committed nor discarded.
        forest.replace_roots(id, vec![root_a]);
    }
}
515        assert!(!root_is_live(&forest, changing_root, key2));
516        // new_root should be alive
517        assert!(root_is_live(&forest, new_root, key3));
518    }
519}