use crate::{errors::DatabaseError, precompiles::PrecompileCache};
use ethrex_common::{
Address, H256, U256,
types::{AccountState, ChainConfig, Code, CodeMetadata},
};
#[cfg(all(feature = "rayon", not(feature = "eip-8025")))]
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rustc_hash::FxHashMap;
use std::sync::{Arc, OnceLock, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
pub mod gen_db;
/// Cached account states, keyed by account address.
type AccountCache = FxHashMap<Address, AccountState>;
/// Cached storage values, keyed by `(account address, storage key)`.
type StorageCache = FxHashMap<(Address, H256), U256>;
/// Cached contract code, keyed by code hash.
type CodeCache = FxHashMap<H256, Code>;
/// Read access to account state, storage, contract code, block hashes and the
/// chain configuration.
///
/// Implementors must be `Send + Sync` so a single instance can be shared across
/// threads (e.g. behind an `Arc<dyn Database>`).
pub trait Database: Send + Sync {
    /// Returns the state of the account at `address`.
    fn get_account_state(&self, address: Address) -> Result<AccountState, DatabaseError>;

    /// Returns the value stored under `key` in the storage of `address`.
    fn get_storage_value(&self, address: Address, key: H256) -> Result<U256, DatabaseError>;

    /// Returns the hash of the block numbered `block_number`.
    fn get_block_hash(&self, block_number: u64) -> Result<H256, DatabaseError>;

    /// Returns the chain configuration.
    fn get_chain_config(&self) -> Result<ChainConfig, DatabaseError>;

    /// Returns the code identified by `code_hash`.
    fn get_account_code(&self, code_hash: H256) -> Result<Code, DatabaseError>;

    /// Returns the metadata of the code identified by `code_hash`.
    fn get_code_metadata(&self, code_hash: H256) -> Result<CodeMetadata, DatabaseError>;

    /// Optional shared [`PrecompileCache`]; the default implementation has none.
    fn precompile_cache(&self) -> Option<&PrecompileCache> {
        None
    }

    /// Hint that the given accounts will be read soon.
    ///
    /// The default visits `addresses` in order through
    /// [`Database::get_account_state`], discarding each result and stopping at
    /// the first error.
    fn prefetch_accounts(&self, addresses: &[Address]) -> Result<(), DatabaseError> {
        addresses
            .iter()
            .try_for_each(|&addr| self.get_account_state(addr).map(|_| ()))
    }

    /// Hint that the given storage slots will be read soon.
    ///
    /// The default visits `keys` in order through
    /// [`Database::get_storage_value`], discarding each result and stopping at
    /// the first error.
    fn prefetch_storage(&self, keys: &[(Address, H256)]) -> Result<(), DatabaseError> {
        keys.iter()
            .try_for_each(|&(addr, key)| self.get_storage_value(addr, key).map(|_| ()))
    }
}
/// A [`Database`] wrapper that memoizes reads from an inner database.
///
/// Account states, storage values, contract code and the chain config are
/// cached after their first successful read; block hashes and code metadata are
/// always forwarded to the inner database. Each cache sits behind its own
/// `RwLock`, so one instance can be shared across threads.
pub struct CachingDatabase {
    /// Backing database that cache misses are forwarded to.
    inner: Arc<dyn Database>,
    /// Account states by address.
    accounts: RwLock<AccountCache>,
    /// Storage values by `(address, key)`.
    storage: RwLock<StorageCache>,
    /// Contract code by code hash.
    code: RwLock<CodeCache>,
    /// `Some` only when enabled in [`CachingDatabase::new`]; exposed through
    /// [`Database::precompile_cache`].
    precompile_cache: Option<PrecompileCache>,
    /// Chain config, fetched from `inner` on first use and reused afterwards.
    chain_config: OnceLock<ChainConfig>,
}
impl CachingDatabase {
    /// Creates a caching layer over `inner` with empty account, storage and
    /// code caches and an unset chain config.
    ///
    /// A [`PrecompileCache`] is allocated only when `precompile_cache_enabled`
    /// is `true`.
    pub fn new(inner: Arc<dyn Database>, precompile_cache_enabled: bool) -> Self {
        let precompile_cache = if precompile_cache_enabled {
            Some(PrecompileCache::new())
        } else {
            None
        };
        Self {
            inner,
            accounts: RwLock::default(),
            storage: RwLock::default(),
            code: RwLock::default(),
            precompile_cache,
            chain_config: OnceLock::new(),
        }
    }

    /// Takes `lock` for shared access, turning lock poisoning into a `DatabaseError`.
    fn lock_shared<T>(lock: &RwLock<T>) -> Result<RwLockReadGuard<'_, T>, DatabaseError> {
        lock.read().map_err(poison_error_to_db_error)
    }

    /// Takes `lock` for exclusive access, turning lock poisoning into a `DatabaseError`.
    fn lock_exclusive<T>(lock: &RwLock<T>) -> Result<RwLockWriteGuard<'_, T>, DatabaseError> {
        lock.write().map_err(poison_error_to_db_error)
    }

    fn read_accounts(&self) -> Result<RwLockReadGuard<'_, AccountCache>, DatabaseError> {
        Self::lock_shared(&self.accounts)
    }

    fn write_accounts(&self) -> Result<RwLockWriteGuard<'_, AccountCache>, DatabaseError> {
        Self::lock_exclusive(&self.accounts)
    }

    fn read_storage(&self) -> Result<RwLockReadGuard<'_, StorageCache>, DatabaseError> {
        Self::lock_shared(&self.storage)
    }

    fn write_storage(&self) -> Result<RwLockWriteGuard<'_, StorageCache>, DatabaseError> {
        Self::lock_exclusive(&self.storage)
    }

    fn read_code(&self) -> Result<RwLockReadGuard<'_, CodeCache>, DatabaseError> {
        Self::lock_shared(&self.code)
    }

    fn write_code(&self) -> Result<RwLockWriteGuard<'_, CodeCache>, DatabaseError> {
        Self::lock_exclusive(&self.code)
    }
}
/// Maps a poisoned cache lock to [`DatabaseError::Custom`], keeping the
/// poison error's text in the message.
fn poison_error_to_db_error<T>(err: PoisonError<T>) -> DatabaseError {
    let message = format!("Cache lock poisoned: {err}");
    DatabaseError::Custom(message)
}
impl Database for CachingDatabase {
    /// Returns the state of `address`, consulting the account cache first.
    ///
    /// On a miss the value is read from `inner` and inserted, unless another
    /// caller cached it in the meantime; in that case the already-cached value
    /// is kept and returned (the same first-writer-wins rule the prefetch paths
    /// use). Errors from `inner` are propagated and nothing is cached.
    fn get_account_state(&self, address: Address) -> Result<AccountState, DatabaseError> {
        if let Some(state) = self.read_accounts()?.get(&address).copied() {
            return Ok(state);
        }
        let fetched = self.inner.get_account_state(address)?;
        // The read guard was released at the end of the `if let`; the write lock
        // is held only for this statement.
        let state = *self.write_accounts()?.entry(address).or_insert(fetched);
        Ok(state)
    }

