use crate::db::sync::{ConcurrentCacheState, ConcurrentStateError};
use alloy::primitives::{Address, B256, U256};
use dashmap::mapref::one::RefMut;
use revm::{
database::{
states::{bundle_state::BundleRetention, plain_account::PlainStorage, CacheAccount},
BundleState, State, TransitionAccount, TransitionState,
},
state::{Account, AccountInfo, Bytecode},
Database, DatabaseCommit, DatabaseRef,
};
use std::{
collections::{hash_map, BTreeMap},
sync::{Arc, RwLock},
};
/// A state layered over a shared parent [`ConcurrentState`].
///
/// The child's backing database is the parent behind an [`Arc`], so cache misses in the child
/// fall through to the parent's [`DatabaseRef`] impl. Create one with
/// [`ConcurrentState::child`] and fold it back with [`ConcurrentState::merge_child`].
pub type Child<Db> = ConcurrentState<Arc<ConcurrentState<Db>>>;
/// A cached state database whose read paths ([`DatabaseRef`]) take `&self`.
///
/// Because reads do not need `&mut self`, the state can be shared behind an [`Arc`] and used as
/// the backing database of [`Child`] states (see [`ConcurrentState::child`]).
#[derive(Debug)]
pub struct ConcurrentState<Db> {
    /// Backing database queried when the caches in `info` miss.
    database: Db,
    /// Caches plus optional transition / bundle bookkeeping.
    pub info: ConcurrentStateInfo,
}
impl<Db> From<State<Db>> for ConcurrentState<Db>
where
Db: DatabaseRef,
{
fn from(value: State<Db>) -> Self {
Self {
database: value.database,
info: ConcurrentStateInfo {
cache: value.cache.into(),
transition_state: value.transition_state,
bundle_state: value.bundle_state,
use_preloaded_bundle: value.use_preloaded_bundle,
block_hashes: value.block_hashes.into(),
},
}
}
}
impl<Db> ConcurrentState<Db> {
pub fn into_parts(self) -> (Db, ConcurrentStateInfo) {
(self.database, self.info)
}
pub fn bundle_size_hint(&self) -> usize {
self.info.bundle_state.size_hint()
}
pub const fn set_state_clear_flag(&mut self, has_state_clear: bool) {
self.info.cache.set_state_clear_flag(has_state_clear);
}
pub fn insert_not_existing(&mut self, address: Address) {
self.info.cache.insert_not_existing(address)
}
pub fn insert_account(&mut self, address: Address, info: AccountInfo) {
self.info.cache.insert_account(address, info)
}
pub fn insert_account_with_storage(
&mut self,
address: Address,
info: AccountInfo,
storage: PlainStorage,
) {
self.info.cache.insert_account_with_storage(address, info, storage)
}
pub fn apply_transition(&mut self, transitions: Vec<(Address, TransitionAccount)>) {
if let Some(s) = self.info.transition_state.as_mut() {
s.add_transitions(transitions)
}
}
pub fn merge_transitions(&mut self, retention: BundleRetention) {
if let Some(transition_state) = self.info.transition_state.take() {
self.info
.bundle_state
.apply_transitions_and_create_reverts(transition_state, retention);
}
}
pub fn take_bundle(&mut self) -> BundleState {
core::mem::take(&mut self.info.bundle_state)
}
}
impl<Db: DatabaseRef + Sync> ConcurrentState<Db> {
pub const fn new(database: Db, info: ConcurrentStateInfo) -> Self {
Self { database, info }
}
pub fn increment_balances(
&mut self,
balances: impl IntoIterator<Item = (Address, u128)>,
) -> Result<(), Db::Error> {
let mut transitions = Vec::new();
for (address, balance) in balances {
if balance == 0 {
continue;
}
let mut original_account = self.load_cache_account_mut(address)?;
transitions.push((
address,
original_account.increment_balance(balance).expect("Balance is not zero"),
))
}
if let Some(s) = self.info.transition_state.as_mut() {
s.add_transitions(transitions)
}
Ok(())
}
pub fn drain_balances(
&mut self,
addresses: impl IntoIterator<Item = Address>,
) -> Result<Vec<u128>, Db::Error> {
let mut transitions = Vec::new();
let mut balances = Vec::new();
for address in addresses {
let mut original_account = self.load_cache_account_mut(address)?;
let (balance, transition) = original_account.drain_balance();
balances.push(balance);
transitions.push((address, transition))
}
if let Some(s) = self.info.transition_state.as_mut() {
s.add_transitions(transitions)
}
Ok(balances)
}
pub fn load_cache_account_mut(
&self,
address: Address,
) -> Result<RefMut<'_, Address, CacheAccount>, Db::Error> {
match self.info.cache.accounts.entry(address) {
dashmap::Entry::Vacant(entry) => {
if self.info.use_preloaded_bundle {
if let Some(account) =
self.info.bundle_state.account(&address).cloned().map(Into::into)
{
return Ok(entry.insert(account));
}
}
let info = self.database.basic_ref(address)?;
let account = match info {
None => CacheAccount::new_loaded_not_existing(),
Some(acc) if acc.is_empty() => {
CacheAccount::new_loaded_empty_eip161(PlainStorage::default())
}
Some(acc) => CacheAccount::new_loaded(acc, PlainStorage::default()),
};
Ok(entry.insert(account))
}
dashmap::Entry::Occupied(entry) => Ok(entry.into_ref()),
}
}
pub fn child(self: &Arc<Self>) -> Child<Db>
where
Db: Send,
{
ConcurrentState::new(self.clone(), Default::default())
}
pub fn merge_child(self: &mut Arc<Self>, child: Child<Db>) -> Result<(), ConcurrentStateError> {
self.can_merge(&child)?;
let (_, info) = child.into_parts();
let this = Arc::get_mut(self).ok_or_else(ConcurrentStateError::not_unique)?;
this.info.cache.absorb(info.cache);
Ok(())
}
pub fn can_merge(self: &Arc<Self>, child: &Child<Db>) -> Result<(), ConcurrentStateError> {
if !self.is_parent(child) {
return Err(ConcurrentStateError::not_parent());
}
if Arc::strong_count(self) != 2 {
return Err(ConcurrentStateError::not_unique());
}
Ok(())
}
pub fn is_parent(self: &Arc<Self>, child: &Child<Db>) -> bool {
Arc::ptr_eq(self, &child.database)
}
}
impl<Db: DatabaseRef + Sync> DatabaseRef for ConcurrentState<Db> {
    type Error = Db::Error;

