use std::iter::zip;
use anyhow::bail;
use axiom_codec::{
constants::{MAX_SUBQUERY_OUTPUTS, USER_RESULT_FIELD_ELEMENTS, USER_RESULT_LEN_BYTES},
types::field_elements::SUBQUERY_RESULT_LEN,
HiLo,
};
use axiom_eth::{
halo2_base::{
gates::{GateInstructions, RangeInstructions},
halo2_proofs::halo2curves::bn256::Fr,
safe_types::{SafeBytes32, SafeTypeChip},
AssignedValue,
QuantumCell::Constant,
},
halo2_proofs::plonk::ConstraintSystem,
keccak::{types::ComponentTypeKeccak, KeccakChip},
rlc::{
circuit::builder::RlcCircuitBuilder,
types::{AssignedVarLenVec, ConcatVarFixedArrayWitness},
},
snark_verifier_sdk::{
halo2::aggregation::{
aggregate_snarks, AggregationCircuit, PreprocessedAndDomainAsWitness,
SnarkAggregationOutput, VerifierUniversality,
},
CircuitExt, SHPLONK,
},
utils::{
build_utils::aggregation::CircuitMetadata,
circuit_utils::unsafe_lt_mask,
component::{
circuit::{ComponentBuilder, ComponentCircuitImpl, CoreBuilder, CoreBuilderOutput},
promise_collector::PromiseCaller,
promise_loader::single::PromiseLoader,
types::FixLenLogical,
utils::create_hasher,
NUM_COMPONENT_OWNED_INSTANCES,
},
enforce_conditional_equality,
snark_verifier::NUM_FE_ACCUMULATOR,
uint_to_bytes_be,
},
};
use itertools::{zip_eq, EitherOrBoth, Itertools};
use crate::{
components::results::results_root::get_results_root_poseidon, ff::Field as _,
utils::client_circuit::vkey::OnchainVerifyingKey,
verify_compute::types::LogicalPublicInstanceVerifyCompute,
};
use super::{
query_hash::{
encode_compute_query_phase1, encode_query_schema, get_data_query_hash, get_query_hash_v2,
get_query_schema_hash,
},
types::{CircuitInputVerifyCompute, ComponentTypeVerifyCompute, CoreParamsVerifyCompute},
};
/// Native scalar field used throughout this circuit (BN254 `Fr`).
type F = Fr;
/// Core builder for the verify-compute component circuit.
///
/// Holds the fixed parameters, the (validated) circuit input, and the state that
/// must be carried from phase 0 to phase 1 of witness generation.
pub struct CoreBuilderVerifyCompute {
    /// Circuit input; `None` until `feed_input` succeeds.
    input: Option<CircuitInputVerifyCompute>,
    /// Circuit parameters (client metadata, subquery capacity, preprocessed length,
    /// succinct verifying key).
    params: CoreParamsVerifyCompute,
    /// Keccak chip and concat witness produced in phase 0 and consumed in phase 1;
    /// reset to `None` by `clear_witnesses`.
    payload: Option<(KeccakChip<F>, ConcatVarFixedArrayWitness<F>)>,
}
/// Promise loader for the keccak component (`ComponentTypeKeccak`) that this
/// circuit calls into.
pub type PromiseLoaderVerifyCompute = PromiseLoader<F, ComponentTypeKeccak<F>>;
/// The verify-compute component circuit: [`CoreBuilderVerifyCompute`] paired with
/// [`PromiseLoaderVerifyCompute`].
pub type ComponentCircuitVerifyCompute =
    ComponentCircuitImpl<F, CoreBuilderVerifyCompute, PromiseLoaderVerifyCompute>;
impl CircuitMetadata for CoreBuilderVerifyCompute {
    /// This circuit exposes an aggregation accumulator among its public instances.
    const HAS_ACCUMULATOR: bool = true;

