// miden_client/store/smt_forest.rs

1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3
4use miden_protocol::account::{
5    AccountId,
6    AccountStorage,
7    StorageMap,
8    StorageMapKey,
9    StorageMapWitness,
10    StorageSlotContent,
11};
12use miden_protocol::asset::{Asset, AssetVault, AssetVaultKey, AssetWitness};
13use miden_protocol::crypto::merkle::smt::{SMT_DEPTH, Smt, SmtForest};
14use miden_protocol::crypto::merkle::{EmptySubtreeRoots, MerkleError};
15use miden_protocol::{EMPTY_WORD, Word};
16
17use super::StoreError;
18
19/// Thin wrapper around `SmtForest` for account vault/storage proofs and updates.
20///
21/// Tracks current SMT roots per account with reference counting to safely pop
22/// roots from the underlying forest when no account references them anymore.
23/// Supports staged updates for transaction rollback via a pending roots stack.
24#[derive(Debug, Default, Clone, Eq, PartialEq)]
25pub struct AccountSmtForest {
26    forest: SmtForest,
27    /// Current roots per account (vault root + storage map roots).
28    account_roots: BTreeMap<AccountId, Vec<Word>>,
29    /// Stack of old roots saved during staging, awaiting commit or undo.
30    pending_old_roots: BTreeMap<AccountId, Vec<Vec<Word>>>,
31    /// Reference count for each SMT root across all accounts.
32    root_refcounts: BTreeMap<Word, usize>,
33}
34
35impl AccountSmtForest {
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    // READERS
41    // --------------------------------------------------------------------------------------------
42
43    /// Returns the current roots for an account.
44    pub fn get_roots(&self, account_id: &AccountId) -> Option<&Vec<Word>> {
45        self.account_roots.get(account_id)
46    }
47
48    /// Retrieves the vault asset and its witness for a specific vault key.
49    pub fn get_asset_and_witness(
50        &self,
51        vault_root: Word,
52        vault_key: AssetVaultKey,
53    ) -> Result<(Asset, AssetWitness), StoreError> {
54        let vault_key_word = vault_key.into();
55        let proof = self.forest.open(vault_root, vault_key_word)?;
56        let asset_word =
57            proof.get(&vault_key_word).ok_or(MerkleError::UntrackedKey(vault_key_word))?;
58        if asset_word == EMPTY_WORD {
59            return Err(MerkleError::UntrackedKey(vault_key_word).into());
60        }
61
62        let asset = Asset::try_from(asset_word)?;
63        let witness = AssetWitness::new(proof)?;
64        Ok((asset, witness))
65    }
66
67    /// Retrieves the storage map witness for a specific map item.
68    pub fn get_storage_map_item_witness(
69        &self,
70        map_root: Word,
71        key: StorageMapKey,
72    ) -> Result<StorageMapWitness, StoreError> {
73        let hashed_key = key.hash().as_word();
74        let proof = self.forest.open(map_root, hashed_key).map_err(StoreError::from)?;
75        Ok(StorageMapWitness::new(proof, [key])?)
76    }
77
78    // ROOT LIFECYCLE
79    // --------------------------------------------------------------------------------------------
80
81    /// Stages new roots for an account, saving old roots for potential rollback.
82    ///
83    /// The old roots are pushed onto a pending stack and their refcounts are preserved.
84    /// Call [`Self::commit_roots`] to release old roots or [`Self::discard_roots`] to
85    /// restore them.
86    pub fn stage_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
87        increment_refcounts(&mut self.root_refcounts, &new_roots);
88        if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
89            self.pending_old_roots.entry(account_id).or_default().push(old_roots);
90        }
91    }
92
93    /// Commits staged changes: releases all pending old roots for the account.
94    pub fn commit_roots(&mut self, account_id: AccountId) {
95        if let Some(old_roots_stack) = self.pending_old_roots.remove(&account_id) {
96            for old_roots in old_roots_stack {
97                let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
98                self.safe_pop_smts(to_pop);
99            }
100        }
101    }
102
103    /// Discards the most recent staged change: restores old roots and releases new roots.
104    ///
105    /// If there are old roots to restore, the current roots are replaced with them.
106    /// If there are no old roots (i.e., the account was first staged without prior state),
107    /// the current roots are simply removed.
108    pub fn discard_roots(&mut self, account_id: AccountId) {
109        let old_roots = self.pending_old_roots.get_mut(&account_id).and_then(Vec::pop);
110
111        // Release the current (staged) roots and restore old ones if available
112        let new_roots = match old_roots {
113            Some(old_roots) => self.account_roots.insert(account_id, old_roots),
114            None => self.account_roots.remove(&account_id),
115        };
116
117        if let Some(new_roots) = new_roots {
118            let to_pop = decrement_refcounts(&mut self.root_refcounts, &new_roots);
119            self.safe_pop_smts(to_pop);
120        }
121
122        // Clean up empty stack
123        if self.pending_old_roots.get(&account_id).is_some_and(Vec::is_empty) {
124            self.pending_old_roots.remove(&account_id);
125        }
126    }
127
128    /// Replaces roots atomically: sets new roots and immediately releases old roots.
129    ///
130    /// Use this when no rollback is needed (e.g., initial insert, network updates).
131    ///
132    /// # Panics
133    ///
134    /// Panics if there are pending staged changes for the account. Use
135    /// [`Self::commit_roots`] or [`Self::discard_roots`] first.
136    pub fn replace_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
137        assert!(
138            !self.pending_old_roots.contains_key(&account_id),
139            "cannot replace roots while staged changes are pending for account {account_id}"
140        );
141        increment_refcounts(&mut self.root_refcounts, &new_roots);
142        if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
143            let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
144            self.safe_pop_smts(to_pop);
145        }
146    }
147
148    // TREE MUTATORS
149    // --------------------------------------------------------------------------------------------
150
151    /// Updates the SMT forest with the new asset values.
152    pub fn update_asset_nodes(
153        &mut self,
154        root: Word,
155        new_assets: impl Iterator<Item = Asset>,
156        removed_vault_keys: impl Iterator<Item = AssetVaultKey>,
157    ) -> Result<Word, StoreError> {
158        let entries: Vec<(Word, Word)> = new_assets
159            .map(|asset| {
160                let key: Word = asset.vault_key().into();
161                let value: Word = asset.into();
162                (key, value)
163            })
164            .chain(removed_vault_keys.map(|key| (key.into(), EMPTY_WORD)))
165            .collect();
166
167        if entries.is_empty() {
168            return Ok(root);
169        }
170
171        let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
172        Ok(new_root)
173    }
174
175    /// Updates the SMT forest with the new storage map values.
176    pub fn update_storage_map_nodes(
177        &mut self,
178        root: Word,
179        entries: impl Iterator<Item = (StorageMapKey, Word)>,
180    ) -> Result<Word, StoreError> {
181        let entries: Vec<(Word, Word)> =
182            entries.map(|(key, value)| (key.hash().as_word(), value)).collect();
183
184        if entries.is_empty() {
185            return Ok(root);
186        }
187
188        let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
189        Ok(new_root)
190    }
191
192    /// Inserts the asset vault SMT nodes to the SMT forest.
193    pub fn insert_asset_nodes(&mut self, vault: &AssetVault) -> Result<(), StoreError> {
194        let smt = Smt::with_entries(vault.assets().map(|asset| {
195            let key: Word = asset.vault_key().into();
196            let value: Word = asset.into();
197            (key, value)
198        }))
199        .map_err(StoreError::from)?;
200
201        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
202        let entries: Vec<(Word, Word)> = smt.entries().map(|(k, v)| (*k, *v)).collect();
203        if entries.is_empty() {
204            return Ok(());
205        }
206        let new_root = self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
207        debug_assert_eq!(new_root, smt.root());
208        Ok(())
209    }
210
211    /// Inserts all storage map SMT nodes to the SMT forest.
212    pub fn insert_storage_map_nodes(&mut self, storage: &AccountStorage) -> Result<(), StoreError> {
213        let maps = storage.slots().iter().filter_map(|slot| match slot.content() {
214            StorageSlotContent::Map(map) => Some(map),
215            StorageSlotContent::Value(_) => None,
216        });
217
218        for map in maps {
219            self.insert_storage_map_nodes_for_map(map)?;
220        }
221        Ok(())
222    }
223
224    /// Inserts the SMT nodes for an account's vault and storage maps into the
225    /// forest, without tracking roots for the account.
226    pub fn insert_account_state(
227        &mut self,
228        vault: &AssetVault,
229        storage: &AccountStorage,
230    ) -> Result<(), StoreError> {
231        self.insert_storage_map_nodes(storage)?;
232        self.insert_asset_nodes(vault)?;
233        Ok(())
234    }
235
236    /// Inserts all SMT nodes for an account's vault and storage, then replaces
237    /// the account's tracked roots atomically.
238    pub fn insert_and_register_account_state(
239        &mut self,
240        account_id: AccountId,
241        vault: &AssetVault,
242        storage: &AccountStorage,
243    ) -> Result<(), StoreError> {
244        self.insert_account_state(vault, storage)?;
245        let roots = Self::collect_account_roots(vault, storage);
246        self.replace_roots(account_id, roots);
247        Ok(())
248    }
249
250    pub fn insert_storage_map_nodes_for_map(&mut self, map: &StorageMap) -> Result<(), StoreError> {
251        let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
252        let entries: Vec<(Word, Word)> =
253            map.entries().map(|(k, v)| (k.hash().as_word(), *v)).collect();
254        if entries.is_empty() {
255            return Ok(());
256        }
257        self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
258        Ok(())
259    }
260
261    // HELPERS
262    // --------------------------------------------------------------------------------------------
263
264    /// Collects all SMT roots (vault root + storage map roots) for an account's state.
265    fn collect_account_roots(vault: &AssetVault, storage: &AccountStorage) -> Vec<Word> {
266        let mut roots = vec![vault.root()];
267        for slot in storage.slots() {
268            if let StorageSlotContent::Map(map) = slot.content() {
269                roots.push(map.root());
270            }
271        }
272        roots
273    }
274
275    /// Pops SMT roots from the forest that are no longer referenced by any account.
276    fn safe_pop_smts(&mut self, roots: impl IntoIterator<Item = Word>) {
277        self.forest.pop_smts(roots);
278    }
279}
280
281fn increment_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) {
282    for root in roots {
283        *refcounts.entry(*root).or_insert(0) += 1;
284    }
285}
286
287/// Decrements refcounts for the given roots, returning those that reached zero.
288fn decrement_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) -> Vec<Word> {
289    let mut to_pop = Vec::new();
290    for root in roots {
291        if let Some(count) = refcounts.get_mut(root) {
292            *count -= 1;
293            if *count == 0 {
294                refcounts.remove(root);
295                to_pop.push(*root);
296            }
297        }
298    }
299    to_pop
300}
301
#[cfg(test)]
mod tests {
    use miden_protocol::testing::account_id::{
        ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET,
        ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET,
    };
    use miden_protocol::{ONE, ZERO};