    /// Returns the account info for `address`, loading the account into the cache on a miss
    /// (see [`ConcurrentState::load_cache_account_mut`]).
    fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        self.load_cache_account_mut(address).map(|a| a.account_info())
    }

    /// Returns the bytecode for `code_hash`, caching it on first access.
    ///
    /// On a miss the preloaded bundle's contracts are consulted first (only when
    /// `use_preloaded_bundle` is set), then the backing database.
    fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        match self.info.cache.contracts.entry(code_hash) {
            dashmap::Entry::Occupied(entry) => Ok(entry.get().clone()),
            dashmap::Entry::Vacant(entry) => {
                if self.info.use_preloaded_bundle {
                    if let Some(code) = self.info.bundle_state.contracts.get(&code_hash) {
                        entry.insert(code.clone());
                        return Ok(code.clone());
                    }
                }
                let code = self.database.code_by_hash_ref(code_hash)?;
                entry.insert(code.clone());
                Ok(code)
            }
        }
    }

    /// Returns the storage value at `index` for `address`, caching it on first access.
    ///
    /// If the account's storage is known (per `status.is_storage_known()`), a missing slot reads
    /// as zero without querying the database. If the cached account has no account data, zero
    /// is returned.
    ///
    /// # Panics
    ///
    /// Panics if `address` has not been loaded into the account cache beforehand.
    fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
        if let Some(mut account) = self.info.cache.accounts.get_mut(&address) {
            let is_storage_known = account.status.is_storage_known();
            Ok(account
                .account
                .as_mut()
                .map(|account| match account.storage.entry(index) {
                    hash_map::Entry::Occupied(entry) => Ok(*entry.get()),
                    hash_map::Entry::Vacant(entry) => {
                        let value = if is_storage_known {
                            U256::ZERO
                        } else {
                            self.database.storage_ref(address, index)?
                        };
                        entry.insert(value);
                        Ok(value)
                    }
                })
                .transpose()?
                .unwrap_or_default())
        } else {
            unreachable!("For accessing any storage account is guaranteed to be loaded beforehand")
        }
    }

    /// Returns the hash of block `number`, caching it.
    ///
    /// On a cache miss the hash is fetched from the backing database and inserted. Entries for
    /// blocks below `number - BLOCK_HASH_HISTORY` (saturating) are then pruned.
    fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
        // Fast path under a shared read lock; the guard is dropped at the end of this scope.
        {
            let hashes = self.info.block_hashes.read().unwrap();
            if let Some(hash) = hashes.get(&number) {
                return Ok(*hash);
            }
        }

        // Fetch without holding the lock so other readers are not blocked on database I/O.
        let hash = self.database.block_hash_ref(number)?;

        // Insert and prune under ONE write guard. `std::sync::RwLock` is not reentrant:
        // acquiring `write()` again while this guard is alive would deadlock or panic.
        let mut hashes = self.info.block_hashes.write().unwrap();
        hashes.insert(number, hash);
        let last_block = number.saturating_sub(revm::primitives::BLOCK_HASH_HISTORY);
        // `split_off` returns every entry with key >= `last_block`; keep only those.
        let to_retain = hashes.split_off(&last_block);
        *hashes = to_retain;
        Ok(hash)
    }
}
impl<Db: DatabaseRef + Sync> DatabaseCommit for ConcurrentState<Db> {
    /// Applies the EVM's resulting state to the cache.
    ///
    /// The transitions this produces are passed to `apply_transition`, which records them only
    /// when transition tracking is enabled.
    fn commit(&mut self, evm_state: revm::primitives::HashMap<Address, Account>) {
        let produced = self.info.cache.apply_evm_state(evm_state);
        Self::apply_transition(self, produced);
    }
}
/// Mutable-database view. Every method delegates to the `&self` [`DatabaseRef`] implementation,
/// since all caches use interior mutability.
impl<Db: DatabaseRef + Sync> Database for ConcurrentState<Db> {
    type Error = Db::Error;

