use std::collections::HashMap;
use std::sync::Arc;
use rayon::prelude::*;
use crate::config::Srs;
use crate::error::Error;
use crate::proof::{Aggregated, Batched, Proof};
use crate::registry::{VkId, VkRegistry};
use crate::traits::PairingEngine;
/// Aggregates batched proofs by grouping them per verification key and
/// folding each group with rayon.
pub struct ParallelAggregator<E: PairingEngine> {
    /// Structured reference string; currently not read by the aggregator.
    #[allow(dead_code)] // handed in by callers but not yet consumed by any method
    srs: Arc<Srs<E>>,
    /// Registry used to look up each group's verification key.
    registry: Arc<VkRegistry<E>>,
    /// Requested worker-thread count (defaults to the current rayon pool size).
    thread_pool_size: usize,
}
impl<E: PairingEngine> ParallelAggregator<E> {
pub fn new(srs: Arc<Srs<E>>, registry: Arc<VkRegistry<E>>) -> Self {
Self {
srs,
registry,
thread_pool_size: rayon::current_num_threads(),
}
}
pub fn with_threads(mut self, num_threads: usize) -> Self {
self.thread_pool_size = num_threads;
self
}
pub fn aggregate(&self, proofs: Vec<Proof<E, Batched>>) -> Result<Proof<E, Aggregated>, Error> {
if proofs.is_empty() {
return Err(Error::EmptyBatch);
}
let grouped = self.group_by_vk(proofs);
let sub_results: Result<Vec<SubAggregateResult<E>>, Error> = grouped
.into_par_iter()
.map(|(vk_id, batch)| self.aggregate_homogeneous(vk_id, batch))
.collect();
let sub_results = sub_results?;
self.combine_sub_aggregates(sub_results)
}
fn group_by_vk(&self, proofs: Vec<Proof<E, Batched>>) -> HashMap<VkId, Vec<Proof<E, Batched>>> {
let mut grouped: HashMap<VkId, Vec<Proof<E, Batched>>> = HashMap::new();
for proof in proofs {
grouped.entry(proof.vk_id()).or_default().push(proof);
}
grouped
}
fn aggregate_homogeneous(
&self,
vk_id: VkId,
proofs: Vec<Proof<E, Batched>>,
) -> Result<SubAggregateResult<E>, Error> {
use crate::crypto::{AccumulatorInstance, ProofAccumulator};
use group::GroupEncoding;
let _vk = self.registry.require(vk_id)?;
let mut accumulator = ProofAccumulator::<E>::new("parallel_sub_aggregate");
let mut public_inputs = Vec::new();
for proof in &proofs {
public_inputs.extend(proof.public_inputs().iter().cloned());
if let Ok(plonk_proof) = crate::crypto::PlonkProof::<E>::from_bytes(proof.data()) {
let instance = AccumulatorInstance {
commitment: plonk_proof.wire_commitments[0],
evaluation: plonk_proof.evaluations.a_eval,
point: E::Fr::from(7u64),
quotient: plonk_proof.opening_proof,
};
accumulator.add(instance);
}
}
let aggregated_data = if !accumulator.is_empty() {
let accumulated = accumulator.fold().map_err(Error::VerificationFailed)?;
let mut data = Vec::new();
data.extend_from_slice(accumulated.adjusted_commitment.to_bytes().as_ref());
data.extend_from_slice(accumulated.combined_quotient.to_bytes().as_ref());
data
} else {
vec![]
};
Ok(SubAggregateResult {
vk_id,
proof_count: proofs.len(),
public_inputs,
aggregated_data,
})
}
fn combine_sub_aggregates(
&self,
sub_results: Vec<SubAggregateResult<E>>,
) -> Result<Proof<E, Aggregated>, Error> {
let mut all_public_inputs = Vec::new();
let mut combined_data = Vec::new();
let mut total_count = 0usize;
for result in sub_results {
all_public_inputs.extend(result.public_inputs);
combined_data.extend(result.aggregated_data);
total_count += result.proof_count;
}
combined_data.extend_from_slice(&(total_count as u32).to_le_bytes());
Ok(Proof::new_aggregated(combined_data, all_public_inputs))
}
}
/// Outcome of folding all proofs that share one verification key.
#[allow(dead_code)] // `vk_id` is recorded but never read back
struct SubAggregateResult<E: PairingEngine> {
    /// Verification key shared by every proof in the group.
    vk_id: VkId,
    /// Number of proofs in the group, including any that contributed no accumulator instance.
    proof_count: usize,
    /// Public inputs of every proof in the group, in group order.
    public_inputs: Vec<E::Fr>,
    /// Serialized folded accumulator; empty when nothing was folded.
    aggregated_data: Vec<u8>,
}
#[cfg(test)]
mod tests {
    // Unit tests for `ParallelAggregator` using mock BN254 keys and empty proofs.
    use super::*;
    use crate::backend::bn254::Bn254;
    use crate::crypto::VerificationKey;
    use crate::proof::Verified;
    use group::{Curve, Group};
    use halo2curves::bn256::{G1, G2};
    use rand::rngs::OsRng;

