use std::collections::{BTreeMap, HashMap};
#[cfg(feature = "rocksdb")]
use miden_large_smt_backend_rocksdb::RocksDbStorage;
use miden_protocol::account::{AccountId, AccountIdPrefix};
use miden_protocol::block::BlockNumber;
use miden_protocol::block::account_tree::{AccountMutationSet, AccountTree, AccountWitness};
use miden_protocol::crypto::merkle::smt::{
LargeSmt,
LeafIndex,
MemoryStorage,
NodeMutation,
SMT_DEPTH,
SmtLeaf,
SmtStorage,
};
use miden_protocol::crypto::merkle::{
EmptySubtreeRoots,
MerkleError,
MerklePath,
NodeIndex,
SparseMerklePath,
};
use miden_protocol::errors::AccountTreeError;
use miden_protocol::{EMPTY_WORD, Word};
use tracing::instrument;
use crate::COMPONENT;
#[cfg(test)]
mod tests;
/// [`AccountTree`] whose [`LargeSmt`] uses [`MemoryStorage`] as its backend.
pub type InMemoryAccountTree = AccountTree<LargeSmt<MemoryStorage>>;
/// [`AccountTree`] whose [`LargeSmt`] uses [`RocksDbStorage`] as its backend.
///
/// Only available when the `rocksdb` feature is enabled.
#[cfg(feature = "rocksdb")]
pub type PersistentAccountTree = AccountTree<LargeSmt<RocksDbStorage>>;
/// Errors produced while computing or applying mutations on an [`AccountTreeWithHistory`].
#[derive(thiserror::Error, Debug)]
pub enum HistoricalError {
    /// An error raised by the underlying Merkle structures.
    #[error(transparent)]
    MerkleError(#[from] MerkleError),
    /// An error raised by the account tree itself.
    #[error(transparent)]
    AccountTreeError(#[from] AccountTreeError),
}
/// Classifies a requested block number relative to the state held by an
/// [`AccountTreeWithHistory`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum HistoricalSelector {
    /// The requested block is newer than the latest block.
    Future,
    /// The requested block is older than the latest block and has a retained overlay.
    At(BlockNumber),
    /// The requested block is the latest block.
    Latest,
    /// The requested block is older than the latest block but no overlay is retained for it.
    TooAncient,
}
/// Data needed to roll the account tree back to its state at `block_number`, extracted from a
/// reversion mutation set.
#[derive(Debug, Clone)]
struct HistoricalOverlay {
    /// The block whose state this overlay restores.
    block_number: BlockNumber,
    /// Root of the tree after the reversion set is applied, i.e. the root at `block_number`.
    root: Word,
    /// New hash of every inner node touched by the reversion. Removed nodes map to the
    /// empty-subtree root at their depth.
    node_mutations: HashMap<NodeIndex, Word>,
    /// Leaf index -> `(key, value)` pair restored by the reversion. A value of [`EMPTY_WORD`]
    /// means the leaf is empty after reverting.
    account_updates: HashMap<LeafIndex<SMT_DEPTH>, (Word, Word)>,
}

impl HistoricalOverlay {
    /// Builds an overlay for `block_number` from the reversion set `rev_set`.
    fn new(block_number: BlockNumber, rev_set: AccountMutationSet) -> Self {
        let root = rev_set.as_mutation_set().root();
        let mutations = rev_set.into_mutation_set();

        // Flatten each node mutation to the hash the node has once the reversion is applied.
        let mut node_mutations: HashMap<NodeIndex, Word> = HashMap::new();
        for (index, mutation) in mutations.node_mutations().iter() {
            let hash = match mutation {
                NodeMutation::Addition(inner) => inner.hash(),
                NodeMutation::Removal => *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()),
            };
            node_mutations.insert(*index, hash);
        }

