axiom_circuit/
scaffold.rs

1use std::{
2    borrow::BorrowMut,
3    cell::RefCell,
4    fmt::Debug,
5    mem,
6    ops::DerefMut,
7    sync::{Arc, Mutex},
8};
9
10use axiom_codec::{
11    constants::{USER_MAX_OUTPUTS, USER_MAX_SUBQUERIES, USER_RESULT_FIELD_ELEMENTS},
12    types::field_elements::SUBQUERY_RESULT_LEN,
13    utils::native::decode_hilo_to_h256,
14    HiLo,
15};
16use axiom_query::axiom_eth::{
17    halo2_base::{
18        gates::{
19            circuit::{BaseConfig, CircuitBuilderStage},
20            RangeChip,
21        },
22        safe_types::SafeTypeChip,
23        virtual_region::manager::VirtualRegionManager,
24        AssignedValue,
25    },
26    halo2_proofs::{
27        circuit::{Layouter, SimpleFloorPlanner},
28        plonk::{Circuit, ConstraintSystem, Error},
29    },
30    rlc::{
31        circuit::{builder::RlcCircuitBuilder, RlcConfig},
32        virtual_region::RlcThreadBreakPoints,
33    },
34    snark_verifier_sdk::CircuitExt,
35    utils::{
36        keccak::decorator::{KeccakCallCollector, RlcKeccakCircuitParams, RlcKeccakConfig},
37        DEFAULT_RLC_CACHE_BITS,
38    },
39    zkevm_hashes::keccak::{
40        component::circuit::shard::LoadedKeccakF,
41        vanilla::{
42            keccak_packed_multi::get_num_keccak_f,
43            param::{NUM_ROUNDS, NUM_WORDS_TO_ABSORB},
44            KeccakCircuitConfig,
45        },
46    },
47    Field,
48};
49use ethers::providers::{JsonRpcClient, Provider};
50use itertools::Itertools;
51
52use crate::{
53    input::flatten::InputFlatten,
54    subquery::caller::SubqueryCaller,
55    types::{AxiomCircuitConfig, AxiomCircuitParams, AxiomCircuitPinning, AxiomV2DataAndResults},
56};
57
58pub trait AxiomCircuitScaffold<P: JsonRpcClient, F: Field>: Default + Clone + Debug {
59    type InputValue: Clone + Debug + Default + InputFlatten<F>;
60    type InputWitness: Clone + Debug + InputFlatten<AssignedValue<F>>;
61    type FirstPhasePayload: Clone = ();
62
63    fn virtual_assign_phase0(
64        builder: &mut RlcCircuitBuilder<F>,
65        range: &RangeChip<F>,
66        subquery_caller: Arc<Mutex<SubqueryCaller<P, F>>>,
67        callback: &mut Vec<HiLo<AssignedValue<F>>>,
68        assigned_inputs: Self::InputWitness,
69    ) -> Self::FirstPhasePayload;
70
71    /// Most people should not use this
72    #[allow(unused_variables)]
73    fn virtual_assign_phase1(
74        builder: &mut RlcCircuitBuilder<F>,
75        range: &RangeChip<F>,
76        payload: Self::FirstPhasePayload,
77    ) {
78    }
79}
80
/// Halo2 circuit wrapping a user [`AxiomCircuitScaffold`] implementation `A`.
///
/// Witness generation runs lazily from `&self` methods, so all mutable state lives in
/// `RefCell`s.
#[derive(Clone, Debug)]
pub struct AxiomCircuit<F: Field, P: JsonRpcClient, A: AxiomCircuitScaffold<P, F>> {
    /// Builder that holds all virtual assignments and the configured circuit params.
    pub builder: RefCell<RlcCircuitBuilder<F>>,
    /// Native inputs; when `None`, `A::InputValue::default()` is assigned instead.
    pub inputs: Option<A::InputValue>,
    /// Provider handed (cloned) to the subquery caller during phase 0.
    pub provider: Provider<P>,
    /// Range chip sharing the builder's lookup manager.
    range: RangeChip<F>,
    /// Phase-0 payload awaiting phase 1; `Some` means phase 0 already ran and is cached.
    payload: RefCell<Option<A::FirstPhasePayload>>,
    /// Data query and compute results recorded by the last phase-0 run.
    output: RefCell<AxiomV2DataAndResults>,
    /// Keccak calls copied from the subquery caller during phase 0.
    keccak_call_collector: RefCell<KeccakCallCollector<F>>,
    /// Rows per keccak round; `0` means no keccak sub-circuit is configured.
    keccak_rows_per_round: usize,
    /// Maximum number of user outputs exposed as public instances.
    max_user_outputs: usize,
    /// Maximum number of subqueries exposed as public instances.
    max_user_subqueries: usize,
}
94
95impl<F: Field, P: JsonRpcClient + Clone, A: AxiomCircuitScaffold<P, F>> AxiomCircuit<F, P, A> {
96    pub fn new(provider: Provider<P>, circuit_params: AxiomCircuitParams) -> Self {
97        Self::from_stage(provider, circuit_params, CircuitBuilderStage::Mock)
98    }
99
100    pub fn prover(provider: Provider<P>, pinning: AxiomCircuitPinning) -> Self {
101        let mut circuit = Self::from_stage(provider, pinning.params, CircuitBuilderStage::Prover);
102        circuit.set_break_points(pinning.break_points);
103        circuit
104    }
105
106    pub fn from_stage(
107        provider: Provider<P>,
108        circuit_params: AxiomCircuitParams,
109        stage: CircuitBuilderStage,
110    ) -> Self {
111        let params = RlcKeccakCircuitParams::from(circuit_params);
112        let rlc_bits = if params.rlc.num_rlc_columns > 0 {
113            DEFAULT_RLC_CACHE_BITS
114        } else {
115            0
116        };
117        let builder =
118            RlcCircuitBuilder::<F>::from_stage(stage, rlc_bits).use_params(params.rlc.clone());
119        let range = RangeChip::new(
120            params.rlc.base.lookup_bits.unwrap(),
121            builder.base.lookup_manager().clone(),
122        );
123        Self {
124            builder: RefCell::new(builder),
125            range,
126            inputs: None,
127            provider,
128            payload: RefCell::new(None),
129            keccak_rows_per_round: params.keccak_rows_per_round,
130            output: Default::default(),
131            keccak_call_collector: RefCell::new(Default::default()),
132            max_user_outputs: USER_MAX_OUTPUTS,
133            max_user_subqueries: USER_MAX_SUBQUERIES,
134        }
135    }
136
137    pub fn set_max_user_outputs(&mut self, max_user_outputs: usize) {
138        self.max_user_outputs = max_user_outputs;
139    }
140
141    pub fn use_max_user_outputs(mut self, max_user_outputs: usize) -> Self {
142        self.set_max_user_outputs(max_user_outputs);
143        self
144    }
145
146    pub fn set_max_user_subqueries(&mut self, max_user_subqueries: usize) {
147        self.max_user_subqueries = max_user_subqueries;
148    }
149
150    pub fn use_max_user_subqueries(mut self, max_user_subqueries: usize) -> Self {
151        self.set_max_user_subqueries(max_user_subqueries);
152        self
153    }
154
155    pub fn set_inputs(&mut self, inputs: Option<A::InputValue>) {
156        self.inputs = inputs;
157    }
158
159    pub fn use_inputs(mut self, inputs: Option<A::InputValue>) -> Self {
160        self.set_inputs(inputs);
161        self
162    }
163
164    pub fn set_break_points(&mut self, break_points: RlcThreadBreakPoints) {
165        self.builder.borrow_mut().set_break_points(break_points);
166    }
167
168    pub fn use_break_points(mut self, break_points: RlcThreadBreakPoints) -> Self {
169        self.set_break_points(break_points);
170        self
171    }
172
173    pub fn set_params(&mut self, params: AxiomCircuitParams) {
174        let params = RlcKeccakCircuitParams::from(params);
175        let mut builder = self.builder.borrow_mut();
176        builder.set_params(params.rlc.clone());
177        self.keccak_rows_per_round = params.keccak_rows_per_round;
178    }
179
180    pub fn use_params(mut self, params: AxiomCircuitParams) -> Self {
181        self.set_params(params);
182        self
183    }
184
185    pub fn set_pinning(&mut self, pinning: AxiomCircuitPinning) {
186        self.set_params(pinning.params);
187        self.set_break_points(pinning.break_points);
188    }
189
190    pub fn use_pinning(mut self, pinning: AxiomCircuitPinning) -> Self {
191        self.set_pinning(pinning);
192        self
193    }
194
195    pub fn set_provider(&mut self, provider: Provider<P>) {
196        self.provider = provider;
197    }
198
199    pub fn use_provider(mut self, provider: Provider<P>) -> Self {
200        self.set_provider(provider);
201        self
202    }
203
204    pub fn break_points(&self) -> RlcThreadBreakPoints {
205        let rlc_params = self.builder.borrow().params();
206        if rlc_params.num_rlc_columns == 0 {
207            let break_points = self.builder.borrow().base.break_points();
208            RlcThreadBreakPoints {
209                base: break_points,
210                rlc: vec![],
211            }
212        } else {
213            self.builder.borrow().break_points()
214        }
215    }
216
217    pub fn pinning(&self) -> AxiomCircuitPinning {
218        AxiomCircuitPinning {
219            params: self.params(),
220            break_points: self.break_points(),
221        }
222    }
223
224    pub fn k(&self) -> usize {
225        self.builder.borrow().params().base.k
226    }
227
228    pub fn output_num_instances(&self) -> usize {
229        self.max_user_outputs * USER_RESULT_FIELD_ELEMENTS
230    }
231
232    pub fn subquery_num_instances(&self) -> usize {
233        self.max_user_subqueries * SUBQUERY_RESULT_LEN
234    }
235
236    fn virtual_assign_phase0(&self) {
237        if self.payload.borrow().is_some() {
238            return;
239        }
240        let is_inputs = self.inputs.is_none();
241        let flattened_inputs = self.inputs.clone().unwrap_or_default().flatten_vec();
242        let assigned_input_vec = self
243            .builder
244            .borrow_mut()
245            .base
246            .main(0)
247            .assign_witnesses(flattened_inputs);
248        let assigned_inputs = A::InputWitness::unflatten(assigned_input_vec).unwrap();
249
250        let subquery_caller = Arc::new(Mutex::new(SubqueryCaller::new(
251            self.provider.clone(),
252            is_inputs,
253        )));
254        let mut callback = Vec::new();
255        let payload = A::virtual_assign_phase0(
256            &mut self.builder.borrow_mut(),
257            &self.range,
258            subquery_caller.clone(),
259            &mut callback,
260            assigned_inputs,
261        );
262        self.payload.borrow_mut().replace(payload);
263
264        let mut flattened_callback = callback
265            .clone()
266            .into_iter()
267            .flat_map(|hilo| hilo.flatten())
268            .collect::<Vec<_>>();
269        flattened_callback.resize_with(self.output_num_instances(), || {
270            self.builder
271                .borrow_mut()
272                .base
273                .main(0)
274                .load_constant(F::ZERO)
275        });
276
277        let mut subquery_instances = subquery_caller.lock().unwrap().instances().clone();
278        subquery_instances.resize_with(self.subquery_num_instances(), || {
279            self.builder
280                .borrow_mut()
281                .base
282                .main(0)
283                .load_constant(F::ZERO)
284        });
285
286        flattened_callback.extend(subquery_instances);
287        let instances = vec![flattened_callback];
288        self.builder.borrow_mut().base.assigned_instances = instances.clone();
289
290        let circuit_output = callback
291            .iter()
292            .map(|hilo| decode_hilo_to_h256(HiLo::from_hi_lo(hilo.hi_lo().map(|x| *x.value()))))
293            .collect_vec();
294        self.output.replace(AxiomV2DataAndResults {
295            data_query: subquery_caller.lock().unwrap().data_query(),
296            compute_results: circuit_output,
297        });
298
299        self.keccak_call_collector.borrow_mut().var_len_calls =
300            subquery_caller.lock().unwrap().keccak_var_len_calls.clone();
301        self.keccak_call_collector.borrow_mut().fix_len_calls =
302            subquery_caller.lock().unwrap().keccak_fix_len_calls.clone();
303    }
304
305    fn virtual_assign_phase1(&self) {
306        let payload = self
307            .payload
308            .borrow_mut()
309            .take()
310            .expect("FirstPhase witness generation was not run");
311        self.builder.borrow_mut().base.main(1);
312        A::virtual_assign_phase1(&mut self.builder.borrow_mut(), &self.range, payload);
313    }
314
315    fn synthesize_without_rlc(
316        &self,
317        config: BaseConfig<F>,
318        mut layouter: impl Layouter<F>,
319    ) -> Result<(), Error> {
320        self.virtual_assign_phase0();
321        if !self.keccak_call_collector.borrow().fix_len_calls.is_empty()
322            || !self.keccak_call_collector.borrow().var_len_calls.is_empty()
323        {
324            panic!("Keccak calls made but keccak_rows_per_round is None");
325        }
326        self.builder.borrow_mut().base.synthesize(
327            config,
328            layouter.namespace(|| "BaseCircuitBuilder raw synthesize phase0"),
329        )
330    }
331
332    fn synthesize_with_rlc_and_keccak(
333        &self,
334        config: RlcConfig<F>,
335        keccak_config: Option<KeccakCircuitConfig<F>>,
336        mut layouter: impl Layouter<F>,
337    ) -> Result<(), Error> {
338        config.base.initialize(&mut layouter);
339        let k = self.builder.borrow().params().base.k;
340        self.virtual_assign_phase0();
341        if let Some(keccak_config) = keccak_config {
342            keccak_config.load_aux_tables(&mut layouter, k as u32)?;
343            let keccak_calls = mem::take(self.keccak_call_collector.borrow_mut().deref_mut());
344
345            keccak_calls.assign_raw_and_constrain(
346                keccak_config.parameters,
347                &keccak_config,
348                &mut layouter.namespace(|| "keccak sub-circuit"),
349                self.builder.borrow_mut().base.pool(0),
350                &self.range,
351            )?;
352        }
353        {
354            let rlc_builder = self.builder.borrow_mut();
355
356            let phase0_layouter = layouter.namespace(|| "RlcCircuitBuilder raw synthesize phase0");
357            rlc_builder.raw_synthesize_phase0(&config, phase0_layouter);
358        }
359
360        layouter.next_phase();
361        self.builder
362            .borrow_mut()
363            .load_challenge(&config, layouter.namespace(|| "load challenges"));
364
365        self.virtual_assign_phase1();
366        {
367            let rlc_builder = self.builder.borrow();
368
369            let phase1_layouter = layouter.namespace(|| "RlcCircuitBuilder raw synthesize phase1");
370            rlc_builder.raw_synthesize_phase1(&config, phase1_layouter, false);
371        }
372
373        let rlc_builder = self.builder.borrow();
374        if !rlc_builder.witness_gen_only() {
375            layouter.assign_region(
376                || "copy constraints",
377                |mut region| {
378                    let constant_cols = config.base.constants();
379                    rlc_builder
380                        .copy_manager()
381                        .assign_raw(constant_cols, &mut region);
382                    Ok(())
383                },
384            )?;
385        }
386        drop(rlc_builder);
387
388        self.builder.borrow_mut().clear();
389        Ok(())
390    }
391
392    fn clear(&self) {
393        self.builder.borrow_mut().clear();
394        self.payload.borrow_mut().take();
395        self.keccak_call_collector.borrow_mut().clear();
396        self.output.borrow_mut().compute_results.clear();
397        self.output.borrow_mut().data_query.clear();
398    }
399
400    pub fn calculate_params(&mut self) {
401        self.virtual_assign_phase0();
402        let keccak_calls = mem::take(self.keccak_call_collector.borrow_mut().deref_mut());
403        let mut capacity = 0;
404        for (call, _) in keccak_calls.fix_len_calls.iter() {
405            capacity += get_num_keccak_f(call.bytes().len());
406        }
407        for (call, _) in keccak_calls.var_len_calls.iter() {
408            capacity += get_num_keccak_f(call.bytes().max_len());
409        }
410        // make mock loaded_keccak_fs just to simulate
411        let copy_manager_ref = self.builder.borrow().copy_manager().clone();
412        let mut copy_manager = copy_manager_ref.lock().unwrap();
413        let virtual_keccak_fs = (0..capacity)
414            .map(|_| {
415                LoadedKeccakF::new(
416                    copy_manager.mock_external_assigned(F::ZERO),
417                    core::array::from_fn(|_| copy_manager.mock_external_assigned(F::ZERO)),
418                    SafeTypeChip::unsafe_to_bool(copy_manager.mock_external_assigned(F::ZERO)),
419                    copy_manager.mock_external_assigned(F::ZERO),
420                    copy_manager.mock_external_assigned(F::ZERO),
421                )
422            })
423            .collect_vec();
424        drop(copy_manager);
425        keccak_calls.pack_and_constrain(
426            virtual_keccak_fs,
427            self.builder.borrow_mut().base.pool(0),
428            self.range.borrow_mut(),
429        );
430        self.virtual_assign_phase1();
431
432        // TMP: use empirical constants
433        let unusable_rows = 109;
434        self.builder
435            .borrow_mut()
436            .calculate_params(Some(unusable_rows));
437        let usable_rows = (1 << self.builder.borrow().base.config_params.k) - unusable_rows;
438        // This is the inverse of [zkevm_hashes::keccak::vanilla::keccak_packed_multi::get_keccak_capacity].
439        let rows_per_round = usable_rows / (capacity * (NUM_ROUNDS + 1) + 1 + NUM_WORDS_TO_ABSORB);
440        // log::info!("RlcKeccakCircuit used capacity: {capacity}");
441        // log::info!("RlcKeccakCircuit optimal rows_per_round : {rows_per_round}");
442        // Empirically more than 50 rows makes the rotations inhibit performance.
443        self.keccak_rows_per_round = rows_per_round.min(50);
444
445        self.clear();
446    }
447
448    pub fn instances(&self) -> Vec<Vec<F>> {
449        self.virtual_assign_phase0();
450        let builder = self.builder.borrow();
451        builder
452            .base
453            .assigned_instances
454            .iter()
455            .map(|instance| instance.iter().map(|x| *x.value()).collect())
456            .collect()
457    }
458
459    pub fn scaffold_output(&self) -> AxiomV2DataAndResults {
460        self.virtual_assign_phase0();
461        self.output.borrow().clone()
462    }
463}
464
465impl<F: Field, P: JsonRpcClient + Clone, A: AxiomCircuitScaffold<P, F>> Circuit<F>
466    for AxiomCircuit<F, P, A>
467{
468    type Config = AxiomCircuitConfig<F>;
469    type Params = AxiomCircuitParams;
470    type FloorPlanner = SimpleFloorPlanner;
471
472    fn without_witnesses(&self) -> Self {
473        unimplemented!()
474    }
475
476    fn params(&self) -> Self::Params {
477        let rlc_params = self.builder.borrow().params();
478        if rlc_params.num_rlc_columns == 0 && self.keccak_rows_per_round == 0 {
479            AxiomCircuitParams::Base(rlc_params.base)
480        } else if self.keccak_rows_per_round == 0 {
481            return AxiomCircuitParams::Rlc(rlc_params);
482        } else {
483            return AxiomCircuitParams::Keccak(RlcKeccakCircuitParams {
484                rlc: rlc_params,
485                keccak_rows_per_round: self.keccak_rows_per_round,
486            });
487        }
488    }
489
490    fn configure_with_params(meta: &mut ConstraintSystem<F>, params: Self::Params) -> Self::Config {
491        match params {
492            AxiomCircuitParams::Rlc(params) => {
493                let k = params.base.k;
494                let mut rlc_config = RlcConfig::configure(meta, params);
495                let usable_rows = (1 << k) - meta.minimum_rows();
496                rlc_config.set_usable_rows(usable_rows);
497                AxiomCircuitConfig::Rlc(rlc_config)
498            }
499            AxiomCircuitParams::Base(params) => {
500                AxiomCircuitConfig::Base(BaseConfig::configure(meta, params))
501            }
502            AxiomCircuitParams::Keccak(params) => {
503                AxiomCircuitConfig::Keccak(RlcKeccakConfig::configure(meta, params))
504            }
505        }
506    }
507
508    fn configure(_: &mut ConstraintSystem<F>) -> Self::Config {
509        unreachable!()
510    }
511
512    fn synthesize(&self, config: Self::Config, layouter: impl Layouter<F>) -> Result<(), Error> {
513        match config {
514            AxiomCircuitConfig::Rlc(config) => {
515                self.synthesize_with_rlc_and_keccak(config, None, layouter)
516            }
517            AxiomCircuitConfig::Base(config) => self.synthesize_without_rlc(config, layouter),
518            AxiomCircuitConfig::Keccak(config) => {
519                self.synthesize_with_rlc_and_keccak(config.rlc, Some(config.keccak), layouter)
520            }
521        }
522    }
523}
524
525impl<F: Field, P: JsonRpcClient + Clone, A: AxiomCircuitScaffold<P, F> + Default> CircuitExt<F>
526    for AxiomCircuit<F, P, A>
527{
528    fn num_instance(&self) -> Vec<usize> {
529        vec![self.output_num_instances() + self.subquery_num_instances()]
530    }
531
532    fn instances(&self) -> Vec<Vec<F>> {
533        self.instances()
534    }
535}