use std::{
cmp,
collections::{BTreeMap, BTreeSet, HashMap},
marker::PhantomData,
mem,
};
use anyhow::{anyhow, bail, ensure, Context as _};
use either::Either;
use ethereum_types::{Address, BigEndianHash as _, U256};
use evm_arithmetization::{
generation::TrieInputs,
proof::{BlockMetadata, TrieRoots},
tries::{MptKey, ReceiptTrie, StateMpt, StorageTrie, TransactionTrie},
world::{Hasher, KeccakHash, PoseidonHash, Type1World, Type2World, World},
GenerationInputs,
};
use itertools::Itertools as _;
use keccak_hash::H256;
use mpt_trie::partial_trie::PartialTrie as _;
use nunny::NonEmpty;
use zk_evm_common::gwei_to_wei;
use crate::observer::{DummyObserver, Observer};
use crate::{
BlockLevelData, BlockTrace, BlockTraceTriePreImages, CombinedPreImages, ContractCodeUsage,
OtherBlockData, SeparateStorageTriesPreImage, SeparateTriePreImage, SeparateTriePreImages,
TxnInfo, TxnMeta, TxnTrace,
};
/// How to interpret a compact (`Combined`) wire-format trie pre-image.
///
/// Ignored for `Separate` pre-images, which always produce a type-1 world.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireDisposition {
    /// Decode with the type-1 frontend into a `Type1World` (Keccak code hashing).
    Type1,
    /// Decode with the type-2 frontend into a `Type2World` (Poseidon code hashing).
    Type2,
}
/// Decode a [`BlockTrace`] into the per-batch [`GenerationInputs`] consumed by
/// the prover.
///
/// # Parameters
/// - `trace`: the block's trie pre-images, contract code and per-transaction traces.
/// - `other`: block-level data (metadata, block hashes, withdrawals in gwei,
///   checkpoint data, burn address, optional GER data).
/// - `batch_size_hint`: target number of transactions per batch; must be non-zero.
/// - `observer`: handed the tries after each batch. Only used for type-1 worlds;
///   the type-2 path uses a [`DummyObserver`].
/// - `wire_disposition`: how a `Combined` (compact wire) pre-image is decoded.
///
/// # Errors
/// Fails if `batch_size_hint` is zero, if the pre-images cannot be decoded, or
/// if replaying the transactions fails.
///
/// # Panics
/// Panics if the resulting world is type-2 (a `Combined` pre-image with
/// [`WireDisposition::Type2`]), because [`GenerationInputs`] can currently only
/// be built from a type-1 world.
pub fn entrypoint(
    trace: BlockTrace,
    other: OtherBlockData,
    batch_size_hint: usize,
    observer: &mut impl Observer<Type1World>,
    wire_disposition: WireDisposition,
) -> anyhow::Result<Vec<GenerationInputs>> {
    ensure!(batch_size_hint != 0);

    let BlockTrace {
        trie_pre_images,
        code_db,
        txn_info,
    } = trace;

    // Missing contract code is fatal only for separate pre-images.
    let fatal_missing_code = match trie_pre_images {
        BlockTraceTriePreImages::Separate(_) => FatalMissingCode(true),
        BlockTraceTriePreImages::Combined(_) => FatalMissingCode(false),
    };

    let start = start(trie_pre_images, wire_disposition)?;

    let OtherBlockData {
        b_data:
            BlockLevelData {
                b_meta,
                b_hashes,
                mut withdrawals,
            },
        checkpoint_state_trie_root,
        checkpoint_consolidated_hash,
        burn_addr,
        ger_data,
    } = other;

    // Withdrawal amounts are supplied in gwei; convert them to wei.
    for (_, amt) in &mut withdrawals {
        *amt = gwei_to_wei(*amt)
    }

    let batches = match start {
        Either::Left((type1world, mut code)) => {
            code.extend(code_db);
            Either::Left(
                middle(
                    type1world,
                    batch(txn_info, batch_size_hint),
                    &mut code,
                    &b_meta,
                    ger_data,
                    withdrawals,
                    fatal_missing_code,
                    observer,
                )?
                .into_iter()
                .map(|it| it.map(Either::Left)),
            )
        }
        Either::Right((type2world, mut code)) => {
            code.extend(code_db);
            Either::Right(
                middle(
                    type2world,
                    batch(txn_info, batch_size_hint),
                    &mut code,
                    &b_meta,
                    ger_data,
                    withdrawals,
                    fatal_missing_code,
                    // `observer` can only observe type-1 worlds.
                    &mut DummyObserver::new(),
                )?
                .into_iter()
                .map(|it| it.map(Either::Right)),
            )
        }
    };

    let mut running_gas_used = 0;
    Ok(batches
        .into_iter()
        .map(
            |Batch {
                 first_txn_ix,
                 gas_used,
                 contract_code,
                 byte_code,
                 before:
                     IntraBlockTries {
                         world,
                         transaction,
                         receipt,
                     },
                 after,
                 withdrawals,
             }| {
                // Record the world's flavour before consuming it, so the (potentially
                // large) world need not be cloned just to pick the code hasher below.
                let is_type1 = world.is_left();
                let (state, storage) = world
                    .expect_left("TODO(0xaatif): evm_arithemetization accepts an SMT")
                    .into_state_and_storage();
                GenerationInputs {
                    txn_number_before: first_txn_ix.into(),
                    gas_used_before: running_gas_used.into(),
                    gas_used_after: {
                        running_gas_used += gas_used;
                        running_gas_used.into()
                    },
                    signed_txns: byte_code.into_iter().map(Into::into).collect(),
                    withdrawals,
                    ger_data,
                    tries: TrieInputs {
                        state_trie: state.into(),
                        transactions_trie: transaction.into(),
                        receipts_trie: receipt.into(),
                        storage_tries: storage.into_iter().map(|(k, v)| (k, v.into())).collect(),
                    },
                    trie_roots_after: after,
                    checkpoint_state_trie_root,
                    checkpoint_consolidated_hash,
                    // Key each code blob by the hash its world flavour uses.
                    contract_code: contract_code
                        .into_iter()
                        .map(|it| {
                            let hash = if is_type1 {
                                <Type1World as World>::CodeHasher::hash(&it)
                            } else {
                                <Type2World as World>::CodeHasher::hash(&it)
                            };
                            (hash, it)
                        })
                        .collect(),
                    block_metadata: b_meta.clone(),
                    block_hashes: b_hashes.clone(),
                    burn_addr,
                }
            },
        )
        .collect())
}
#[allow(clippy::type_complexity)]
fn start(
pre_images: BlockTraceTriePreImages,
wire_disposition: WireDisposition,
) -> anyhow::Result<
Either<(Type1World, Hash2Code<KeccakHash>), (Type2World, Hash2Code<PoseidonHash>)>,
> {
Ok(match pre_images {
BlockTraceTriePreImages::Separate(SeparateTriePreImages {
state: SeparateTriePreImage::Direct(state),
storage: SeparateStorageTriesPreImage::MultipleTries(storage),
}) => {
let state =
state
.items()
.try_fold(StateMpt::new(), |mut acc, (nibbles, hash_or_val)| {
let path = MptKey::from_nibbles(nibbles);
match hash_or_val {
mpt_trie::trie_ops::ValOrHash::Val(bytes) => {
acc.insert(
path.into_hash()
.context("invalid path length in direct state trie")?,
rlp::decode(&bytes)
.context("invalid AccountRlp in direct state trie")?,
)?;
}
mpt_trie::trie_ops::ValOrHash::Hash(h) => {
acc.insert_hash(path, h)?;
}
};
anyhow::Ok(acc)
})?;
let storage = storage
.into_iter()
.map(|(k, SeparateTriePreImage::Direct(v))| {
v.items()
.try_fold(StorageTrie::default(), |mut acc, (nibbles, hash_or_val)| {
let path = MptKey::from_nibbles(nibbles);
match hash_or_val {
mpt_trie::trie_ops::ValOrHash::Val(value) => {
acc.insert(path, value)?;
}
mpt_trie::trie_ops::ValOrHash::Hash(h) => {
acc.insert_hash(path, h)?;
}
};
anyhow::Ok(acc)
})
.map(|v| (k, v))
})
.collect::<Result<_, _>>()?;
Either::Left((Type1World::new(state, storage)?, Hash2Code::new()))
}
BlockTraceTriePreImages::Combined(CombinedPreImages { compact }) => {
let instructions = crate::wire::parse(&compact)
.context("couldn't parse instructions from binary format")?;
match wire_disposition {
WireDisposition::Type1 => {
let crate::type1::Frontend {
state,
storage,
code,
} = crate::type1::frontend(instructions)?;
Either::Left((
Type1World::new(state, storage)?,
Hash2Code::from_iter(code.into_iter().map(NonEmpty::into_vec)),
))
}
WireDisposition::Type2 => {
let crate::type2::Frontend { world: trie, code } =
crate::type2::frontend(instructions)?;
Either::Right((
trie,
Hash2Code::from_iter(code.into_iter().map(NonEmpty::into_vec)),
))
}
}
}
})
}
fn batch(txns: Vec<TxnInfo>, batch_size_hint: usize) -> Vec<Vec<Option<TxnInfo>>> {
let hint = cmp::max(batch_size_hint, 1);
let mut txns = txns.into_iter().map(Some).collect::<Vec<_>>();
let n_batches = txns.iter().chunks(hint).into_iter().count();
match (txns.len(), n_batches) {
(_, 2..) => txns
.into_iter()
.chunks(hint)
.into_iter()
.map(FromIterator::from_iter)
.collect(),
(2.., ..2) => {
let second = txns.split_off(txns.len() / 2);
vec![txns, second]
}
(0 | 1, _) => txns
.into_iter()
.pad_using(2, |_ix| None)
.map(|it| vec![it])
.collect(),
}
}
#[test]
fn test_batch() {
    // Assert that batching `n` default transactions with `hint` yields batches
    // whose lengths are exactly `exp`.
    #[track_caller]
    fn do_test(n: usize, hint: usize, exp: impl IntoIterator<Item = usize>) {
        itertools::assert_equal(
            exp,
            batch(vec![TxnInfo::default(); n], hint)
                .iter()
                .map(Vec::len),
        )
    }

    // A zero hint behaves like a hint of one; short inputs are padded to two batches.
    do_test(0, 0, [1, 1]);
    do_test(1, 0, [1, 1]);
    do_test(2, 0, [1, 1]);
    do_test(3, 0, [1, 1, 1]);

    do_test(3, 1, [1, 1, 1]);
    do_test(3, 2, [2, 1]);
    // A single chunk is split in two, with the second half taking the remainder.
    do_test(3, 3, [1, 2]);

    // Hints at least as large as the number of transactions.
    do_test(0, 5, [1, 1]);
    do_test(1, 5, [1, 1]);
    do_test(2, 5, [1, 1]);
    do_test(4, 4, [2, 2]);

    // Exact and inexact multiples of the hint.
    do_test(4, 2, [2, 2]);
    do_test(5, 2, [2, 2, 1]);
}
/// The result of replaying one batch of transactions.
#[derive(Debug)]
struct Batch<StateTrieT> {
    /// Index of the batch's first transaction. Dummy transactions are not counted.
    pub first_txn_ix: usize,
    /// Total gas used by the batch's transactions.
    pub gas_used: u64,
    /// Contract code read (when known) or written by the batch. Always includes
    /// the empty code.
    pub contract_code: BTreeSet<Vec<u8>>,
    /// The batch's non-empty raw transactions.
    pub byte_code: Vec<NonEmpty<Vec<u8>>>,
    /// The tries as of the start of the batch, masked to what the batch touches.
    pub before: IntraBlockTries<StateTrieT>,
    /// The trie roots after the batch.
    pub after: TrieRoots,
    /// Withdrawals applied with this batch. Only non-empty for the block's final batch.
    pub withdrawals: Vec<(Address, U256)>,
}

