mod tx;
use super::*;
use crate::{
rlc::circuit::{builder::RlcCircuitBuilder, RlcCircuitParams},
utils::eth_circuit::{
create_circuit, EthCircuitImpl, EthCircuitInstructions, EthCircuitParams,
},
};
use ark_std::{end_timer, start_timer};
use ethers_core::utils::keccak256;
use halo2_base::{
gates::circuit::{BaseCircuitParams, CircuitBuilderStage},
halo2_proofs::{
dev::MockProver,
halo2curves::bn256::Fr,
plonk::{keygen_pk, keygen_vk},
},
utils::{
fs::gen_srs,
testing::{check_proof_with_instances, gen_proof_with_instances},
},
};
use hex::FromHex;
use std::{fs::File, io::Write, marker::PhantomData, path::Path};
use test_case::test_case;
use test_log::test;
const TEST_K: u32 = 15;
/// Minimal [`EthCircuitInstructions`] implementation that feeds a single [`MPTInput`]
/// through the MPT chip's two-phase `parse_mpt_inclusion` gadget.
#[derive(Clone)]
struct MptTest<F: Field>(MPTInput, PhantomData<F>);

impl<F: Field> EthCircuitInstructions<F> for MptTest<F> {
    /// Witness produced by phase 0 and handed to phase 1.
    type FirstPhasePayload = MPTProofWitness<F>;

    /// Phase 0: assigns a copy of the stored input into the base builder's `main(0)` context
    /// and runs `parse_mpt_inclusion_phase0` on it.
    fn virtual_assign_phase0(
        &self,
        builder: &mut RlcCircuitBuilder<F>,
        mpt: &MPTChip<F>,
    ) -> Self::FirstPhasePayload {
        let MptTest(input, _) = self;
        let main_ctx = builder.base.main(0);
        let assigned_proof = input.clone().assign(main_ctx);
        mpt.parse_mpt_inclusion_phase0(main_ctx, assigned_proof)
    }

