1use anyhow::{bail, Result};
13use axiom_eth::{
14 halo2_base::{
15 gates::{circuit::CircuitBuilderStage, GateInstructions, RangeChip, RangeInstructions},
16 utils::ScalarField,
17 AssignedValue, Context,
18 QuantumCell::{Constant, Existing, Witness},
19 },
20 halo2_proofs::{
21 halo2curves::bn256::{Bn256, Fr},
22 poly::kzg::commitment::ParamsKZG,
23 },
24 snark_verifier_sdk::{
25 halo2::aggregation::{AggregationCircuit, VerifierUniversality},
26 Snark, SHPLONK,
27 },
28 utils::snark_verifier::{
29 get_accumulator_indices, AggregationCircuitParams, NUM_FE_ACCUMULATOR,
30 },
31};
32use itertools::Itertools;
33
34use crate::Field;
35
36pub struct EthBlockHeaderChainIntermediateAggregationCircuit(pub AggregationCircuit);
38
39impl EthBlockHeaderChainIntermediateAggregationCircuit {
40 pub fn get_num_instance(max_depth: usize, initial_depth: usize) -> usize {
42 assert!(max_depth >= initial_depth);
43 5 + 2 * ((1 << (max_depth - initial_depth)) + initial_depth)
44 }
45}
46
47#[derive(Clone, Debug)]
51pub struct EthBlockHeaderChainIntermediateAggregationInput {
52 pub num_blocks: u32,
54 pub snarks: Vec<Snark>,
60 pub max_depth: usize,
61 pub initial_depth: usize,
62 }
65
66impl EthBlockHeaderChainIntermediateAggregationInput {
67 pub fn new(
73 snarks: Vec<Snark>,
74 num_blocks: u32,
75 max_depth: usize,
76 initial_depth: usize,
77 ) -> Self {
78 assert_ne!(num_blocks, 0);
79 assert_eq!(snarks.len(), 2);
80 assert!(max_depth > initial_depth);
81 assert!(num_blocks <= 1 << max_depth);
82
83 Self { snarks, num_blocks, max_depth, initial_depth }
84 }
85}
86
87impl EthBlockHeaderChainIntermediateAggregationInput {
88 pub fn build(
89 self,
90 stage: CircuitBuilderStage,
91 circuit_params: AggregationCircuitParams,
92 kzg_params: &ParamsKZG<Bn256>,
93 ) -> Result<EthBlockHeaderChainIntermediateAggregationCircuit> {
94 let num_blocks = self.num_blocks;
95 let max_depth = self.max_depth;
96 let initial_depth = self.initial_depth;
97 log::info!(
98 "New EthBlockHeaderChainAggregationCircuit | num_blocks: {num_blocks} | max_depth: {max_depth} | initial_depth: {initial_depth}"
99 );
100 let prev_acc_indices = get_accumulator_indices(&self.snarks);
101 if self.max_depth == self.initial_depth + 1
102 && prev_acc_indices.iter().any(|indices| !indices.is_empty())
103 {
104 bail!("Snarks to be aggregated must not have accumulators: they should come from EthBlockHeaderChainCircuit");
105 }
106 if self.max_depth > self.initial_depth + 1
107 && prev_acc_indices.iter().any(|indices| indices.len() != NUM_FE_ACCUMULATOR)
108 {
109 bail!("Snarks to be aggregated must all be EthBlockHeaderChainIntermediateAggregationCircuits");
110 }
111 let mut circuit = AggregationCircuit::new::<SHPLONK>(
112 stage,
113 circuit_params,
114 kzg_params,
115 self.snarks,
116 VerifierUniversality::None,
117 );
118 let mut prev_instances = circuit.previous_instances().clone();
119 for (prev_instance, acc_indices) in prev_instances.iter_mut().zip_eq(prev_acc_indices) {
121 for i in acc_indices.into_iter().sorted().rev() {
122 prev_instance.remove(i);
123 }
124 }
125
126 let builder = &mut circuit.builder;
127 let range = builder.range_chip();
129 let ctx = builder.main(0);
130 let num_blocks_minus_one = ctx.load_witness(Fr::from(num_blocks as u64 - 1));
131
132 let new_instances = join_previous_instances::<Fr>(
133 ctx,
134 &range,
135 prev_instances.try_into().unwrap(),
136 num_blocks_minus_one,
137 max_depth,
138 initial_depth,
139 );
140 if builder.assigned_instances.len() != 1 {
141 bail!("should only have 1 instance column");
142 }
143 assert_eq!(builder.assigned_instances[0].len(), NUM_FE_ACCUMULATOR);
144 builder.assigned_instances[0].extend(new_instances);
145
146 Ok(EthBlockHeaderChainIntermediateAggregationCircuit(circuit))
147 }
148}
149
/// Joins the public instances of two circuits of depth `max_depth - 1` into the instances of a
/// single circuit of depth `max_depth`, constraining that the two chains link up.
///
/// Each entry of `prev_instances` must already have its accumulator removed. It is laid out as:
/// - `[0..2]`: the starting (`prev_hash`) block hash;
/// - `[2..4]`: the ending block hash;
/// - `[4]`: packed block numbers `start * 2^32 + end` (unpacked with `split_u64_into_u32s`);
/// - `[5..]`: roots. The first `cutoff = 2 * 2^(max_depth - 1 - initial_depth)` of these are
///   kept from both inputs; for the remaining ones a single input's values are selected.
///
/// `num_blocks_minus_one` is the claimed chain length minus one. Let `sel` be
/// `num_blocks_minus_one < 2^(max_depth - 1)`, i.e. the whole chain fits in the first input.
///
/// If `sel` holds, the end hash, end block number and trailing roots are taken from `instance0`,
/// and the linking checks below are skipped.
///
/// Otherwise the end hash, end block number and trailing roots are taken from `instance1`, and
/// the following are constrained:
/// - the end hash of `instance0` equals the start hash of `instance1`;
/// - the end block number of `instance0` plus one equals the start block number of `instance1`;
/// - `instance0` covers exactly `2^(max_depth - 1)` blocks.
///
/// In all cases, each input's `end - start` must be `< 2^(max_depth - 1)`, and the joined
/// `end - start` must equal `num_blocks_minus_one`.
///
/// The returned vector contains, in order:
/// 1. the start hash;
/// 2. the selected end hash;
/// 3. the packed start/end block numbers;
/// 4. `roots0[..cutoff]`;
/// 5. `roots1[..cutoff]`;
/// 6. the selected trailing roots.
///
/// Its length is `num_instance + cutoff`.
///
/// # Panics
/// Panics if either input does not have `get_num_instance(max_depth - 1, initial_depth)`
/// elements. It also panics, via underflow or the assertion in `get_num_instance`, if
/// `max_depth == 0` or `max_depth - 1 < initial_depth`.
pub fn join_previous_instances<F: Field>(
    ctx: &mut Context<F>,
    range: &RangeChip<F>,
    prev_instances: [Vec<AssignedValue<F>>; 2],
    num_blocks_minus_one: AssignedValue<F>,
    max_depth: usize,
    initial_depth: usize,
) -> Vec<AssignedValue<F>> {
    let prev_depth = max_depth - 1;
    let num_instance = EthBlockHeaderChainIntermediateAggregationCircuit::get_num_instance(
        prev_depth,
        initial_depth,
    );
    assert_eq!(num_instance, prev_instances[0].len());
    assert_eq!(num_instance, prev_instances[1].len());

    let [instance0, instance1] = prev_instances;
    // 1 iff num_blocks <= 2^prev_depth, i.e. the whole chain is contained in instance0.
    let mountain_selector = range.is_less_than_safe(ctx, num_blocks_minus_one, 1u64 << prev_depth);

    // Join block hashes.
    let prev_hash = &instance0[..2];
    let intermed_hash0 = &instance0[2..4];
    let intermed_hash1 = &instance1[..2];
    let end_hash = &instance1[2..4];
    for (a, b) in intermed_hash0.iter().zip(intermed_hash1.iter()) {
        // Enforce a == b || mountain_selector.
        let mut eq_check = range.gate().is_equal(ctx, *a, *b);
        eq_check = range.gate().or(ctx, eq_check, mountain_selector);
        range.gate().assert_is_const(ctx, &eq_check, &F::ONE);
    }
    // End hash: instance0's if the selector is set, otherwise instance1's.
    let end_hash = intermed_hash0
        .iter()
        .zip(end_hash.iter())
        .map(|(a, b)| range.gate().select(ctx, *a, *b, mountain_selector))
        .collect_vec();

    // Unpack and sanity-check block numbers: each input covers at most 2^prev_depth blocks.
    let (start_block_number, intermed_block_num0) = split_u64_into_u32s(ctx, range, instance0[4]);
    let (intermed_block_num1, mut end_block_number) = split_u64_into_u32s(ctx, range, instance1[4]);
    let num_blocks0_minus_one = range.gate().sub(ctx, intermed_block_num0, start_block_number);
    let num_blocks1_minus_one = range.gate().sub(ctx, end_block_number, intermed_block_num1);
    range.check_less_than_safe(ctx, num_blocks0_minus_one, 1 << prev_depth);
    range.check_less_than_safe(ctx, num_blocks1_minus_one, 1 << prev_depth);

    end_block_number =
        range.gate().select(ctx, intermed_block_num0, end_block_number, mountain_selector);
    // The chains must link up: end0 + 1 == start1, unless the selector is set.
    let next_block_num0 = range.gate().add(ctx, intermed_block_num0, Constant(F::ONE));
    let mut eq_check = range.gate().is_equal(ctx, next_block_num0, intermed_block_num1);
    eq_check = range.gate().or(ctx, eq_check, mountain_selector);
    range.gate().assert_is_const(ctx, &eq_check, &F::ONE);
    // Unless the selector is set, instance0 must be full: exactly 2^prev_depth blocks.
    let prev_max_blocks = range.gate().pow_of_two()[prev_depth];
    let is_max_depth0 =
        range.gate().is_equal(ctx, num_blocks0_minus_one, Constant(prev_max_blocks - F::ONE));
    eq_check = range.gate().or(ctx, is_max_depth0, mountain_selector);
    range.gate().assert_is_const(ctx, &eq_check, &F::ONE);
    // The joined chain must span exactly num_blocks blocks.
    let boundary_num_diff = range.gate().sub(ctx, end_block_number, start_block_number);
    ctx.constrain_equal(&boundary_num_diff, &num_blocks_minus_one);
    // Re-pack as start * 2^32 + end.
    let boundary_block_numbers = range.gate().mul_add(
        ctx,
        Constant(range.gate().pow_of_two()[32]),
        start_block_number,
        end_block_number,
    );

    // Roots of each input; `cutoff` leading elements are kept from both inputs.
    let roots0 = &instance0[5..];
    let roots1 = &instance1[5..];
    let cutoff = 2 * (1 << (prev_depth - initial_depth));

    // num_instance + cutoff == get_num_instance(max_depth, initial_depth)
    let mut instances = Vec::with_capacity(num_instance + cutoff);
    instances.extend_from_slice(prev_hash);
    instances.extend_from_slice(&end_hash);
    instances.push(boundary_block_numbers);
    instances.extend_from_slice(&roots0[..cutoff]);
    instances.extend_from_slice(&roots1[..cutoff]);
    // Trailing roots: instance0's if the selector is set, otherwise instance1's.
    instances.extend(
        roots0[cutoff..]
            .iter()
            .zip(roots1[cutoff..].iter())
            .map(|(a, b)| range.gate().select(ctx, *a, *b, mountain_selector)),
    );

    instances
}
252
/// Splits `num` into two 32-bit limbs `(first, second)`, where `first` is the high limb and
/// `second` the low limb of the lower 64 bits of `num`'s witness value.
///
/// Both limbs are range checked to 32 bits. The assigned row is intended to constrain
/// `second + first * 2^32 == num`. NOTE(review): this assumes the halo2-lib basic gate
/// `a + b * c = d` is enabled at offset 0; confirm against the gate definition. Under that
/// assumption, the constraints are unsatisfiable if `num` does not fit in 64 bits.
///
/// Callers use `first` as the start block number and `second` as the end block number.
fn split_u64_into_u32s<F: ScalarField>(
    ctx: &mut Context<F>,
    range: &RangeChip<F>,
    num: AssignedValue<F>,
) -> (AssignedValue<F>, AssignedValue<F>) {
    // Off-circuit witness computation: high and low 32 bits of the lower 64 bits of `num`.
    let v = num.value().get_lower_64();
    let first = F::from(v >> 32);
    let second = F::from(v & u32::MAX as u64);
    // Cells [second, first, 2^32, num] with the gate selector at the first of them.
    ctx.assign_region(
        [Witness(second), Witness(first), Constant(F::from(1u64 << 32)), Existing(num)],
        [0],
    );
    // Positional fetch of the two witness cells just assigned (4th- and 3rd-to-last cells);
    // this depends on the exact assignment above.
    let second = ctx.get(-4);
    let first = ctx.get(-3);
    for limb in [first, second] {
        range.range_check(ctx, limb, 32);
    }
    (first, second)
}