impl<T> Batch<T> {
    /// Transform the world held in `before` with `f`, leaving every other field unchanged.
    ///
    /// `f` is invoked exactly once, so any `FnOnce` is accepted. This matches
    /// [`IntraBlockTries::map`], to which it is forwarded.
    fn map<U>(self, f: impl FnOnce(T) -> U) -> Batch<U> {
        let Self {
            first_txn_ix,
            gas_used,
            contract_code,
            byte_code,
            before,
            after,
            withdrawals,
        } = self;
        Batch {
            first_txn_ix,
            gas_used,
            contract_code,
            byte_code,
            before: before.map(f),
            after,
            withdrawals,
        }
    }
}
/// The tries a batch is executed against: the world state, plus the
/// transaction and receipt tries that are built up over the block.
#[derive(Debug)]
pub struct IntraBlockTries<WorldT> {
    pub world: WorldT,
    pub transaction: TransactionTrie,
    pub receipt: ReceiptTrie,
}

impl<T> IntraBlockTries<T> {
    /// Replace the world with `f(world)`, carrying the transaction and receipt
    /// tries over unchanged.
    fn map<U>(self, f: impl FnOnce(T) -> U) -> IntraBlockTries<U> {
        IntraBlockTries {
            world: f(self.world),
            transaction: self.transaction,
            receipt: self.receipt,
        }
    }
}
/// Whether a contract-code read whose hash is missing from the code database
/// is an error (`true`) or tolerated (`false`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FatalMissingCode(pub bool);
/// Replay the block's transactions over `world`, producing one [`Batch`] per
/// entry of `batches`.
///
/// For each batch, this function:
/// - snapshots the tries *before* the batch. The snapshot is taken before any
///   pre-execution hook runs, then masked down to what the batch touched.
/// - applies each transaction's trace to `world`.
/// - inserts non-empty transactions and their receipts into the block's
///   transaction and receipt tries.
/// - records the trie roots *after* the batch.
///
/// `observer` is handed the tries after each batch.
///
/// - `world`: state at the start of the block; mutated as the block is replayed.
/// - `batches`: `None` entries are dummy transactions. They are replayed as
///   `TxnInfo::default()` and do not advance the transaction index.
/// - `code`: known contract code; code written by a transaction is added to it.
/// - `block`, `ger_data`: forwarded to [`do_pre_execution`].
/// - `withdrawals`: credited to `world` after the final transaction, and
///   attached to the final batch only.
/// - `fatal_missing_code`: whether a code read with an unknown hash is an
///   error. Otherwise the account's code is set by hash alone.
///
/// # Errors
/// Fails on undecodable receipts, or on missing code when `fatal_missing_code`
/// is set. Also propagates errors from the pre-execution hook and from
/// world/trie operations.
#[allow(clippy::too_many_arguments)]
fn middle<WorldT: World + Clone>(
    mut world: WorldT,
    batches: Vec<Vec<Option<TxnInfo>>>,
    code: &mut Hash2Code<WorldT::CodeHasher>,
    block: &BlockMetadata,
    ger_data: Option<(H256, H256)>,
    mut withdrawals: Vec<(Address, U256)>,
    fatal_missing_code: FatalMissingCode,
    observer: &mut impl Observer<WorldT>,
) -> anyhow::Result<Vec<Batch<WorldT>>>
where
    WorldT::SubtriePath: Ord + From<Address>,
{
    // Per-block tries, built up as transactions are replayed.
    let mut transaction_trie = TransactionTrie::new();
    let mut receipt_trie = ReceiptTrie::new();
    let mut out = vec![];
    // `txn_ix` counts only real (non-dummy) transactions. `loop_ix` counts every
    // entry, dummies included; reaching `loop_len` marks the end of the final batch.
    let mut txn_ix = 0; let mut loop_ix = 0; let loop_len = batches.iter().flatten().count();
    for (batch_index, batch) in batches.into_iter().enumerate() {
        // If the batch holds only dummies, `txn_ix` is unchanged at its end, so
        // the transaction/receipt mask ranges below are empty.
        let batch_first_txn_ix = txn_ix; let mut batch_gas_used = 0;
        let mut batch_byte_code = vec![];
        // The empty code is always included.
        let mut batch_contract_code = BTreeSet::from([vec![]]);
        // Snapshot of the tries before this batch (and before any pre-execution
        // hook). It is masked once we know what the batch touched.
        let mut before = IntraBlockTries {
            world: world.clone(),
            transaction: transaction_trie.clone(),
            receipt: receipt_trie.clone(),
        };
        // Keys the batch touches; used to mask `before` after the loop.
        let mut storage_masks = BTreeMap::<_, BTreeSet<MptKey>>::new();
        let mut state_mask = BTreeSet::<WorldT::SubtriePath>::new();
        // Pre-execution hooks run whenever the batch starts at transaction index 0.
        // NOTE(review): dummies don't advance `txn_ix`, so a batch following an
        // all-dummy batch (e.g. in an empty block) runs the hooks again. This
        // presumably mirrors the prover, which keys on `txn_number_before`; confirm.
        if txn_ix == 0 {
            do_pre_execution(
                block,
                ger_data,
                &mut storage_masks,
                &mut state_mask,
                &mut world,
            )?;
        }
        for txn in batch {
            // Dummy (`None`) entries are replayed as defaults but don't advance `txn_ix`.
            let do_increment_txn_ix = txn.is_some();
            let TxnInfo {
                traces,
                meta:
                    TxnMeta {
                        byte_code,
                        new_receipt_trie_node_byte,
                        gas_used: txn_gas_used,
                    },
            } = txn.unwrap_or_default();
            // Used only to identify the transaction in error messages.
            let tx_hash = keccak_hash::keccak(&byte_code);
            // Only transactions with non-empty bytes enter the transaction and receipt tries.
            if let Ok(nonempty) = nunny::Vec::new(byte_code) {
                batch_byte_code.push(nonempty.clone());
                transaction_trie.insert(txn_ix, nonempty.into())?;
                receipt_trie.insert(
                    txn_ix,
                    map_receipt_bytes(new_receipt_trie_node_byte.clone())?,
                )?;
            }
            batch_gas_used += txn_gas_used;
            // `just_access` is true when the trace is the default (empty) trace.
            for (
                addr,
                just_access,
                TxnTrace {
                    balance,
                    nonce,
                    storage_read,
                    storage_written,
                    code_usage,
                    self_destructed,
                },
            ) in traces
                .into_iter()
                .map(|(addr, trc)| (addr, trc == TxnTrace::default(), trc))
            {
                // NOTE(review): the receipt depends only on the transaction, yet it
                // is decoded once per touched address. Hoisting it eagerly out of
                // this loop would also decode it for transactions with no traces,
                // which currently skip decoding; hoist lazily if at all.
                let (_, _, receipt) = evm_arithmetization::generation::mpt::decode_receipt(
                    &map_receipt_bytes(new_receipt_trie_node_byte.clone())?,
                )
                .map_err(|e| anyhow!("{e:?}"))
                .context(format!("couldn't decode receipt in txn {tx_hash:x}"))?;
                // An address absent from `world` is "born" here and gets storage.
                let born = !world.contains(addr)?;
                if born {
                    world.create_storage(addr)?
                }
                // Apply the trace unless it is a bare access, or the account was
                // born in a transaction whose receipt status is false.
                let do_writes = !just_access
                    && match born {
                        true => receipt.status,
                        false => true,
                    };
                // Every slot the trace reads or writes must survive masking.
                let storage_mask = storage_masks.entry(addr).or_default();
                storage_mask.extend(
                    storage_written
                        .keys()
                        .chain(&storage_read)
                        .map(|it| MptKey::from_hash(keccak_hash::keccak(it))),
                );
                if do_writes {
                    if let Some(new) = balance {
                        world.update_balance(addr, |it| *it = new)?
                    }
                    if let Some(new) = nonce {
                        world.update_nonce(addr, |it| *it = new)?
                    }
                    if let Some(usage) = code_usage {
                        match usage {
                            ContractCodeUsage::Read(hash) => {
                                // Known code is installed and recorded for the batch.
                                // Unknown code is either fatal or installed by hash alone.
                                match (fatal_missing_code, code.get(hash)) {
                                    (FatalMissingCode(true), None) => {
                                        bail!("no code for hash {hash:x}")
                                    }
                                    (_, Some(byte_code)) => {
                                        world.set_code(addr, Either::Left(&byte_code))?;
                                        batch_contract_code.insert(byte_code);
                                    }
                                    (_, None) => world.set_code(addr, Either::Right(hash))?,
                                }
                            }
                            ContractCodeUsage::Write(bytes) => {
                                code.insert(bytes.clone());
                                world.set_code(addr, Either::Left(&bytes))?;
                                batch_contract_code.insert(bytes);
                            }
                        };
                    }
                    if !storage_written.is_empty() {
                        for (k, v) in storage_written {
                            // Writing zero deletes the slot. Whatever the deletion
                            // reports is added to the storage mask.
                            match v.is_zero() {
                                true => storage_mask
                                    .extend(world.reporting_destroy_slot(addr, k.into_uint())?),
                                false => world.store_int(addr, k.into_uint(), v)?,
                            }
                        }
                    }
                    state_mask.insert(<WorldT::SubtriePath>::from(addr));
                } else {
                    // Even a bare access keeps the account in the state mask.
                    state_mask.insert(<WorldT::SubtriePath>::from(addr));
                }
                // Self-destructed accounts lose their storage and are removed.
                // Whatever the removal reports is kept in the state mask.
                if self_destructed {
                    world.destroy_storage(addr)?;
                    state_mask.extend(world.reporting_destroy(addr)?)
                }
            }
            if do_increment_txn_ix {
                txn_ix += 1;
            }
            loop_ix += 1;
        }
        out.push(Batch {
            first_txn_ix: batch_first_txn_ix,
            gas_used: batch_gas_used,
            contract_code: batch_contract_code,
            byte_code: batch_byte_code,
            // Withdrawals are credited (and their accounts kept in the state mask)
            // only after every entry has been replayed, i.e. in the final batch.
            withdrawals: match loop_ix == loop_len {
                true => {
                    for (addr, amt) in &withdrawals {
                        state_mask.insert(<WorldT::SubtriePath>::from(*addr));
                        world.update_balance(*addr, |it| *it += *amt)?;
                    }
                    mem::take(&mut withdrawals)
                }
                false => vec![],
            },
            // Prune the pre-batch snapshot down to what this batch touched.
            before: {
                before.world.mask(state_mask)?;
                before.receipt.mask(batch_first_txn_ix..txn_ix)?;
                before.transaction.mask(batch_first_txn_ix..txn_ix)?;
                before.world.mask_storage(storage_masks)?;
                before
            },
            after: TrieRoots {
                state_root: world.root(),
                transactions_root: transaction_trie.root(),
                receipts_root: receipt_trie.root(),
            },
        });
        observer.collect_tries(
            block.block_number,
            batch_index,
            &world,
            &transaction_trie,
            &receipt_trie,
        )
    }
    Ok(out)
}
/// Run the chain-specific pre-execution hook. The hook is chosen at compile
/// time by cargo features:
/// - `eth_mainnet`: [`do_beacon_hook`]. This takes precedence if both
///   features are enabled.
/// - `cdk_erigon`: [`do_scalable_hook`].
/// - neither: nothing is done.
///
/// Touched storage slots and accounts are recorded in `trim_storage` and
/// `trim_state` so they survive masking.
fn do_pre_execution<WorldT: World + Clone>(
    block: &BlockMetadata,
    ger_data: Option<(H256, H256)>,
    trim_storage: &mut BTreeMap<ethereum_types::H160, BTreeSet<MptKey>>,
    trim_state: &mut BTreeSet<WorldT::SubtriePath>,
    world: &mut WorldT,
) -> anyhow::Result<()>
where
    WorldT::SubtriePath: From<Address> + Ord,
{
    let mainnet = cfg!(feature = "eth_mainnet");
    let cdk_erigon = cfg!(feature = "cdk_erigon");
    match (mainnet, cdk_erigon) {
        (true, _) => do_beacon_hook(
            block.block_timestamp,
            trim_storage,
            block.parent_beacon_block_root,
            trim_state,
            world,
        ),
        (false, true) => do_scalable_hook(block, ger_data, trim_storage, trim_state, world),
        (false, false) => Ok(()),
    }
}
fn do_scalable_hook<WorldT: World + Clone>(
block: &BlockMetadata,
ger_data: Option<(H256, H256)>,
trim_storage: &mut BTreeMap<ethereum_types::H160, BTreeSet<MptKey>>,
trim_state: &mut BTreeSet<WorldT::SubtriePath>,
world: &mut WorldT,
) -> anyhow::Result<()>
where
WorldT::SubtriePath: From<Address> + Ord,
{
use evm_arithmetization::testing_utils::{
ADDRESS_SCALABLE_L2, GLOBAL_EXIT_ROOT_ADDRESS, GLOBAL_EXIT_ROOT_STORAGE_POS,
LAST_BLOCK_STORAGE_POS, STATE_ROOT_STORAGE_POS, TIMESTAMP_STORAGE_POS,
};
if block.block_number.is_zero() {
return Err(anyhow!("Attempted to prove the Genesis block!"));
}
let scalable_trim = trim_storage.entry(ADDRESS_SCALABLE_L2).or_default();
let timestamp = world
.load_int(ADDRESS_SCALABLE_L2, U256::from(TIMESTAMP_STORAGE_POS.1))
.unwrap_or_default();
let timestamp = core::cmp::max(timestamp, block.block_timestamp);
for (ix, u) in [
(U256::from(LAST_BLOCK_STORAGE_POS.1), block.block_number),
(U256::from(TIMESTAMP_STORAGE_POS.1), timestamp),
] {
let slot = MptKey::from_slot_position(ix);
ensure!(!u.is_zero());
world.store_int(ADDRESS_SCALABLE_L2, ix, u)?;
scalable_trim.insert(slot);
}
let prev_block_root_hash = world.root();
let mut arr = [0; 64];
(block.block_number - 1).to_big_endian(&mut arr[0..32]);
U256::from(STATE_ROOT_STORAGE_POS.1).to_big_endian(&mut arr[32..64]);
let slot = MptKey::from_hash(keccak_hash::keccak(arr));
world.store_hash(
ADDRESS_SCALABLE_L2,
keccak_hash::keccak(arr),
prev_block_root_hash,
)?;
scalable_trim.insert(slot);
trim_state.insert(<WorldT::SubtriePath>::from(ADDRESS_SCALABLE_L2));
if let Some((root, l1blockhash)) = ger_data {
let ger_trim = trim_storage.entry(GLOBAL_EXIT_ROOT_ADDRESS).or_default();
let mut arr = [0; 64];
arr[0..32].copy_from_slice(&root.0);
U256::from(GLOBAL_EXIT_ROOT_STORAGE_POS.1).to_big_endian(&mut arr[32..64]);
let slot = MptKey::from_hash(keccak_hash::keccak(arr));
world.store_hash(
GLOBAL_EXIT_ROOT_ADDRESS,
keccak_hash::keccak(arr),
l1blockhash,
)?;
ger_trim.insert(slot);
trim_state.insert(<WorldT::SubtriePath>::from(GLOBAL_EXIT_ROOT_ADDRESS));
}
Ok(())
}
/// Pre-execution hook for `eth_mainnet` builds: records the block's timestamp
/// and parent beacon block root in the beacon-roots contract's storage.
///
/// The layout matches the EIP-4788 ring buffer:
/// - the timestamp goes in slot `block_timestamp % HISTORY_BUFFER_LENGTH`;
/// - the root goes in that slot plus `HISTORY_BUFFER_LENGTH`.
///
/// A zero value is not stored; the slot is destroyed instead. Both slots, plus
/// anything the destruction reports, are added to `trim_storage`, and the
/// contract's address is added to `trim_state`, so they survive masking.
fn do_beacon_hook<WorldT: World + Clone>(
    block_timestamp: U256,
    trim_storage: &mut BTreeMap<ethereum_types::H160, BTreeSet<MptKey>>,
    parent_beacon_block_root: H256,
    trim_state: &mut BTreeSet<WorldT::SubtriePath>,
    world: &mut WorldT,
) -> anyhow::Result<()>
where
    WorldT::SubtriePath: From<Address> + Ord,
{
    use evm_arithmetization::testing_utils::{
        BEACON_ROOTS_CONTRACT_ADDRESS, HISTORY_BUFFER_LENGTH,
    };

    let timestamp_idx = block_timestamp % HISTORY_BUFFER_LENGTH.value;
    let root_idx = timestamp_idx + HISTORY_BUFFER_LENGTH.value;
    let parent_root = U256::from_big_endian(parent_beacon_block_root.as_bytes());

    let keep = trim_storage
        .entry(BEACON_ROOTS_CONTRACT_ADDRESS)
        .or_default();
    for (position, value) in [(timestamp_idx, block_timestamp), (root_idx, parent_root)] {
        // The slot itself is kept whether it ends up written or destroyed.
        keep.insert(MptKey::from_slot_position(position));
        if value.is_zero() {
            keep.extend(world.reporting_destroy_slot(BEACON_ROOTS_CONTRACT_ADDRESS, position)?);
        } else {
            world.store_int(BEACON_ROOTS_CONTRACT_ADDRESS, position, value)?;
        }
    }

    trim_state.insert(<WorldT::SubtriePath>::from(BEACON_ROOTS_CONTRACT_ADDRESS));
    Ok(())
}
/// Normalise a transaction's receipt bytes into the form inserted into the
/// receipt trie.
///
/// - If `bytes` already decode as a legacy receipt RLP, they are returned
///   unchanged.
/// - Otherwise they are decoded as an RLP byte string, and that string's
///   payload is returned. NOTE(review): this is presumably a typed receipt
///   wrapped in an RLP string; confirm against the tracer.
///
/// # Errors
/// Fails if `bytes` are neither a legacy receipt nor an RLP byte string.
fn map_receipt_bytes(bytes: Vec<u8>) -> anyhow::Result<Vec<u8>> {
    let is_legacy =
        rlp::decode::<evm_arithmetization::generation::mpt::LegacyReceiptRlp>(&bytes).is_ok();
    if is_legacy {
        return Ok(bytes);
    }
    rlp::decode::<Vec<u8>>(&bytes)
        .context("couldn't decode receipt as a legacy receipt or raw bytes")
}
/// Contract-code database, keyed by the hash that `H` assigns to each code blob.
///
/// Every constructor seeds it with the empty code, so a lookup of the empty
/// code's hash always succeeds.
struct Hash2Code<H: Hasher> {
    inner: HashMap<H256, Vec<u8>>,
    _phantom: PhantomData<H>,
}