        // Index restored key/value pairs by the leaf they live in.
        let account_updates: HashMap<LeafIndex<SMT_DEPTH>, (Word, Word)> = mutations
            .new_pairs()
            .iter()
            .map(|(&key, &value)| (LeafIndex::from(key), (key, value)))
            .collect();

        Self {
            block_number,
            root,
            node_mutations,
            account_updates,
        }
    }
}
/// An account tree that keeps its latest state plus per-block reversion overlays.
///
/// Up to [`AccountTreeWithHistory::MAX_HISTORY`] overlays are kept, so roots and account
/// witnesses can also be served for recent historical blocks.
#[derive(Debug)]
pub struct AccountTreeWithHistory<S>
where
    S: SmtStorage,
{
    /// Block number of the state held in `latest`.
    block_number: BlockNumber,
    /// The current account tree.
    latest: AccountTree<LargeSmt<S>>,
    /// Reversion overlays, keyed by the block whose state each one restores.
    overlays: BTreeMap<BlockNumber, HistoricalOverlay>,
}
impl<S: SmtStorage> AccountTreeWithHistory<S> {
    /// Maximum number of historical overlays retained.
    ///
    /// After each applied block, the oldest overlays beyond this limit are dropped. At most this
    /// many blocks before the latest one can therefore be queried.
    pub const MAX_HISTORY: usize = 50;

    /// Wraps `account_tree` as the latest state at `block_number`, with an empty history.
    pub fn new(account_tree: AccountTree<LargeSmt<S>>, block_number: BlockNumber) -> Self {
        Self {
            block_number,
            latest: account_tree,
            overlays: BTreeMap::new(),
        }
    }

    /// Drops the oldest overlays until at most [`Self::MAX_HISTORY`] remain.
    fn drain_excess(overlays: &mut BTreeMap<BlockNumber, HistoricalOverlay>) {
        while overlays.len() > Self::MAX_HISTORY {
            overlays.pop_first();
        }
    }

    /// Returns the block number of the latest state.
    pub fn block_number_latest(&self) -> BlockNumber {
        self.block_number
    }

    /// Returns the root of the latest account tree.
    pub fn root_latest(&self) -> Word {
        self.latest.root()
    }

    /// Returns the account tree root as of `block_number`.
    ///
    /// Returns `None` if `block_number` is newer than the latest block, or older than the
    /// retained history.
    pub fn root_at(&self, block_number: BlockNumber) -> Option<Word> {
        match self.historical_selector(block_number) {
            HistoricalSelector::Latest => Some(self.latest.root()),
            HistoricalSelector::At(block_number) => {
                let overlay = self.overlays.get(&block_number)?;
                debug_assert_eq!(overlay.block_number, block_number);
                Some(overlay.root)
            },
            HistoricalSelector::Future | HistoricalSelector::TooAncient => None,
        }
    }

    /// Returns the number of accounts in the latest tree.
    pub fn num_accounts_latest(&self) -> usize {
        self.latest.num_accounts()
    }

    /// Returns the number of retained historical overlays.
    pub fn history_len(&self) -> usize {
        self.overlays.len()
    }

    /// Opens a witness for `account_id` against the latest tree.
    #[instrument(target = COMPONENT, skip_all)]
    pub fn open_latest(&self, account_id: AccountId) -> AccountWitness {
        self.latest.open(account_id)
    }

    /// Opens a witness for `account_id` as of `block_number`.
    ///
    /// Historical witnesses are rebuilt from the latest witness by replaying the retained
    /// reversion overlays. Returns `None` if `block_number` is outside the queryable range, or
    /// if the reconstructed witness cannot be built.
    #[instrument(target = COMPONENT, skip_all)]
    pub fn open_at(
        &self,
        account_id: AccountId,
        block_number: BlockNumber,
    ) -> Option<AccountWitness> {
        match self.historical_selector(block_number) {
            HistoricalSelector::Latest => Some(self.latest.open(account_id)),
            // `historical_selector` only yields `At` for blocks that have a retained overlay.
            HistoricalSelector::At(block_number) => {
                self.reconstruct_historical_witness(account_id, block_number)
            },
            HistoricalSelector::Future | HistoricalSelector::TooAncient => None,
        }
    }

