1#![allow(clippy::type_complexity)]
2use std::{iter, marker::PhantomData};
3
4use crate::utils::component::utils::compute_poseidon;
5use crate::Field;
6use crate::{
7 rlc::{chip::RlcChip, circuit::builder::RlcCircuitBuilder, RLC_PHASE},
8 utils::component::{
9 types::FixLenLogical, utils::compute_poseidon_merkle_tree, FlattenVirtualRow,
10 FlattenVirtualTable, LogicalResult, PromiseShardMetadata, SelectedDataShardsInMerkle,
11 },
12};
13use getset::{CopyGetters, Getters};
14use halo2_base::{
15 gates::{circuit::builder::BaseCircuitBuilder, GateInstructions, RangeChip, RangeInstructions},
16 AssignedValue,
17};
18use itertools::Itertools;
19use serde::{Deserialize, Serialize};
20
21use super::{
22 super::{
23 promise_collector::PromiseResultWitness,
24 types::Flatten,
25 utils::{compute_commitment_with_flatten, create_hasher},
26 ComponentPromiseResultsInMerkle, ComponentType, ComponentTypeId,
27 },
28 flatten_witness_to_rlc,
29};
30
/// Private module used to seal [`SingleComponentLoader`].
///
/// The module itself is not `pub`, so code outside this file cannot name
/// `Sealed` and therefore cannot implement any trait that has it as a supertrait.
mod private {
    /// Marker supertrait; it carries no methods.
    pub trait Sealed {}
}
35
36#[derive(Clone, Debug, Hash, Getters, CopyGetters, Serialize, Deserialize, Eq, PartialEq)]
37pub struct SingleComponentLoaderParams {
39 #[getset(get_copy = "pub")]
41 max_height: usize,
42 #[getset(get = "pub")]
44 shard_caps: Vec<usize>,
45}
46
47impl SingleComponentLoaderParams {
48 pub fn new(max_height: usize, shard_caps: Vec<usize>) -> Self {
50 assert!(shard_caps.len() <= 1 << max_height);
52 Self { max_height, shard_caps }
53 }
54 pub fn new_for_one_shard(cap: usize) -> Self {
56 Self { max_height: 0, shard_caps: vec![cap] }
57 }
58}
59
60impl Default for SingleComponentLoaderParams {
61 fn default() -> Self {
62 Self::new(0, vec![0])
63 }
64}
65
66pub trait SingleComponentLoader<F: Field>: private::Sealed {
68 fn get_component_type_id(&self) -> ComponentTypeId;
70 fn get_component_type_name(&self) -> &'static str;
72 fn get_params(&self) -> &SingleComponentLoaderParams;
73 fn promise_results_ready(&self) -> bool;
75 fn load_promise_results(&mut self, promise_results: ComponentPromiseResultsInMerkle<F>);
77 fn load_dummy_promise_results(&mut self);
79 fn assign_and_compute_commitment(
82 &self,
83 builder: &mut RlcCircuitBuilder<F>,
84 ) -> (AssignedValue<F>, FlattenVirtualTable<AssignedValue<F>>);
85 fn generate_lookup_rlc(
87 &self,
88 builder: &mut RlcCircuitBuilder<F>,
89 promise_calls: &[&PromiseResultWitness<F>],
90 promise_results: &[FlattenVirtualRow<AssignedValue<F>>],
91 ) -> (Vec<AssignedValue<F>>, Vec<AssignedValue<F>>);
92}
93
94type PromiseVirtualTableResults<F> = SelectedDataShardsInMerkle<F, FlattenVirtualTable<F>>;
96
97pub struct SingleComponentLoaderImpl<F: Field, T: ComponentType<F>> {
99 val_promise_results: Option<PromiseVirtualTableResults<F>>,
100 params: SingleComponentLoaderParams,
101 _phantom: PhantomData<T>,
102}
103
104impl<F: Field, T: ComponentType<F>> SingleComponentLoaderImpl<F, T> {
105 pub fn new(params: SingleComponentLoaderParams) -> Self {
107 Self { val_promise_results: None, params, _phantom: PhantomData }
108 }
109 fn create_dummy_promise_result_merkle(&self) -> PromiseVirtualTableResults<F> {
111 let num_shards = self.params.shard_caps.len();
112 let num_leaves = num_shards.next_power_of_two();
113 let mut leaves = Vec::with_capacity(num_leaves);
114 for i in 0..num_leaves {
115 let commit = F::ZERO;
116 leaves.push(PromiseShardMetadata::<F> {
117 commit,
118 capacity: if i < num_shards { self.params.shard_caps[i] } else { 0 },
119 });
120 }
121 let shards = self
122 .params
123 .shard_caps
124 .iter()
125 .copied()
126 .enumerate()
127 .map(|(idx, shard_cap)| {
128 let dummy_input = Flatten::<F> {
129 fields: vec![F::ZERO; T::InputValue::get_num_fields()],
130 field_size: T::InputValue::get_field_size(),
131 };
132 let dummy_output = Flatten::<F> {
133 fields: vec![F::ZERO; T::OutputValue::get_num_fields()],
134 field_size: T::OutputValue::get_field_size(),
135 };
136 let shard = vec![(dummy_input, dummy_output); shard_cap];
137 (idx, shard)
138 })
139 .collect_vec();
140 PromiseVirtualTableResults::<F>::new(leaves, shards)
141 }
142}
143
144impl<F: Field, T: ComponentType<F>> private::Sealed for SingleComponentLoaderImpl<F, T> {}
145
146impl<F: Field, T: ComponentType<F>> SingleComponentLoader<F> for SingleComponentLoaderImpl<F, T> {
147 fn get_component_type_id(&self) -> ComponentTypeId {
148 T::get_type_id()
149 }
150 fn get_component_type_name(&self) -> &'static str {
151 T::get_type_name()
152 }
153 fn get_params(&self) -> &SingleComponentLoaderParams {
154 &self.params
155 }
156 fn promise_results_ready(&self) -> bool {
157 self.val_promise_results.is_some()
158 }
159 fn load_promise_results(&mut self, promise_results: ComponentPromiseResultsInMerkle<F>) {
160 assert!(promise_results.leaves().len() <= 1 << self.params.max_height);
162 assert_eq!(promise_results.shards().len(), self.params.shard_caps().len());
163
164 let merkle_vt = promise_results.map_data(|typeless_prs| {
165 typeless_prs
166 .into_iter()
167 .flat_map(|typeless_prs| {
168 FlattenVirtualTable::<F>::from(
169 LogicalResult::<F, T>::try_from(typeless_prs).unwrap(),
170 )
171 })
172 .collect_vec()
173 });
174
175 for (shard_idx, shard) in merkle_vt.shards() {
176 let shard_capacity = merkle_vt.leaves()[*shard_idx].capacity;
177 assert_eq!(shard_capacity, shard.len());
178 }
179 self.val_promise_results = Some(merkle_vt);
180 }
181
182 fn assign_and_compute_commitment(
183 &self,
184 builder: &mut RlcCircuitBuilder<F>,
185 ) -> (AssignedValue<F>, FlattenVirtualTable<AssignedValue<F>>) {
186 let val_promise_results =
187 if let Some(val_promise_results) = self.val_promise_results.as_ref() {
188 val_promise_results.clone()
189 } else {
190 self.create_dummy_promise_result_merkle()
191 };
192 let leaves_to_load = val_promise_results.leaves();
193
194 let assigned_per_shard = val_promise_results
195 .shards()
196 .iter()
197 .map(|(_, vt)| {
198 let ctx = builder.base.main(0);
199 let witness_vt =
200 vt.iter().map(|(v_i, v_o)| (v_i.assign(ctx), v_o.assign(ctx))).collect_vec();
201 let commit = T::Commiter::compute_commitment(&mut builder.base, &witness_vt);
202 (commit, witness_vt)
203 })
204 .collect_vec();
205
206 let range_chip = &builder.range_chip();
207 let gate_chip = &range_chip.gate;
208 let ctx = builder.base.main(0);
209 let selected_shards = ctx.assign_witnesses(
212 val_promise_results.shards().iter().map(|(shard_idx, _)| F::from(*shard_idx as u64)),
213 );
214
215 let num_leaves = 1 << self.params.max_height;
232 let leaves_commits = ctx.assign_witnesses(
233 leaves_to_load.iter().map(|l| l.commit).chain(iter::repeat(F::ZERO)).take(num_leaves),
234 );
235 let mut assigned_vts = Vec::with_capacity(assigned_per_shard.len());
236 for (selected_shard, (shard_commit, assigned_vt)) in
237 selected_shards.into_iter().zip_eq(assigned_per_shard)
238 {
239 range_chip.check_less_than_safe(ctx, selected_shard, num_leaves as u64);
240 let leaf_commit =
241 gate_chip.select_from_idx(ctx, leaves_commits.clone(), selected_shard);
242 ctx.constrain_equal(&leaf_commit, &shard_commit);
243
244 assigned_vts.push(assigned_vt);
245 }
246 let flatten_assigned_vts = assigned_vts.into_iter().flatten().collect_vec();
247
248 if leaves_commits.len() == 1 {
250 return (leaves_commits[0], flatten_assigned_vts);
251 };
252
253 let mut hasher = create_hasher::<F>();
254 hasher.initialize_consts(ctx, gate_chip);
255 let nodes = compute_poseidon_merkle_tree(ctx, gate_chip, &hasher, leaves_commits);
256
257 let leftmost_nodes =
259 (0..=self.params.max_height).rev().map(|i| nodes[(1 << i) - 1]).collect_vec();
260 let result_height: AssignedValue<F> =
262 ctx.load_witness(F::from(leaves_to_load.len().ilog2() as u64));
263 range_chip.check_less_than_safe(ctx, result_height, (self.params.max_height + 1) as u64);
264
265 let output_commit = gate_chip.select_from_idx(ctx, leftmost_nodes, result_height);
266
267 (output_commit, flatten_assigned_vts)
268 }
269
270 fn generate_lookup_rlc(
271 &self,
272 builder: &mut RlcCircuitBuilder<F>,
273 promise_calls: &[&PromiseResultWitness<F>],
274 promise_results: &[FlattenVirtualRow<AssignedValue<F>>],
275 ) -> (Vec<AssignedValue<F>>, Vec<AssignedValue<F>>) {
276 let range_chip = &builder.range_chip();
277 let rlc_chip = builder.rlc_chip(&range_chip.gate);
278 generate_lookup_rlcs_impl::<F, T>(
279 builder,
280 range_chip,
281 &rlc_chip,
282 promise_calls,
283 promise_results,
284 )
285 }
286
287 fn load_dummy_promise_results(&mut self) {
288 let vt = self.create_dummy_promise_result_merkle();
289 self.val_promise_results = Some(vt);
290 }
291}
292
293pub fn generate_lookup_rlcs_impl<F: Field, T: ComponentType<F>>(
299 builder: &mut RlcCircuitBuilder<F>,
300 range_chip: &RangeChip<F>,
301 rlc_chip: &RlcChip<F>,
302 promise_calls: &[&PromiseResultWitness<F>],
303 promise_results: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
304) -> (Vec<AssignedValue<F>>, Vec<AssignedValue<F>>) {
305 let gate_ctx = builder.base.main(RLC_PHASE);
306
307 let input_multiplier =
308 rlc_chip.rlc_pow_fixed(gate_ctx, range_chip.gate(), T::OutputValue::get_num_fields());
309
310 let to_lookup_rlc =
311 builder.parallelize_phase1(promise_calls.to_vec(), |(gate_ctx, rlc_ctx), (f_i, f_o)| {
312 let i_rlc = f_i.to_rlc((gate_ctx, rlc_ctx), range_chip, rlc_chip);
313 let o_rlc = flatten_witness_to_rlc(rlc_ctx, rlc_chip, f_o);
314 range_chip.gate.mul_add(gate_ctx, i_rlc, input_multiplier, o_rlc)
315 });
316
317 let (gate_ctx, rlc_ctx) = builder.rlc_ctx_pair();
318
319 let lookup_table_rlc = T::rlc_virtual_rows(
320 (gate_ctx, rlc_ctx),
321 range_chip,
322 rlc_chip,
323 &promise_results
324 .iter()
325 .map(|(f_i, f_o)| {
326 (
327 T::InputWitness::try_from(f_i.clone()).unwrap(),
328 T::OutputWitness::try_from(f_o.clone()).unwrap(),
329 )
330 })
331 .collect_vec(),
332 );
333 (to_lookup_rlc, lookup_table_rlc)
334}
335
336pub trait ComponentCommiter<F: Field> {
338 fn compute_commitment(
340 builder: &mut BaseCircuitBuilder<F>,
341 witness_promise_results: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
342 ) -> AssignedValue<F>;
343 fn compute_native_commitment(witness_promise_results: &[(Flatten<F>, Flatten<F>)]) -> F;
348}
349
350pub struct BasicComponentCommiter<F: Field>(PhantomData<F>);
352
353impl<F: Field> ComponentCommiter<F> for BasicComponentCommiter<F> {
354 fn compute_commitment(
355 builder: &mut BaseCircuitBuilder<F>,
356 witness_promise_results: &[(Flatten<AssignedValue<F>>, Flatten<AssignedValue<F>>)],
357 ) -> AssignedValue<F> {
358 let range_chip = &builder.range_chip();
359 let ctx = builder.main(0);
360
361 let mut hasher = create_hasher::<F>();
362 hasher.initialize_consts(ctx, &range_chip.gate);
363 compute_commitment_with_flatten(ctx, &range_chip.gate, &hasher, witness_promise_results)
364 }
365 fn compute_native_commitment(witness_promise_results: &[(Flatten<F>, Flatten<F>)]) -> F {
366 let to_commit = witness_promise_results
367 .iter()
368 .flat_map(|(i, o)| i.fields.iter().chain(o.fields.iter()).copied())
369 .collect_vec();
370 compute_poseidon(&to_commit)
371 }
372}