    /// Returns the value under `key` in the storage of `address`, consulting
    /// the storage cache first. Miss handling matches `get_account_state`.
    fn get_storage_value(&self, address: Address, key: H256) -> Result<U256, DatabaseError> {
        if let Some(value) = self.read_storage()?.get(&(address, key)).copied() {
            return Ok(value);
        }
        let fetched = self.inner.get_storage_value(address, key)?;
        let value = *self
            .write_storage()?
            .entry((address, key))
            .or_insert(fetched);
        Ok(value)
    }

    /// Forwarded to `inner`; block hashes are not cached.
    fn get_block_hash(&self, block_number: u64) -> Result<H256, DatabaseError> {
        self.inner.get_block_hash(block_number)
    }

    /// Returns the chain config, fetching it from `inner` on first use.
    ///
    /// Concurrent first calls may each query `inner`, but only the first value
    /// stored in the `OnceLock` is ever returned.
    fn get_chain_config(&self) -> Result<ChainConfig, DatabaseError> {
        if let Some(cfg) = self.chain_config.get() {
            return Ok(*cfg);
        }
        let fetched = self.inner.get_chain_config()?;
        Ok(*self.chain_config.get_or_init(|| fetched))
    }

    /// Returns the code for `code_hash`, consulting the code cache first.
    /// Miss handling matches `get_account_state`.
    fn get_account_code(&self, code_hash: H256) -> Result<Code, DatabaseError> {
        if let Some(code) = self.read_code()?.get(&code_hash).cloned() {
            return Ok(code);
        }
        let fetched = self.inner.get_account_code(code_hash)?;
        let code = self
            .write_code()?
            .entry(code_hash)
            .or_insert(fetched)
            .clone();
        Ok(code)
    }