    /// Builds a verification key with random commitments and the given public-input count.
    fn mock_vk(num_public_inputs: usize) -> VerificationKey<Bn254> {
        let random_g1 = || G1::random(OsRng).to_affine();
        VerificationKey {
            num_public_inputs,
            domain_size: 1024,
            selector_commitments: vec![random_g1(), random_g1()],
            permutation_commitments: vec![random_g1()],
            x_g2: G2::random(OsRng).to_affine(),
            g2_generator: G2::generator().to_affine(),
        }
    }

    /// Creates an empty verified proof for `vk_id` and submits it as a batched proof.
    fn mock_batched_proof(vk_id: VkId) -> Proof<Bn254, Batched> {
        let verified = Proof::<Bn254, Verified>::new_verified(Vec::new(), Vec::new(), vk_id);
        verified.submit()
    }

    /// Aggregator over a mock SRS and a registry holding two mock keys.
    fn setup_aggregator() -> ParallelAggregator<Bn254> {
        let srs = Arc::new(Srs::<Bn254>::mock(10));
        let registry = Arc::new(VkRegistry::<Bn254>::new());
        registry.register("test1", mock_vk(5));
        registry.register("test2", mock_vk(3));
        ParallelAggregator::new(srs, registry)
    }

    #[test]
    fn parallel_aggregator_empty_batch_fails() {
        let outcome = setup_aggregator().aggregate(Vec::new());
        assert!(matches!(outcome, Err(Error::EmptyBatch)));
    }

    #[test]
    fn parallel_aggregator_groups_by_vk() {
        let aggregator = setup_aggregator();
        let batch = vec![
            mock_batched_proof(VkId::new(0)),
            mock_batched_proof(VkId::new(1)),
            mock_batched_proof(VkId::new(0)),
        ];
        let by_key = aggregator.group_by_vk(batch);
        assert_eq!(by_key.len(), 2);
        assert_eq!(by_key[&VkId::new(0)].len(), 2);
        assert_eq!(by_key[&VkId::new(1)].len(), 1);
    }

    #[test]
    fn parallel_aggregator_requires_known_vk() {
        let aggregator = setup_aggregator();
        let outcome = aggregator.aggregate(vec![mock_batched_proof(VkId::new(999))]);
        assert!(matches!(outcome, Err(Error::UnknownVk(_))));
    }

    #[test]
    fn parallel_aggregator_processes_multiple_vks() {
        let aggregator = setup_aggregator();
        let batch = vec![
            mock_batched_proof(VkId::new(0)),
            mock_batched_proof(VkId::new(1)),
        ];
        assert!(aggregator.aggregate(batch).is_ok());
    }
}