1use alloc::collections::BTreeMap;
2use alloc::vec::Vec;
3
4use miden_protocol::account::{
5 AccountId,
6 AccountStorage,
7 StorageMap,
8 StorageMapKey,
9 StorageMapWitness,
10 StorageSlotContent,
11};
12use miden_protocol::asset::{Asset, AssetVault, AssetVaultKey, AssetWitness};
13use miden_protocol::crypto::merkle::smt::{SMT_DEPTH, Smt, SmtForest};
14use miden_protocol::crypto::merkle::{EmptySubtreeRoots, MerkleError};
15use miden_protocol::{EMPTY_WORD, Word};
16
17use super::StoreError;
18
/// Stores the sparse Merkle trees of account asset vaults and storage maps in a single
/// [`SmtForest`] and reference-counts the roots registered for each account.
///
/// Root updates can either be applied immediately (`replace_roots`) or staged (`stage_roots`)
/// and later committed (`commit_roots`) or discarded (`discard_roots`). When a root is no longer
/// referenced by any account's current or pending roots, it is popped from the forest.
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub struct AccountSmtForest {
    /// Node storage for all vault and storage-map SMTs tracked by this struct.
    forest: SmtForest,
    /// Current roots of each account. For full account state this is the vault root followed
    /// by the root of each map storage slot, in slot order.
    account_roots: BTreeMap<AccountId, Vec<Word>>,
    /// For accounts whose staged changes replaced existing roots: the roots each stage replaced,
    /// oldest first. Popped by `discard_roots`, released by `commit_roots`.
    pending_old_roots: BTreeMap<AccountId, Vec<Vec<Word>>>,
    /// Number of references to each root across `account_roots` and `pending_old_roots`. A root
    /// is dropped from this map (and popped from the forest) when its count reaches zero.
    root_refcounts: BTreeMap<Word, usize>,
}
34
35impl AccountSmtForest {
36 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn get_roots(&self, account_id: &AccountId) -> Option<&Vec<Word>> {
45 self.account_roots.get(account_id)
46 }
47
48 pub fn get_asset_and_witness(
50 &self,
51 vault_root: Word,
52 vault_key: AssetVaultKey,
53 ) -> Result<(Asset, AssetWitness), StoreError> {
54 let vault_key_word = vault_key.into();
55 let proof = self.forest.open(vault_root, vault_key_word)?;
56 let asset_word =
57 proof.get(&vault_key_word).ok_or(MerkleError::UntrackedKey(vault_key_word))?;
58 if asset_word == EMPTY_WORD {
59 return Err(MerkleError::UntrackedKey(vault_key_word).into());
60 }
61
62 let asset = Asset::from_key_value_words(vault_key_word, asset_word)?;
63 let witness = AssetWitness::new(proof)?;
64 Ok((asset, witness))
65 }
66
67 pub fn get_storage_map_item_witness(
69 &self,
70 map_root: Word,
71 key: StorageMapKey,
72 ) -> Result<StorageMapWitness, StoreError> {
73 let hashed_key = key.hash().as_word();
74 let proof = self.forest.open(map_root, hashed_key).map_err(StoreError::from)?;
75 Ok(StorageMapWitness::new(proof, [key])?)
76 }
77
78 pub fn stage_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
87 increment_refcounts(&mut self.root_refcounts, &new_roots);
88 if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
89 self.pending_old_roots.entry(account_id).or_default().push(old_roots);
90 }
91 }
92
93 pub fn commit_roots(&mut self, account_id: AccountId) {
95 if let Some(old_roots_stack) = self.pending_old_roots.remove(&account_id) {
96 for old_roots in old_roots_stack {
97 let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
98 self.safe_pop_smts(to_pop);
99 }
100 }
101 }
102
103 pub fn discard_roots(&mut self, account_id: AccountId) {
109 let old_roots = self.pending_old_roots.get_mut(&account_id).and_then(Vec::pop);
110
111 let new_roots = match old_roots {
113 Some(old_roots) => self.account_roots.insert(account_id, old_roots),
114 None => self.account_roots.remove(&account_id),
115 };
116
117 if let Some(new_roots) = new_roots {
118 let to_pop = decrement_refcounts(&mut self.root_refcounts, &new_roots);
119 self.safe_pop_smts(to_pop);
120 }
121
122 if self.pending_old_roots.get(&account_id).is_some_and(Vec::is_empty) {
124 self.pending_old_roots.remove(&account_id);
125 }
126 }
127
128 pub fn replace_roots(&mut self, account_id: AccountId, new_roots: Vec<Word>) {
137 assert!(
138 !self.pending_old_roots.contains_key(&account_id),
139 "cannot replace roots while staged changes are pending for account {account_id}"
140 );
141 increment_refcounts(&mut self.root_refcounts, &new_roots);
142 if let Some(old_roots) = self.account_roots.insert(account_id, new_roots) {
143 let to_pop = decrement_refcounts(&mut self.root_refcounts, &old_roots);
144 self.safe_pop_smts(to_pop);
145 }
146 }
147
148 pub fn update_asset_nodes(
153 &mut self,
154 root: Word,
155 new_assets: impl Iterator<Item = Asset>,
156 removed_vault_keys: impl Iterator<Item = AssetVaultKey>,
157 ) -> Result<Word, StoreError> {
158 let entries: Vec<(Word, Word)> = new_assets
159 .map(|asset| {
160 let key: Word = asset.vault_key().into();
161 let value = asset.to_value_word();
162 (key, value)
163 })
164 .chain(removed_vault_keys.map(|vault_key| (vault_key.into(), EMPTY_WORD)))
165 .collect();
166
167 if entries.is_empty() {
168 return Ok(root);
169 }
170
171 let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
172 Ok(new_root)
173 }
174
175 pub fn update_storage_map_nodes(
177 &mut self,
178 root: Word,
179 entries: impl Iterator<Item = (StorageMapKey, Word)>,
180 ) -> Result<Word, StoreError> {
181 let entries: Vec<(Word, Word)> =
182 entries.map(|(key, value)| (key.hash().as_word(), value)).collect();
183
184 if entries.is_empty() {
185 return Ok(root);
186 }
187
188 let new_root = self.forest.batch_insert(root, entries).map_err(StoreError::from)?;
189 Ok(new_root)
190 }
191
192 pub fn insert_asset_nodes(&mut self, vault: &AssetVault) -> Result<(), StoreError> {
194 let smt = Smt::with_entries(vault.assets().map(|asset| {
195 let key: Word = asset.vault_key().into();
196 let value = asset.to_value_word();
197 (key, value)
198 }))
199 .map_err(StoreError::from)?;
200
201 let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
202 let entries: Vec<(Word, Word)> = smt.entries().map(|(k, v)| (*k, *v)).collect();
203 if entries.is_empty() {
204 return Ok(());
205 }
206 let new_root = self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
207 debug_assert_eq!(new_root, smt.root());
208 Ok(())
209 }
210
211 pub fn insert_storage_map_nodes(&mut self, storage: &AccountStorage) -> Result<(), StoreError> {
213 let maps = storage.slots().iter().filter_map(|slot| match slot.content() {
214 StorageSlotContent::Map(map) => Some(map),
215 StorageSlotContent::Value(_) => None,
216 });
217
218 for map in maps {
219 self.insert_storage_map_nodes_for_map(map)?;
220 }
221 Ok(())
222 }
223
224 pub fn insert_account_state(
227 &mut self,
228 vault: &AssetVault,
229 storage: &AccountStorage,
230 ) -> Result<(), StoreError> {
231 self.insert_storage_map_nodes(storage)?;
232 self.insert_asset_nodes(vault)?;
233 Ok(())
234 }
235
236 pub fn insert_and_stage_account_state(
239 &mut self,
240 account_id: AccountId,
241 vault: &AssetVault,
242 storage: &AccountStorage,
243 ) -> Result<(), StoreError> {
244 self.insert_account_state(vault, storage)?;
245 let roots = Self::collect_account_roots(vault, storage);
246 self.stage_roots(account_id, roots);
247 Ok(())
248 }
249
250 pub fn insert_and_register_account_state(
253 &mut self,
254 account_id: AccountId,
255 vault: &AssetVault,
256 storage: &AccountStorage,
257 ) -> Result<(), StoreError> {
258 self.insert_account_state(vault, storage)?;
259 let roots = Self::collect_account_roots(vault, storage);
260 self.replace_roots(account_id, roots);
261 Ok(())
262 }
263
264 pub fn insert_storage_map_nodes_for_map(&mut self, map: &StorageMap) -> Result<(), StoreError> {
266 let empty_root = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
267 let entries: Vec<(Word, Word)> =
268 map.entries().map(|(k, v)| (k.hash().as_word(), *v)).collect();
269 if entries.is_empty() {
270 return Ok(());
271 }
272 self.forest.batch_insert(empty_root, entries).map_err(StoreError::from)?;
273 Ok(())
274 }
275
276 fn collect_account_roots(vault: &AssetVault, storage: &AccountStorage) -> Vec<Word> {
281 let mut roots = vec![vault.root()];
282 for slot in storage.slots() {
283 if let StorageSlotContent::Map(map) = slot.content() {
284 roots.push(map.root());
285 }
286 }
287 roots
288 }
289
290 fn safe_pop_smts(&mut self, roots: impl IntoIterator<Item = Word>) {
292 self.forest.pop_smts(roots);
293 }
294}
295
296fn increment_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) {
297 for root in roots {
298 *refcounts.entry(*root).or_insert(0) += 1;
299 }
300}
301
302fn decrement_refcounts(refcounts: &mut BTreeMap<Word, usize>, roots: &[Word]) -> Vec<Word> {
304 let mut to_pop = Vec::new();
305 for root in roots {
306 if let Some(count) = refcounts.get_mut(root) {
307 *count -= 1;
308 if *count == 0 {
309 refcounts.remove(root);
310 to_pop.push(*root);
311 }
312 }
313 }
314 to_pop
315}
316
#[cfg(test)]
mod tests {
    use miden_protocol::testing::account_id::{
        ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET,
        ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET,
    };
    use miden_protocol::{ONE, ZERO};

