use core::{
error::Error,
fmt::Display,
ops::{Deref, DerefMut},
};
use primitives::{Address, StorageKey, StorageValue, B256};
use state::{
bal::{alloy::AlloyBal, Bal, BalError, BlockAccessIndex},
Account, AccountId, AccountInfo, Bytecode, EvmState,
};
use std::sync::Arc;
use crate::{DBErrorMarker, Database, DatabaseCommit};
/// Block access list (BAL) bookkeeping used by [`BalDatabase`].
///
/// Bundles an optional read-only BAL that lookups are resolved against, an optional
/// builder that committed account changes are recorded into, and the
/// [`BlockAccessIndex`] passed to both.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BalState {
    /// BAL consulted by the read helpers; `None` disables BAL lookups.
    pub bal: Option<Arc<Bal>>,
    /// Builder that commits are written into; `None` disables recording.
    pub bal_builder: Option<Bal>,
    /// Index handed to BAL reads and builder updates.
    pub bal_index: BlockAccessIndex,
}
impl BalState {
    /// Creates a state with no BAL, no builder and a default index (same as `Default`).
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Rewinds the index to [`BlockAccessIndex::PRE_EXECUTION`].
    #[inline]
    pub const fn reset_bal_index(&mut self) {
        self.bal_index = BlockAccessIndex::PRE_EXECUTION;
    }

    /// Advances the index one step via `BlockAccessIndex::increment`.
    #[inline]
    pub const fn bump_bal_index(&mut self) {
        self.bal_index.increment();
    }

    /// Returns the current index.
    #[inline]
    pub const fn bal_index(&self) -> BlockAccessIndex {
        self.bal_index
    }

    /// Returns another `Arc` handle to the loaded BAL, if one is set.
    #[inline]
    pub fn bal(&self) -> Option<Arc<Bal>> {
        self.bal.as_ref().map(Arc::clone)
    }

    /// Returns a clone of the builder, if one is set (an owned `Bal`, not a shared handle).
    #[inline]
    pub fn bal_builder(&self) -> Option<Bal> {
        self.bal_builder.clone()
    }

    /// Builder-style setter installing `bal` as the BAL to read from.
    #[inline]
    pub fn with_bal(self, bal: Arc<Bal>) -> Self {
        Self {
            bal: Some(bal),
            ..self
        }
    }

    /// Builder-style setter that starts recording into a fresh `Bal::new()`,
    /// replacing any existing builder.
    #[inline]
    pub fn with_bal_builder(self) -> Self {
        Self {
            bal_builder: Some(Bal::new()),
            ..self
        }
    }

    /// Takes the builder out (leaving `None`) and rewinds the index to
    /// [`BlockAccessIndex::PRE_EXECUTION`]; the rewind happens even when no builder was set.
    #[inline]
    pub const fn take_built_bal(&mut self) -> Option<Bal> {
        self.reset_bal_index();
        self.bal_builder.take()
    }

    /// Like [`Self::take_built_bal`], converting the result with `Bal::into_alloy_bal`.
    #[inline]
    pub fn take_built_alloy_bal(&mut self) -> Option<AlloyBal> {
        let built = self.take_built_bal()?;
        Some(built.into_alloy_bal())
    }

    /// Resolves `address` to its position in the loaded BAL, wrapped as an [`AccountId`].
    ///
    /// Returns `Ok(None)` when no BAL is loaded.
    ///
    /// # Errors
    /// [`BalError::AccountNotFound`] when a BAL is loaded but does not list `address`.
    ///
    /// # Panics
    /// If `AccountId::new` rejects the position.
    #[inline]
    pub fn get_account_id(&self, address: &Address) -> Result<Option<AccountId>, BalError> {
        let Some(bal) = self.bal.as_deref() else {
            return Ok(None);
        };
        let Some(position) = bal.accounts.get_full(address).map(|entry| entry.0) else {
            return Err(BalError::AccountNotFound { address: *address });
        };
        Ok(Some(AccountId::new(position).expect("too many bals")))
    }

    /// Overlays the BAL's account info for `address` onto `basic`.
    ///
    /// Returns `Ok(false)` when no BAL is loaded, otherwise the result of
    /// [`Self::basic_by_account_id`].
    ///
    /// # Errors
    /// [`BalError::AccountNotFound`] if a loaded BAL does not list `address`, or any error
    /// from [`Self::basic_by_account_id`].
    #[inline]
    pub fn basic(
        &self,
        address: Address,
        basic: &mut Option<AccountInfo>,
    ) -> Result<bool, BalError> {
        match self.get_account_id(&address)? {
            Some(account_id) => self.basic_by_account_id(account_id, basic),
            None => Ok(false),
        }
    }

