1use core::{
3 error::Error,
4 fmt::Display,
5 ops::{Deref, DerefMut},
6};
7use primitives::{Address, StorageKey, StorageValue, B256};
8use state::{
9 bal::{alloy::AlloyBal, Bal, BalError, BlockAccessIndex},
10 Account, AccountId, AccountInfo, Bytecode, EvmState,
11};
12use std::sync::Arc;
13
14use crate::{DBErrorMarker, Database, DatabaseCommit};
15
16#[derive(Clone, Default, Debug, PartialEq, Eq)]
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19pub struct BalState {
20 pub bal: Option<Arc<Bal>>,
22 pub bal_builder: Option<Bal>,
25 pub bal_index: BlockAccessIndex,
28}
29
30impl BalState {
31 #[inline]
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 #[inline]
39 pub const fn reset_bal_index(&mut self) {
40 self.bal_index = BlockAccessIndex::PRE_EXECUTION;
41 }
42
43 #[inline]
45 pub const fn bump_bal_index(&mut self) {
46 self.bal_index.increment();
47 }
48
49 #[inline]
51 pub const fn bal_index(&self) -> BlockAccessIndex {
52 self.bal_index
53 }
54
55 #[inline]
57 pub fn bal(&self) -> Option<Arc<Bal>> {
58 self.bal.clone()
59 }
60
61 #[inline]
63 pub fn bal_builder(&self) -> Option<Bal> {
64 self.bal_builder.clone()
65 }
66
67 #[inline]
69 pub fn with_bal(mut self, bal: Arc<Bal>) -> Self {
70 self.bal = Some(bal);
71 self
72 }
73
74 #[inline]
76 pub fn with_bal_builder(mut self) -> Self {
77 self.bal_builder = Some(Bal::new());
78 self
79 }
80
81 #[inline]
83 pub const fn take_built_bal(&mut self) -> Option<Bal> {
84 self.reset_bal_index();
85 self.bal_builder.take()
86 }
87
88 #[inline]
90 pub fn take_built_alloy_bal(&mut self) -> Option<AlloyBal> {
91 self.take_built_bal().map(|bal| bal.into_alloy_bal())
92 }
93
94 #[inline]
98 pub fn get_account_id(&self, address: &Address) -> Result<Option<AccountId>, BalError> {
99 self.bal
100 .as_ref()
101 .map(|bal| {
102 bal.accounts
103 .get_full(address)
104 .map(|i| AccountId::new(i.0).expect("too many bals"))
105 .ok_or(BalError::AccountNotFound { address: *address })
106 })
107 .transpose()
108 }
109
110 #[inline]
116 pub fn basic(
117 &self,
118 address: Address,
119 basic: &mut Option<AccountInfo>,
120 ) -> Result<bool, BalError> {
121 let Some(account_id) = self.get_account_id(&address)? else {
122 return Ok(false);
123 };
124 self.basic_by_account_id(account_id, basic)
125 }
126
127 #[inline]
129 pub fn basic_by_account_id(
130 &self,
131 account_id: AccountId,
132 basic: &mut Option<AccountInfo>,
133 ) -> Result<bool, BalError> {
134 let Some(bal) = &self.bal else {
135 return Ok(false);
136 };
137 let is_none = basic.is_none();
138 let mut bal_basic = core::mem::take(basic).unwrap_or_default();
139 let changed = bal.populate_account_info(account_id, self.bal_index, &mut bal_basic)?;
140
141 if !changed && is_none {
143 return Ok(true);
144 }
145
146 *basic = Some(bal_basic);
147 Ok(true)
148 }
149
150 #[inline]
154 pub fn storage(
155 &self,
156 account: &Address,
157 storage_key: StorageKey,
158 ) -> Result<Option<StorageValue>, BalError> {
159 let Some(bal) = &self.bal else {
160 return Ok(None);
161 };
162
163 let Some(bal_account) = bal.accounts.get(account) else {
164 return Err(BalError::AccountNotFound { address: *account });
165 };
166
167 Ok(bal_account
168 .storage
169 .get_bal_writes(account, storage_key)?
170 .get(self.bal_index))
171 }
172
173 #[inline]
177 pub fn storage_by_account_id(
178 &self,
179 account_id: AccountId,
180 storage_key: StorageKey,
181 ) -> Result<Option<StorageValue>, BalError> {
182 let Some(bal) = &self.bal else {
183 return Ok(None);
184 };
185
186 let Some((address, bal_account)) = bal.accounts.get_index(account_id.get()) else {
187 return Err(BalError::InvalidAccountId { account_id });
188 };
189
190 Ok(bal_account
191 .storage
192 .get_bal_writes(address, storage_key)?
193 .get(self.bal_index))
194 }
195
196 #[inline]
198 pub fn commit(&mut self, changes: &EvmState) {
199 if let Some(bal_builder) = &mut self.bal_builder {
200 for (address, account) in changes.iter() {
201 bal_builder.update_account(self.bal_index, *address, account);
202 }
203 }
204 }
205
206 #[inline]
208 pub fn commit_one(&mut self, address: Address, account: &Account) {
209 if let Some(bal_builder) = &mut self.bal_builder {
210 bal_builder.update_account(self.bal_index, address, account);
211 }
212 }
213}
214
215#[derive(Clone, Debug, PartialEq, Eq)]
217#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
218pub struct BalDatabase<DB> {
219 pub bal_state: BalState,
221 pub db: DB,
223}
224
225impl<DB> Deref for BalDatabase<DB> {
226 type Target = DB;
227
228 fn deref(&self) -> &Self::Target {
229 &self.db
230 }
231}
232
233impl<DB> DerefMut for BalDatabase<DB> {
234 fn deref_mut(&mut self) -> &mut Self::Target {
235 &mut self.db
236 }
237}
238
239impl<DB> BalDatabase<DB> {
240 #[inline]
242 pub fn new(db: DB) -> Self {
243 Self {
244 bal_state: BalState::default(),
245 db,
246 }
247 }
248
249 #[inline]
251 pub fn with_bal_option(self, bal: Option<Arc<Bal>>) -> Self {
252 Self {
253 bal_state: BalState {
254 bal,
255 ..self.bal_state
256 },
257 ..self
258 }
259 }
260
261 #[inline]
263 pub fn with_bal_builder(self) -> Self {
264 Self {
265 bal_state: self.bal_state.with_bal_builder(),
266 ..self
267 }
268 }
269
270 #[inline]
272 pub const fn reset_bal_index(mut self) -> Self {
273 self.bal_state.reset_bal_index();
274 self
275 }
276
277 #[inline]
279 pub const fn bump_bal_index(&mut self) {
280 self.bal_state.bump_bal_index();
281 }
282}
283
284#[derive(Clone, Debug, PartialEq, Eq)]
286#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
287pub enum EvmDatabaseError<ERROR> {
288 Bal(BalError),
290 Database(ERROR),
292}
293
294impl<ERROR> From<BalError> for EvmDatabaseError<ERROR> {
295 fn from(error: BalError) -> Self {
296 Self::Bal(error)
297 }
298}
299
300impl<ERROR: core::error::Error + Send + Sync + 'static> DBErrorMarker for EvmDatabaseError<ERROR> {
301 fn is_fatal(&self) -> bool {
302 match self {
303 Self::Bal(_) => false,
304 Self::Database(_) => true,
305 }
306 }
307}
308
309impl<ERROR: Display> Display for EvmDatabaseError<ERROR> {
310 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
311 match self {
312 Self::Bal(error) => write!(f, "Bal error: {error}"),
313 Self::Database(error) => write!(f, "Database error: {error}"),
314 }
315 }
316}
317
318impl<ERROR: Error> Error for EvmDatabaseError<ERROR> {}
319
320impl<ERROR> EvmDatabaseError<ERROR> {
321 pub fn into_external_error(self) -> ERROR {
325 match self {
326 Self::Bal(_) => panic!("Expected database error, got BAL error"),
327 Self::Database(error) => error,
328 }
329 }
330}
331
332impl<DB: Database> Database for BalDatabase<DB> {
333 type Error = EvmDatabaseError<DB::Error>;
334
335 #[inline]
336 fn basic(&mut self, address: Address) -> Result<Option<AccountInfo>, Self::Error> {
337 let account_id = self.bal_state.get_account_id(&address)?;
338
339 let mut account = self.db.basic(address).map_err(EvmDatabaseError::Database)?;
340
341 if let Some(account_id) = account_id {
342 self.bal_state
343 .basic_by_account_id(account_id, &mut account)?;
344 }
345
346 Ok(account)
347 }
348
349 #[inline]
350 fn code_by_hash(&mut self, code_hash: B256) -> Result<Bytecode, Self::Error> {
351 self.db
352 .code_by_hash(code_hash)
353 .map_err(EvmDatabaseError::Database)
354 }
355
356 #[inline]
357 fn storage(&mut self, address: Address, key: StorageKey) -> Result<StorageValue, Self::Error> {
358 if let Some(storage) = self.bal_state.storage(&address, key)? {
359 return Ok(storage);
360 }
361
362 self.db
363 .storage(address, key)
364 .map_err(EvmDatabaseError::Database)
365 }
366
367 #[inline]
368 fn storage_by_account_id(
369 &mut self,
370 address: Address,
371 account_id: AccountId,
372 storage_key: StorageKey,
373 ) -> Result<StorageValue, Self::Error> {
374 if let Some(value) = self
375 .bal_state
376 .storage_by_account_id(account_id, storage_key)?
377 {
378 return Ok(value);
379 }
380
381 self.db
382 .storage(address, storage_key)
383 .map_err(EvmDatabaseError::Database)
384 }
385
386 fn block_hash(&mut self, number: u64) -> Result<B256, Self::Error> {
387 self.db
388 .block_hash(number)
389 .map_err(EvmDatabaseError::Database)
390 }
391}
392
393impl<DB: DatabaseCommit> DatabaseCommit for BalDatabase<DB> {
394 fn commit(&mut self, changes: EvmState) {
395 self.bal_state.commit(&changes);
396 self.db.commit(changes);
397 }
398
399 fn commit_iter(&mut self, changes: &mut dyn Iterator<Item = (Address, Account)>) {
400 let bal_state = &mut self.bal_state;
401 let mut changes = changes.map(|(address, account)| {
402 bal_state.commit_one(address, &account);
403 (address, account)
404 });
405 self.db.commit_iter(&mut changes);
406 }
407}