    use super::*;

    fn account_a() -> AccountId {
        AccountId::try_from(ACCOUNT_ID_PUBLIC_FUNGIBLE_FAUCET).unwrap()
    }

    fn account_b() -> AccountId {
        AccountId::try_from(ACCOUNT_ID_PUBLIC_NON_FUNGIBLE_FAUCET).unwrap()
    }

    /// Inserts a single-entry storage map into the forest and returns its root. The root is not
    /// registered for any account.
    fn insert_map(forest: &mut AccountSmtForest, key: Word, value: Word) -> Word {
        let mut map = StorageMap::new();
        map.insert(StorageMapKey::new(key), value).unwrap();
        forest.insert_storage_map_nodes_for_map(&map).unwrap();
        map.root()
    }

    /// Returns whether `key` can still be opened under `root`; used as a proxy for whether the
    /// root's SMT is still held by the forest.
    fn root_is_live(forest: &AccountSmtForest, root: Word, key: Word) -> bool {
        forest.get_storage_map_item_witness(root, StorageMapKey::new(key)).is_ok()
    }

    #[test]
    fn stage_then_commit_releases_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        forest.replace_roots(id, vec![root1]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root1]));

        forest.stage_roots(id, vec![root2]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root2]));

        // Both the staged and the replaced roots stay live until commit.
        assert!(root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));

        forest.commit_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root2]));
        assert!(!root_is_live(&forest, root1, key1));
        assert!(root_is_live(&forest, root2, key2));
    }

    #[test]
    fn stage_then_discard_restores_old_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root1 = insert_map(&mut forest, key1, val);
        let root2 = insert_map(&mut forest, key2, val);

        forest.replace_roots(id, vec![root1]);

        forest.stage_roots(id, vec![root2]);
        forest.discard_roots(id);

        assert_eq!(forest.get_roots(&id), Some(&vec![root1]));
        assert!(root_is_live(&forest, root1, key1));
        assert!(!root_is_live(&forest, root2, key2));
    }

    #[test]
    fn shared_root_survives_single_account_replacement() {
        let mut forest = AccountSmtForest::new();
        let id1 = account_a();
        let id2 = account_b();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let shared_root = insert_map(&mut forest, key, val);

        forest.replace_roots(id1, vec![shared_root]);
        forest.replace_roots(id2, vec![shared_root]);

        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let other_root = insert_map(&mut forest, key2, val);
        forest.replace_roots(id1, vec![other_root]);

        // id2 still references the shared root.
        assert!(root_is_live(&forest, shared_root, key));

        forest.replace_roots(id2, vec![other_root]);
        assert!(!root_is_live(&forest, shared_root, key));
    }

    #[test]
    fn multiple_stages_discard_one_at_a_time() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);
        let root_c = insert_map(&mut forest, key_c, val);

        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));

        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_b]));
        assert!(!root_is_live(&forest, root_c, key_c));
        assert!(root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));

        forest.discard_roots(id);
        assert_eq!(forest.get_roots(&id), Some(&vec![root_a]));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_a, key_a));
    }

    #[test]
    fn multiple_stages_commit_releases_all_old() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key_a: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key_b: Word = [ZERO, ONE, ZERO, ZERO].into();
        let key_c: Word = [ZERO, ZERO, ONE, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let root_a = insert_map(&mut forest, key_a, val);
        let root_b = insert_map(&mut forest, key_b, val);
        let root_c = insert_map(&mut forest, key_c, val);

        forest.replace_roots(id, vec![root_a]);
        forest.stage_roots(id, vec![root_b]);
        forest.stage_roots(id, vec![root_c]);
        forest.commit_roots(id);

        assert_eq!(forest.get_roots(&id), Some(&vec![root_c]));
        assert!(!root_is_live(&forest, root_a, key_a));
        assert!(!root_is_live(&forest, root_b, key_b));
        assert!(root_is_live(&forest, root_c, key_c));
    }

    #[test]
    fn unchanged_root_survives_stage_commit() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let shared_root = insert_map(&mut forest, key1, val);
        let changing_root = insert_map(&mut forest, key2, val);

        forest.replace_roots(id, vec![shared_root, changing_root]);

        let key3: Word = [ZERO, ZERO, ONE, ZERO].into();
        let new_root = insert_map(&mut forest, key3, val);
        forest.stage_roots(id, vec![shared_root, new_root]);
        forest.commit_roots(id);

        // The root kept across the stage must not be released by the commit.
        assert!(root_is_live(&forest, shared_root, key1));
        assert!(!root_is_live(&forest, changing_root, key2));
        assert!(root_is_live(&forest, new_root, key3));
    }

    #[test]
    fn discard_of_newly_staged_account_removes_it() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let root = insert_map(&mut forest, key, val);

        // The account has no prior roots, so staging records nothing to restore.
        forest.stage_roots(id, vec![root]);
        assert_eq!(forest.get_roots(&id), Some(&vec![root]));

        forest.discard_roots(id);
        assert!(forest.get_roots(&id).is_none());
        assert!(!root_is_live(&forest, root, key));
    }

    #[test]
    fn commit_of_newly_staged_account_keeps_roots() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key: Word = [ONE, ZERO, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();
        let root = insert_map(&mut forest, key, val);

        forest.stage_roots(id, vec![root]);
        forest.commit_roots(id);

        assert_eq!(forest.get_roots(&id), Some(&vec![root]));
        assert!(root_is_live(&forest, root, key));
    }

    #[test]
    fn duplicate_root_within_account_released_once_unreferenced() {
        let mut forest = AccountSmtForest::new();
        let id = account_a();

        let key1: Word = [ONE, ZERO, ZERO, ZERO].into();
        let key2: Word = [ZERO, ONE, ZERO, ZERO].into();
        let val: Word = [ONE, ONE, ONE, ONE].into();

        let dup_root = insert_map(&mut forest, key1, val);
        let other_root = insert_map(&mut forest, key2, val);

        // Two slots can share a root (e.g. identical maps); each occurrence is counted.
        forest.replace_roots(id, vec![dup_root, dup_root]);
        assert!(root_is_live(&forest, dup_root, key1));

        forest.replace_roots(id, vec![other_root]);
        assert!(!root_is_live(&forest, dup_root, key1));
        assert!(root_is_live(&forest, other_root, key2));
    }
}