    /// Forwarded to `inner`; code metadata is not cached.
    fn get_code_metadata(&self, code_hash: H256) -> Result<CodeMetadata, DatabaseError> {
        self.inner.get_code_metadata(code_hash)
    }

    fn precompile_cache(&self) -> Option<&PrecompileCache> {
        self.precompile_cache.as_ref()
    }

    /// Parallel prefetch of `addresses` into the account cache.
    ///
    /// Addresses that are already cached are skipped, so `inner` is queried only
    /// for real misses — the same work the sequential default does through
    /// `get_account_state`. Fetches run in parallel without holding any cache
    /// lock; results are then inserted under one write lock, never overwriting
    /// an entry cached in the meantime. Any `inner` error aborts the prefetch
    /// before anything is inserted.
    #[cfg(all(feature = "rayon", not(feature = "eip-8025")))]
    fn prefetch_accounts(&self, addresses: &[Address]) -> Result<(), DatabaseError> {
        let cached = self.read_accounts()?;
        let missing: Vec<Address> = addresses
            .iter()
            .copied()
            .filter(|addr| !cached.contains_key(addr))
            .collect();
        // Release the read lock before the (potentially slow) parallel fetch.
        drop(cached);
        if missing.is_empty() {
            return Ok(());
        }

        let fetched: Vec<(Address, AccountState)> = missing
            .par_iter()
            .map(|&addr| self.inner.get_account_state(addr).map(|s| (addr, s)))
            .collect::<Result<_, _>>()?;
        let mut cache = self.write_accounts()?;
        for (addr, state) in fetched {
            cache.entry(addr).or_insert(state);
        }
        Ok(())
    }

    /// Parallel prefetch of `keys` into the storage cache.
    ///
    /// Same strategy as `prefetch_accounts`: cached keys are skipped, misses are
    /// fetched in parallel lock-free, and results are inserted first-writer-wins
    /// under one write lock. Any `inner` error aborts before insertion.
    #[cfg(all(feature = "rayon", not(feature = "eip-8025")))]
    fn prefetch_storage(&self, keys: &[(Address, H256)]) -> Result<(), DatabaseError> {
        let cached = self.read_storage()?;
        let missing: Vec<(Address, H256)> = keys
            .iter()
            .copied()
            .filter(|slot| !cached.contains_key(slot))
            .collect();
        // Release the read lock before the (potentially slow) parallel fetch.
        drop(cached);
        if missing.is_empty() {
            return Ok(());
        }

        let fetched: Vec<((Address, H256), U256)> = missing
            .par_iter()
            .map(|&(addr, key)| {
                self.inner
                    .get_storage_value(addr, key)
                    .map(|v| ((addr, key), v))
            })
            .collect::<Result<_, _>>()?;
        let mut cache = self.write_storage()?;
        for (slot, value) in fetched {
            cache.entry(slot).or_insert(value);
        }
        Ok(())
    }
}