    use super::*;

    fn account_a() -> AccountId {
        AccountId::try_from(ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET).unwrap()
    }

    fn account_b() -> AccountId {
        AccountId::try_from(ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET).unwrap()
    }

    /// Creates a `StorageMap` with a single entry and inserts its nodes into the forest.
    /// Returns the map's root.
    fn insert_map(forest: &mut AccountSmtForest, key: Word, value: Word) -> Word {
        let mut map = StorageMap::new();
        map.insert(StorageMapKey::new(key), value).unwrap();
        forest.insert_storage_map_nodes_for_map(&map).unwrap();
        map.root()
    }

    /// Returns true if the forest can still serve a proof for the given root.
    fn root_is_live(forest: &AccountSmtForest, root: Word, key: Word) -> bool {
        forest.get_storage_map_item_witness(root, StorageMapKey::new(key)).is_ok()
    }

    #[test]
    fn stage_then_commit_releases_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        // Initial state
        forest.replace_roots(id, vec![root1]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root1]));

        // Stage new roots (apply_delta)
        forest.stage_roots(id, vec![root2]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root2]));

        // Both roots alive during staging (old preserved for rollback)
        assert!(root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));

        // Commit — old roots released
        forest.commit_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root2]));
        assert!(!root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));
    }

    #[test]
    fn stage_then_discard_restores_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        forest.replace_roots(id, vec![root1]);

        // Stage and discard (rollback)
        forest.stage_roots(id, vec![root2]);
        forest.discard_roots(id);

        assert_eq!(forest.get_roots(&id), Some(&vec![root1]));
        assert!(root_is_live(&forest, root1, key1));
        assert!(!root_is_live(&forest, root2, key2));
    }

    #[test]
    fn shared_root_survives_single_account_replacement() {
        let mut forest = AccountSmtForest::new();
        let id1 = account_a();
        let id2 = account_b();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let shared_root = insert_map(&mut forest, key, val);

        // Both accounts reference the same root
        forest.replace_roots(id1, vec![shared_root]);
        forest.replace_roots(id2, vec![shared_root]);

        // Replace id1 with a different root
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let other_root = insert_map(&mut forest, key2, val);
        forest.replace_roots(id1, vec![other_root]);

        // Shared root still alive (id2 still references it)
        assert!(root_is_live(&forest, shared_root, key));

        // Replace id2 too — now shared root should be popped
        forest.replace_roots(id2, vec![other_root]);
        assert!(!root_is_live(&forest, shared_root, key));
    }

    #[test]
    fn multiple_stages_discard_one_at_a_time() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);
        let root_c = insert_map(&mut forest, key_c, val);

        // A -> B -> C
        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));

        // Discard C -> back to B
        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_b]));
        assert!(!root_is_live(&forest, root_c, key_c));
        assert!(root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));

        // Discard B -> back to A
        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_a]));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));
    }

    #[test]
    fn multiple_stages_commit_releases_all_old() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);
        let root_c = insert_map(&mut forest, key_c, val);

        // A -> B -> C, then commit
        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        forest.commit_roots(id);

        // Only C survives
        assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));
        assert!(!root_is_live(&forest, root_a, key_a));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_c, key_c));
    }

    #[test]
    fn unchanged_root_survives_stage_commit() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let shared_root = insert_map(&mut forest, key1, val);
        let changing_root = insert_map(&mut forest, key2, val);

        // Initial: [shared, changing]
        forest.replace_roots(id, vec![shared_root, changing_root]);

        // Delta only changes the second root; shared_root stays
        let key3: Word = [ZERO, ZERO, ONE, ZERO].into();
        let new_root = insert_map(&mut forest, key3, val);
        forest.stage_roots(id, vec![shared_root, new_root]);
        forest.commit_roots(id);

        // shared_root must survive (it's in both old and new)
        assert!(root_is_live(&forest, shared_root, key1));
        // changing_root should be popped
        assert!(!root_is_live(&forest, changing_root, key2));
        // new_root should be alive
        assert!(root_is_live(&forest, new_root, key3));
    }

    #[test]
    fn stage_without_prior_roots_then_discard_removes_account() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let root = insert_map(&mut forest, key, val);

        // First stage for an untracked account: there are no old roots to save
        forest.stage_roots(id, vec![root]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root]));

        // Discarding drops the account entirely and releases the staged root
        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), None);
        assert!(!root_is_live(&forest, root, key));
    }

    #[test]
    fn stage_without_prior_roots_then_commit_keeps_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let root = insert_map(&mut forest, key, val);

        forest.stage_roots(id, vec![root]);
        forest.commit_roots(id);

        // Nothing was pending, so the staged root stays registered and alive
        assert_eq!(forest.get_roots(&id), Some(&vec![root]));
        assert!(root_is_live(&forest, root, key));
    }

    #[test]
    #[should_panic(expected = "cannot replace roots while staged changes are pending")]
    fn replace_roots_panics_with_pending_stage() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        forest.replace_roots(id, vec![root1]);
        forest.stage_roots(id, vec![root2]);

        // Staged changes are pending for `id`, so this must panic
        forest.replace_roots(id, vec![root1]);
    }
}