    /// Returns the commitment stored for `account_id` in the latest tree.
    pub fn get_latest_commitment(&self, account_id: AccountId) -> Word {
        self.latest.get(account_id)
    }

    /// Returns whether the latest tree contains an account with the given ID prefix.
    pub fn contains_account_id_prefix_in_latest(&self, prefix: AccountIdPrefix) -> bool {
        self.latest.contains_account_id_prefix(prefix)
    }

    /// Classifies `desired_block_number` relative to the latest block and the retained history.
    fn historical_selector(&self, desired_block_number: BlockNumber) -> HistoricalSelector {
        if desired_block_number == self.block_number {
            HistoricalSelector::Latest
        } else if desired_block_number > self.block_number {
            HistoricalSelector::Future
        } else if self.overlays.contains_key(&desired_block_number) {
            HistoricalSelector::At(desired_block_number)
        } else {
            HistoricalSelector::TooAncient
        }
    }

    /// Rebuilds the witness for `account_id` as of `block_target`.
    ///
    /// Starts from the latest witness and applies the overlays from newest to oldest, stopping
    /// at `block_target`.
    #[instrument(target = COMPONENT, skip_all)]
    fn reconstruct_historical_witness(
        &self,
        account_id: AccountId,
        block_target: BlockNumber,
    ) -> Option<AccountWitness> {
        let (latest_path, leaf) = self.latest.open(account_id).into_proof().into_parts();
        let path_nodes = Self::initialize_path_nodes(latest_path);
        let leaf_index = NodeIndex::from(leaf.index());

        // Overlays are keyed by the block they restore. Walk them newest first, so each one
        // undoes the block applied on top of it, ending with the overlay for `block_target`.
        let (path, leaf) = Self::apply_reversion_overlays(
            self.overlays.range(block_target..).rev().map(|(_, overlay)| overlay),
            path_nodes,
            leaf_index,
            leaf,
        )?;

        let commitment = match leaf {
            SmtLeaf::Empty(_) => EMPTY_WORD,
            SmtLeaf::Single((_, value)) => value,
            SmtLeaf::Multiple(_) => unreachable!("AccountTree uses prefix-free IDs"),
        };
        AccountWitness::new(account_id, commitment, path).ok()
    }

    /// Converts `path` into a dense array of sibling hashes indexed by `depth - 1`.
    ///
    /// Slot 0 holds the sibling at depth 1; the last slot holds the leaf's sibling.
    fn initialize_path_nodes(path: SparseMerklePath) -> [Word; SMT_DEPTH as usize] {
        let mut path_nodes: [Word; SMT_DEPTH as usize] = MerklePath::from(path)
            .to_vec()
            .try_into()
            .expect("MerklePath should have exactly SMT_DEPTH nodes");
        // `MerklePath` is ordered leaf-first; flip it so the array can be indexed by depth.
        path_nodes.reverse();
        path_nodes
    }

