1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3
4use miden_protocol::account::{
5 AccountId,
6 AccountStorage,
7 StorageMap,
8 StorageMapKey,
9 StorageMapWitness,
10 StorageSlotContent,
11};
12use miden_protocol::asset::{Asset, AssetVault, AssetVaultKey, AssetWitness};
13use miden_protocol::crypto::merkle::smt::{SMT_DEPTH, Smt, SmtForest};
14use miden_protocol::crypto::merkle::{EmptySubtreeRoots, MerkleError};
15use miden_protocol::{EMPTY_WORD, Word};
16
17use super::StoreError;
18
19#[derive(Debug, Default, Clone, Eq, PartialEq)]
25pub struct AccountSmtForest {
26 forest: SmtForest,
27 account_roots: BTreeMap<AccountId, Vec<Word>>,
29 pending_old_roots: BTreeMap<AccountId, Vec<Vec<Word>>>,
31 root_refcounts: BTreeMap<Word, usize>,
33}
34
35impl AccountSmtForest {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn get_roots(&self, account_id: &AccountId) -> Option<&Vec<Word>> {
45 self.account_roots.get(account_id)
46 }
47
48 pub fn get_asset_and_witness(
50 &self,
51 vault_root: Word,
52 vault_key: AssetVaultKey,
53 ) -> Result<(Asset, AssetWitness), StoreError> {
54 let vault_key_word = vault_key.into();
55 let proof = self.forest.open(vault_root, vault_key_word)?;
56 let asset_word =
57 proof.get(&vault_key_word).ok_or(MerkleError::UntrackedKey(vault_key_word))?;
58 if asset_word == EMPTY_WORD {
59 return Err(MerkleError::UntrackedKey(vault_key_word).into());
60 }
61
62 let asset = Asset::try_from(asset_word)?;
63 let witness = AssetWitness::new(proof)?;
64 Ok((asset, witness))
65 }
66
67 pub fn get_storage_map_item_witness(
69 &self,
70 map_root: Word,
71 key: StorageMapKey,
72 ) -> Result<StorageMapWitness, StoreError> {
73 let hashed_key = key.hash().as_word();
74 let proof = self.forest.open(map_root, hashed_key).map_err(StoreError::from)?;
75 Ok(StorageMapWitness::new(proof, [key])?)
76 }
77
78 pub fn stage_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
87 increment_refcounts(&mut self.root_refcounts, &new_roots);
88 if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
89 self.pending_old_roots.entry(account_id).or_default().push(old_roots);
90 }
91 }
92
93 pub fn commit_roots(&mut self, account_id: AccountId) {
95 if let Some(old_roots_stack) = self.pending_old_roots.remove(&account_id) {
96 for old_roots in old_roots_stack {
97 let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
98 self.safe_pop_smts(to_pop);
99 }
100 }
101 }
102
103 pub fn discard_roots(&mut self, account_id: AccountId) {
109 let old_roots = self.pending_old_roots.get_mut(&account_id).and_then(Vec::pop);
110
111 let new_roots = match old_roots {
113 Some(old_roots) => self.account_roots.insert(account_id, old_roots),
114 None => self.account_roots.remove(&account_id),
115 };
116
117 if let Some(new_roots) = new_roots {
118 let to_pop = decrement_refcounts(&mut self.root_refcounts, &new_roots);
119 self.safe_pop_smts(to_pop);
120 }
121
122 if self.pending_old_roots.get(&account_id).is_some_and(Vec::is_empty) {
124 self.pending_old_roots.remove(&account_id);
125 }
126 }
127
128 pub fn replace_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
137 assert!(
138 !self.pending_old_roots.contains_key(&account_id),
139 "cannot replace roots while staged changes are pending for account {account_id}"
140 );
141 increment_refcounts(&mut self.root_refcounts, &new_roots);
142 if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
143 let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
144 self.safe_pop_smts(to_pop);
145 }
146 }
147
148 pub fn update_asset_nodes(
153 &mut self,
154 root: Word,
155 new_assets: impl Iterator<Item = Asset>,
156 removed_vault_keys: impl Iterator<Item = AssetVaultKey>,
157 ) -> Result<Word, StoreError> {
158 let entries: Vec<(Word, Word)> = new_assets
159 .map(|asset| {
160 let key: Word = asset.vault_key().into();
161 let value: Word = asset.into();
162 (key, value)
163 })
164 .chain(removed_vault_keys.map(|key| (key.into(), EMPTY_WORD)))
165 .collect();
166
167 if entries.is_empty() {
168 return Ok(root);
169 }
170
171 let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
172 Ok(new_root)
173 }
174
175 pub fn update_storage_map_nodes(
177 &mut self,
178 root: Word,
179 entries: impl Iterator<Item = (StorageMapKey, Word)>,
180 ) -> Result<Word, StoreError> {
181 let entries: Vec<(Word, Word)> =
182 entries.map(|(key, value)| (key.hash().as_word(), value)).collect();
183
184 if entries.is_empty() {
185 return Ok(root);
186 }
187
188 let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
189 Ok(new_root)
190 }
191
192 pub fn insert_asset_nodes(&mut self, vault: &AssetVault) -> Result<(), StoreError> {
194 let smt = Smt::with_entries(vault.assets().map(|asset| {
195 let key: Word = asset.vault_key().into();
196 let value: Word = asset.into();
197 (key, value)
198 }))
199 .map_err(StoreError::from)?;
200
201 let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
202 let entries: Vec<(Word, Word)> = smt.entries().map(|(k, v)| (*k, *v)).collect();
203 if entries.is_empty() {
204 return Ok(());
205 }
206 let new_root = self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
207 debug_assert_eq!(new_root, smt.root());
208 Ok(())
209 }
210
211 pub fn insert_storage_map_nodes(&mut self, storage: &AccountStorage) -> Result<(), StoreError> {
213 let maps = storage.slots().iter().filter_map(|slot| match slot.content() {
214 StorageSlotContent::Map(map) => Some(map),
215 StorageSlotContent::Value(_) => None,
216 });
217
218 for map in maps {
219 self.insert_storage_map_nodes_for_map(map)?;
220 }
221 Ok(())
222 }
223
224 pub fn insert_account_state(
227 &mut self,
228 vault: &AssetVault,
229 storage: &AccountStorage,
230 ) -> Result<(), StoreError> {
231 self.insert_storage_map_nodes(storage)?;
232 self.insert_asset_nodes(vault)?;
233 Ok(())
234 }
235
236 pub fn insert_and_register_account_state(
239 &mut self,
240 account_id: AccountId,
241 vault: &AssetVault,
242 storage: &AccountStorage,
243 ) -> Result<(), StoreError> {
244 self.insert_account_state(vault, storage)?;
245 let roots = Self::collect_account_roots(vault, storage);
246 self.replace_roots(account_id, roots);
247 Ok(())
248 }
249
250 pub fn insert_storage_map_nodes_for_map(&mut self, map: &StorageMap) -> Result<(), StoreError> {
251 let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
252 let entries: Vec<(Word, Word)> =
253 map.entries().map(|(k, v)| (k.hash().as_word(), *v)).collect();
254 if entries.is_empty() {
255 return Ok(());
256 }
257 self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
258 Ok(())
259 }
260
261 fn collect_account_roots(vault: &AssetVault, storage: &AccountStorage) -> Vec<Word> {
266 let mut roots = vec![vault.root()];
267 for slot in storage.slots() {
268 if let StorageSlotContent::Map(map) = slot.content() {
269 roots.push(map.root());
270 }
271 }
272 roots
273 }
274
275 fn safe_pop_smts(&mut self, roots: impl IntoIterator<Item = Word>) {
277 self.forest.pop_smts(roots);
278 }
279}
280
281fn increment_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) {
282 for root in roots {
283 *refcounts.entry(*root).or_insert(0) += 1;
284 }
285}
286
287fn decrement_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) -> Vec<Word> {
289 let mut to_pop = Vec::new();
290 for root in roots {
291 if let Some(count) = refcounts.get_mut(root) {
292 *count -= 1;
293 if *count == 0 {
294 refcounts.remove(root);
295 to_pop.push(*root);
296 }
297 }
298 }
299 to_pop
300}
301
#[cfg(test)]
mod tests {
    use miden_protocol::testing::account_id::{
        ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET,
        ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET,
    };
    use miden_protocol::{ONE, ZERO};

