use revm::{
database::{states::bundle_state::BundleRetention, BundleState, Cache, CacheDB, State},
primitives::B256,
Database,
};
use std::{collections::BTreeMap, convert::Infallible, sync::Arc};
/// A source of [`Database`] handles.
///
/// Each call to [`DbConnect::connect`] produces a new, owned
/// `Self::Database` value (or an error). The `Sync` supertrait allows a
/// shared reference to the connector to be used from several threads.
pub trait DbConnect: Sync {
    /// Error returned when a database handle cannot be produced.
    type Error: core::error::Error;

    /// The database type produced by [`DbConnect::connect`].
    type Database: Database;

    /// Produce a new database handle.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the implementation cannot produce a handle.
    fn connect(&self) -> Result<Self::Database, Self::Error>;
}
/// Any cloneable, `Sync` [`Database`] acts as its own connector.
///
/// Connecting clones the database, and the connection can never fail.
impl<Db> DbConnect for Db
where
    Db: Database + Clone + Sync,
{
    type Database = Db;
    type Error = Infallible;

    fn connect(&self) -> Result<Db, Infallible> {
        let handle = Db::clone(self);
        Ok(handle)
    }
}
/// Infallible operations on a state accumulator such as revm's [`State`].
pub trait StateAcc {
    /// Set the accumulator's state-clear flag to `flag`.
    fn set_state_clear_flag(&mut self, flag: bool);

    /// Merge transitions, following the given [`BundleRetention`] policy.
    fn merge_transitions(&mut self, retention: BundleRetention);

    /// Take the accumulated [`BundleState`] out of `self`.
    fn take_bundle(&mut self) -> BundleState;

    /// Add the given block hashes to the accumulator.
    ///
    /// NOTE(review): the `u64` keys are presumably block numbers; confirm
    /// with callers.
    fn set_block_hashes(&mut self, block_hashes: &BTreeMap<u64, B256>);
}
impl<Db: Database> StateAcc for State<Db> {
fn set_state_clear_flag(&mut self, flag: bool) {
Self::set_state_clear_flag(self, flag)
}
fn merge_transitions(&mut self, retention: BundleRetention) {
Self::merge_transitions(self, retention)
}
fn take_bundle(&mut self) -> BundleState {
Self::take_bundle(self)
}
fn set_block_hashes(&mut self, block_hashes: &BTreeMap<u64, B256>) {
self.block_hashes.extend(block_hashes)
}
}
/// Fallible counterpart of [`StateAcc`].
///
/// This is for accumulators that may be unable to grant mutable access, for
/// example one held behind a shared [`Arc`].
pub trait TryStateAcc: Sync {
    /// Error returned when the accumulator cannot be mutated.
    type Error: core::error::Error;

    /// Fallible version of [`StateAcc::set_state_clear_flag`].
    fn try_set_state_clear_flag(&mut self, flag: bool) -> Result<(), Self::Error>;

    /// Fallible version of [`StateAcc::merge_transitions`].
    fn try_merge_transitions(&mut self, retention: BundleRetention) -> Result<(), Self::Error>;

    /// Fallible version of [`StateAcc::take_bundle`].
    fn try_take_bundle(&mut self) -> Result<BundleState, Self::Error>;

    /// Fallible version of [`StateAcc::set_block_hashes`].
    fn try_set_block_hashes(
        &mut self,
        block_hashes: &BTreeMap<u64, B256>,
    ) -> Result<(), Self::Error>;
}
impl<Db> TryStateAcc for Db
where
Db: StateAcc + Sync,
{
type Error = Infallible;
fn try_set_state_clear_flag(&mut self, flag: bool) -> Result<(), Infallible> {
self.set_state_clear_flag(flag);
Ok(())
}
fn try_merge_transitions(&mut self, retention: BundleRetention) -> Result<(), Infallible> {
self.merge_transitions(retention);
Ok(())
}
fn try_take_bundle(&mut self) -> Result<BundleState, Infallible> {
Ok(self.take_bundle())
}
fn try_set_block_hashes(
&mut self,
block_hashes: &BTreeMap<u64, B256>,
) -> Result<(), Infallible> {
self.set_block_hashes(block_hashes);
Ok(())
}
}
#[derive(thiserror::Error, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArcUpgradeError {
#[error("Arc reference is not unique, cannot mutate")]
NotUnique,
}
impl<Db> TryStateAcc for Arc<Db>
where
Db: StateAcc + Sync + Send,
{
type Error = ArcUpgradeError;
fn try_set_state_clear_flag(&mut self, flag: bool) -> Result<(), ArcUpgradeError> {
Self::get_mut(self).ok_or(ArcUpgradeError::NotUnique)?.set_state_clear_flag(flag);
Ok(())
}
fn try_merge_transitions(&mut self, retention: BundleRetention) -> Result<(), ArcUpgradeError> {
Self::get_mut(self).ok_or(ArcUpgradeError::NotUnique)?.merge_transitions(retention);
Ok(())
}
fn try_take_bundle(&mut self) -> Result<BundleState, ArcUpgradeError> {
Ok(Self::get_mut(self).ok_or(ArcUpgradeError::NotUnique)?.take_bundle())
}
fn try_set_block_hashes(
&mut self,
block_hashes: &BTreeMap<u64, B256>,
) -> Result<(), ArcUpgradeError> {
Self::get_mut(self).ok_or(ArcUpgradeError::NotUnique)?.set_block_hashes(block_hashes);
Ok(())
}
}
/// Trait for databases that own a revm [`Cache`] and can always access it,
/// including mutably.
pub trait CachingDb {
    /// Borrow the cache.
    fn cache(&self) -> &Cache;

