Skip to main content

axiom_core/aggregation/
intermediate.rs

//! Intermediate aggregation circuits that aggregate in a binary tree topology:
//! The leaves of the tree are formed by [crate::header_chain::EthBlockHeaderChainCircuit]s, and intermediate nodes
//! of the tree are formed by [EthBlockHeaderChainIntermediateAggregationCircuit]s.
4//!
5//! An [EthBlockHeaderChainIntermediateAggregationCircuit] can aggregate either:
6//! - two [crate::header_chain::EthBlockHeaderChainCircuit]s or
7//! - two [EthBlockHeaderChainIntermediateAggregationCircuit]s.
8//!
9//! The root of the aggregation tree will be a [super::final_merkle::EthBlockHeaderChainRootAggregationCircuit].
10//! The difference between Intermediate and Root aggregation circuits is that the Intermediate ones
11//! do not have a keccak sub-circuit: all keccaks are delayed until the Root aggregation.
12use anyhow::{bail, Result};
13use axiom_eth::{
14    halo2_base::{
15        gates::{circuit::CircuitBuilderStage, GateInstructions, RangeChip, RangeInstructions},
16        utils::ScalarField,
17        AssignedValue, Context,
18        QuantumCell::{Constant, Existing, Witness},
19    },
20    halo2_proofs::{
21        halo2curves::bn256::{Bn256, Fr},
22        poly::kzg::commitment::ParamsKZG,
23    },
24    snark_verifier_sdk::{
25        halo2::aggregation::{AggregationCircuit, VerifierUniversality},
26        Snark, SHPLONK,
27    },
28    utils::snark_verifier::{
29        get_accumulator_indices, AggregationCircuitParams, NUM_FE_ACCUMULATOR,
30    },
31};
32use itertools::Itertools;
33
34use crate::Field;
35
/// Newtype wrapper around an [AggregationCircuit] built by
/// [EthBlockHeaderChainIntermediateAggregationInput::build].
///
/// The wrapper lets intermediate aggregation circuits be told apart, at the type level, from
/// other uses of [AggregationCircuit].
pub struct EthBlockHeaderChainIntermediateAggregationCircuit(pub AggregationCircuit);
38
39impl EthBlockHeaderChainIntermediateAggregationCircuit {
40    /// The number of instances NOT INCLUDING the accumulator
41    pub fn get_num_instance(max_depth: usize, initial_depth: usize) -> usize {
42        assert!(max_depth >= initial_depth);
43        5 + 2 * ((1 << (max_depth - initial_depth)) + initial_depth)
44    }
45}
46
47/// The input to create an intermediate [AggregationCircuit] that aggregates [crate::header_chain::EthBlockHeaderChainCircuit]s.
48/// These are intemediate aggregations because they do not perform additional keccaks. Therefore the public instance format (after excluding accumulators) is
49/// different from that of the original [crate::header_chain::EthBlockHeaderChainCircuit]s.
50#[derive(Clone, Debug)]
51pub struct EthBlockHeaderChainIntermediateAggregationInput {
52    // aggregation circuit with `instances` the accumulator (two G1 points) for delayed pairing verification
53    pub num_blocks: u32,
54    /// `snarks` should be exactly two snarks of either
55    /// - `EthBlockHeaderChainCircuit` if `max_depth == initial_depth + 1` or
56    /// - `EthBlockHeaderChainIntermediateAggregationCircuit` (this circuit) otherwise
57    ///
58    /// Assumes `num_blocks > 0`.
59    pub snarks: Vec<Snark>,
60    pub max_depth: usize,
61    pub initial_depth: usize,
62    // because the aggregation circuit doesn't have a keccak chip, in the mountain range
63    // vector we will store the `2^{max_depth - initial_depth}` "new roots" as well as the length `initial_depth` mountain range tail, which determines the smallest entries in the mountain range.
64}
65
66impl EthBlockHeaderChainIntermediateAggregationInput {
67    /// `snarks` should be exactly two snarks of either
68    /// - `EthBlockHeaderChainCircuit` if `max_depth == initial_depth + 1` or
69    /// - `EthBlockHeaderChainAggregationCircuit` otherwise
70    ///
71    /// Assumes `num_blocks > 0`.
72    pub fn new(
73        snarks: Vec<Snark>,
74        num_blocks: u32,
75        max_depth: usize,
76        initial_depth: usize,
77    ) -> Self {
78        assert_ne!(num_blocks, 0);
79        assert_eq!(snarks.len(), 2);
80        assert!(max_depth > initial_depth);
81        assert!(num_blocks <= 1 << max_depth);
82
83        Self { snarks, num_blocks, max_depth, initial_depth }
84    }
85}
86
87impl EthBlockHeaderChainIntermediateAggregationInput {
88    pub fn build(
89        self,
90        stage: CircuitBuilderStage,
91        circuit_params: AggregationCircuitParams,
92        kzg_params: &ParamsKZG<Bn256>,
93    ) -> Result<EthBlockHeaderChainIntermediateAggregationCircuit> {
94        let num_blocks = self.num_blocks;
95        let max_depth = self.max_depth;
96        let initial_depth = self.initial_depth;
97        log::info!(
98            "New EthBlockHeaderChainAggregationCircuit | num_blocks: {num_blocks} | max_depth: {max_depth} | initial_depth: {initial_depth}"
99        );
100        let prev_acc_indices = get_accumulator_indices(&self.snarks);
101        if self.max_depth == self.initial_depth + 1
102            && prev_acc_indices.iter().any(|indices| !indices.is_empty())
103        {
104            bail!("Snarks to be aggregated must not have accumulators: they should come from EthBlockHeaderChainCircuit");
105        }
106        if self.max_depth > self.initial_depth + 1
107            && prev_acc_indices.iter().any(|indices| indices.len() != NUM_FE_ACCUMULATOR)
108        {
109            bail!("Snarks to be aggregated must all be EthBlockHeaderChainIntermediateAggregationCircuits");
110        }
111        let mut circuit = AggregationCircuit::new::<SHPLONK>(
112            stage,
113            circuit_params,
114            kzg_params,
115            self.snarks,
116            VerifierUniversality::None,
117        );
118        let mut prev_instances = circuit.previous_instances().clone();
119        // remove old accumulators
120        for (prev_instance, acc_indices) in prev_instances.iter_mut().zip_eq(prev_acc_indices) {
121            for i in acc_indices.into_iter().sorted().rev() {
122                prev_instance.remove(i);
123            }
124        }
125
126        let builder = &mut circuit.builder;
127        // TODO: slight computational overhead from recreating RangeChip; builder should store RangeChip as OnceCell
128        let range = builder.range_chip();
129        let ctx = builder.main(0);
130        let num_blocks_minus_one = ctx.load_witness(Fr::from(num_blocks as u64 - 1));
131
132        let new_instances = join_previous_instances::<Fr>(
133            ctx,
134            &range,
135            prev_instances.try_into().unwrap(),
136            num_blocks_minus_one,
137            max_depth,
138            initial_depth,
139        );
140        if builder.assigned_instances.len() != 1 {
141            bail!("should only have 1 instance column");
142        }
143        assert_eq!(builder.assigned_instances[0].len(), NUM_FE_ACCUMULATOR);
144        builder.assigned_instances[0].extend(new_instances);
145
146        Ok(EthBlockHeaderChainIntermediateAggregationCircuit(circuit))
147    }
148}
149
/// Joins the previous instances of two child circuits of max depth `max_depth - 1` into the
/// instances of one depth-`max_depth` [EthBlockHeaderChainIntermediateAggregationCircuit].
///
/// `prev_instances[0]` and `prev_instances[1]` are the instances of the first and second child.
/// The children are two `EthBlockHeaderChainIntermediateAggregationCircuit`s of max depth
/// `max_depth - 1`. If `max_depth - 1 == initial_depth`, they are instead two
/// `EthBlockHeaderChainCircuit`s.
///
/// This function:
/// - checks that they form a chain of `max_depth`;
/// - updates the merkle mountain range:
///     - stores the latest `2^{max_depth - initial_depth}` roots for keccak later;
///     - selects the correct last `initial_depth` roots for the smallest part of the range.
///
/// Instance layout of both the inputs and the output, as fixed by the slicing below:
/// - `[0..2]`: previous block hash;
/// - `[2..4]`: end block hash;
/// - `[4]`: start and end block numbers packed as `start * 2^32 + end`;
/// - `[5..]`: merkle roots, two field elements per root.
///
/// When `num_blocks <= 2^{max_depth - 1}` (`mountain_selector` is true), the whole chain lies in
/// the first child. In that case:
/// - the end hash, end block number and mountain range tail are taken from `prev_instances[0]`;
/// - the checks that link the first child to the second child are disabled.
///
/// Returns the new instances for the depth `max_depth` circuit (without accumulators).
///
/// ## Assumptions
/// - `prev_instances` are the previous instances **with old accumulators removed**.
///
/// ## Panics
/// If `max_depth <= initial_depth`, or if either previous instance does not have exactly
/// `get_num_instance(max_depth - 1, initial_depth)` elements.
pub fn join_previous_instances<F: Field>(
    ctx: &mut Context<F>,
    range: &RangeChip<F>,
    prev_instances: [Vec<AssignedValue<F>>; 2],
    num_blocks_minus_one: AssignedValue<F>,
    max_depth: usize,
    initial_depth: usize,
) -> Vec<AssignedValue<F>> {
    let prev_depth = max_depth - 1;
    let num_instance = EthBlockHeaderChainIntermediateAggregationCircuit::get_num_instance(
        prev_depth,
        initial_depth,
    );
    assert_eq!(num_instance, prev_instances[0].len());
    assert_eq!(num_instance, prev_instances[1].len());

    let [instance0, instance1] = prev_instances;
    // 1 iff num_blocks - 1 < 2^prev_depth, i.e. the whole chain fits in the first child
    let mountain_selector = range.is_less_than_safe(ctx, num_blocks_minus_one, 1u64 << prev_depth);

    // join block hashes
    let prev_hash = &instance0[..2];
    let intermed_hash0 = &instance0[2..4];
    let intermed_hash1 = &instance1[..2];
    let end_hash = &instance1[2..4];
    // the end hash of the first child must equal the previous hash of the second child
    for (a, b) in intermed_hash0.iter().zip(intermed_hash1.iter()) {
        // a == b || num_blocks <= 2^prev_depth
        let mut eq_check = range.gate().is_equal(ctx, *a, *b);
        eq_check = range.gate().or(ctx, eq_check, mountain_selector);
        range.gate().assert_is_const(ctx, &eq_check, &F::ONE);
    }
    // end hash: the first child's if mountain_selector, otherwise the second child's
    let end_hash = intermed_hash0
        .iter()
        .zip(end_hash.iter())
        .map(|(a, b)| range.gate().select(ctx, *a, *b, mountain_selector))
        .collect_vec();

    // join & sanitize block numbers
    let (start_block_number, intermed_block_num0) = split_u64_into_u32s(ctx, range, instance0[4]);
    let (intermed_block_num1, mut end_block_number) = split_u64_into_u32s(ctx, range, instance1[4]);
    let num_blocks0_minus_one = range.gate().sub(ctx, intermed_block_num0, start_block_number);
    let num_blocks1_minus_one = range.gate().sub(ctx, end_block_number, intermed_block_num1);
    // each child covers at most 2^prev_depth blocks; this also rejects end < start, since the
    // field subtraction would then wrap to a huge value
    range.check_less_than_safe(ctx, num_blocks0_minus_one, 1 << prev_depth);
    range.check_less_than_safe(ctx, num_blocks1_minus_one, 1 << prev_depth);

    end_block_number =
        range.gate().select(ctx, intermed_block_num0, end_block_number, mountain_selector);
    // make sure chains link up
    let next_block_num0 = range.gate().add(ctx, intermed_block_num0, Constant(F::ONE));
    let mut eq_check = range.gate().is_equal(ctx, next_block_num0, intermed_block_num1);
    eq_check = range.gate().or(ctx, eq_check, mountain_selector);
    range.gate().assert_is_const(ctx, &eq_check, &F::ONE);
    // if num_blocks > 2^prev_depth, then num_blocks0 must equal 2^prev_depth
    let prev_max_blocks = range.gate().pow_of_two()[prev_depth];
    let is_max_depth0 =
        range.gate().is_equal(ctx, num_blocks0_minus_one, Constant(prev_max_blocks - F::ONE));
    eq_check = range.gate().or(ctx, is_max_depth0, mountain_selector);
    range.gate().assert_is_const(ctx, &eq_check, &F::ONE);
    // check number of blocks is correct
    let boundary_num_diff = range.gate().sub(ctx, end_block_number, start_block_number);
    ctx.constrain_equal(&boundary_num_diff, &num_blocks_minus_one);
    // concatenate block numbers: start * 2^32 + end
    let boundary_block_numbers = range.gate().mul_add(
        ctx,
        Constant(range.gate().pow_of_two()[32]),
        start_block_number,
        end_block_number,
    );

    // update merkle roots
    let roots0 = &instance0[5..];
    let roots1 = &instance1[5..];
    // field elements taken by each child's 2^{prev_depth - initial_depth} "new roots"
    let cutoff = 2 * (1 << (prev_depth - initial_depth));

    // join merkle mountain ranges
    let mut instances = Vec::with_capacity(num_instance + cutoff);
    instances.extend_from_slice(prev_hash);
    instances.extend_from_slice(&end_hash);
    instances.push(boundary_block_numbers);
    // keep the new roots of both children (hashed later, outside this circuit)
    instances.extend_from_slice(&roots0[..cutoff]);
    instances.extend_from_slice(&roots1[..cutoff]);
    // the mountain range tail comes from whichever child holds the end of the chain
    instances.extend(
        roots0[cutoff..]
            .iter()
            .zip(roots1[cutoff..].iter())
            .map(|(a, b)| range.gate().select(ctx, *a, *b, mountain_selector)),
    );

    instances
}
252
/// Splits `num` into its high and low 32-bit limbs, returning `(high, low)`.
///
/// Both limbs are witnessed from the lower 64 bits of `num`'s value. A single gate constrains
/// `low + high * 2^32 == num`, and each limb is range checked to 32 bits, so the decomposition
/// is unique. If `num` does not fit in 64 bits, the witnessed limbs cannot satisfy that gate.
///
/// Note: the cell order and the range-check order below are part of the circuit layout.
fn split_u64_into_u32s<F: ScalarField>(
    ctx: &mut Context<F>,
    range: &RangeChip<F>,
    num: AssignedValue<F>,
) -> (AssignedValue<F>, AssignedValue<F>) {
    let v = num.value().get_lower_64();
    // `first` is the high limb, `second` the low limb
    let first = F::from(v >> 32);
    let second = F::from(v & u32::MAX as u64);
    // cells [low, high, 2^32, num] with a gate at offset 0: low + high * 2^32 == num
    ctx.assign_region(
        [Witness(second), Witness(first), Constant(F::from(1u64 << 32)), Existing(num)],
        [0],
    );
    // read back the two witness cells just assigned by `assign_region`
    let second = ctx.get(-4);
    let first = ctx.get(-3);
    for limb in [first, second] {
        range.range_check(ctx, limb, 32);
    }
    (first, second)
}