    use super::*;

    // FIXTURES
    // --------------------------------------------------------------------------------------------

    fn account_a() -> AccountId {
        ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET.try_into().unwrap()
    }

    fn account_b() -> AccountId {
        ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET.try_into().unwrap()
    }

    /// Returns the word holding `ONE` at `position` and `ZERO` everywhere else.
    fn unit_key(position: usize) -> Word {
        let mut elements = [ZERO; 4];
        elements[position] = ONE;
        elements.into()
    }

    /// Returns the value stored under every key in these tests.
    fn all_ones() -> Word {
        [ONE; 4].into()
    }

    /// Builds a storage map holding the single entry `key -> value`, inserts its nodes into the
    /// forest and returns the map's root.
    fn insert_map(forest: &mut AccountSmtForest, key: Word, value: Word) -> Word {
        let mut map = StorageMap::new();
        map.insert(StorageMapKey::new(key), value).unwrap();
        forest.insert_storage_map_nodes_for_map(&map).unwrap();
        map.root()
    }

    /// A root counts as live while the forest can still produce a witness for `key` under it.
    fn root_is_live(forest: &AccountSmtForest, root: Word, key: Word) -> bool {
        forest.get_storage_map_item_witness(root, StorageMapKey::new(key)).is_ok()
    }

    /// Asserts that the roots registered for `id` are exactly `expected`.
    #[track_caller]
    fn assert_roots(forest: &AccountSmtForest, id: AccountId, expected: &[Word]) {
        assert_eq!(forest.get_roots(&id).map(Vec::as_slice), Some(expected));
    }

