use std::{hash::Hash, io::Write};
use anyhow::{anyhow, bail};
use axiom_codec::{
constants::{
USER_ADVICE_COLS, USER_FIXED_COLS, USER_INSTANCE_COLS, USER_LOOKUP_ADVICE_COLS,
USER_MAX_OUTPUTS, USER_MAX_SUBQUERIES, USER_RESULT_FIELD_ELEMENTS,
},
decoder::native::decode_compute_snark,
types::{
field_elements::SUBQUERY_RESULT_LEN,
native::{AxiomV2ComputeQuery, AxiomV2ComputeSnark},
},
utils::writer::{write_curve_compressed, write_field_le},
HiLo,
};
use axiom_eth::{
halo2_base::{
gates::circuit::{builder::BaseCircuitBuilder, BaseCircuitParams},
halo2_proofs::{
halo2curves::{
bn256::{Bn256, Fr},
ff::PrimeField,
serde::SerdeObject,
CurveAffine,
},
plonk::VerifyingKey,
},
utils::ScalarField,
},
halo2curves::{bn256::G1Affine, ff::Field as _},
rlc::circuit::RlcCircuitParams,
snark_verifier::{
pcs::kzg::KzgDecidingKey,
system::halo2::transcript_initial_state,
util::arithmetic::fe_to_limbs,
verifier::{
plonk::{PlonkProof, PlonkProtocol},
SnarkVerifier,
},
},
snark_verifier_sdk::{
halo2::{aggregation::AggregationCircuit, PoseidonTranscript, POSEIDON_SPEC},
CircuitExt, NativeLoader, PlonkVerifier, Snark, BITS, LIMBS, SHPLONK,
},
};
use ethers_core::types::H256;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::{
components::results::types::{CircuitOutputResultsRoot, LogicOutputResultsRoot},
utils::client_circuit::{metadata::AxiomV2CircuitMetadata, vkey::OnchainVerifyingKey},
Field,
};
pub fn get_metadata_from_protocol(
protocol: &PlonkProtocol<G1Affine>,
rlc_params: RlcCircuitParams,
max_outputs: usize,
) -> anyhow::Result<AxiomV2CircuitMetadata> {
let num_advice_per_phase =
rlc_params.base.num_advice_per_phase.iter().map(|x| *x as u16).collect();
let num_lookup_advice_per_phase =
rlc_params.base.num_lookup_advice_per_phase.iter().map(|x| *x as u8).collect();
let num_rlc_columns = rlc_params.num_rlc_columns as u16;
let num_fixed = rlc_params.base.num_fixed as u8;
let mut metadata = AxiomV2CircuitMetadata {
version: 0,
num_advice_per_phase,
num_lookup_advice_per_phase,
num_rlc_columns,
num_fixed,
max_outputs: max_outputs as u16,
..Default::default()
};
if protocol.num_instance.len() != 1 {
bail!("Only one instance column supported right now");
}
metadata.num_instance = protocol.num_instance.iter().map(|&x| x as u32).collect();
let mut num_challenge_incl_system = protocol.num_challenge.clone();
if num_challenge_incl_system.pop() != Some(1) {
bail!("last challenge must be [alpha]");
}
if num_challenge_incl_system.pop() != Some(2) {
bail!("second last challenge must be [beta, gamma]");
}
let last_challenge = num_challenge_incl_system.last_mut();
if last_challenge.is_none() {
bail!("num_challenge must have at least 3 challenges");
}
let last_challenge = last_challenge.unwrap();
if *last_challenge == 0 {
bail!("third last challenge must include theta");
}
*last_challenge -= 1;
let num_challenge: Vec<u8> = num_challenge_incl_system.iter().map(|x| *x as u8).collect();
if num_challenge != vec![0] && num_challenge != vec![1, 0] {
log::debug!("num_challenge: {:?}", num_challenge);
bail!("Only phase0 BaseCircuitBuilder or phase0+1 RlcCircuitBuilder supported right now");
}
metadata.num_challenge = num_challenge;
metadata.is_aggregation = if protocol.accumulator_indices.is_empty() {
false
} else {
if protocol.accumulator_indices.len() != 1
|| protocol.accumulator_indices[0] != AggregationCircuit::accumulator_indices().unwrap()
{
bail!("invalid accumulator indices");
}
true
};
Ok(metadata)
}
pub fn write_onchain_vkey<C>(vkey: &OnchainVerifyingKey<C>) -> anyhow::Result<Vec<H256>>
where
C: CurveAffine + SerdeObject,
C::Scalar: Field + SerdeObject,
{
let metadata = vkey.circuit_metadata.encode()?;
let tmp = C::Repr::default();
let compressed_curve_bytes = tmp.as_ref().len();
let tmp = <C::Scalar as PrimeField>::Repr::default();
let field_bytes = tmp.as_ref().len();
let mut writer =
Vec::with_capacity(field_bytes + vkey.preprocessed.len() * compressed_curve_bytes);
writer.write_all(&metadata.to_fixed_bytes())?;
write_field_le(&mut writer, vkey.transcript_initial_state)?;
for &point in &vkey.preprocessed {
write_curve_compressed(&mut writer, point)?;
}
Ok(writer.chunks_exact(32).map(H256::from_slice).collect())
}
/// Builds an [`OnchainVerifyingKey`] from a halo2 verifying key.
///
/// The preprocessed commitments are the fixed-column commitments followed by the permutation
/// commitments. The transcript initial state is computed from `vk` with snark-verifier's
/// `transcript_initial_state`. The given `circuit_metadata` is stored as-is.
pub fn get_onchain_vk_from_vk<C: CurveAffine>(
    vk: &VerifyingKey<C>,
    circuit_metadata: AxiomV2CircuitMetadata,
) -> OnchainVerifyingKey<C> {
    let fixed_commitments = vk.fixed_commitments().iter();
    let permutation_commitments = vk.permutation().commitments().iter();
    OnchainVerifyingKey {
        preprocessed: fixed_commitments
            .chain(permutation_commitments)
            .copied()
            .map(Into::into)
            .collect(),
        transcript_initial_state: transcript_initial_state(vk),
        circuit_metadata,
    }
}
/// Builds an [`OnchainVerifyingKey`] from a snark-verifier [`PlonkProtocol`].
///
/// The protocol's preprocessed commitments are cloned, and its transcript initial state is
/// copied. The given `circuit_metadata` is stored as-is.
///
/// # Panics
/// Panics if `protocol.transcript_initial_state` is `None`.
pub fn get_onchain_vk_from_protocol<C: CurveAffine>(
    protocol: &PlonkProtocol<C>,
    circuit_metadata: AxiomV2CircuitMetadata,
) -> OnchainVerifyingKey<C> {
    let preprocessed = protocol.preprocessed.clone();
    let transcript_initial_state = protocol
        .transcript_initial_state
        .expect("PlonkProtocol must have transcript_initial_state set to build an onchain vkey");
    OnchainVerifyingKey { circuit_metadata, preprocessed, transcript_initial_state }
}
pub fn reconstruct_snark_from_compute_query(
subquery_results: LogicOutputResultsRoot,
compute_query: AxiomV2ComputeQuery,
) -> anyhow::Result<(Snark, AxiomV2CircuitMetadata)> {
let subquery_results = CircuitOutputResultsRoot::<Fr>::try_from(subquery_results)?;
let vkey = compute_query.vkey.into_iter().flat_map(|u| u.0).collect_vec();
let mut reader = &vkey[..];
let onchain_vk = OnchainVerifyingKey::<G1Affine>::read(&mut reader)?;
let client_metadata = onchain_vk.circuit_metadata.clone();
let k = compute_query.k as usize;
let protocol = onchain_vk.into_plonk_protocol(k)?;
if client_metadata.num_instance.len() != 1 {
bail!("Only one instance column supported right now");
}
let num_instance = client_metadata.num_instance[0] as usize;
let num_subqueries = subquery_results.num_subqueries;
let result_len = compute_query.result_len as usize;
let max_outputs = client_metadata.max_outputs as usize;
if result_len > max_outputs {
bail!("user_output_len exceeds user max outputs");
}
let mut reader = &compute_query.compute_proof[..];
let AxiomV2ComputeSnark { compute_results, proof_transcript, kzg_accumulator } =
decode_compute_snark(
&mut reader,
compute_query.result_len,
client_metadata.is_aggregation,
)?;
let mut instance = Vec::with_capacity(num_instance);
if let Some((lhs, rhs)) = kzg_accumulator {
instance.extend(
[lhs.x, lhs.y, rhs.x, rhs.y].into_iter().flat_map(fe_to_limbs::<_, Fr, LIMBS, BITS>),
);
}
let mut compute_results =
compute_results.into_iter().flat_map(|out| HiLo::from(out).hi_lo()).collect_vec();
assert_eq!(compute_results.len(), result_len * USER_RESULT_FIELD_ELEMENTS);
compute_results
.resize((client_metadata.max_outputs as usize) * USER_RESULT_FIELD_ELEMENTS, Fr::ZERO);
instance.extend(compute_results);
for result in &subquery_results.results.rows[..num_subqueries] {
instance.extend(result.to_fixed_array());
}
if instance.len() > num_instance {
bail!("Num subqueries exceeds num_instance limit");
}
instance.resize(num_instance, Fr::ZERO);
let snark = Snark::new(protocol, vec![instance], proof_transcript);
Ok((snark, client_metadata))
}
pub fn verify_snark(dk: &KzgDecidingKey<Bn256>, snark: &Snark) -> anyhow::Result<()> {
let mut transcript =
PoseidonTranscript::<NativeLoader, &[u8]>::from_spec(snark.proof(), POSEIDON_SPEC.clone());
let proof: PlonkProof<_, _, SHPLONK> =
PlonkVerifier::read_proof(dk, &snark.protocol, &snark.instances, &mut transcript)
.map_err(|_| anyhow!("Failed to read PlonkProof"))?;
PlonkVerifier::verify(dk, &snark.protocol, &snark.instances, &proof)
.map_err(|_| anyhow!("PlonkVerifier failed"))?;
Ok(())
}
lazy_static::lazy_static! {
    /// Metadata of the default client compute circuit.
    ///
    /// - Column counts are the `USER_*` column constants.
    /// - There are no RLC columns and the challenge layout is `[0]`.
    /// - There is one instance column, sized for `USER_MAX_OUTPUTS` results plus
    ///   `USER_MAX_SUBQUERIES` subquery results.
    /// - The circuit is not an aggregation circuit.
    pub static ref DEFAULT_CLIENT_METADATA: AxiomV2CircuitMetadata = {
        // Same formula as `UserCircuitParams::num_instances`.
        let instance_cells = USER_MAX_OUTPUTS * USER_RESULT_FIELD_ELEMENTS
            + USER_MAX_SUBQUERIES * SUBQUERY_RESULT_LEN;
        AxiomV2CircuitMetadata {
            version: 0,
            num_advice_per_phase: vec![USER_ADVICE_COLS as u16],
            num_lookup_advice_per_phase: vec![USER_LOOKUP_ADVICE_COLS as u8],
            num_rlc_columns: 0,
            num_fixed: USER_FIXED_COLS as u8,
            num_instance: vec![instance_cells as u32],
            num_challenge: vec![0],
            max_outputs: USER_MAX_OUTPUTS as u16,
            is_aggregation: false,
        }
    };
}
/// Shape parameters of a user (client) compute circuit built on a single-phase
/// `BaseCircuitBuilder`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct UserCircuitParams {
    /// Number of advice columns (used as the only entry of `num_advice_per_phase`).
    pub num_advice_cols: usize,
    /// Number of lookup advice columns (used as the only entry of
    /// `num_lookup_advice_per_phase`).
    pub num_lookup_advice_cols: usize,
    /// Number of fixed columns.
    pub num_fixed_cols: usize,
    /// Maximum number of user outputs; each takes `USER_RESULT_FIELD_ELEMENTS` instance cells.
    pub max_outputs: usize,
    /// Maximum number of subqueries; each takes `SUBQUERY_RESULT_LEN` instance cells.
    pub max_subqueries: usize,
}