    /// Applies BAL entry `account_id` at the current index to `basic` through
    /// `Bal::populate_account_info`.
    ///
    /// Returns `Ok(false)` without touching `basic` when no BAL is loaded, and `Ok(true)`
    /// otherwise. A `None` input is populated starting from `AccountInfo::default()`, but stays
    /// `None` when `populate_account_info` reports no change.
    ///
    /// # Errors
    /// Propagates errors from `populate_account_info`. NOTE: `*basic` has already been taken
    /// at that point, so it is left as `None` when an error is returned.
    #[inline]
    pub fn basic_by_account_id(
        &self,
        account_id: AccountId,
        basic: &mut Option<AccountInfo>,
    ) -> Result<bool, BalError> {
        let Some(bal) = self.bal.as_deref() else {
            return Ok(false);
        };
        let was_absent = basic.is_none();
        let mut info = basic.take().unwrap_or_default();
        // Populate first (always runs); only materialize `Some` if something changed or
        // the caller already had an account.
        if bal.populate_account_info(account_id, self.bal_index, &mut info)? || !was_absent {
            *basic = Some(info);
        }
        Ok(true)
    }

    /// Reads `storage_key` of `account` from the loaded BAL at the current index.
    ///
    /// Returns `Ok(None)` when no BAL is loaded, or when the BAL's recorded writes give no
    /// value for the current index.
    ///
    /// # Errors
    /// [`BalError::AccountNotFound`] if a loaded BAL does not list `account`, or any error
    /// from `get_bal_writes`.
    #[inline]
    pub fn storage(
        &self,
        account: &Address,
        storage_key: StorageKey,
    ) -> Result<Option<StorageValue>, BalError> {
        match &self.bal {
            None => Ok(None),
            Some(bal) => {
                let bal_account = bal
                    .accounts
                    .get(account)
                    .ok_or(BalError::AccountNotFound { address: *account })?;
                let writes = bal_account.storage.get_bal_writes(account, storage_key)?;
                Ok(writes.get(self.bal_index))
            }
        }
    }

    /// Same as [`Self::storage`], addressing the account by its BAL position.
    ///
    /// # Errors
    /// [`BalError::InvalidAccountId`] if `account_id` is out of range for the loaded BAL, or
    /// any error from `get_bal_writes`.
    #[inline]
    pub fn storage_by_account_id(
        &self,
        account_id: AccountId,
        storage_key: StorageKey,
    ) -> Result<Option<StorageValue>, BalError> {
        let Some(bal) = self.bal.as_deref() else {
            return Ok(None);
        };
        let (address, bal_account) = bal
            .accounts
            .get_index(account_id.get())
            .ok_or(BalError::InvalidAccountId { account_id })?;
        let writes = bal_account.storage.get_bal_writes(address, storage_key)?;
        Ok(writes.get(self.bal_index))
    }

    /// Records every account in `changes` into the builder at the current index.
    /// Does nothing when no builder is set.
    #[inline]
    pub fn commit(&mut self, changes: &EvmState) {
        let index = self.bal_index;
        let Some(builder) = self.bal_builder.as_mut() else {
            return;
        };
        for (address, account) in changes.iter() {
            builder.update_account(index, *address, account);
        }
    }

    /// Records a single account change into the builder at the current index.
    /// Does nothing when no builder is set.
    #[inline]
    pub fn commit_one(&mut self, address: Address, account: &Account) {
        let index = self.bal_index;
        if let Some(builder) = self.bal_builder.as_mut() {
            builder.update_account(index, address, account);
        }
    }
}
/// Database wrapper that answers reads from a BAL where possible, falls back to `db`,
/// and records commits into an optional BAL builder held in `bal_state`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BalDatabase<DB> {
    /// BAL, builder and index consulted by the `Database` / `DatabaseCommit` impls.
    pub bal_state: BalState,
    /// Wrapped database serving everything the BAL does not answer.
    pub db: DB,
}
/// Exposes the wrapped database's API directly on the wrapper.
impl<DB> Deref for BalDatabase<DB> {
    type Target = DB;

    fn deref(&self) -> &DB {
        let Self { db, .. } = self;
        db
    }
}
/// Mutable access to the wrapped database.
///
/// NOTE: calls made through this go straight to `db` and do not touch `bal_state`.
impl<DB> DerefMut for BalDatabase<DB> {
    fn deref_mut(&mut self) -> &mut DB {
        let Self { db, .. } = self;
        db
    }
}
impl<DB> BalDatabase<DB> {
    /// Wraps `db` with a default [`BalState`] (no BAL, no builder).
    #[inline]
    pub fn new(db: DB) -> Self {
        Self {
            db,
            bal_state: BalState::default(),
        }
    }

    /// Replaces the BAL used for reads (`None` disables lookups); builder and index are kept.
    #[inline]
    pub fn with_bal_option(mut self, bal: Option<Arc<Bal>>) -> Self {
        self.bal_state.bal = bal;
        self
    }

    /// Starts recording commits into a fresh builder via [`BalState::with_bal_builder`].
    #[inline]
    pub fn with_bal_builder(mut self) -> Self {
        self.bal_state = self.bal_state.with_bal_builder();
        self
    }