    // TESTS
    // --------------------------------------------------------------------------------------------

    #[test]
    fn stage_then_commit_releases_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();
        let (key1, key2) = (unit_key(0), unit_key(1));

        let root1 = insert_map(&mut forest, key1, all_ones());
        let root2 = insert_map(&mut forest, key2, all_ones());

        // Initial state.
        forest.replace_roots(id, vec![root1]);
        assert_roots(&forest, id, &[root1]);

        // Staging switches the current roots right away.
        forest.stage_roots(id, vec![root2]);
        assert_roots(&forest, id, &[root2]);

        // While the change is only staged, both root sets stay available.
        assert!(root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));

        // Committing releases the replaced root.
        forest.commit_roots(id);
        assert_roots(&forest, id, &[root2]);
        assert!(!root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));
    }

    #[test]
    fn stage_then_discard_restores_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();
        let (key1, key2) = (unit_key(0), unit_key(1));

        let root1 = insert_map(&mut forest, key1, all_ones());
        let root2 = insert_map(&mut forest, key2, all_ones());

        forest.replace_roots(id, vec![root1]);

        // Stage a change and throw it away again.
        forest.stage_roots(id, vec![root2]);
        forest.discard_roots(id);

        assert_roots(&forest, id, &[root1]);
        assert!(root_is_live(&forest, root1, key1));
        assert!(!root_is_live(&forest, root2, key2));
    }

    #[test]
    fn shared_root_survives_single_account_replacement() {
        let mut forest = AccountSmtForest::new();
        let (id1, id2) = (account_a(), account_b());

        let key = unit_key(0);
        let shared_root = insert_map(&mut forest, key, all_ones());

        // Two accounts reference the same root.
        forest.replace_roots(id1, vec![shared_root]);
        forest.replace_roots(id2, vec![shared_root]);

        // Moving only the first account away must keep the shared root alive.
        let other_root = insert_map(&mut forest, unit_key(1), all_ones());
        forest.replace_roots(id1, vec![other_root]);

        assert!(root_is_live(&forest, shared_root, key));

        // Once the second account moves away too, the shared root is released.
        forest.replace_roots(id2, vec![other_root]);
        assert!(!root_is_live(&forest, shared_root, key));
    }

    #[test]
    fn multiple_stages_discard_one_at_a_time() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();
        let (key_a, key_b, key_c) = (unit_key(0), unit_key(1), unit_key(2));

        let root_a = insert_map(&mut forest, key_a, all_ones());
        let root_b = insert_map(&mut forest, key_b, all_ones());
        let root_c = insert_map(&mut forest, key_c, all_ones());

        // a -> b -> c, where the last two steps are only staged.
        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        assert_roots(&forest, id, &[root_c]);

        // The first discard steps back from c to b.
        forest.discard_roots(id);
        assert_roots(&forest, id, &[root_b]);
        assert!(!root_is_live(&forest, root_c, key_c));
        assert!(root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));

        // The second discard steps back from b to a.
        forest.discard_roots(id);
        assert_roots(&forest, id, &[root_a]);
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));
    }

    #[test]
    fn multiple_stages_commit_releases_all_old() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();
        let (key_a, key_b, key_c) = (unit_key(0), unit_key(1), unit_key(2));

        let root_a = insert_map(&mut forest, key_a, all_ones());
        let root_b = insert_map(&mut forest, key_b, all_ones());
        let root_c = insert_map(&mut forest, key_c, all_ones());

        // a -> b -> c, then a single commit covering both staged steps.
        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        forest.commit_roots(id);

        assert_roots(&forest, id, &[root_c]);
        // Both replaced roots are gone; only the latest one remains.
        assert!(!root_is_live(&forest, root_a, key_a));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_c, key_c));
    }

    #[test]
    fn unchanged_root_survives_stage_commit() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();
        let (key1, key2, key3) = (unit_key(0), unit_key(1), unit_key(2));

        let shared_root = insert_map(&mut forest, key1, all_ones());
        let changing_root = insert_map(&mut forest, key2, all_ones());

        forest.replace_roots(id, vec![shared_root, changing_root]);

        // The update keeps `shared_root` and swaps `changing_root` for `new_root`.
        let new_root = insert_map(&mut forest, key3, all_ones());
        forest.stage_roots(id, vec![shared_root, new_root]);
        forest.commit_roots(id);

        // Still referenced by the new root set.
        assert!(root_is_live(&forest, shared_root, key1));
        // Only referenced by the old root set.
        assert!(!root_is_live(&forest, changing_root, key2));
        // Introduced by the new root set.
        assert!(root_is_live(&forest, new_root, key3));
    }
}