impl UserCircuitParams {
    /// Total number of public instance cells: space for `max_outputs` results plus
    /// `max_subqueries` subquery results.
    pub fn num_instances(&self) -> usize {
        self.max_outputs * USER_RESULT_FIELD_ELEMENTS + self.max_subqueries * SUBQUERY_RESULT_LEN
    }

    /// Builds the [`BaseCircuitParams`] for this shape with degree parameter `k`.
    ///
    /// `lookup_bits` is set to `k - 1`, and `num_instance_columns` to `USER_INSTANCE_COLS`.
    ///
    /// # Panics
    /// Panics if `k == 0`. Previously that case panicked only in debug builds and wrapped
    /// `lookup_bits` to `usize::MAX` in release builds.
    pub fn base_circuit_params(&self, k: usize) -> BaseCircuitParams {
        let lookup_bits = k.checked_sub(1).expect("k must be positive: lookup_bits is k - 1");
        BaseCircuitParams {
            k,
            num_advice_per_phase: vec![self.num_advice_cols],
            num_lookup_advice_per_phase: vec![self.num_lookup_advice_cols],
            num_fixed: self.num_fixed_cols,
            lookup_bits: Some(lookup_bits),
            num_instance_columns: USER_INSTANCE_COLS,
        }
    }
}
/// Default user circuit shape, taken directly from the `USER_*` constants. The output and
/// subquery capacities are listed first; the column counts follow.
pub const DEFAULT_USER_PARAMS: UserCircuitParams = UserCircuitParams {
    max_outputs: USER_MAX_OUTPUTS,
    max_subqueries: USER_MAX_SUBQUERIES,
    num_advice_cols: USER_ADVICE_COLS,
    num_lookup_advice_cols: USER_LOOKUP_ADVICE_COLS,
    num_fixed_cols: USER_FIXED_COLS,
};
/// Builds a placeholder compute circuit with the shape given by `user_params` at degree `k`.
///
/// The first instance column is filled with `user_params.num_instances()` zero-valued
/// witnesses assigned in phase-0 context 0. This is useful where only the circuit's shape
/// matters (presumably key generation — confirm with callers).
///
/// # Panics
/// Panics if the builder does not have exactly `USER_INSTANCE_COLS` instance columns.
pub fn dummy_compute_circuit<F: ScalarField>(
    user_params: UserCircuitParams,
    k: u32,
) -> BaseCircuitBuilder<F> {
    let params = user_params.base_circuit_params(k as usize);
    let mut circuit = BaseCircuitBuilder::new(false).use_params(params);
    let zeros = vec![F::ZERO; user_params.num_instances()];
    let placeholder_instances = circuit.main(0).assign_witnesses(zeros);
    assert_eq!(circuit.assigned_instances.len(), USER_INSTANCE_COLS);
    circuit.assigned_instances[0] = placeholder_instances;
    circuit
}
/// Placeholder compute circuit over `Fr`, using `DEFAULT_USER_PARAMS` at degree `k`.
pub fn default_compute_circuit(k: u32) -> BaseCircuitBuilder<Fr> {
    dummy_compute_circuit::<Fr>(DEFAULT_USER_PARAMS, k)
}