    /// Accumulator limbs live in instance column 0, in the `NUM_FE_ACCUMULATOR`
    /// rows immediately following the component-owned instances.
    fn accumulator_indices() -> Option<Vec<(usize, usize)>> {
        let first_row = NUM_COMPONENT_OWNED_INSTANCES;
        let end_row = first_row + NUM_FE_ACCUMULATOR;
        let indices: Vec<(usize, usize)> = (first_row..end_row).map(|row| (0, row)).collect();
        Some(indices)
    }

    /// One instance column, sized for the component-owned instances, the
    /// accumulator limbs, and the logical public instances.
    fn num_instance(&self) -> Vec<usize> {
        let owned = NUM_COMPONENT_OWNED_INSTANCES;
        let logical = super::types::NUM_LOGICAL_INSTANCE;
        let column_len = owned + NUM_FE_ACCUMULATOR + logical;
        vec![column_len]
    }
}
impl CoreBuilderVerifyCompute {
    /// Maximum number of user outputs declared in the client circuit metadata,
    /// converted to `usize`.
    pub fn client_max_outputs(&self) -> usize {
        let metadata = self.params.client_metadata();
        metadata.max_outputs as usize
    }
}
impl ComponentBuilder<F> for CoreBuilderVerifyCompute {
    type Params = CoreParamsVerifyCompute;

    /// Creates a builder holding `params`, with no input and no phase-0 payload yet.
    fn new(params: Self::Params) -> Self {
        Self { params, input: None, payload: None }
    }

    /// Returns a copy of the stored parameters.
    fn get_params(&self) -> Self::Params {
        Clone::clone(&self.params)
    }

    /// Discards the keccak chip / concat witness carried from phase 0 to phase 1.
    fn clear_witnesses(&mut self) {
        self.payload = None;
    }

    /// Parameters are not derived from the input here; the stored ones are returned.
    fn calculate_params(&mut self) -> Self::Params {
        Clone::clone(&self.params)
    }

    /// This core builder configures no additional columns or gates.
    fn configure_with_params(_meta: &mut ConstraintSystem<F>, _params: Self::Params) {}
}
impl CoreBuilder<F> for CoreBuilderVerifyCompute {
    type CompType = ComponentTypeVerifyCompute;
    type PublicInstanceValue = LogicalPublicInstanceVerifyCompute<F>;
    type PublicInstanceWitness = LogicalPublicInstanceVerifyCompute<AssignedValue<F>>;
    type CoreInput = CircuitInputVerifyCompute;

    /// Validates and normalizes the circuit input, then stores it for witness generation.
    ///
    /// The subquery results are padded to `subquery_results_capacity` via
    /// `resize_with_first`. The compute snark's protocol is then rebuilt from the client
    /// circuit metadata, the snark's transcript initial state and its preprocessed
    /// commitments (`OnchainVerifyingKey::into_plonk_protocol`).
    ///
    /// # Errors
    /// Returns an error in any of these cases:
    /// - there are more subquery results than the configured capacity;
    /// - the compute snark's preprocessed length differs from `preprocessed_len`;
    /// - the compute snark has no transcript initial state;
    /// - rebuilding the protocol fails.
    fn feed_input(&mut self, mut input: Self::CoreInput) -> anyhow::Result<()> {
        let cap = self.params.subquery_results_capacity();
        let len = input.subquery_results.results.len();
        if cap < len {
            bail!("Feed CircuitInputVerifyCompute Error: length of subquery_results {len} is greater than subquery_results_capacity {cap}");
        }
        input.subquery_results.resize_with_first(cap);
        if self.params.preprocessed_len() != input.compute_snark().protocol.preprocessed.len() {
            bail!("Feed CircuitInputVerifyCompute Error: preprocessed_len does not match compute_snark");
        }
        let compute_snark = &mut input.compute_snark;
        // A missing initial state is malformed input: report it as an error (like the
        // checks above) instead of panicking.
        let transcript_initial_state = match compute_snark.protocol.transcript_initial_state {
            Some(state) => state,
            None => bail!(
                "Feed CircuitInputVerifyCompute Error: compute_snark is missing transcript_initial_state"
            ),
        };
        let client_vk = OnchainVerifyingKey {
            circuit_metadata: self.params.client_metadata().clone(),
            transcript_initial_state,
            preprocessed: std::mem::take(&mut compute_snark.protocol.preprocessed),
        };
        compute_snark.protocol = client_vk.into_plonk_protocol(compute_snark.protocol.domain.k)?;
        self.input = Some(input);
        Ok(())
    }