    /// Borrow the cache mutably.
    fn cache_mut(&mut self) -> &mut Cache;

    /// Consume `self` and return its cache.
    fn into_cache(self) -> Cache;

    /// Extend this cache with clones of the entries in `cache`, leaving
    /// `cache` itself untouched.
    ///
    /// Each field is merged with `Extend::extend`:
    /// - `accounts`, `contracts` and `block_hashes` receive cloned or copied
    ///   `(key, value)` pairs;
    /// - `logs` receives cloned entries.
    ///
    /// NOTE(review): for map fields, entries from `cache` presumably replace
    /// existing entries on key collision, and logs are appended. Confirm
    /// against the field types of revm's `Cache`.
    fn extend_ref(&mut self, cache: &Cache) {
        // Borrow the inner cache once instead of calling `cache_mut` for
        // every field.
        let inner = self.cache_mut();
        inner.accounts.extend(cache.accounts.iter().map(|(k, v)| (*k, v.clone())));
        inner.contracts.extend(cache.contracts.iter().map(|(k, v)| (*k, v.clone())));
        inner.logs.extend(cache.logs.iter().cloned());
        inner.block_hashes.extend(cache.block_hashes.iter().map(|(k, v)| (*k, *v)));
    }

    /// Extend this cache by moving in the entries of `cache`.
    ///
    /// This is the same field-by-field merge as [`CachingDb::extend_ref`],
    /// but the entries are moved instead of cloned.
    fn extend(&mut self, cache: Cache) {
        // Single mutable borrow of the inner cache for all four fields.
        let inner = self.cache_mut();
        inner.accounts.extend(cache.accounts);
        inner.contracts.extend(cache.contracts);
        inner.logs.extend(cache.logs);
        inner.block_hashes.extend(cache.block_hashes);
    }
}
/// [`CacheDB`] implements [`CachingDb`] through its `cache` field.
impl<Db> CachingDb for CacheDB<Db> {
    fn cache(&self) -> &Cache {
        let Self { cache, .. } = self;
        cache
    }

    fn cache_mut(&mut self) -> &mut Cache {
        let Self { cache, .. } = self;
        cache
    }

    fn into_cache(self) -> Cache {
        // The wrapped database and any other fields are dropped here.
        let Self { cache, .. } = self;
        cache
    }
}
/// Fallible counterpart of [`CachingDb`].
///
/// This is for databases that can always lend their [`Cache`] immutably but
/// may refuse mutable or owned access, for example `Arc<CacheDB<_>>`.
pub trait TryCachingDb {
    /// Error returned when mutable or owned access to the cache is refused.
    type Error: core::error::Error;

    /// Borrow the cache. Shared access cannot fail.
    fn cache(&self) -> &Cache;

    /// Try to borrow the cache mutably.
    fn try_cache_mut(&mut self) -> Result<&mut Cache, Self::Error>;

    /// Try to consume `self` and return its cache.
    fn try_into_cache(self) -> Result<Cache, Self::Error>;

    /// Try to extend this cache with clones of the entries in `cache`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`TryCachingDb::try_cache_mut`]. In that
    /// case nothing has been modified.
    fn try_extend_ref(&mut self, cache: &Cache) -> Result<(), Self::Error>
    where
        Self: Sized,
    {
        let target = self.try_cache_mut()?;
        target
            .accounts
            .extend(cache.accounts.iter().map(|(address, account)| (*address, account.clone())));
        target
            .contracts
            .extend(cache.contracts.iter().map(|(hash, code)| (*hash, code.clone())));
        target.logs.extend(cache.logs.iter().cloned());
        target.block_hashes.extend(cache.block_hashes.iter().map(|(number, hash)| (*number, *hash)));
        Ok(())
    }

    /// Try to extend this cache by moving in the entries of `cache`.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`TryCachingDb::try_cache_mut`]. In that
    /// case nothing has been modified and `cache` is dropped.
    fn try_extend(&mut self, cache: Cache) -> Result<(), Self::Error>
    where
        Self: Sized,
    {
        let target = self.try_cache_mut()?;
        target.accounts.extend(cache.accounts);
        target.contracts.extend(cache.contracts);
        target.logs.extend(cache.logs);
        target.block_hashes.extend(cache.block_hashes);
        Ok(())
    }
}
/// Cache access through an [`Arc`].
///
/// Shared reads always succeed. Mutable access and unwrapping succeed only
/// when this `Arc` is the sole handle; otherwise they return
/// [`ArcUpgradeError::NotUnique`].
impl<Db> TryCachingDb for Arc<Db>
where
    Db: CachingDb,
{
    type Error = ArcUpgradeError;

    fn cache(&self) -> &Cache {
        // `&Arc<Db>` deref-coerces to `&Db` for the explicit trait call.
        <Db as CachingDb>::cache(self)
    }

    fn try_cache_mut(&mut self) -> Result<&mut Cache, Self::Error> {
        match Arc::get_mut(self) {
            Some(db) => Ok(db.cache_mut()),
            None => Err(ArcUpgradeError::NotUnique),
        }
    }

    fn try_into_cache(self) -> Result<Cache, Self::Error> {
        // `Arc::into_inner` guarantees that, across all clones, exactly one
        // caller obtains the value. On failure this handle is simply dropped.
        Arc::into_inner(self)
            .map(<Db as CachingDb>::into_cache)
            .ok_or(ArcUpgradeError::NotUnique)
    }
}