Skip to main content

axiom_eth/utils/component/promise_loader/
multi.rs

1#![allow(clippy::type_complexity)]
2use std::{collections::HashMap, marker::PhantomData};
3
4use crate::{
5    rlc::{
6        chip::RlcChip,
7        circuit::builder::{RlcCircuitBuilder, RlcContextPair},
8    },
9    utils::component::{
10        circuit::LoaderParamsPerComponentType,
11        promise_loader::comp_loader::SingleComponentLoaderImpl,
12    },
13    Field,
14};
15use getset::{CopyGetters, Setters};
16use halo2_base::{
17    gates::GateInstructions,
18    halo2_proofs::{
19        circuit::Layouter,
20        plonk::{ConstraintSystem, SecondPhase},
21    },
22    virtual_region::{
23        copy_constraints::SharedCopyConstraintManager, lookups::basic::BasicDynLookupConfig,
24    },
25    AssignedValue,
26};
27use itertools::Itertools;
28use serde::{Deserialize, Serialize};
29
30use crate::utils::component::{
31    circuit::{ComponentBuilder, PromiseBuilder},
32    promise_collector::{PromiseCallsGetter, PromiseCommitSetter, PromiseResultsGetter},
33    promise_loader::flatten_witness_to_rlc,
34    types::{FixLenLogical, Flatten, LogicalEmpty},
35    ComponentType, ComponentTypeId,
36};
37
38use super::comp_loader::{SingleComponentLoader, SingleComponentLoaderParams};
39
40pub trait ComponentTypeList<F: Field> {
41    fn get_component_type_ids() -> Vec<ComponentTypeId>;
42    fn build_component_loaders(
43        params_per_component: &HashMap<ComponentTypeId, SingleComponentLoaderParams>,
44    ) -> Vec<Box<dyn SingleComponentLoader<F>>>;
45}
46pub struct ComponentTypeListEnd<F: Field> {
47    _phantom: PhantomData<F>,
48}
49impl<F: Field> ComponentTypeList<F> for ComponentTypeListEnd<F> {
50    fn get_component_type_ids() -> Vec<ComponentTypeId> {
51        vec![]
52    }
53    fn build_component_loaders(
54        _params_per_component: &HashMap<ComponentTypeId, SingleComponentLoaderParams>,
55    ) -> Vec<Box<dyn SingleComponentLoader<F>>> {
56        vec![]
57    }
58}
59pub struct ComponentTypeListImpl<F: Field, HEAD: ComponentType<F>, LATER: ComponentTypeList<F>> {
60    _phantom: PhantomData<(F, HEAD, LATER)>,
61}
62impl<F: Field, HEAD: ComponentType<F>, LATER: ComponentTypeList<F>> ComponentTypeList<F>
63    for ComponentTypeListImpl<F, HEAD, LATER>
64{
65    fn get_component_type_ids() -> Vec<ComponentTypeId> {
66        let mut ret = vec![HEAD::get_type_id()];
67        ret.extend(LATER::get_component_type_ids());
68        ret
69    }
70    fn build_component_loaders(
71        params_per_component: &HashMap<ComponentTypeId, SingleComponentLoaderParams>,
72    ) -> Vec<Box<dyn SingleComponentLoader<F>>> {
73        type Loader<F, HEAD> = SingleComponentLoaderImpl<F, HEAD>;
74        let mut ret = Vec::new();
75        if let Some(params) = params_per_component.get(&HEAD::get_type_id()) {
76            let comp_loader: Box<dyn SingleComponentLoader<F>> =
77                Box::new(Loader::<F, HEAD>::new(params.clone()));
78            ret.push(comp_loader);
79        }
80        ret.extend(LATER::build_component_loaders(params_per_component));
81        ret
82    }
83}
/// Spells out the `ComponentTypeList` type for field `$field` and the given component types,
/// in the order written.
///
/// `component_type_list!(F, A, B)` expands to
/// `ComponentTypeListImpl<F, A, ComponentTypeListImpl<F, B, ComponentTypeListEnd<F>>>`.
/// A trailing comma after the last component type is accepted.
#[macro_export]
macro_rules! component_type_list {
    ($field:ty, $comp_type:ty $(,)?) => {
        $crate::utils::component::promise_loader::multi::ComponentTypeListImpl<
            $field,
            $comp_type,
            $crate::utils::component::promise_loader::multi::ComponentTypeListEnd<$field>
        >
    };
    ($field:ty, $comp_type:ty, $($comp_types:ty),+ $(,)?) => {
        // Recurse on the tail; the trailing comma (if any) is dropped here.
        $crate::utils::component::promise_loader::multi::ComponentTypeListImpl<
            $field,
            $comp_type,
            $crate::component_type_list!($field, $($comp_types),+)
        >
    };
}
93
94#[derive(Clone)]
95pub struct MultiPromiseLoaderConfig {
96    pub dyn_lookup_config: BasicDynLookupConfig<1>,
97}
98
99// TODO: this is useless now because comp_loaders already have the information.
100#[derive(Clone, Default, Serialize, Deserialize)]
101pub struct MultiPromiseLoaderParams {
102    pub params_per_component: HashMap<ComponentTypeId, SingleComponentLoaderParams>,
103}
104
105/// Load promises of multiple component types which share the same lookup table.
106/// The size of promise result it receives MUST match its capacity.
107/// VT is a virtual component type which is used to generate lookup table. Its promise
108/// results should not be fulfilled by external.
109/// TODO: Currently we don't support promise calls for virtual component types so we enforce output to be empty.
110/// TODO: remove virtual component type.
111#[derive(CopyGetters, Setters)]
112pub struct MultiPromiseLoader<
113    F: Field,
114    VT: ComponentType<F, OutputValue = LogicalEmpty<F>, OutputWitness = LogicalEmpty<AssignedValue<F>>>,
115    CLIST: ComponentTypeList<F>,
116    A: RlcAdapter<F>,
117> {
118    params: MultiPromiseLoaderParams,
119    // ComponentTypeId -> (input, output)
120    witness_promise_results: Option<
121        HashMap<ComponentTypeId, Vec<(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)>>,
122    >,
123    // (to lookup, lookup table)
124    witness_rlc_lookup: Option<(Vec<AssignedValue<F>>, Vec<AssignedValue<F>>)>,
125    // A bit hacky..
126    witness_gen_only: bool,
127    copy_manager: Option<SharedCopyConstraintManager<F>>,
128    pub(super) comp_loaders: Vec<Box<dyn SingleComponentLoader<F>>>,
129    _phantom: PhantomData<(VT, CLIST, A)>,
130}
131
132pub trait RlcAdapter<F: Field> {
133    fn to_rlc(
134        ctx_pair: RlcContextPair<F>,
135        gate: &impl GateInstructions<F>,
136        rlc: &RlcChip<F>,
137        type_id: &ComponentTypeId,
138        io_pairs: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
139    ) -> Vec<AssignedValue<F>>;
140}
141
142impl<
143        F: Field,
144        VT: ComponentType<
145            F,
146            OutputValue = LogicalEmpty<F>,
147            OutputWitness = LogicalEmpty<AssignedValue<F>>,
148        >,
149        CLIST: ComponentTypeList<F>,
150        A: RlcAdapter<F>,
151    > ComponentBuilder<F> for MultiPromiseLoader<F, VT, CLIST, A>
152{
153    type Config = MultiPromiseLoaderConfig;
154    type Params = MultiPromiseLoaderParams;
155
156    /// Create MultiPromiseLoader
157    fn new(params: MultiPromiseLoaderParams) -> Self {
158        let comp_loaders = CLIST::build_component_loaders(&params.params_per_component);
159        Self {
160            params,
161            witness_promise_results: None,
162            witness_rlc_lookup: None,
163            witness_gen_only: false,
164            copy_manager: None,
165            comp_loaders,
166            _phantom: PhantomData,
167        }
168    }
169    fn get_params(&self) -> Self::Params {
170        self.params.clone()
171    }
172
173    fn clear_witnesses(&mut self) {
174        self.witness_promise_results = None;
175        self.witness_rlc_lookup = None;
176        self.copy_manager = None;
177    }
178
179    fn configure_with_params(
180        meta: &mut ConstraintSystem<F>,
181        _params: Self::Params,
182    ) -> Self::Config {
183        // TODO: adjust num of columns based on params.
184        let dyn_lookup_config = BasicDynLookupConfig::new(meta, || SecondPhase, 1);
185        Self::Config { dyn_lookup_config }
186    }
187    fn calculate_params(&mut self) -> Self::Params {
188        self.params.clone()
189    }
190}
191
192impl<
193        F: Field,
194        VT: ComponentType<
195            F,
196            OutputValue = LogicalEmpty<F>,
197            OutputWitness = LogicalEmpty<AssignedValue<F>>,
198        >,
199        CLIST: ComponentTypeList<F>,
200        A: RlcAdapter<F>,
201    > PromiseBuilder<F> for MultiPromiseLoader<F, VT, CLIST, A>
202{
203    // NOTE: the actual dependencies are based on the params.
204    fn get_component_type_dependencies() -> Vec<ComponentTypeId> {
205        CLIST::get_component_type_ids()
206    }
207    fn extract_loader_params_per_component_type(
208        params: &Self::Params,
209    ) -> Vec<LoaderParamsPerComponentType> {
210        let mut ret = Vec::new();
211        for type_id in Self::get_component_type_dependencies() {
212            if let Some(loader_params) = params.params_per_component.get(&type_id) {
213                ret.push(LoaderParamsPerComponentType {
214                    component_type_id: type_id,
215                    loader_params: loader_params.clone(),
216                })
217            }
218        }
219        ret
220    }
221    fn fulfill_promise_results(&mut self, promise_results_getter: &impl PromiseResultsGetter<F>) {
222        assert!(
223            promise_results_getter.get_results_by_component_type_id(&VT::get_type_id()).is_none(),
224            "promise results of the virtual component type should not be fulfilled"
225        );
226        for comp_loader in &mut self.comp_loaders {
227            let component_type_id = comp_loader.get_component_type_id();
228            let promise_results = promise_results_getter
229                .get_results_by_component_type_id(&component_type_id)
230                .unwrap_or_else(|| {
231                    panic!("missing promise results for component type id {:?}", component_type_id)
232                });
233
234            comp_loader.load_promise_results(promise_results.clone());
235        }
236    }
237
238    fn virtual_assign_phase0(
239        &mut self,
240        builder: &mut RlcCircuitBuilder<F>,
241        promise_commit_setter: &mut impl PromiseCommitSetter<F>,
242    ) {
243        assert!(self.witness_promise_results.is_none());
244        self.witness_gen_only = builder.witness_gen_only();
245
246        let mut witness_promise_results = HashMap::new();
247
248        for comp_loader in &self.comp_loaders {
249            // TODO: Multi-thread here?
250            let (commit, witness_promise_results_per_type) =
251                comp_loader.assign_and_compute_commitment(builder);
252            let component_type_id = comp_loader.get_component_type_id();
253            promise_commit_setter
254                .set_commit_by_component_type_id(component_type_id.clone(), commit);
255            witness_promise_results.insert(component_type_id, witness_promise_results_per_type);
256        }
257
258        self.witness_promise_results = Some(witness_promise_results);
259    }
260
261    fn raw_synthesize_phase0(&mut self, _config: &Self::Config, _layouter: &mut impl Layouter<F>) {
262        // Do nothing.
263    }
264
265    fn virtual_assign_phase1(
266        &mut self,
267        builder: &mut RlcCircuitBuilder<F>,
268        promise_calls_getter: &mut impl PromiseCallsGetter<F>,
269    ) {
270        assert!(self.witness_promise_results.is_some());
271        let range_chip = &builder.range_chip();
272        let rlc_chip = builder.rlc_chip(&range_chip.gate);
273        let (gate_ctx, rlc_ctx) = builder.rlc_ctx_pair();
274
275        let input_multiplier =
276            rlc_chip.rlc_pow_fixed(gate_ctx, &range_chip.gate, VT::OutputValue::get_num_fields());
277
278        let component_type_id = VT::get_type_id();
279
280        let calls_per_context =
281            promise_calls_getter.get_calls_by_component_type_id(&component_type_id).unwrap();
282
283        let to_lookup_rlc = calls_per_context
284            .values()
285            .flatten()
286            .map(|(f_i, f_o)| {
287                let i_rlc = f_i.to_rlc((gate_ctx, rlc_ctx), range_chip, &rlc_chip);
288                let o_rlc = flatten_witness_to_rlc(rlc_ctx, &rlc_chip, f_o);
289                range_chip.gate.mul_add(gate_ctx, i_rlc, input_multiplier, o_rlc)
290            })
291            .collect_vec();
292
293        let num_dependencies = self.comp_loaders.len();
294        let mut lookup_table_rlc = Vec::with_capacity(num_dependencies);
295
296        // **Order must be deterministic.**
297        for comp_loader in &self.comp_loaders {
298            let component_type_id = comp_loader.get_component_type_id();
299            let ctx_pair = (&mut *gate_ctx, &mut *rlc_ctx);
300            let lookup_table_rlc_per_type = A::to_rlc(
301                ctx_pair,
302                &range_chip.gate,
303                &rlc_chip,
304                &component_type_id,
305                &self.witness_promise_results.as_ref().unwrap()[&component_type_id],
306            );
307            lookup_table_rlc.push(lookup_table_rlc_per_type);
308        }
309        let lookup_table_rlc = lookup_table_rlc.concat();
310
311        self.witness_rlc_lookup = Some((to_lookup_rlc, lookup_table_rlc));
312        self.copy_manager = Some(builder.copy_manager().clone());
313    }
314
315    fn raw_synthesize_phase1(&mut self, config: &Self::Config, layouter: &mut impl Layouter<F>) {
316        assert!(self.witness_rlc_lookup.is_some());
317
318        let (to_lookup, lookup_table) = self.witness_rlc_lookup.as_ref().unwrap();
319        let dyn_lookup_config = &config.dyn_lookup_config;
320
321        let copy_manager = (!self.witness_gen_only).then(|| self.copy_manager.as_ref().unwrap());
322        dyn_lookup_config.assign_virtual_table_to_raw(
323            layouter.namespace(|| {
324                format!("promise loader adds advice to lookup for {}", VT::get_type_name())
325            }),
326            lookup_table.iter().map(|a| [*a; 1]),
327            copy_manager,
328        );
329
330        dyn_lookup_config.assign_virtual_to_lookup_to_raw(
331            layouter
332                .namespace(|| format!("promise loader loads lookup table {}", VT::get_type_name())),
333            to_lookup.iter().map(|a| [*a; 1]),
334            copy_manager,
335        );
336    }
337}