use crate::db::CachingDb;
use alloy::{
consensus::constants::KECCAK_EMPTY,
primitives::{Address, B256, U256},
};
use revm::{
bytecode::Bytecode,
database::{in_memory_db::Cache, AccountState, DbAccount},
primitives::HashMap,
state::{Account, AccountInfo},
Database, DatabaseCommit, DatabaseRef,
};
use super::TryCachingDb;
/// A database wrapper that layers a write cache over an inner database.
///
/// Reads are served from `cache` when it holds a matching entry and otherwise
/// fall through to `inner` through its [`DatabaseRef`] methods. State changes
/// applied with [`DatabaseCommit::commit`] are written only into `cache`;
/// commits never write to `inner`.
#[derive(Debug)]
pub struct CacheOnWrite<Db> {
    /// Overlay holding the accounts (with their storage), contracts and block
    /// hashes written to this layer.
    cache: Cache,
    /// The wrapped database that backs any read not satisfied by `cache`.
    inner: Db,
}
impl<Db> Default for CacheOnWrite<Db>
where
Db: Default,
{
fn default() -> Self {
Self::new(Db::default())
}
}
impl<Db> CacheOnWrite<Db> {
pub fn new(inner: Db) -> Self {
Self { cache: Default::default(), inner }
}
pub const fn new_with_cache(inner: Db, cache: Cache) -> Self {
Self { cache, inner }
}
pub const fn inner(&self) -> &Db {
&self.inner
}
pub const fn inner_mut(&mut self) -> &mut Db {
&mut self.inner
}
pub const fn cache(&self) -> &Cache {
&self.cache
}
pub const fn cache_mut(&mut self) -> &mut Cache {
&mut self.cache
}
pub fn into_parts(self) -> (Db, Cache) {
(self.inner, self.cache)
}
pub fn into_cache(self) -> Cache {
self.cache
}
pub fn nest(self) -> CacheOnWrite<Self> {
CacheOnWrite::new(self)
}
fn insert_contract(&mut self, account: &mut AccountInfo) {
if let Some(code) = &account.code {
if !code.is_empty() {
if account.code_hash == KECCAK_EMPTY {
account.code_hash = code.hash_slow();
}
self.cache.contracts.entry(account.code_hash).or_insert_with(|| code.clone());
}
}
if account.code_hash.is_zero() {
account.code_hash = KECCAK_EMPTY;
}
}
}
impl<Db> CachingDb for CacheOnWrite<Db> {
    /// Borrows this layer's write cache.
    fn cache(&self) -> &Cache {
        let Self { cache, .. } = self;
        cache
    }

    /// Mutably borrows this layer's write cache.
    fn cache_mut(&mut self) -> &mut Cache {
        let Self { cache, .. } = self;
        cache
    }

    /// Consumes the wrapper and returns its write cache; the inner database is dropped.
    fn into_cache(self) -> Cache {
        let Self { cache, .. } = self;
        cache
    }
}
impl<Db> CacheOnWrite<Db>
where
    Db: CachingDb,
{
    /// Merges this layer's cache into the inner database and returns it.
    ///
    /// The merge is performed by the inner database's [`CachingDb::extend`].
    pub fn flatten(self) -> Db {
        let (mut inner, cache) = self.into_parts();
        inner.extend(cache);
        inner
    }
}
impl<Db> CacheOnWrite<Db>
where
    Db: TryCachingDb,
{
    /// Fallibly merges this layer's cache into the inner database and returns it.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the inner database's `try_extend`.
    pub fn try_flatten(self) -> Result<Db, Db::Error> {
        let (mut inner, cache) = self.into_parts();
        inner.try_extend(cache)?;
        Ok(inner)
    }
}
/// Mutable-receiver reads; every method forwards to the [`DatabaseRef`] impl,
/// since reading never needs to modify the cache or the inner database.
impl<Db: DatabaseRef> Database for CacheOnWrite<Db> {
    type Error = Db::Error;

    fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
        <Self as DatabaseRef>::basic_ref(self, address)
    }

    fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
        <Self as DatabaseRef>::code_by_hash_ref(self, code_hash)
    }

    fn storage(&mut self, address: Address, index: U256) -> Result<U256, Self::Error> {
        <Self as DatabaseRef>::storage_ref(self, address, index)
    }

    fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
        <Self as DatabaseRef>::block_hash_ref(self, number)
    }
}
impl<Db: DatabaseRef> DatabaseRef for CacheOnWrite<Db> {
type Error = Db::Error;
fn basic_ref(&self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
if let Some(account) = self.cache.accounts.get(&address).map(DbAccount::info) {
return Ok(account);
}
self.inner.basic_ref(address)
}
fn code_by_hash_ref(&self, code_hash: B256) -> Result<Bytecode, Self::Error> {
if let Some(code) = self.cache.contracts.get(&code_hash) {
return Ok(code.clone());
}
self.inner.code_by_hash_ref(code_hash)
}
fn storage_ref(&self, address: Address, index: U256) -> Result<U256, Self::Error> {
if let Some(storage) =
self.cache.accounts.get(&address).and_then(|a| a.storage.get(&index).cloned())
{
return Ok(storage);
}
self.inner.storage_ref(address, index)
}
fn block_hash_ref(&self, number: u64) -> Result<B256, Self::Error> {
if let Some(hash) = self.cache.block_hashes.get(&U256::from(number)) {
return Ok(*hash);
}
self.inner.block_hash_ref(number)
}
}
impl<Db> DatabaseCommit for CacheOnWrite<Db> {
    /// Applies post-execution state changes to this layer's cache only.
    ///
    /// Accounts are handled as follows:
    /// - Untouched accounts are skipped.
    /// - Self-destructed accounts are reset to a default, `NotExisting` entry
    ///   with no storage.
    /// - Every other account has its code registered in the cache and its info
    ///   replaced, and its changed slots' present values merged into the cached
    ///   storage. Newly created accounts first have their cached storage cleared
    ///   and are marked `StorageCleared`. Accounts already marked
    ///   `StorageCleared` keep that state. All remaining accounts become `Touched`.
    fn commit(&mut self, changes: HashMap<Address, Account>) {
        let touched = changes.into_iter().filter(|(_, account)| account.is_touched());
        for (address, mut account) in touched {
            if account.is_selfdestructed() {
                let entry = self.cache.accounts.entry(address).or_default();
                entry.storage.clear();
                entry.account_state = AccountState::NotExisting;
                entry.info = AccountInfo::default();
                continue;
            }

            let created = account.is_created();
            self.insert_contract(&mut account.info);

            let entry = self.cache.accounts.entry(address).or_default();
            entry.info = account.info;
            if created {
                entry.storage.clear();
            }
            entry.account_state = if created || entry.account_state.is_storage_cleared() {
                AccountState::StorageCleared
            } else {
                AccountState::Touched
            };

            let present = account.storage.into_iter().map(|(slot, value)| (slot, value.present_value()));
            entry.storage.extend(present);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::test_utils::DbTestExt;
    use revm::{
        bytecode::Bytecode,
        database::InMemoryDB,
        state::{AccountStatus, EvmStorageSlot},
    };

    /// Concrete change-set type passed to `commit` in these tests.
    type ChangeSet =
        std::collections::HashMap<Address, Account, revm::primitives::map::DefaultHashBuilder>;

    /// Builds a one-entry change set for `address`: a `Touched` account with a
    /// balance of 5000 and a single slot `slot` changed from zero to `value`.
    fn touched_slot_change(address: Address, slot: U256, value: U256) -> ChangeSet {
        let info = AccountInfo { balance: U256::from(5000), ..Default::default() };
        let account = Account {
            original_info: Box::new(info.clone()),
            info,
            storage: std::iter::once((slot, EvmStorageSlot::new_changed(U256::ZERO, value, 0)))
                .collect(),
            transaction_id: 0,
            status: AccountStatus::Touched,
        };
        std::iter::once((address, account)).collect()
    }

    /// Reads the balance of an account that the test expects to exist.
    fn balance_of(db: &mut CacheOnWrite<InMemoryDB>, address: Address) -> U256 {
        db.basic(address).unwrap().unwrap().balance
    }

    #[test]
    fn state_isolation_regression() {
        let address = Address::with_last_byte(42);
        let (slot0, slot1) = (U256::from(0), U256::from(1));
        let expected = U256::from(42);

        let mut mem_db = InMemoryDB::default();
        mem_db.commit(touched_slot_change(address, slot0, expected));
        assert_eq!(mem_db.storage(address, slot0).unwrap(), expected);
        assert_eq!(mem_db.storage(address, slot1).unwrap(), U256::ZERO);

        // Wrapping alone must not change what is visible.
        let mut cow_db = CacheOnWrite::new(mem_db);
        assert_eq!(cow_db.storage(address, slot0).unwrap(), expected);
        assert_eq!(cow_db.storage(address, slot1).unwrap(), U256::ZERO);

        // Writing slot 1 in the overlay must not hide slot 0 held by the inner db.
        cow_db.commit(touched_slot_change(address, slot1, expected));
        assert_eq!(cow_db.storage(address, slot0).unwrap(), expected);
        assert_eq!(cow_db.storage(address, slot1).unwrap(), expected);
    }

    #[test]
    fn test_set_balance() {
        let addr = Address::repeat_byte(1);
        let mut cow = CacheOnWrite::new(InMemoryDB::default());

        cow.test_set_balance(addr, U256::from(1000));
        assert_eq!(balance_of(&mut cow, addr), U256::from(1000));

        cow.test_increase_balance(addr, U256::from(500));
        assert_eq!(balance_of(&mut cow, addr), U256::from(1500));

        cow.test_decrease_balance(addr, U256::from(200));
        assert_eq!(balance_of(&mut cow, addr), U256::from(1300));
    }

    #[test]
    fn test_set_nonce() {
        let addr = Address::repeat_byte(2);
        let mut cow = CacheOnWrite::new(InMemoryDB::default());

        cow.test_set_nonce(addr, 42);
        let info = cow.basic(addr).unwrap().unwrap();
        assert_eq!(info.nonce, 42);
    }

    #[test]
    fn test_set_storage() {
        let addr = Address::repeat_byte(3);
        let (first_slot, first_value) = (U256::from(10), U256::from(999));
        let (second_slot, second_value) = (U256::from(20), U256::from(888));
        let mut cow = CacheOnWrite::new(InMemoryDB::default());

        cow.test_set_storage(addr, first_slot, first_value);
        assert_eq!(cow.storage(addr, first_slot).unwrap(), first_value);

        // A second write must not disturb the first slot.
        cow.test_set_storage(addr, second_slot, second_value);
        assert_eq!(cow.storage(addr, first_slot).unwrap(), first_value);
        assert_eq!(cow.storage(addr, second_slot).unwrap(), second_value);
    }

    #[test]
    fn test_set_bytecode() {
        let addr = Address::repeat_byte(4);
        let mut cow = CacheOnWrite::new(InMemoryDB::default());
        let raw = alloy::primitives::Bytes::from_static(&[0x60, 0x00, 0x60, 0x00]);
        let code = Bytecode::new_raw(raw);
        let code_hash = code.hash_slow();

        cow.test_set_bytecode(addr, code.clone());

        assert_eq!(cow.basic(addr).unwrap().unwrap().code_hash, code_hash);
        assert_eq!(cow.code_by_hash(code_hash).unwrap(), code);
    }

    #[test]
    fn test_cow_isolates_writes_from_inner() {
        let addr = Address::repeat_byte(5);
        let mut inner = InMemoryDB::default();
        inner.test_set_balance(addr, U256::from(100));

        let mut cow = CacheOnWrite::new(inner);
        cow.test_set_balance(addr, U256::from(500));

        assert_eq!(balance_of(&mut cow, addr), U256::from(500));
        let inner_balance = cow.inner().basic_ref(addr).unwrap().unwrap().balance;
        assert_eq!(inner_balance, U256::from(100));
    }

    #[test]
    fn test_nested_cow_flatten() {
        let addr1 = Address::repeat_byte(6);
        let addr2 = Address::repeat_byte(7);
        let mut base = InMemoryDB::default();
        base.test_set_balance(addr1, U256::from(100));

        let mut outer = CacheOnWrite::new(base);
        outer.test_set_balance(addr2, U256::from(50));

        let mut nested = outer.nest();
        nested.test_set_balance(addr1, U256::from(200));

        // Flattening the nested layer folds its writes into the outer layer.
        let mut outer = nested.flatten();
        assert_eq!(balance_of(&mut outer, addr1), U256::from(200));
        assert_eq!(balance_of(&mut outer, addr2), U256::from(50));

        // Flattening again folds everything into the base database.
        let base = outer.flatten();
        assert_eq!(base.basic_ref(addr1).unwrap().unwrap().balance, U256::from(200));
        assert_eq!(base.basic_ref(addr2).unwrap().unwrap().balance, U256::from(50));
    }
}