    /// Delegates to [`DatabaseRef::basic_ref`].
    fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        <Self as DatabaseRef>::basic_ref(self, address)
    }

    /// Delegates to [`DatabaseRef::code_by_hash_ref`].
    fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        <Self as DatabaseRef>::code_by_hash_ref(self, code_hash)
    }

    /// Delegates to [`DatabaseRef::storage_ref`].
    fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
        <Self as DatabaseRef>::storage_ref(self, address, index)
    }

    /// Delegates to [`DatabaseRef::block_hash_ref`].
    fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
        <Self as DatabaseRef>::block_hash_ref(self, number)
    }
}
/// Bookkeeping owned by a [`ConcurrentState`]: caches plus optional transition and bundle
/// tracking. `Default` yields no transition tracking and no preloaded bundle.
#[derive(Debug, Default)]
pub struct ConcurrentStateInfo {
    /// Account and contract caches, consulted before the backing database.
    pub cache: ConcurrentCacheState,
    /// When `Some`, transitions from `commit`, `increment_balances` and `drain_balances` are
    /// recorded here; when `None` they are dropped.
    pub transition_state: Option<TransitionState>,
    /// Bundle built by `merge_transitions`. It is also a read source on cache misses when
    /// `use_preloaded_bundle` is set.
    pub bundle_state: BundleState,
    /// If `true`, account and contract cache misses consult `bundle_state` before the database.
    pub use_preloaded_bundle: bool,
    /// Cached block hashes keyed by block number. On each miss, `block_hash_ref` prunes entries
    /// below `number - BLOCK_HASH_HISTORY`.
    pub block_hashes: RwLock<BTreeMap<u64, B256>>,
}
#[cfg(test)]
mod test {
    use super::*;
    use revm::database::EmptyDB;

    /// Compile-time check that a `Child` state implements all three revm database traits.
    #[test]
    const fn assert_child_trait_impls() {
        const fn assert_database_ref<T: DatabaseRef>() {}
        const fn assert_database_commit<T: DatabaseCommit>() {}
        const fn assert_database<T: Database>() {}
        assert_database_ref::<Child<EmptyDB>>();
        assert_database_commit::<Child<EmptyDB>>();
        assert_database::<Child<EmptyDB>>();
    }

