axiom_query/subquery_aggregation/circuit.rs

1use std::collections::HashMap;
2
3use anyhow::{bail, Result};
4use axiom_eth::{
5    halo2_base::gates::{circuit::CircuitBuilderStage, GateChip},
6    halo2_proofs::poly::kzg::commitment::ParamsKZG,
7    halo2curves::bn256::Bn256,
8    snark_verifier_sdk::{halo2::aggregation::AggregationCircuit, SHPLONK},
9    utils::{
10        build_utils::pinning::aggregation::AggregationCircuitPinning,
11        component::{
12            promise_loader::multi::ComponentTypeList, types::ComponentPublicInstances,
13            utils::create_hasher, ComponentType,
14        },
15        snark_verifier::{
16            create_universal_aggregation_circuit, AggregationCircuitParams, NUM_FE_ACCUMULATOR,
17        },
18    },
19};
20use itertools::{zip_eq, Itertools};
21
22use crate::components::{
23    results::{circuit::SubqueryDependencies, types::LogicalPublicInstanceResultsRoot},
24    subqueries::{
25        account::types::ComponentTypeAccountSubquery,
26        block_header::types::{ComponentTypeHeaderSubquery, LogicalPublicInstanceHeader},
27        receipt::types::ComponentTypeReceiptSubquery,
28        solidity_mappings::types::ComponentTypeSolidityNestedMappingSubquery,
29        storage::types::ComponentTypeStorageSubquery,
30        transaction::types::ComponentTypeTxSubquery,
31    },
32};
33
34use super::types::{InputSubqueryAggregation, LogicalPublicInstanceSubqueryAgg, F};
35
36impl InputSubqueryAggregation {
37    /// Builds general circuit
38    ///
39    /// Warning: this MUST return a circuit implementing `CircuitExt` with accumulator indices provided.
40    /// In particular, do not return `BaseCircuitBuilder`.
41    pub fn build(
42        self,
43        stage: CircuitBuilderStage,
44        circuit_params: AggregationCircuitParams,
45        kzg_params: &ParamsKZG<Bn256>,
46    ) -> Result<AggregationCircuit> {
47        // dependency checks
48        if self.snark_storage.is_some() && self.snark_account.is_none() {
49            bail!("Storage snark requires Account snark");
50        }
51        if self.snark_solidity_mapping.is_some() && self.snark_storage.is_none() {
52            bail!("SolidityMapping snark requires Storage snark");
53        }
54        const NUM_SNARKS: usize = 7;
55        let snarks = vec![
56            Some(self.snark_header),
57            self.snark_account,
58            self.snark_storage,
59            self.snark_tx,
60            self.snark_receipt,
61            self.snark_solidity_mapping,
62            Some(self.snark_results_root),
63        ];
64        let snarks_enabled = snarks.iter().map(|s| s.is_some()).collect_vec();
65        let subquery_type_ids = [
66            ComponentTypeHeaderSubquery::<F>::get_type_id(),
67            ComponentTypeAccountSubquery::<F>::get_type_id(),
68            ComponentTypeStorageSubquery::<F>::get_type_id(),
69            ComponentTypeTxSubquery::<F>::get_type_id(),
70            ComponentTypeReceiptSubquery::<F>::get_type_id(),
71            ComponentTypeSolidityNestedMappingSubquery::<F>::get_type_id(),
72        ];
73        if snarks.iter().flatten().any(|s| s.agg_vk_hash_idx.is_some()) {
74            bail!("[SubqueryAggregation] No snark should be universal.");
75        }
76        let snarks = snarks.into_iter().flatten().map(|s| s.inner).collect_vec();
77        let agg_vkey_hash_indices = vec![None; snarks.len()];
78        let (mut circuit, previous_instances, agg_vkey_hash) =
79            create_universal_aggregation_circuit::<SHPLONK>(
80                stage,
81                circuit_params,
82                kzg_params,
83                snarks,
84                agg_vkey_hash_indices,
85            );
86
87        let builder = &mut circuit.builder;
88        let ctx = builder.main(0);
89
90        // Parse aggregated component public instances
91        let mut previous_instances = previous_instances.into_iter();
92        let mut get_next_pis =
93            || ComponentPublicInstances::try_from(previous_instances.next().unwrap());
94        let mut pis = Vec::with_capacity(NUM_SNARKS);
95        for snark_enabled in snarks_enabled {
96            if snark_enabled {
97                pis.push(Some(get_next_pis()?));
98            } else {
99                pis.push(None);
100            }
101        }
102        let pis_header = pis[0].clone().unwrap();
103        let pis_results = pis.pop().unwrap().unwrap();
104
105        // Load promise commit keccak as a public input
106        let promise_keccak = ctx.load_witness(self.promise_commit_keccak);
107        // ======== Create Poseidon hasher ===========
108        let gate = GateChip::default();
109        let mut hasher = create_hasher();
110        hasher.initialize_consts(ctx, &gate);
111        // Insert subquery output commits
112        // Unclear if this is a necessary precaution, but we store based on `subquery_type_ids` so the order does not depend on ordering in other modules
113        let mut subquery_commits = HashMap::new();
114        // Insert subquery promise commits
115        let mut subquery_promises = HashMap::new();
116        for (type_id, pi) in zip_eq(subquery_type_ids, &pis) {
117            if let Some(pi) = pi {
118                subquery_commits.insert(type_id.clone(), pi.output_commit);
119                subquery_promises.insert(type_id, pi.promise_result_commit);
120            }
121        }
122        // Hash each subquery output commit with the `promise_commit_keccak`, to be compared with subquery promises later.
123        // This matches the promise public output computation in `ComponentCircuitImpl::generate_public_instances`.
124        // The dependencies of a non-Header subquery circuit are always [Keccak, <Single Subquery Type>]
125        // We only need to calculate the hash for components that are called: Header, Account, Storage. Currently Tx, Receipt, SolidityNestedMapping are not called.
126        let mut hashed_commits = HashMap::new();
127        for type_id in [
128            ComponentTypeHeaderSubquery::<F>::get_type_id(),
129            ComponentTypeAccountSubquery::<F>::get_type_id(),
130            ComponentTypeStorageSubquery::<F>::get_type_id(),
131        ] {
132            if let Some(output_commit) = subquery_commits.get(&type_id) {
133                hashed_commits.insert(
134                    type_id,
135                    hasher.hash_fix_len_array(ctx, &gate, &[promise_keccak, *output_commit]),
136                );
137            }
138        }
139
140        // ======== Manually check all promise calls between subqueries: =======
141        // Header calls Keccak
142        {
143            let hashed_commit_keccak = hasher.hash_fix_len_array(ctx, &gate, &[promise_keccak]);
144            let header_promise_commit =
145                subquery_promises[&ComponentTypeHeaderSubquery::<F>::get_type_id()];
146            log::debug!("hash(promise_keccak): {:?}", hashed_commit_keccak.value());
147            log::debug!("header_promise_commit: {:?}", header_promise_commit.value());
148            ctx.constrain_equal(&hashed_commit_keccak, &header_promise_commit);
149        }
150        // Below when we say promise_header and commit_header, we actually mean promise_keccak_header and commit_keccak_header because both have been hashed with a promise_keccak.
151        // Account calls Keccak & Header
152        if let Some(promise_header) =
153            subquery_promises.get(&ComponentTypeAccountSubquery::<F>::get_type_id())
154        {
155            let commit_header = hashed_commits[&ComponentTypeHeaderSubquery::<F>::get_type_id()];
156            log::debug!("account:commit_header: {:?}", commit_header.value());
157            log::debug!("account:promise_header: {:?}", promise_header.value());
158            ctx.constrain_equal(&commit_header, promise_header);
159        }
160        // Storage calls Keccak & Account
161        if let Some(promise_account) =
162            subquery_promises.get(&ComponentTypeStorageSubquery::<F>::get_type_id())
163        {
164            let commit_account = hashed_commits[&ComponentTypeAccountSubquery::<F>::get_type_id()];
165            log::debug!("storage:commit_account: {:?}", commit_account.value());
166            log::debug!("storage:promise_account: {:?}", promise_account.value());
167            ctx.constrain_equal(&commit_account, promise_account);
168        }
169        // Tx calls Keccak & Header
170        if let Some(promise_header) =
171            subquery_promises.get(&ComponentTypeTxSubquery::<F>::get_type_id())
172        {
173            let commit_header = hashed_commits[&ComponentTypeHeaderSubquery::<F>::get_type_id()];
174            log::debug!("tx:commit_header: {:?}", commit_header.value());
175            log::debug!("tx:promise_header: {:?}", promise_header.value());
176            ctx.constrain_equal(&commit_header, promise_header);
177        }
178        // Receipt calls Keccak & Header
179        if let Some(promise_header) =
180            subquery_promises.get(&ComponentTypeReceiptSubquery::<F>::get_type_id())
181        {
182            let commit_header = hashed_commits[&ComponentTypeHeaderSubquery::<F>::get_type_id()];
183            log::debug!("receipt:commit_header: {:?}", commit_header.value());
184            log::debug!("receipt:promise_header: {:?}", promise_header.value());
185            ctx.constrain_equal(&commit_header, promise_header);
186        }
187        // SolidityNestedMapping calls Keccak & Storage
188        if let Some(promise_storage) =
189            subquery_promises.get(&ComponentTypeSolidityNestedMappingSubquery::<F>::get_type_id())
190        {
191            let commit_storage = hashed_commits[&ComponentTypeStorageSubquery::<F>::get_type_id()];
192            log::debug!("solidity_nested_mapping:commit_storage: {:?}", commit_storage.value());
193            log::debug!("solidity_nested_mapping:promise_storage: {:?}", promise_storage.value());
194            ctx.constrain_equal(&commit_storage, promise_storage);
195        }
196
197        // Get keccakPacked(blockhashMmr)
198        let LogicalPublicInstanceHeader { mmr_keccak } = pis_header.other.try_into()?;
199
200        // ======== results root =========
201        // MUST match order in `InputResultsRootShard::build`
202        let type_ids = SubqueryDependencies::<F>::get_component_type_ids();
203        // We now collect the promises from snarks in the order they were commited to in ResultsRoot
204        let mut results_deps_commits = Vec::new();
205        results_deps_commits.push(promise_keccak);
206        for t_id in &type_ids {
207            if let Some(commit) = subquery_commits.get(t_id) {
208                results_deps_commits.push(*commit);
209            }
210        }
211
212        let results_promise_commit = hasher.hash_fix_len_array(ctx, &gate, &results_deps_commits);
213
214        log::debug!("results_promise_commit: {:?}", results_promise_commit.value());
215        log::debug!("promise_result_commit: {:?}", pis_results.promise_result_commit.value());
216        ctx.constrain_equal(&results_promise_commit, &pis_results.promise_result_commit);
217
218        // We have implicitly checked all Components use the same `promise_keccak` above.
219
220        let LogicalPublicInstanceResultsRoot { results_root_poseidon, commit_subquery_hashes } =
221            pis_results.other.try_into().unwrap();
222
223        let logical_pis = LogicalPublicInstanceSubqueryAgg {
224            promise_keccak,
225            agg_vkey_hash,
226            results_root_poseidon,
227            commit_subquery_hashes,
228            mmr_keccak,
229        };
230        if builder.assigned_instances.len() != 1 {
231            bail!("should only have 1 instance column");
232        }
233        assert_eq!(builder.assigned_instances[0].len(), NUM_FE_ACCUMULATOR);
234        builder.assigned_instances[0].extend(logical_pis.flatten());
235
236        Ok(circuit)
237    }
238
239    pub fn prover_circuit(
240        self,
241        pinning: AggregationCircuitPinning,
242        kzg_params: &ParamsKZG<Bn256>,
243    ) -> Result<AggregationCircuit> {
244        Ok(self
245            .build(CircuitBuilderStage::Prover, pinning.params, kzg_params)?
246            .use_break_points(pinning.break_points))
247    }
248}