    /// Phase-0 witness generation for the verify-compute circuit.
    ///
    /// Verifies the compute snark in-circuit with `aggregate_snarks`. It then splits the
    /// snark's public instances into three parts:
    /// - an optional accumulator;
    /// - the user's compute results;
    /// - the user's subquery instances.
    ///
    /// From these it computes:
    /// - the query schema hash, data query hash and query hash;
    /// - the keccak hash of the compute results;
    /// - the Poseidon results root;
    /// - a Poseidon commitment to the subquery hashes.
    ///
    /// The keccak chip and the concat witness returned by `get_query_hash_v2` are stored
    /// in `self.payload` for `virtual_assign_phase1`.
    ///
    /// # Panics
    /// - If `feed_input` has not succeeded (`self.input` is `None`).
    /// - If the compute snark's instances do not match the expected layout. This covers
    ///   the `assert_eq!` checks below and `split_off` receiving an index past the end.
    fn virtual_assign_phase0(
        &mut self,
        builder: &mut RlcCircuitBuilder<F>,
        promise_caller: PromiseCaller<F>,
    ) -> CoreBuilderOutput<F, Self::CompType> {
        let keccak_chip =
            KeccakChip::new_with_promise_collector(builder.range_chip(), promise_caller.clone());
        let keccak = &keccak_chip;
        let range = keccak.range();
        let gate = range.gate();
        let safe = SafeTypeChip::new(range);
        let input = self.input.as_ref().unwrap();
        let pool = builder.base.pool(0);
        // Verify the single compute snark. With `VerifierUniversality::Full`, its
        // preprocessed commitments and domain size `k` are returned as witnesses
        // (`PreprocessedAndDomainAsWitness`).
        let SnarkAggregationOutput {
            mut preprocessed,
            mut previous_instances,
            accumulator,
            mut proof_transcripts,
        } = aggregate_snarks::<SHPLONK>(
            pool,
            range,
            self.params.svk().into(),
            [input.compute_snark().clone()],
            VerifierUniversality::Full,
        );
        let ctx = builder.base.main(0);
        let source_chain_id = ctx.load_witness(F::from(input.source_chain_id));
        let ne_cq = safe.load_bool(ctx, input.nonempty_compute_query);
        // Exactly one snark was aggregated, so each per-snark vector holds one entry.
        let PreprocessedAndDomainAsWitness { mut preprocessed, k } = preprocessed.pop().unwrap();
        // The last entry of `preprocessed` is used as the transcript initial state.
        let transcript_init_state = preprocessed.pop().unwrap();
        let compute_proof_transcript = proof_transcripts.pop().unwrap();
        let mut compute_instances = previous_instances.pop().unwrap();
        // If the compute snark declares accumulator indices, they must match
        // `AggregationCircuit`'s layout. Its first NUM_FE_ACCUMULATOR instances are then
        // its accumulator, and they are removed from the instance list here.
        let compute_accumulator = {
            let acc_indices = &input.compute_snark().protocol.accumulator_indices;
            if acc_indices.is_empty() {
                None
            } else {
                assert_eq!(acc_indices.len(), 1);
                assert_eq!(&acc_indices[0], &AggregationCircuit::accumulator_indices().unwrap());
                Some(compute_instances.drain(0..NUM_FE_ACCUMULATOR).collect_vec())
            }
        };
        // Of the remaining instances, the first USER_RESULT_FIELD_ELEMENTS *
        // client_max_outputs are the compute results. Everything after them is subquery data.
        let mut compute_results = compute_instances;
        let query_instances =
            compute_results.split_off(USER_RESULT_FIELD_ELEMENTS * self.client_max_outputs());
        let result_len = ctx.load_witness(F::from(input.result_len as u64));
        const RESULT_LEN_BITS: usize = 8 * USER_RESULT_LEN_BYTES;
        range.range_check(ctx, result_len, RESULT_LEN_BITS);
        let compute_results = AssignedVarLenVec { values: compute_results, len: result_len };
        let encoded_query_schema = encode_query_schema(
            ctx,
            range,
            k,
            result_len,
            self.params.client_metadata(),
            transcript_init_state,
            &preprocessed,
        );
        let query_schema = get_query_schema_hash(ctx, keccak, &encoded_query_schema, ne_cq);
        // Load each subquery hash as a hi/lo pair of witnesses, and also decompose it
        // into 32 big-endian bytes (16 bytes per half).
        let subquery_hashes = &input.subquery_results.subquery_hashes;
        let (subquery_hashes, subquery_hashes_hilo): (Vec<_>, Vec<_>) = subquery_hashes
            .iter()
            .map(|subquery_hash| {
                let hilo = subquery_hash.hi_lo();
                let hilo = hilo.map(|x| ctx.load_witness(x));
                let bytes = hilo.map(|x| uint_to_bytes_be(ctx, range, &x, 16)).concat();
                (SafeBytes32::try_from(bytes).unwrap(), hilo)
            })
            .unzip();
        let num_subqueries = input.subquery_results.num_subqueries as u64;
        let num_subqueries = ctx.load_witness(F::from(num_subqueries));
        let total_subquery_capacity = input.subquery_results.results.len() as u64;
        // Constrain num_subqueries <= total_subquery_capacity.
        range.check_less_than_safe(ctx, num_subqueries, total_subquery_capacity + 1);
        let (data_query_hash, encoded_source_chain_id) =
            get_data_query_hash(ctx, keccak, source_chain_id, &subquery_hashes, num_subqueries);
        let (query_hash, concat_proof_witness) = get_query_hash_v2(
            ctx,
            keccak,
            &encoded_source_chain_id,
            &data_query_hash,
            &encoded_query_schema,
            compute_accumulator,
            &compute_results,
            compute_proof_transcript,
            ne_cq,
        );
        // Bound result_len by client_max_outputs when ne_cq is set, and by num_subqueries
        // otherwise: result_len < max_res_len + 1.
        let max_res_len = gate.select(
            ctx,
            Constant(F::from(self.client_max_outputs() as u64)),
            num_subqueries,
            ne_cq,
        );
        let max_res_len_p1 = gate.inc(ctx, max_res_len);
        range.check_less_than(ctx, result_len, max_res_len_p1, RESULT_LEN_BITS);
        let table = input.subquery_results.results.assign(ctx);
        assert_eq!(query_instances.len() % SUBQUERY_RESULT_LEN, 0);
        let user_subqueries = query_instances.chunks_exact(SUBQUERY_RESULT_LEN);
        // NOTE(review): presumably subquery_mask[i] is 1 for i < num_subqueries and 0
        // otherwise (inferred from `unsafe_lt_mask`'s name) — confirm.
        let subquery_mask =
            unsafe_lt_mask(ctx, gate, num_subqueries, total_subquery_capacity as usize);
        // Link the compute snark's subquery instances to the subquery results table.
        // Where a user chunk and a table row coexist, each user element is conditionally
        // constrained (condition: ne_cq) against the row's key followed by its value,
        // multiplied by the row's mask bit. User chunks with no table row must be zero.
        // Table rows with no user chunk end the loop.
        for (i, it) in user_subqueries.zip_longest(&table.rows).enumerate() {
            match it {
                EitherOrBoth::Both(user, row) => {
                    let key = &row.key.0;
                    let out = &row.value.0;
                    for (&usr, &res) in zip(user, key.iter().chain(out.iter())) {
                        let res = gate.mul(ctx, res, subquery_mask[i]);
                        enforce_conditional_equality(ctx, gate, usr, res, ne_cq);
                    }
                }
                EitherOrBoth::Left(user) => {
                    for v in user {
                        gate.assert_is_const(ctx, v, &Fr::ZERO);
                    }
                }
                EitherOrBoth::Right(_) => {
                    break;
                }
            }
        }
        // Each result occupies 32 bytes in the hashed byte string.
        let result_byte_len = gate.mul(ctx, result_len, Constant(F::from(32)));
        // Build the compute-results byte string. Where both a user result and a subquery
        // row exist, the user value is selected when ne_cq is set and the subquery value
        // otherwise. Where only one side exists, that side is used as-is. Every field
        // element becomes 16 big-endian bytes.
        let mut compute_results_bytes = vec![];
        for it in
            compute_results.values.chunks_exact(USER_RESULT_FIELD_ELEMENTS).zip_longest(&table.rows)
        {
            let words = match it {
                EitherOrBoth::Both(user, subquery_res) => zip_eq(user, &subquery_res.value.0)
                    .map(|(&user, &val)| gate.select(ctx, user, val, *ne_cq.as_ref()))
                    .collect_vec(),
                EitherOrBoth::Left(user) => user.to_vec(),
                EitherOrBoth::Right(subquery_res) => subquery_res.value.0.to_vec(),
            };
            for word in &words {
                compute_results_bytes.extend(uint_to_bytes_be(ctx, range, word, 16));
            }
        }
        // Variable-length keccak over only the first result_byte_len bytes.
        let compute_results_hash = keccak.keccak_var_len(
            ctx,
            compute_results_bytes.into_iter().map(From::from).collect(),
            result_byte_len,
            0,
        );
        let mut poseidon = create_hasher();
        poseidon.initialize_consts(ctx, range.gate());
        let results_root_poseidon = {
            assert_eq!(table.rows.len(), total_subquery_capacity as usize);
            get_results_root_poseidon(
                ctx,
                range,
                &poseidon,
                &table.rows,
                num_subqueries,
                &subquery_mask,
            )
        };
        // Poseidon commitment to the flattened hi/lo limbs of the subquery hashes, with
        // variable length num_subqueries * MAX_SUBQUERY_OUTPUTS field elements.
        // NOTE(review): this covers exactly num_subqueries hashes (2 limbs each) only if
        // MAX_SUBQUERY_OUTPUTS == 2 — confirm this reuse of the constant is intended.
        let promise_subquery_hashes = {
            let to_commit = subquery_hashes_hilo.into_iter().flatten().collect_vec();
            let len = gate.mul(ctx, num_subqueries, Constant(F::from(MAX_SUBQUERY_OUTPUTS as u64)));
            poseidon.hash_var_len_array(ctx, range, &to_commit, len)
        };
        let logical_pis = LogicalPublicInstanceVerifyCompute {
            accumulator,
            source_chain_id,
            compute_results_hash: HiLo::from_hi_lo(compute_results_hash.hi_lo()),
            query_hash: HiLo::from_hi_lo(query_hash.hi_lo()),
            query_schema: HiLo::from_hi_lo(query_schema.hi_lo()),
            results_root_poseidon,
            promise_subquery_hashes,
        };
        // Carry the keccak chip and concat witness over to phase 1.
        self.payload = Some((keccak_chip, concat_proof_witness));
        CoreBuilderOutput {
            public_instances: logical_pis.into_raw(),
            virtual_table: vec![],
            logical_results: vec![],
        }
    }

    /// Phase-1 witness generation.
    ///
    /// Consumes the keccak chip and concat witness stored by `virtual_assign_phase0`,
    /// then completes the compute-query encoding with RLC via
    /// `encode_compute_query_phase1`.
    ///
    /// # Panics
    /// Panics if no phase-0 payload is stored.
    fn virtual_assign_phase1(&mut self, builder: &mut RlcCircuitBuilder<F>) {
        let (keccak, payload) = self.payload.take().unwrap();
        let gate = keccak.gate();
        let rlc = builder.rlc_chip(gate);
        let rlc_pair = builder.rlc_ctx_pair();
        encode_compute_query_phase1(rlc_pair, gate, &rlc, payload);
    }
}