Skip to main content

axiom_eth/utils/component/promise_loader/
comp_loader.rs

1#![allow(clippy::type_complexity)]
2use std::{iter, marker::PhantomData};
3
4use crate::utils::component::utils::compute_poseidon;
5use crate::Field;
6use crate::{
7    rlc::{chip::RlcChip, circuit::builder::RlcCircuitBuilder, RLC_PHASE},
8    utils::component::{
9        types::FixLenLogical, utils::compute_poseidon_merkle_tree, FlattenVirtualRow,
10        FlattenVirtualTable, LogicalResult, PromiseShardMetadata, SelectedDataShardsInMerkle,
11    },
12};
13use getset::{CopyGetters, Getters};
14use halo2_base::{
15    gates::{circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip, RangeInstructions},
16    AssignedValue,
17};
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20
21use super::{
22    super::{
23        promise_collector::PromiseResultWitness,
24        types::Flatten,
25        utils::{compute_commitment_with_flatten, create_hasher},
26        ComponentPromiseResultsInMerkle, ComponentType, ComponentTypeId,
27    },
28    flatten_witness_to_rlc,
29};
30
/// Private module backing the sealed-trait pattern: `SingleComponentLoader` requires
/// `private::Sealed` as a supertrait, and since this module is not public, code outside
/// this file's module cannot name `Sealed` and therefore cannot implement the loader trait.
mod private {
    /// Marker supertrait; it has no methods and exists only to restrict implementors.
    pub trait Sealed {}
}
35
36#[derive(Clone, Debug, Hash, Getters, CopyGetters, Serialize, Deserialize, Eq, PartialEq)]
37/// Specify what merkle tree of commits can be loaded.
38pub struct SingleComponentLoaderParams {
39    /// The maximum height of the merkle tree this loader can load.
40    #[getset(get_copy = "pub")]
41    max_height: usize,
42    /// Specify the number of shards to be loaded and the capacity of each shard.
43    #[getset(get = "pub")]
44    shard_caps: Vec<usize>,
45}
46
47impl SingleComponentLoaderParams {
48    /// Create SingleComponentLoaderParams
49    pub fn new(max_height: usize, shard_caps: Vec<usize>) -> Self {
50        // Tip: binary tree with only 1 node has height 0.
51        assert!(shard_caps.len() <= 1 << max_height);
52        Self { max_height, shard_caps }
53    }
54    /// Create SingleComponentLoaderParams for only 1 shard
55    pub fn new_for_one_shard(cap: usize) -> Self {
56        Self { max_height: 0, shard_caps: vec![cap] }
57    }
58}
59
60impl Default for SingleComponentLoaderParams {
61    fn default() -> Self {
62        Self::new(0, vec![0])
63    }
64}
65
66/// Object safe trait for loading promises of a component type.
67pub trait SingleComponentLoader<F: Field>: private::Sealed {
68    /// Get the component type id this loader is for.
69    fn get_component_type_id(&self) -> ComponentTypeId;
70    /// Get ComponentTypeName for logging/debugging.
71    fn get_component_type_name(&self) -> &'static str;
72    fn get_params(&self) -> &SingleComponentLoaderParams;
73    /// Check if promise results are ready.
74    fn promise_results_ready(&self) -> bool;
75    /// Load promise results from promise results getter.
76    fn load_promise_results(&mut self, promise_results: ComponentPromiseResultsInMerkle<F>);
77    /// Load dummy promise results according to the loader params.
78    fn load_dummy_promise_results(&mut self);
79    /// Return (merkle_tree_root, concat_assigned_virtual_tables). Data is preloaded.
80    /// TODO: do we really need  to return assigned virtual table?
81    fn assign_and_compute_commitment(
82        &self,
83        builder: &mut RlcCircuitBuilder<F>,
84    ) -> (AssignedValue<F>, FlattenVirtualTable<AssignedValue<F>>);
85    /// Return (to_lookup, lookup_table)
86    fn generate_lookup_rlc(
87        &self,
88        builder: &mut RlcCircuitBuilder<F>,
89        promise_calls: &[&PromiseResultWitness<F>],
90        promise_results: &[FlattenVirtualRow<AssignedValue<F>>],
91    ) -> (Vec<AssignedValue<F>>, Vec<AssignedValue<F>>);
92}
93
94/// Promise results but each shard is in the virtual table format.
95type PromiseVirtualTableResults<F> = SelectedDataShardsInMerkle<F, FlattenVirtualTable<F>>;
96
97/// Implementation of SingleComponentLoader for a component type.
98pub struct SingleComponentLoaderImpl<F: Field, T: ComponentType<F>> {
99    val_promise_results: Option<PromiseVirtualTableResults<F>>,
100    params: SingleComponentLoaderParams,
101    _phantom: PhantomData<T>,
102}
103
104impl<F: Field, T: ComponentType<F>> SingleComponentLoaderImpl<F, T> {
105    /// Create SingleComponentLoaderImpl for T.
106    pub fn new(params: SingleComponentLoaderParams) -> Self {
107        Self { val_promise_results: None, params, _phantom: PhantomData }
108    }
109    /// Create dummy promise results based on params for CircuitBuilder params calculation.
110    fn create_dummy_promise_result_merkle(&self) -> PromiseVirtualTableResults<F> {
111        let num_shards = self.params.shard_caps.len();
112        let num_leaves = num_shards.next_power_of_two();
113        let mut leaves = Vec::with_capacity(num_leaves);
114        for i in 0..num_leaves {
115            let commit = F::ZERO;
116            leaves.push(PromiseShardMetadata::<F> {
117                commit,
118                capacity: if i < num_shards { self.params.shard_caps[i] } else { 0 },
119            });
120        }
121        let shards = self
122            .params
123            .shard_caps
124            .iter()
125            .copied()
126            .enumerate()
127            .map(|(idx, shard_cap)| {
128                let dummy_input = Flatten::<F> {
129                    fields: vec![F::ZERO; T::InputValue::get_num_fields()],
130                    field_size: T::InputValue::get_field_size(),
131                };
132                let dummy_output = Flatten::<F> {
133                    fields: vec![F::ZERO; T::OutputValue::get_num_fields()],
134                    field_size: T::OutputValue::get_field_size(),
135                };
136                let shard = vec![(dummy_input, dummy_output); shard_cap];
137                (idx, shard)
138            })
139            .collect_vec();
140        PromiseVirtualTableResults::<F>::new(leaves, shards)
141    }
142}
143
144impl<F: Field, T: ComponentType<F>> private::Sealed for SingleComponentLoaderImpl<F, T> {}
145
146impl<F: Field, T: ComponentType<F>> SingleComponentLoader<F> for SingleComponentLoaderImpl<F, T> {
147    fn get_component_type_id(&self) -> ComponentTypeId {
148        T::get_type_id()
149    }
150    fn get_component_type_name(&self) -> &'static str {
151        T::get_type_name()
152    }
153    fn get_params(&self) -> &SingleComponentLoaderParams {
154        &self.params
155    }
156    fn promise_results_ready(&self) -> bool {
157        self.val_promise_results.is_some()
158    }
159    fn load_promise_results(&mut self, promise_results: ComponentPromiseResultsInMerkle<F>) {
160        // Tip: binary tree with only 1 node has height 0.
161        assert!(promise_results.leaves().len() <= 1 << self.params.max_height);
162        assert_eq!(promise_results.shards().len(), self.params.shard_caps().len());
163
164        let merkle_vt = promise_results.map_data(|typeless_prs| {
165            typeless_prs
166                .into_iter()
167                .flat_map(|typeless_prs| {
168                    FlattenVirtualTable::<F>::from(
169                        LogicalResult::<F, T>::try_from(typeless_prs).unwrap(),
170                    )
171                })
172                .collect_vec()
173        });
174
175        for (shard_idx, shard) in merkle_vt.shards() {
176            let shard_capacity = merkle_vt.leaves()[*shard_idx].capacity;
177            assert_eq!(shard_capacity, shard.len());
178        }
179        self.val_promise_results = Some(merkle_vt);
180    }
181
182    fn assign_and_compute_commitment(
183        &self,
184        builder: &mut RlcCircuitBuilder<F>,
185    ) -> (AssignedValue<F>, FlattenVirtualTable<AssignedValue<F>>) {
186        let val_promise_results =
187            if let Some(val_promise_results) = self.val_promise_results.as_ref() {
188                val_promise_results.clone()
189            } else {
190                self.create_dummy_promise_result_merkle()
191            };
192        let leaves_to_load = val_promise_results.leaves();
193
194        let assigned_per_shard = val_promise_results
195            .shards()
196            .iter()
197            .map(|(_, vt)| {
198                let ctx = builder.base.main(0);
199                let witness_vt =
200                    vt.iter().map(|(v_i, v_o)| (v_i.assign(ctx), v_o.assign(ctx))).collect_vec();
201                let commit = T::Commiter::compute_commitment(&mut builder.base, &witness_vt);
202                (commit, witness_vt)
203            })
204            .collect_vec();
205
206        let range_chip = &builder.range_chip();
207        let gate_chip = &range_chip.gate;
208        let ctx = builder.base.main(0);
209        // Indexes of selected shards. The length is deterministic because we had
210        // checked selected_shard.len() == params.shard_caps.len() in load_promise_results.
211        let selected_shards = ctx.assign_witnesses(
212            val_promise_results.shards().iter().map(|(shard_idx, _)| F::from(*shard_idx as u64)),
213        );
214
215        // The circuit has a fixed `self.params.max_height`. This is the maximum height merkle tree supported.
216        // However the private inputs of the circuit will dictate the actual heigh of the merkle tree of shards that this circuit is using.
217        // Example:
218        // We have max height is 3.
219        // However, this circuit will only get `leaves_to_load` for 4 shard commitments: [a, b, c, d]
220        // It will compute the merkle root of [a, b, c, d] where `4` is a private witness.
221        // Then it may have `selected_shards = [0, 2]`, meaning it only de-commits the shards for a, c. It does this by using `select_from_idx` on `[a, b, c, d]` to get `a, c`.
222        // Because we always decommit the leaves, meaning we dictate that the leaves much be flat hashes of virtual tables of fixed size (given by `shard_caps`), the
223        // private witness for the true height (in this example `4`), is commited to by the merkle root we generate.
224        // In other words, our definition of shard commitment provides domain separation for the merkle leaves.
225
226        // The loader's behavior should not depend on inputs. So the loader always computes a merkle tree with a pre-defined height.
227        // Then we put the merkle tree to load in the left-bottom of the pre-defined merkle tree. The rest of the leaves are filled with zeros.
228        // The root of the merkle tree to load will be on the leftmost path of the pre-defined merkle tree. So we can select the root by
229        // the height of the merkle tree to load.
230
231        let num_leaves = 1 << self.params.max_height;
232        let leaves_commits = ctx.assign_witnesses(
233            leaves_to_load.iter().map(|l| l.commit).chain(iter::repeat(F::ZERO)).take(num_leaves),
234        );
235        let mut assigned_vts = Vec::with_capacity(assigned_per_shard.len());
236        for (selected_shard, (shard_commit, assigned_vt)) in
237            selected_shards.into_iter().zip_eq(assigned_per_shard)
238        {
239            range_chip.check_less_than_safe(ctx, selected_shard, num_leaves as u64);
240            let leaf_commit =
241                gate_chip.select_from_idx(ctx, leaves_commits.clone(), selected_shard);
242            ctx.constrain_equal(&leaf_commit, &shard_commit);
243
244            assigned_vts.push(assigned_vt);
245        }
246        let flatten_assigned_vts = assigned_vts.into_iter().flatten().collect_vec();
247
248        // Optimization: if there is only one shard, we don't need to compute the merkle tree so no need to create hasher.
249        if leaves_commits.len() == 1 {
250            return (leaves_commits[0], flatten_assigned_vts);
251        };
252
253        let mut hasher = create_hasher::<F>();
254        hasher.initialize_consts(ctx, gate_chip);
255        let nodes = compute_poseidon_merkle_tree(ctx, gate_chip, &hasher, leaves_commits);
256
257        // Leftmost nodes of the pre-defined merkle tree from bottom to top.
258        let leftmost_nodes =
259            (0..=self.params.max_height).rev().map(|i| nodes[(1 << i) - 1]).collect_vec();
260        // The height of the merkle tree to load.
261        let result_height: AssignedValue<F> =
262            ctx.load_witness(F::from(leaves_to_load.len().ilog2() as u64));
263        range_chip.check_less_than_safe(ctx, result_height, (self.params.max_height + 1) as u64);
264
265        let output_commit = gate_chip.select_from_idx(ctx, leftmost_nodes, result_height);
266
267        (output_commit, flatten_assigned_vts)
268    }
269
270    fn generate_lookup_rlc(
271        &self,
272        builder: &mut RlcCircuitBuilder<F>,
273        promise_calls: &[&PromiseResultWitness<F>],
274        promise_results: &[FlattenVirtualRow<AssignedValue<F>>],
275    ) -> (Vec<AssignedValue<F>>, Vec<AssignedValue<F>>) {
276        let range_chip = &builder.range_chip();
277        let rlc_chip = builder.rlc_chip(&range_chip.gate);
278        generate_lookup_rlcs_impl::<F, T>(
279            builder,
280            range_chip,
281            &rlc_chip,
282            promise_calls,
283            promise_results,
284        )
285    }
286
287    fn load_dummy_promise_results(&mut self) {
288        let vt = self.create_dummy_promise_result_merkle();
289        self.val_promise_results = Some(vt);
290    }
291}
292
293/// Returns `(to_lookup_rlc, lookup_table_rlc)`
294/// where `to_lookup_rlc` corresponds to `promise_calls` and
295/// `lookup_table_rlc` corresponds to `promise_results`.
296///
297/// This should only be called in phase1.
298pub fn generate_lookup_rlcs_impl<F: Field, T: ComponentType<F>>(
299    builder: &mut RlcCircuitBuilder<F>,
300    range_chip: &RangeChip<F>,
301    rlc_chip: &RlcChip<F>,
302    promise_calls: &[&PromiseResultWitness<F>],
303    promise_results: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
304) -> (Vec<AssignedValue<F>>, Vec<AssignedValue<F>>) {
305    let gate_ctx = builder.base.main(RLC_PHASE);
306
307    let input_multiplier =
308        rlc_chip.rlc_pow_fixed(gate_ctx, range_chip.gate(), T::OutputValue::get_num_fields());
309
310    let to_lookup_rlc =
311        builder.parallelize_phase1(promise_calls.to_vec(), |(gate_ctx, rlc_ctx), (f_i, f_o)| {
312            let i_rlc = f_i.to_rlc((gate_ctx, rlc_ctx), range_chip, rlc_chip);
313            let o_rlc = flatten_witness_to_rlc(rlc_ctx, rlc_chip, f_o);
314            range_chip.gate.mul_add(gate_ctx, i_rlc, input_multiplier, o_rlc)
315        });
316
317    let (gate_ctx, rlc_ctx) = builder.rlc_ctx_pair();
318
319    let lookup_table_rlc = T::rlc_virtual_rows(
320        (gate_ctx, rlc_ctx),
321        range_chip,
322        rlc_chip,
323        &promise_results
324            .iter()
325            .map(|(f_i, f_o)| {
326                (
327                    T::InputWitness::try_from(f_i.clone()).unwrap(),
328                    T::OutputWitness::try_from(f_o.clone()).unwrap(),
329                )
330            })
331            .collect_vec(),
332    );
333    (to_lookup_rlc, lookup_table_rlc)
334}
335
336/// Trait for computing commit of ONE virtual table.
337pub trait ComponentCommiter<F: Field> {
338    /// Compute the commitment of a virtual table.
339    fn compute_commitment(
340        builder: &mut BaseCircuitBuilder<F>,
341        witness_promise_results: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
342    ) -> AssignedValue<F>;
343    /// The implementor **must** enforce that the output of this function
344    /// is the same as the output value of `compute_commitment`.
345    /// We allow a separate implemenation purely for performance, as the native commitmnt
346    /// computation is much faster than doing it in the circuit.
347    fn compute_native_commitment(witness_promise_results: &[(Flatten<F>, Flatten<F>)]) -> F;
348}
349
350/// BasicComponentCommiter simply compute poseidon of all virtual rows.
351pub struct BasicComponentCommiter<F: Field>(PhantomData<F>);
352
353impl<F: Field> ComponentCommiter<F> for BasicComponentCommiter<F> {
354    fn compute_commitment(
355        builder: &mut BaseCircuitBuilder<F>,
356        witness_promise_results: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
357    ) -> AssignedValue<F> {
358        let range_chip = &builder.range_chip();
359        let ctx = builder.main(0);
360
361        let mut hasher = create_hasher::<F>();
362        hasher.initialize_consts(ctx, &range_chip.gate);
363        compute_commitment_with_flatten(ctx, &range_chip.gate, &hasher, witness_promise_results)
364    }
365    fn compute_native_commitment(witness_promise_results: &[(Flatten<F>, Flatten<F>)]) -> F {
366        let to_commit = witness_promise_results
367            .iter()
368            .flat_map(|(i, o)| i.fields.iter().chain(o.fields.iter()).copied())
369            .collect_vec();
370        compute_poseidon(&to_commit)
371    }
372}