    /// Consumes the wrapper, rewinds the BAL index, and hands the wrapper back.
    #[inline]
    pub const fn reset_bal_index(mut self) -> Self {
        self.bal_state.reset_bal_index();
        self
    }

    /// Advances the BAL index one step in place.
    #[inline]
    pub const fn bump_bal_index(&mut self) {
        self.bal_state.bump_bal_index();
    }
}
/// Error type of [`BalDatabase`]: either a BAL failure or an error from the wrapped database.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum EvmDatabaseError<ERROR> {
    /// The BAL could not answer the request.
    Bal(BalError),
    /// The wrapped database returned an error.
    Database(ERROR),
}
impl<ERROR> From<BalError> for EvmDatabaseError<ERROR> {
fn from(error: BalError) -> Self {
Self::Bal(error)
}
}
impl<ERROR: core::error::Error + Send + Sync + 'static> DBErrorMarker for EvmDatabaseError<ERROR> {}
impl<ERROR: Display> Display for EvmDatabaseError<ERROR> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::Bal(error) => write!(f, "Bal error: {error}"),
Self::Database(error) => write!(f, "Database error: {error}"),
}
}
}
impl<ERROR: Error> Error for EvmDatabaseError<ERROR> {}
impl<ERROR> EvmDatabaseError<ERROR> {
    /// Unwraps the error produced by the wrapped database.
    ///
    /// # Panics
    /// If `self` is [`EvmDatabaseError::Bal`]. The message includes the BAL error, and
    /// `#[track_caller]` makes the panic report the caller's location instead of this helper's.
    #[track_caller]
    pub fn into_external_error(self) -> ERROR {
        match self {
            // Keep the BAL error in the message; discarding it made this panic undiagnosable.
            Self::Bal(error) => panic!("Expected database error, got BAL error: {error}"),
            Self::Database(error) => error,
        }
    }
}
/// Reads consult the BAL first and fall back to the wrapped database. BAL failures surface
/// as [`EvmDatabaseError::Bal`] and wrapped-database failures as
/// [`EvmDatabaseError::Database`].
impl<DB: Database> Database for BalDatabase<DB> {
    type Error = EvmDatabaseError<DB::Error>;

    /// Loads the account from the wrapped database, then overlays the BAL's view of it.
    ///
    /// The BAL id is resolved before `db` is queried, so an address missing from a loaded
    /// BAL fails without calling `db.basic`.
    #[inline]
    fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        let Some(bal_id) = self.bal_state.get_account_id(&address)? else {
            // No BAL loaded: the wrapped database's answer is final.
            return self.db.basic(address).map_err(EvmDatabaseError::Database);
        };
        let mut info = self.db.basic(address).map_err(EvmDatabaseError::Database)?;
        self.bal_state.basic_by_account_id(bal_id, &mut info)?;
        Ok(info)
    }

    /// Bytecode is served by the wrapped database only; the BAL is not consulted.
    #[inline]
    fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        let code = self.db.code_by_hash(code_hash);
        code.map_err(EvmDatabaseError::Database)
    }

    /// Returns the BAL's value for the slot if it has one, otherwise asks the wrapped database.
    #[inline]
    fn storage(&mut self, address: Address, key: StorageKey) -> Result<StorageValue, Self::Error> {
        match self.bal_state.storage(&address, key)? {
            Some(value) => Ok(value),
            None => self.db.storage(address, key).map_err(EvmDatabaseError::Database),
        }
    }

    /// Like `storage`, but resolves the BAL entry by `account_id`. The fallback goes through
    /// the wrapped database's `storage` using `address`.
    #[inline]
    fn storage_by_account_id(
        &mut self,
        address: Address,
        account_id: AccountId,
        storage_key: StorageKey,
    ) -> Result<StorageValue, Self::Error> {
        match self
            .bal_state
            .storage_by_account_id(account_id, storage_key)?
        {
            Some(value) => Ok(value),
            None => self
                .db
                .storage(address, storage_key)
                .map_err(EvmDatabaseError::Database),
        }
    }

    /// Block hashes are served by the wrapped database only.
    fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
        let hash = self.db.block_hash(number);
        hash.map_err(EvmDatabaseError::Database)
    }
}
/// Commits are recorded into the BAL builder (if any) and then forwarded to the wrapped
/// database.
impl<DB: DatabaseCommit> DatabaseCommit for BalDatabase<DB> {
    fn commit(&mut self, changes: EvmState) {
        // Record by reference first, then hand ownership to the wrapped database.
        self.bal_state.commit(&changes);
        self.db.commit(changes);
    }

    fn commit_iter(&mut self, changes: &mut dyn Iterator<Item = (Address, Account)>) {
        // Split the borrow so the recording closure and the wrapped database can coexist.
        let Self { bal_state, db } = self;
        // Each item is recorded at the moment `db` pulls it, before `db` sees it.
        let mut recorded =
            changes.inspect(|(address, account)| bal_state.commit_one(*address, account));
        db.commit_iter(&mut recorded);
    }
}