    /// Rewinds a leaf and its authentication path through `overlays`, newest first.
    ///
    /// `path_nodes` holds sibling hashes indexed by `depth - 1`, and `leaf_index` is the
    /// leaf's node index. Returns the rewound path together with the rewound leaf, or `None`
    /// if the path cannot be converted back into a [`SparseMerklePath`].
    #[instrument(target = COMPONENT, skip_all)]
    fn apply_reversion_overlays<'a>(
        overlays: impl IntoIterator<Item = &'a HistoricalOverlay>,
        mut path_nodes: [Word; SMT_DEPTH as usize],
        leaf_index: NodeIndex,
        mut leaf: SmtLeaf,
    ) -> Option<(SparseMerklePath, SmtLeaf)> {
        // Sibling positions depend only on the leaf, so compute them once, not per overlay.
        let siblings: Vec<(NodeIndex, usize)> = leaf_index
            .proof_indices()
            .map(|sibling| {
                let slot = sibling
                    .depth()
                    .checked_sub(1)
                    .expect("proof_indices should not include root");
                (sibling, usize::from(slot))
            })
            .collect();

        // The deepest sibling is itself a leaf. Node mutations only record inner nodes, so
        // that hash must be recomputed from the overlay's account updates instead. Otherwise
        // a change to the neighbouring leaf would leave a stale hash in the path.
        let leaf_sibling = LeafIndex::<SMT_DEPTH>::try_from(leaf_index.sibling()).ok()?;
        let leaf_sibling_slot = usize::from(SMT_DEPTH) - 1;

        for overlay in overlays {
            for &(sibling, slot) in &siblings {
                if let Some(hash) = overlay.node_mutations.get(&sibling) {
                    path_nodes[slot] = *hash;
                }
            }

            if let Some(&(key, value)) = overlay.account_updates.get(&leaf_sibling) {
                path_nodes[leaf_sibling_slot] =
                    Self::leaf_from_update(leaf_sibling, key, value).hash();
            }

            if let Some(&(key, value)) = overlay.account_updates.get(&leaf.index()) {
                leaf = Self::leaf_from_update(leaf.index(), key, value);
            }
        }

        // Restore the leaf-first ordering that `MerklePath` uses.
        let dense: Vec<Word> = path_nodes.iter().rev().copied().collect();
        let path = SparseMerklePath::try_from(MerklePath::new(dense)).ok()?;
        Some((path, leaf))
    }

    /// Builds the leaf at `index` holding `key -> value`.
    ///
    /// Returns an empty leaf if `value` is [`EMPTY_WORD`].
    fn leaf_from_update(index: LeafIndex<SMT_DEPTH>, key: Word, value: Word) -> SmtLeaf {
        if value == EMPTY_WORD {
            SmtLeaf::new_empty(index)
        } else {
            SmtLeaf::new_single(key, value)
        }
    }

    /// Computes mutations for `account_commitments` and applies them as a new block.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::compute_mutations`] and [`Self::apply_mutations`].
    pub fn compute_and_apply_mutations(
        &mut self,
        account_commitments: impl IntoIterator<Item = (AccountId, Word)>,
    ) -> Result<(), HistoricalError> {
        let mutations = self.compute_mutations(account_commitments)?;
        self.apply_mutations(mutations)
    }

    /// Computes, without applying, the mutations that `account_commitments` would cause on the
    /// latest tree.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `AccountTree::compute_mutations`.
    pub fn compute_mutations(
        &self,
        account_commitments: impl IntoIterator<Item = (AccountId, Word)>,
    ) -> Result<AccountMutationSet, HistoricalError> {
        Ok(self.latest.compute_mutations(account_commitments)?)
    }

    /// Applies `mutations` to the latest tree and advances the block number by one.
    ///
    /// A reversion overlay for the current (pre-apply) block is recorded, and the oldest
    /// overlays are pruned beyond [`Self::MAX_HISTORY`].
    ///
    /// # Errors
    ///
    /// Propagates any error from `AccountTree::apply_mutations_with_reversion`. In that case
    /// this wrapper's block number and history are left untouched.
    #[instrument(target = COMPONENT, skip_all)]
    pub fn apply_mutations(
        &mut self,
        mutations: AccountMutationSet,
    ) -> Result<(), HistoricalError> {
        let rev = self.latest.apply_mutations_with_reversion(mutations)?;
        let block_num = self.block_number;
        let overlay = HistoricalOverlay::new(block_num, rev);
        self.overlays.insert(block_num, overlay);
        self.block_number = block_num.child();
        Self::drain_excess(&mut self.overlays);
        Ok(())
    }
}