impl<H: Hasher> Hash2Code<H> {
    /// Create a database containing only the empty code.
    pub fn new() -> Self {
        let mut this = Self {
            inner: HashMap::new(),
            _phantom: PhantomData,
        };
        this.insert(vec![]);
        this
    }

    /// Look up code by its hash, returning an owned copy.
    ///
    /// This only reads the map, so a shared borrow suffices. Callers holding
    /// `&mut Hash2Code` can still call it.
    pub fn get(&self, hash: H256) -> Option<Vec<u8>> {
        self.inner.get(&hash).cloned()
    }

    /// Insert `code`, keyed by `H::hash(&code)`. Re-inserting identical code is
    /// a no-op in effect.
    pub fn insert(&mut self, code: Vec<u8>) {
        self.inner.insert(H::hash(&code), code);
    }
}

impl<H: Hasher> Extend<Vec<u8>> for Hash2Code<H> {
    /// Insert each code blob, keyed by its hash.
    fn extend<II: IntoIterator<Item = Vec<u8>>>(&mut self, iter: II) {
        for it in iter {
            self.insert(it)
        }
    }
}

impl<H: Hasher> FromIterator<Vec<u8>> for Hash2Code<H> {
    /// Build a database holding the empty code plus every blob in `iter`.
    fn from_iter<II: IntoIterator<Item = Vec<u8>>>(iter: II) -> Self {
        let mut this = Self::new();
        this.extend(iter);
        this
    }
}