axiom-eth 0.4.0

This crate is the main library for building ZK circuits that prove data about the Ethereum Virtual Machine (EVM).
Documentation
/// Tests that use `cita-trie` to generate random transaction tries.
mod tx;

use super::*;
use crate::{
    rlc::circuit::{builder::RlcCircuitBuilder, RlcCircuitParams},
    utils::eth_circuit::{
        create_circuit, EthCircuitImpl, EthCircuitInstructions, EthCircuitParams,
    },
};
use ark_std::{end_timer, start_timer};
use ethers_core::utils::keccak256;
use halo2_base::{
    gates::circuit::{BaseCircuitParams, CircuitBuilderStage},
    halo2_proofs::{
        dev::MockProver,
        halo2curves::bn256::Fr,
        plonk::{keygen_pk, keygen_vk},
    },
    utils::{
        fs::gen_srs,
        testing::{check_proof_with_instances, gen_proof_with_instances},
    },
};
use hex::FromHex;
use std::{fs::File, io::Write, marker::PhantomData, path::Path};
use test_case::test_case;
use test_log::test;

/// Circuit size parameter `k` (the circuit has `2^k` rows) shared by `default_params` and the
/// `MockProver` run in `test_mock_mpt`.
const TEST_K: u32 = 15;

/// Circuit instructions that check a single MPT proof (inclusion or, when
/// `slot_is_empty` is set in the input, non-inclusion).
///
/// Field `0` is the proof input; the `PhantomData` only pins the circuit field type `F`.
#[derive(Clone)]
struct MptTest<F: Field>(MPTInput, PhantomData<F>);

impl<F: Field> EthCircuitInstructions<F> for MptTest<F> {
    type FirstPhasePayload = MPTProofWitness<F>;

    /// Phase 0: assigns the MPT proof into the phase-0 main context and runs the phase-0 part
    /// of the inclusion parsing, returning the witness that phase 1 consumes.
    fn virtual_assign_phase0(
        &self,
        builder: &mut RlcCircuitBuilder<F>,
        mpt: &MPTChip<F>,
    ) -> Self::FirstPhasePayload {
        let main_ctx = builder.base.main(0);
        let assigned_proof = self.0.clone().assign(main_ctx);
        mpt.parse_mpt_inclusion_phase0(main_ctx, assigned_proof)
    }

    /// Phase 1: finishes the inclusion parsing using the (gate, RLC) context pair.
    fn virtual_assign_phase1(
        &self,
        builder: &mut RlcCircuitBuilder<F>,
        mpt: &MPTChip<F>,
        mpt_witness: Self::FirstPhasePayload,
    ) {
        let ctx_pair = builder.rlc_ctx_pair();
        mpt.parse_mpt_inclusion_phase1(ctx_pair, mpt_witness);
    }
}

/// Creates an `EthCircuitImpl` that checks `inputs` at the given builder `stage`.
///
/// Keccak promises are filled in via `mock_fulfill_keccak_promises(None)`. Unless the stage is
/// witness-generation-only, the circuit parameters are recomputed before returning.
fn test_mpt_circuit<F: Field>(
    stage: CircuitBuilderStage,
    params: RlcCircuitParams,
    inputs: MPTInput,
) -> EthCircuitImpl<F, MptTest<F>> {
    let instructions = MptTest(inputs, PhantomData);
    let mut circuit = create_circuit(stage, params, instructions);
    circuit.mock_fulfill_keccak_promises(None);
    if stage.witness_gen_only() {
        return circuit;
    }
    circuit.calculate_params();
    circuit
}

/// Decodes a hex string into bytes.
///
/// An optional leading `0x` is ignored. An odd number of hex digits is left-padded with a
/// single `0`, so `"abc"` decodes to `[0x0a, 0xbc]`.
///
/// # Panics
/// Panics if the remaining characters are not valid hexadecimal.
fn from_hex(s: &str) -> Vec<u8> {
    let digits = s.strip_prefix("0x").unwrap_or(s);
    let padded = if digits.len() % 2 == 1 { format!("0{digits}") } else { digits.to_string() };
    Vec::from_hex(&padded).unwrap_or_else(|e| panic!("invalid hex string {padded:?}: {e}"))
}