    /// Phase 1: completes the parse with the builder's (gate, RLC) context pair.
    fn virtual_assign_phase1(
        &self,
        builder: &mut RlcCircuitBuilder<F>,
        mpt: &MPTChip<F>,
        mpt_witness: Self::FirstPhasePayload,
    ) {
        mpt.parse_mpt_inclusion_phase1(builder.rlc_ctx_pair(), mpt_witness);
    }
}
/// Builds the MPT test circuit for `inputs` at builder `stage` using `params`.
///
/// Keccak promises are filled with `mock_fulfill_keccak_promises(None)`. Circuit parameters
/// are then recomputed via `calculate_params`, except when `stage` is witness-generation-only.
fn test_mpt_circuit<F: Field>(
    stage: CircuitBuilderStage,
    params: RlcCircuitParams,
    inputs: MPTInput,
) -> EthCircuitImpl<F, MptTest<F>> {
    let mut circuit = create_circuit(stage, params, MptTest(inputs, PhantomData));
    circuit.mock_fulfill_keccak_promises(None);
    if stage.witness_gen_only() {
        return circuit;
    }
    circuit.calculate_params();
    circuit
}
/// Decodes a hexadecimal string into bytes.
///
/// - An optional leading `0x`/`0X` is accepted and ignored. `x` is never a hex digit, so
///   such inputs were previously always rejected.
/// - An odd number of digits is left-padded with one `0`, so `"abc"` decodes to `[0x0a, 0xbc]`.
/// - The empty string decodes to an empty vector.
///
/// # Panics
/// Panics (naming the offending input) if the remaining characters are not valid hex.
fn from_hex(s: &str) -> Vec<u8> {
    let digits = s.strip_prefix("0x").or_else(|| s.strip_prefix("0X")).unwrap_or(s);
    let padded = if digits.len() % 2 == 1 { format!("0{digits}") } else { digits.to_string() };
    Vec::from_hex(padded).unwrap_or_else(|e| panic!("invalid hex string {s:?}: {e}"))
}
/// Builds an [`MPTInput`] for the first storage slot described in the JSON file at `path`.
///
/// Reads the top-level `storageHash` and `storageProof[0].{key, value, proof}` fields. This
/// appears to follow the `eth_getProof` response layout; confirm for new fixtures.
/// - The MPT path is `keccak256(key)`.
/// - The expected leaf value is the RLP encoding of the slot value bytes.
/// - `slot_is_empty`, `max_depth`, `max_key_byte_len` and `key_byte_len` are passed through
///   unchanged.
///
/// Every hex field may carry an optional `0x` prefix.
///
/// # Panics
/// Panics in any of these cases:
/// - the file cannot be read or is not valid JSON;
/// - a field above is missing or has the wrong JSON type;
/// - a hex string is malformed;
/// - `storageHash` does not decode to exactly 32 bytes.
fn mpt_input_storage(
    path: impl AsRef<Path>,
    slot_is_empty: bool,
    max_depth: usize,
    max_key_byte_len: usize,
    key_byte_len: Option<usize>,
) -> MPTInput {
    /// Removes an optional `0x` prefix. Unlike slicing off two bytes, this keeps unprefixed
    /// input intact and cannot panic on strings shorter than two bytes.
    fn strip_0x(s: &str) -> &str {
        s.strip_prefix("0x").unwrap_or(s)
    }

    let pf_str = std::fs::read_to_string(path).unwrap();
    let pf: serde_json::Value = serde_json::from_str(&pf_str).unwrap();
    let storage_pf = &pf["storageProof"][0];

    // The key is hashed to obtain the trie path; previously a `0x`-prefixed key panicked here.
    let key_hex: String = serde_json::from_value(storage_pf["key"].clone()).unwrap();
    let key_path = keccak256(from_hex(strip_0x(&key_hex))).to_vec().into();

    // `from_hex` left-pads odd-length digit strings (e.g. `0x1`), so no manual padding is needed.
    let value_hex: String = serde_json::from_value(storage_pf["value"].clone()).unwrap();
    let value = ::rlp::encode(&from_hex(strip_0x(&value_hex))).to_vec();

    let root_hash_hex: String = serde_json::from_value(pf["storageHash"].clone()).unwrap();
    let proof_hex: Vec<String> = serde_json::from_value(storage_pf["proof"].clone()).unwrap();
    let proof = proof_hex.iter().map(|node| from_hex(strip_0x(node))).collect();

    MPTInput {
        path: key_path,
        value,
        root_hash: H256::from_slice(&from_hex(strip_0x(&root_hash_hex))),
        proof,
        slot_is_empty,
        // Presumably sized for an RLP-encoded value of at most 32 bytes
        // (1 length-prefix byte + 32 payload bytes) — confirm against the MPT chip.
        value_max_byte_len: 33,
        max_depth,
        max_key_byte_len,
        key_byte_len,
    }
}
/// Storage-inclusion input loaded from the default fixture, with `max_depth = 8` and a
/// 32-byte key.
fn default_input() -> MPTInput {
    const DEFAULT_PF: &str = "scripts/input_gen/default_storage_pf.json";
    let (max_depth, max_key_byte_len) = (8, 32);
    mpt_input_storage(DEFAULT_PF, false, max_depth, max_key_byte_len, Some(max_key_byte_len))
}
/// Parameters used by the mock-prover tests.
///
/// Starts from `RlcCircuitParams::default()` and overrides three base fields: degree
/// `TEST_K`, `lookup_bits = Some(8)`, and a single instance column.
fn default_params() -> RlcCircuitParams {
    let mut params = RlcCircuitParams::default();
    let base = &mut params.base;
    base.k = TEST_K as usize;
    base.lookup_bits = Some(8);
    base.num_instance_columns = 1;
    params
}
#[test_case("scripts/input_gen/default_storage_pf.json", false, 5, 32, None; "default storage inclusion fixed")]
#[test_case("scripts/input_gen/default_storage_pf.json", false, 5, 32, Some(32); "default storage inclusion var")]
#[test_case("scripts/input_gen/noninclusion_branch_pf.json", true, 5, 32, None; "noninclusion branch fixed")]
#[test_case("scripts/input_gen/noninclusion_branch_pf.json", true, 5, 32, Some(32); "noninclusion branch var")]
#[test_case("scripts/input_gen/noninclusion_extension_pf.json", true, 6, 32, None; "noninclusion extension fixed")]
#[test_case("scripts/input_gen/noninclusion_extension_pf.json", true, 6, 32, Some(32); "noninclusion extension var")]
#[test_case("scripts/input_gen/noninclusion_extension_pf2.json", true, 6, 32, None; "noninclusion branch then extension fixed")]
#[test_case("scripts/input_gen/noninclusion_extension_pf2.json", true, 6, 32, Some(32); "noninclusion branch then extension var")]
#[test_case("scripts/input_gen/empty_storage_pf.json", true, 5, 32, None; "empty root fixed")]
#[test_case("scripts/input_gen/empty_storage_pf.json", true, 5, 32, Some(32); "empty root var")]
/// Runs the MPT circuit for one proof fixture under `MockProver` and asserts that every
/// constraint is satisfied.
pub fn test_mock_mpt(
    path: &str,
    slot_is_empty: bool,
    max_depth: usize,
    max_key_byte_len: usize,
    key_byte_len: Option<usize>,
) {
    // Logging may already have been initialised by another test; ignore that error.
    let _ = env_logger::builder().is_test(true).try_init();

    let mpt_input =
        mpt_input_storage(path, slot_is_empty, max_depth, max_key_byte_len, key_byte_len);
    let mock_circuit =
        test_mpt_circuit::<Fr>(CircuitBuilderStage::Mock, default_params(), mpt_input);

    let public_instances = mock_circuit.instances();
    assert_eq!(public_instances.len(), 1);

    let prover = MockProver::run(TEST_K, &mock_circuit, public_instances).unwrap();
    prover.assert_satisfied();
}
/// Benchmarks key generation, SHPLONK proving and verification of the default
/// storage-inclusion MPT circuit for degrees `k` in `15..18`.
///
/// Outputs:
/// - `data/bench/mpt.csv`: one row per degree. The per-phase column vectors are written as
///   quoted fields so their internal commas do not split CSV columns.
/// - `configs/bench/mpt.json`: the circuit parameters computed for every degree.
///
/// # Errors
/// Returns an error if any of these fail:
/// - creating an output directory or file, or writing to it;
/// - serialising the parameters;
/// - key generation.
#[test]
#[ignore = "bench"]
fn bench_mpt_inclusion_fixed() -> Result<(), Box<dyn std::error::Error>> {
    // `File::create` does not create parent directories, so ensure both exist first.
    std::fs::create_dir_all("configs/bench")?;
    std::fs::create_dir_all("data/bench")?;
    let bench_params_file = File::create("configs/bench/mpt.json")?;
    let mut fs_results = File::create("data/bench/mpt.csv")?;
    writeln!(fs_results, "degree,total_advice,num_rlc_columns,num_advice,num_lookup,num_fixed,proof_time,verify_time")?;

    let bench_k = 15..18;
    let mut all_bench_params = vec![];
    for k in bench_k {
        println!("---------------------- degree = {k} ------------------------------");
        let params = gen_srs(k);
        let mut bench_params = EthCircuitParams::default().rlc;
        bench_params.base.k = k as usize;

        // Keygen pass: compute the circuit layout, then generate the proving/verifying keys.
        let mut circuit =
            test_mpt_circuit(CircuitBuilderStage::Keygen, bench_params, default_input());
        let bench_params = circuit.calculate_params().rlc;
        all_bench_params.push(bench_params.clone());
        let vk = keygen_vk(&params, &circuit)?;
        let pk = keygen_pk(&params, vk, &circuit)?;
        let break_points = circuit.break_points();

        // Prover pass: reuse the keygen break points so the witness matches the layout.
        let proof_time = start_timer!(|| "Create proof SHPLONK");
        let circuit =
            test_mpt_circuit(CircuitBuilderStage::Prover, bench_params.clone(), default_input())
                .use_break_points(break_points);
        let instances = circuit.instances();
        assert_eq!(instances.len(), 1);
        let proof = gen_proof_with_instances(&params, &pk, circuit, &[&instances[0]]);
        end_timer!(proof_time);

        let verify_time = start_timer!(|| "Verify time");
        check_proof_with_instances(&params, pk.get_vk(), &proof, &[&instances[0]], true);
        end_timer!(verify_time);

        let RlcCircuitParams {
            base:
                BaseCircuitParams {
                    k: degree,
                    num_advice_per_phase,
                    num_fixed,
                    num_lookup_advice_per_phase,
                    ..
                },
            num_rlc_columns,
        } = bench_params;
        // Debug output of the per-phase vectors (e.g. `[2, 1]`) contains commas, so quote it.
        writeln!(
            fs_results,
            "{},{},{},\"{:?}\",\"{:?}\",{},{:.2}s,{:?}",
            degree,
            num_rlc_columns
                + num_advice_per_phase.iter().sum::<usize>()
                + num_lookup_advice_per_phase.iter().sum::<usize>(),
            num_rlc_columns,
            num_advice_per_phase,
            num_lookup_advice_per_phase,
            num_fixed,
            proof_time.time.elapsed().as_secs_f64(),
            verify_time.time.elapsed()
        )?;
    }
    serde_json::to_writer_pretty(bench_params_file, &all_bench_params)?;
    Ok(())
}