    /// Exercises `can_merge` / `merge_child`:
    /// - child writes are not visible in the parent before merging;
    /// - extra `Arc` handles block the merge;
    /// - a child of a different parent is rejected;
    /// - a successful merge exposes the child's write.
    #[test]
    fn merge_child() {
        let addr = Address::repeat_byte(1);
        let mut parent = Arc::new(ConcurrentState::new(EmptyDB::new(), Default::default()));
        let mut child = parent.child();
        child.increment_balances([(addr, 100)]).unwrap();
        // The parent still reports no account info for `addr`; only the child holds the balance.
        assert!(parent.load_cache_account_mut(addr).unwrap().value().account_info().is_none());
        assert_eq!(
            child.load_cache_account_mut(addr).unwrap().value().account_info().unwrap().balance,
            U256::from(100)
        );
        // Only `parent` and `child` hold the Arc at this point.
        assert_eq!(Arc::strong_count(&parent), 2);
        // A second child raises the strong count to 3, so merging is refused.
        // `child_2` is consumed by the failed `merge_child`.
        let child_2 = parent.child();
        assert_eq!(parent.can_merge(&child_2).unwrap_err(), ConcurrentStateError::not_unique());
        assert_eq!(parent.merge_child(child_2).unwrap_err(), ConcurrentStateError::not_unique());
        // A child created from a different parent fails the identity check.
        let parent_2 = Arc::new(ConcurrentState::new(EmptyDB::new(), Default::default()));
        let child_2 = parent_2.child();
        assert_eq!(parent.merge_child(child_2).unwrap_err(), ConcurrentStateError::not_parent());
        // With only the original child left, the merge succeeds and the parent sees its write.
        parent.can_merge(&child).unwrap();
        parent.merge_child(child).unwrap();
        assert_eq!(
            parent.load_cache_account_mut(addr).unwrap().value().account_info().unwrap().balance,
            U256::from(100)
        );
    }

    /// Successive increments accumulate on the cached account.
    #[test]
    fn test_increment_balances() {
        let addr = Address::repeat_byte(10);
        let mut state = ConcurrentState::new(EmptyDB::new(), Default::default());
        state.increment_balances([(addr, 1000)]).unwrap();
        assert_eq!(state.basic(addr).unwrap().unwrap().balance, U256::from(1000));
        state.increment_balances([(addr, 500)]).unwrap();
        assert_eq!(state.basic(addr).unwrap().unwrap().balance, U256::from(1500));
    }

    /// Draining returns the previous balance and leaves the account at zero.
    #[test]
    fn test_drain_balances() {
        let addr = Address::repeat_byte(11);
        let mut state = ConcurrentState::new(EmptyDB::new(), Default::default());
        state.increment_balances([(addr, 1000)]).unwrap();
        let drained = state.drain_balances([addr]).unwrap();
        assert_eq!(drained, vec![1000]);
        assert_eq!(state.basic(addr).unwrap().unwrap().balance, U256::ZERO);
    }

    /// An inserted account is returned as-is by `basic`.
    #[test]
    fn test_insert_account() {
        let addr = Address::repeat_byte(12);
        let mut state = ConcurrentState::new(EmptyDB::new(), Default::default());
        let info = AccountInfo { balance: U256::from(999), nonce: 42, ..Default::default() };
        state.insert_account(addr, info);
        let loaded = state.basic(addr).unwrap().unwrap();
        assert_eq!(loaded.balance, U256::from(999));
        assert_eq!(loaded.nonce, 42);
    }

    /// Child writes stay in the child's cache until merged into the parent.
    #[test]
    fn test_child_isolation() {
        let addr = Address::repeat_byte(14);
        let parent = Arc::new(ConcurrentState::new(EmptyDB::new(), Default::default()));
        let mut child = parent.child();
        child.increment_balances([(addr, 500)]).unwrap();
        assert_eq!(
            child.info.cache.accounts.get(&addr).unwrap().account_info().unwrap().balance,
            U256::from(500)
        );
        // The child's load fell through to the parent, so the parent has an entry for `addr`,
        // but it carries no account info.
        assert_eq!(
            parent.info.cache.accounts.get(&addr).unwrap().account_info(),
            None );
        let mut parent = parent;
        parent.merge_child(child).unwrap();
        assert_eq!(parent.basic_ref(addr).unwrap().unwrap().balance, U256::from(500));
    }

    /// The child starts from the parent's balance (100), adds 50, and the merge publishes 150.
    #[test]
    fn test_child_merge_accumulates() {
        let addr = Address::repeat_byte(15);
        let mut parent = Arc::new(ConcurrentState::new(EmptyDB::new(), Default::default()));
        Arc::get_mut(&mut parent).unwrap().increment_balances([(addr, 100)]).unwrap();
        let mut child = parent.child();
        child.increment_balances([(addr, 50)]).unwrap();
        assert_eq!(child.basic(addr).unwrap().unwrap().balance, U256::from(150));
        assert_eq!(parent.basic_ref(addr).unwrap().unwrap().balance, U256::from(100));
        parent.merge_child(child).unwrap();
        assert_eq!(parent.basic_ref(addr).unwrap().unwrap().balance, U256::from(150));
    }
}