/// Builds an `MPTInput` for a storage-slot proof read from an `eth_getProof`-style JSON file.
///
/// The input files are generated by running `query_test.sh` in the `scripts/input_gen`
/// directory of this repo. Only the first entry of `storageProof` is used. It must contain
/// `key`, `value` and `proof`; `storageHash` is read from the top level. Every hex string may
/// optionally carry a `0x` prefix.
///
/// NOTE: the key bytes are hashed exactly as decoded. They should therefore be the full 32-byte
/// slot encoding rather than a minimal quantity such as `0x0`. Confirm against the fixtures.
///
/// # Panics
/// Panics if the file cannot be read or parsed, if an expected field is missing, or if a hex
/// string is malformed.
fn mpt_input_storage(
    path: impl AsRef<Path>,
    slot_is_empty: bool,
    max_depth: usize,
    max_key_byte_len: usize,
    key_byte_len: Option<usize>,
) -> MPTInput {
    // Removes a leading `0x`, if present. This keeps every field tolerant of RPC-style strings.
    fn strip_0x(s: &str) -> &str {
        s.strip_prefix("0x").unwrap_or(s)
    }

    let pf_str = std::fs::read_to_string(path).expect("failed to read storage proof file");
    let pf: serde_json::Value =
        serde_json::from_str(&pf_str).expect("storage proof file is not valid JSON");
    let storage_pf = &pf["storageProof"][0];

    // The trie path is keccak256 of the slot key bytes.
    let key_bytes_str: String =
        serde_json::from_value(storage_pf["key"].clone()).expect("missing storageProof[0].key");
    let path = keccak256(from_hex(strip_0x(&key_bytes_str))).to_vec().into();

    // The leaf holds the RLP encoding of the slot value. `from_hex` already left-pads
    // odd-length quantities such as `0x1`, so no extra padding is needed here.
    let value_bytes_str: String = serde_json::from_value(storage_pf["value"].clone())
        .expect("missing storageProof[0].value");
    let value = ::rlp::encode(&from_hex(strip_0x(&value_bytes_str))).to_vec();

    let root_hash_str: String =
        serde_json::from_value(pf["storageHash"].clone()).expect("missing storageHash");
    let pf_strs: Vec<String> = serde_json::from_value(storage_pf["proof"].clone())
        .expect("missing storageProof[0].proof");
    let proof = pf_strs.iter().map(|node| from_hex(strip_0x(node))).collect();

    // A storage word is at most 32 bytes; its RLP encoding adds one length-prefix byte.
    let value_max_byte_len = 33;

    MPTInput {
        path,
        value,
        root_hash: H256::from_slice(&from_hex(strip_0x(&root_hash_str))),
        proof,
        slot_is_empty,
        value_max_byte_len,
        max_depth,
        max_key_byte_len,
        key_byte_len,
    }
}

/// Inclusion proof from `default_storage_pf.json`, with `max_depth = 8` and fixed 32-byte keys.
fn default_input() -> MPTInput {
    const PROOF_PATH: &str = "scripts/input_gen/default_storage_pf.json";
    let max_depth = 8;
    let key_len = 32;
    mpt_input_storage(PROOF_PATH, false, max_depth, key_len, Some(key_len))
}

/// RLC circuit parameters for the mock tests.
///
/// Starts from `RlcCircuitParams::default()` and sets degree `TEST_K`, 8 lookup bits and one
/// instance column.
fn default_params() -> RlcCircuitParams {
    let mut params = RlcCircuitParams::default();
    let base = &mut params.base;
    base.k = TEST_K as usize;
    base.lookup_bits = Some(8);
    base.num_instance_columns = 1;
    params
}

#[test_case("scripts/input_gen/default_storage_pf.json", false, 5, 32, None; "default storage inclusion fixed")]
#[test_case("scripts/input_gen/default_storage_pf.json", false, 5, 32, Some(32); "default storage inclusion var")]
#[test_case("scripts/input_gen/noninclusion_branch_pf.json", true, 5, 32, None; "noninclusion branch fixed")]
#[test_case("scripts/input_gen/noninclusion_branch_pf.json", true, 5, 32, Some(32); "noninclusion branch var")]
#[test_case("scripts/input_gen/noninclusion_extension_pf.json", true, 6, 32, None; "noninclusion extension fixed")]
#[test_case("scripts/input_gen/noninclusion_extension_pf.json", true, 6, 32, Some(32); "noninclusion extension var")]
#[test_case("scripts/input_gen/noninclusion_extension_pf2.json", true, 6, 32, None; "noninclusion branch then extension fixed")]
#[test_case("scripts/input_gen/noninclusion_extension_pf2.json", true, 6, 32, Some(32); "noninclusion branch then extension var")]
#[test_case("scripts/input_gen/empty_storage_pf.json", true, 5, 32, None; "empty root fixed")]
#[test_case("scripts/input_gen/empty_storage_pf.json", true, 5, 32, Some(32); "empty root var")]
pub fn test_mock_mpt(
    path: &str,
    slot_is_empty: bool,
    max_depth: usize,
    max_key_byte_len: usize,
    key_byte_len: Option<usize>,
) {
    let _ = env_logger::builder().is_test(true).try_init();
    let input = mpt_input_storage(path, slot_is_empty, max_depth, max_key_byte_len, key_byte_len);
    let circuit = test_mpt_circuit::<Fr>(CircuitBuilderStage::Mock, default_params(), input);
    let instances = circuit.instances();
    assert_eq!(instances.len(), 1);
    MockProver::run(TEST_K, &circuit, instances).unwrap().assert_satisfied();
}

/// Benchmarks keygen, SHPLONK proving and verification of the default storage MPT proof for
/// each circuit degree `k` in `15..18`.
///
/// Writes one CSV row per degree to `data/bench/mpt.csv`. Writes the list of computed circuit
/// parameters to `configs/bench/mpt.json`.
///
/// # Errors
/// Returns an error if an output directory or file cannot be created or written, or if key
/// generation fails.
#[test]
#[ignore = "bench"]
fn bench_mpt_inclusion_fixed() -> Result<(), Box<dyn std::error::Error>> {
    std::fs::create_dir_all("configs/bench")?;
    let bench_params_file = File::create("configs/bench/mpt.json")?;
    std::fs::create_dir_all("data/bench")?;
    let mut fs_results = File::create("data/bench/mpt.csv")?;
    writeln!(fs_results, "degree,total_advice,num_rlc_columns,num_advice,num_lookup,num_fixed,proof_time,verify_time")?;

    let bench_k = 15..18;
    let mut all_bench_params = vec![];
    for k in bench_k {
        println!("---------------------- degree = {k} ------------------------------");
        let params = gen_srs(k);
        let mut bench_params = EthCircuitParams::default().rlc;
        bench_params.base.k = k as usize;

        let mut circuit =
            test_mpt_circuit(CircuitBuilderStage::Keygen, bench_params, default_input());
        let bench_params = circuit.calculate_params().rlc;
        all_bench_params.push(bench_params.clone());
        let vk = keygen_vk(&params, &circuit)?;
        let pk = keygen_pk(&params, vk, &circuit)?;
        let break_points = circuit.break_points();

        // create a proof
        let proof_time = start_timer!(|| "Create proof SHPLONK");
        let circuit =
            test_mpt_circuit(CircuitBuilderStage::Prover, bench_params.clone(), default_input())
                .use_break_points(break_points);
        let instances = circuit.instances();
        assert_eq!(instances.len(), 1);
        let proof = gen_proof_with_instances(&params, &pk, circuit, &[&instances[0]]);
        end_timer!(proof_time);
        // `Instant::elapsed` measures up to the moment it is called. Read it here: reading it
        // when the CSV row is written would also count verification time.
        let proof_elapsed = proof_time.time.elapsed();

        let verify_time = start_timer!(|| "Verify time");
        check_proof_with_instances(&params, pk.get_vk(), &proof, &[&instances[0]], true);
        end_timer!(verify_time);
        let verify_elapsed = verify_time.time.elapsed();

        let RlcCircuitParams {
            base:
                BaseCircuitParams {
                    k,
                    num_advice_per_phase,
                    num_fixed,
                    num_lookup_advice_per_phase,
                    ..
                },
            num_rlc_columns,
        } = bench_params;
        // The per-phase column lists are written with `{:?}`, whose output (e.g. `[3, 1]`)
        // contains commas. Quote them so each row keeps the header's 8 CSV fields.
        writeln!(
            fs_results,
            "{},{},{},\"{:?}\",\"{:?}\",{},{:.2}s,{:?}",
            k,
            num_rlc_columns
                + num_advice_per_phase.iter().sum::<usize>()
                + num_lookup_advice_per_phase.iter().sum::<usize>(),
            num_rlc_columns,
            num_advice_per_phase,
            num_lookup_advice_per_phase,
            num_fixed,
            proof_elapsed.as_secs_f64(),
            verify_elapsed
        )?;
    }
    serde_json::to_writer_pretty(bench_params_file, &all_bench_params)?;
    Ok(())
}