use std::collections::HashMap;
use std::sync::Arc;
use crate::config::{AggregatorBuilder, Srs};
use crate::error::Error;
use crate::proof::{Aggregated, Batched, Proof, Verified};
use crate::registry::{VkId, VkRegistry};
use crate::traits::PairingEngine;
use ff::PrimeField;
/// Opaque ticket identifying a proof accepted by `Aggregator::submit`.
///
/// It is a plain integer wrapped in a newtype, so it is cheap to copy,
/// compare and use as a hash key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ProofHandle(usize);
/// Collects verified proofs into a batch and folds the batch into one aggregated proof.
///
/// External code constructs it through [`Aggregator::builder`].
pub struct Aggregator<E: PairingEngine> {
    /// Reference-counted structured reference string, exposed via `srs()`.
    srs: Arc<Srs<E>>,
    /// Verification keys that submitted proofs must reference.
    registry: VkRegistry<E>,
    /// Proofs waiting for the next `aggregate()` call.
    queue: Vec<Proof<E, Batched>>,
    /// SHA-256 fingerprints of the queued proofs; used to reject duplicates within a batch.
    seen_proofs: std::collections::HashSet<[u8; 32]>,
    /// Maximum number of proofs the queue may hold at once.
    max_batch_size: usize,
    /// Whether `aggregate()` takes the parallel path (only effective with the `parallel` feature).
    parallel: bool,
    /// Value of the next [`ProofHandle`] to hand out; only ever incremented.
    next_handle: usize,
    /// Accumulator handed to callers through `accumulator_mut()`; not read by `aggregate()`.
    external_accumulator: crate::crypto::ProofAccumulator<E>,
}
impl<E: PairingEngine> Aggregator<E> {
/// Starts configuring a new [`Aggregator`]; see [`AggregatorBuilder`] for the options.
pub fn builder() -> AggregatorBuilder<E> {
    AggregatorBuilder::<E>::new()
}
/// Creates an aggregator with an empty registry, queue and duplicate set.
///
/// Crate-internal; code outside the crate goes through [`Aggregator::builder`].
pub(crate) fn new(srs: Arc<Srs<E>>, max_batch_size: usize, parallel: bool) -> Self {
    // Locals are built in the same order the original struct literal evaluated them.
    let registry = VkRegistry::new();
    let queue = Vec::new();
    let seen_proofs = std::collections::HashSet::new();
    let external_accumulator = crate::crypto::ProofAccumulator::new("external_proofs");
    Self {
        srs,
        registry,
        queue,
        seen_proofs,
        max_batch_size,
        parallel,
        next_handle: 0,
        external_accumulator,
    }
}
pub fn max_batch_size(&self) -> usize {
self.max_batch_size
}
pub fn registry(&self) -> &VkRegistry<E> {
&self.registry
}
/// Borrows the structured reference string held behind the internal `Arc`.
pub fn srs(&self) -> &Srs<E> {
    &*self.srs
}
pub fn accumulator_mut(&mut self) -> &mut crate::crypto::ProofAccumulator<E> {
&mut self.external_accumulator
}
/// Number of entries currently held by the external accumulator.
pub fn external_count(&self) -> usize {
    let external = &self.external_accumulator;
    external.len()
}
pub fn register_circuit(&self, name: &str, vk: crate::crypto::VerificationKey<E>) -> VkId {
self.registry.register(name, vk)
}
pub fn queue_len(&self) -> usize {
self.queue.len()
}
/// `true` when no proof is waiting to be aggregated.
pub fn queue_is_empty(&self) -> bool {
    self.queue_len() == 0
}
/// Queues a verified proof for inclusion in the next `aggregate()` call.
///
/// Returns a [`ProofHandle`] drawn from a counter that this method only increments.
///
/// # Errors
///
/// - [`Error::BatchTooLarge`] if the queue already holds `max_batch_size` proofs.
/// - The error reported by `VkRegistry::require` when the proof's verification key
///   is not registered (the tests expect `Error::UnknownVk`).
/// - [`Error::VerificationFailed`] if an identical proof (same bytes and public
///   inputs) is already queued in the current batch.
pub fn submit(&mut self, proof: Proof<E, Verified>) -> Result<ProofHandle, Error> {
    use sha2::{Digest, Sha256};

    if self.queue.len() >= self.max_batch_size {
        return Err(Error::BatchTooLarge {
            got: self.queue.len() + 1,
            max: self.max_batch_size,
        });
    }
    self.registry.require(proof.vk_id())?;

    // Fingerprint for duplicate detection. The proof bytes and the input list are
    // each length-prefixed so the encoding is unambiguous: without the prefix, bytes
    // `X || repr(a)` with no inputs hash the same as bytes `X` with input `a`, and a
    // distinct proof would be rejected as a duplicate. `usize -> u64` is lossless on
    // every target with a pointer width of at most 64 bits.
    let mut hasher = Sha256::new();
    hasher.update((proof.data().len() as u64).to_le_bytes());
    hasher.update(proof.data());
    hasher.update((proof.public_inputs().len() as u64).to_le_bytes());
    for input in proof.public_inputs() {
        hasher.update(input.to_repr());
    }
    let fingerprint: [u8; 32] = hasher.finalize().into();

    // `insert` returns `false` when the fingerprint was already present, so one
    // lookup both detects and records the proof.
    if !self.seen_proofs.insert(fingerprint) {
        return Err(Error::VerificationFailed("Duplicate proof in batch".to_string()));
    }

    let handle = ProofHandle(self.next_handle);
    self.next_handle += 1;
    self.queue.push(proof.submit());
    Ok(handle)
}
pub fn aggregate(&mut self) -> Result<Proof<E, Aggregated>, Error> {
if self.queue.is_empty() {
return Err(Error::EmptyBatch);
}
let batch: Vec<_> = self.queue.drain(..).collect();
self.seen_proofs.clear();
if self.parallel {
self.aggregate_parallel(batch)
} else {
self.aggregate_sequential(batch)
}
}
fn aggregate_sequential(
&self,
batch: Vec<Proof<E, Batched>>,
) -> Result<Proof<E, Aggregated>, Error> {
use crate::crypto::{AccumulatorInstance, ProofAccumulator};
let grouped = self.group_by_vk(batch);
let mut accumulator = ProofAccumulator::<E>::new("samaharam_aggregation");
let mut all_public_inputs = Vec::new();
for (vk_id, proofs) in grouped {
let registered_vk = self.registry.get(vk_id).ok_or(Error::UnknownVk(vk_id))?;
for proof in proofs {
all_public_inputs.extend(proof.public_inputs().iter().cloned());
if let Ok(plonk_proof) = crate::crypto::PlonkProof::<E>::from_bytes(proof.data()) {
let zeta = Self::compute_evaluation_point(
&plonk_proof,
proof.public_inputs(),
®istered_vk.vk,
);
let instance = AccumulatorInstance {
commitment: plonk_proof.wire_commitments[0],
evaluation: plonk_proof.evaluations.a_eval,
point: zeta,
quotient: plonk_proof.opening_proof,
};
accumulator.add(instance);
}
}
}
let aggregated_data = if accumulator.is_empty() {
vec![]
} else {
let accumulated = accumulator.fold().map_err(Error::VerificationFailed)?;
let mut data = Vec::new();
use group::GroupEncoding;
data.extend_from_slice(accumulated.adjusted_commitment.to_bytes().as_ref());
data.extend_from_slice(accumulated.combined_quotient.to_bytes().as_ref());
data.extend_from_slice(&(accumulated.count as u32).to_le_bytes());
data
};
Ok(Proof::new_aggregated(aggregated_data, all_public_inputs))
}
#[cfg(feature = "parallel")]
/// Parallel counterpart of `aggregate_sequential`.
///
/// Each verification-key group is folded on a rayon worker. The per-group adjusted
/// commitments and quotients are then summed, and the group counts are added up.
/// The error and empty-batch behaviour matches the sequential path:
/// - an unregistered key yields [`Error::UnknownVk`];
/// - groups in which no proof could be accumulated are not folded;
/// - the output data is empty when nothing was accumulated.
///
/// Otherwise the output data is `adjusted || quotient || count (u32 LE)`.
///
/// NOTE(review): groups are summed without a fresh random combination across groups.
/// Confirm this preserves the soundness argument of `ProofAccumulator::fold`.
///
/// # Errors
///
/// - [`Error::UnknownVk`] if a group's verification key is not registered.
/// - [`Error::VerificationFailed`] if a fold fails or the total count does not fit
///   in a `u32`.
fn aggregate_parallel(
    &self,
    batch: Vec<Proof<E, Batched>>,
) -> Result<Proof<E, Aggregated>, Error> {
    use crate::crypto::{AccumulatedProof, AccumulatorInstance, PlonkProof, ProofAccumulator};
    use group::{Curve, Group, GroupEncoding};
    use rayon::prelude::*;
    use std::convert::TryFrom;

    /// Public inputs of one group plus its folded accumulator (`None` if empty).
    type GroupResult<P> =
        Result<(Vec<<P as PairingEngine>::Fr>, Option<AccumulatedProof<P>>), Error>;

    // Resolve every group's key up front so an unknown key is an error, as in the
    // sequential path, instead of silently dropping that group's proofs.
    let groups = self
        .group_by_vk(batch)
        .into_iter()
        .map(|(vk_id, proofs)| {
            self.registry
                .get(vk_id)
                .map(|registered| (registered, proofs))
                .ok_or(Error::UnknownVk(vk_id))
        })
        .collect::<Result<Vec<_>, Error>>()?;

    let group_results: Vec<GroupResult<E>> = groups
        .into_par_iter()
        .map(|(registered_vk, proofs)| {
            let mut accumulator = ProofAccumulator::<E>::new("samaharam_parallel");
            let mut public_inputs = Vec::new();
            for proof in proofs {
                public_inputs.extend(proof.public_inputs().iter().cloned());

                let plonk_proof = match PlonkProof::<E>::from_bytes(proof.data()) {
                    Ok(parsed) => parsed,
                    Err(_) => continue,
                };
                // Proof bytes are caller-supplied; skip instead of panicking.
                let commitment = match plonk_proof.wire_commitments.first() {
                    Some(first) => *first,
                    None => continue,
                };
                let zeta = Self::compute_evaluation_point(
                    &plonk_proof,
                    proof.public_inputs(),
                    &registered_vk.vk,
                );
                accumulator.add(AccumulatorInstance {
                    commitment,
                    evaluation: plonk_proof.evaluations.a_eval,
                    point: zeta,
                    quotient: plonk_proof.opening_proof,
                });
            }
            // Mirror the sequential path's `is_empty` guard: never fold nothing.
            if accumulator.is_empty() {
                return Ok((public_inputs, None));
            }
            let accumulated = accumulator.fold().map_err(Error::VerificationFailed)?;
            Ok((public_inputs, Some(accumulated)))
        })
        .collect();

    let mut all_public_inputs = Vec::new();
    let mut any_folded = false;
    let mut total_count = 0usize;
    let mut combined_adjusted = E::G1::identity();
    let mut combined_quotient = E::G1::identity();
    for result in group_results {
        let (inputs, folded) = result?;
        all_public_inputs.extend(inputs);
        if let Some(acc) = folded {
            any_folded = true;
            total_count += acc.count;
            combined_adjusted += acc.adjusted_commitment.into();
            combined_quotient += acc.combined_quotient.into();
        }
    }

    if !any_folded {
        return Ok(Proof::new_aggregated(Vec::new(), all_public_inputs));
    }

    // The wire format stores the count as a u32; refuse to truncate silently.
    let count = u32::try_from(total_count).map_err(|_| {
        Error::VerificationFailed(format!(
            "aggregated proof count {} does not fit in u32",
            total_count
        ))
    })?;

    let mut aggregated_data = Vec::new();
    aggregated_data.extend_from_slice(combined_adjusted.to_affine().to_bytes().as_ref());
    aggregated_data.extend_from_slice(combined_quotient.to_affine().to_bytes().as_ref());
    aggregated_data.extend_from_slice(&count.to_le_bytes());
    Ok(Proof::new_aggregated(aggregated_data, all_public_inputs))
}
#[cfg(not(feature = "parallel"))]
fn aggregate_parallel(
&self,
batch: Vec<Proof<E, Batched>>,
) -> Result<Proof<E, Aggregated>, Error> {
self.aggregate_sequential(batch)
}
/// Partitions `batch` by verification key, deterministically.
///
/// Groups appear in the order in which their key is first seen in `batch`, and
/// proofs within a group keep their submission order. A `Vec` is returned rather
/// than a `HashMap` for a reason: std `HashMap` iteration order is unspecified and,
/// with the default `RandomState`, differs between runs. That made the order of
/// aggregated public inputs, and the accumulator folding order, nondeterministic.
fn group_by_vk(&self, batch: Vec<Proof<E, Batched>>) -> Vec<(VkId, Vec<Proof<E, Batched>>)> {
    // Maps each key to its index in `groups`, giving O(1) lookup per proof.
    let mut slot_of: HashMap<VkId, usize> = HashMap::new();
    let mut groups: Vec<(VkId, Vec<Proof<E, Batched>>)> = Vec::new();
    for proof in batch {
        let vk_id = proof.vk_id();
        let slot = *slot_of.entry(vk_id).or_insert_with(|| {
            groups.push((vk_id, Vec::new()));
            groups.len() - 1
        });
        groups[slot].1.push(proof);
    }
    groups
}
/// Re-derives the evaluation challenge `zeta` for one PLONK proof by replaying a
/// transcript labelled `"PLONK"`.
///
/// Absorption and squeeze order:
/// 1. the verification key's selector commitments;
/// 2. the public inputs;
/// 3. the proof's wire commitments, then squeeze `beta` and `gamma`;
/// 4. the `z` commitment, then squeeze `alpha`;
/// 5. the `t` commitments, then squeeze and return `zeta`.
///
/// NOTE(review): this order must match the transcript used when the proof was
/// produced; confirm against the prover. `vk.permutation_commitments` and
/// `vk.domain_size` are not absorbed here; confirm the prover omits them too.
fn compute_evaluation_point(
plonk_proof: &crate::crypto::PlonkProof<E>,
public_inputs: &[E::Fr],
vk: &crate::crypto::VerificationKey<E>,
) -> E::Fr {
use crate::crypto::Transcript;
let mut transcript = Transcript::new("PLONK");
for commitment in &vk.selector_commitments {
transcript.append_g1::<E>("selector", commitment);
}
for pi in public_inputs {
transcript.append_scalar::<E>("public_input", pi);
}
for wc in &plonk_proof.wire_commitments {
transcript.append_g1::<E>("wire", wc);
}
// beta, gamma and alpha are not used here. They are still squeezed, presumably
// because `challenge_scalar` advances the transcript state (verify), so that
// zeta is derived from the same state the prover had. Do not remove these calls.
let _beta: E::Fr = transcript.challenge_scalar::<E>("beta");
let _gamma: E::Fr = transcript.challenge_scalar::<E>("gamma");
transcript.append_g1::<E>("z", &plonk_proof.z_commitment);
let _alpha: E::Fr = transcript.challenge_scalar::<E>("alpha");
for tc in &plonk_proof.t_commitments {
transcript.append_g1::<E>("t", tc);
}
transcript.challenge_scalar::<E>("zeta")
}
}
/// Shows the configuration and queue depth only. The SRS, registry, accumulators
/// and duplicate-detection state are left out.
impl<E: PairingEngine> std::fmt::Debug for Aggregator<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Aggregator");
        out.field("max_batch_size", &self.max_batch_size);
        out.field("queue_len", &self.queue.len());
        out.field("parallel", &self.parallel);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::bn254::Bn254;
    use crate::crypto::VerificationKey;
    use crate::proof::Pending;
    use group::{Curve, Group};
    use halo2curves::bn256::{G1, G2};
    use rand::rngs::OsRng;

    /// Verification key with random commitments; only its shape matters here.
    fn mock_vk(num_public_inputs: usize) -> VerificationKey<Bn254> {
        VerificationKey {
            num_public_inputs,
            domain_size: 1024,
            selector_commitments: vec![
                G1::random(OsRng).to_affine(),
                G1::random(OsRng).to_affine(),
            ],
            permutation_commitments: vec![G1::random(OsRng).to_affine()],
            x_g2: G2::random(OsRng).to_affine(),
            g2_generator: G2::generator().to_affine(),
        }
    }

    /// Aggregator over a mock SRS that accepts at most four proofs per batch.
    fn setup_aggregator() -> Aggregator<Bn254> {
        let srs = Arc::new(Srs::<Bn254>::mock(10));
        Aggregator::builder().with_srs(srs).max_batch_size(4).build().unwrap()
    }

    #[test]
    fn aggregator_starts_empty() {
        let agg = setup_aggregator();
        assert!(agg.queue_is_empty());
        assert_eq!(agg.queue_len(), 0);
    }

    #[test]
    fn aggregator_registers_circuits() {
        let agg = setup_aggregator();
        let vk1 = agg.register_circuit("transfer", mock_vk(5));
        let vk2 = agg.register_circuit("deposit", mock_vk(3));
        assert_ne!(vk1, vk2);
        assert!(agg.registry().contains(vk1));
        assert!(agg.registry().contains(vk2));
    }

    #[test]
    fn aggregator_rejects_unknown_vk() {
        let mut agg = setup_aggregator();
        let unknown_vk = VkId::new(999);
        let proof = Proof::<Bn254, Pending>::new(vec![], vec![], unknown_vk);
        let verified = unsafe_create_verified_proof(proof);
        let result = agg.submit(verified);
        assert!(matches!(result, Err(Error::UnknownVk(_))));
    }

    #[test]
    fn aggregator_rejects_when_full() {
        let mut agg = setup_aggregator();
        let vk = agg.register_circuit("test", mock_vk(1));
        for i in 0..4 {
            let proof = Proof::<Bn254, Pending>::new(vec![i as u8], vec![], vk);
            let verified = unsafe_create_verified_proof(proof);
            agg.submit(verified).unwrap();
        }
        let proof = Proof::<Bn254, Pending>::new(vec![99], vec![], vk);
        let verified = unsafe_create_verified_proof(proof);
        let result = agg.submit(verified);
        assert!(matches!(
            result,
            Err(Error::BatchTooLarge { got: 5, max: 4 })
        ));
    }

    #[test]
    fn aggregator_rejects_duplicate_proof() {
        let mut agg = setup_aggregator();
        let vk = agg.register_circuit("test", mock_vk(1));
        let first = unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(vec![7], vec![], vk));
        agg.submit(first).unwrap();
        let duplicate =
            unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(vec![7], vec![], vk));
        assert!(matches!(
            agg.submit(duplicate),
            Err(Error::VerificationFailed(_))
        ));
        // The rejected duplicate must not have been queued.
        assert_eq!(agg.queue_len(), 1);
    }

    #[test]
    fn aggregator_accepts_resubmission_after_aggregate() {
        let mut agg = setup_aggregator();
        let vk = agg.register_circuit("test", mock_vk(1));
        let proof = unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(vec![7], vec![], vk));
        agg.submit(proof).unwrap();
        agg.aggregate().unwrap();
        // `aggregate` clears the duplicate set, so the same proof may start a new batch.
        let again = unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(vec![7], vec![], vk));
        assert!(agg.submit(again).is_ok());
    }

    #[test]
    fn aggregator_issues_distinct_handles() {
        let mut agg = setup_aggregator();
        let vk = agg.register_circuit("test", mock_vk(1));
        let h0 = agg
            .submit(unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(
                vec![1],
                vec![],
                vk,
            )))
            .unwrap();
        let h1 = agg
            .submit(unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(
                vec![2],
                vec![],
                vk,
            )))
            .unwrap();
        assert_ne!(h0, h1);
    }

    #[test]
    fn aggregator_empty_batch_fails() {
        let mut agg = setup_aggregator();
        let result = agg.aggregate();
        assert!(matches!(result, Err(Error::EmptyBatch)));
    }

    #[test]
    fn aggregator_drains_queue_on_aggregate() {
        let mut agg = setup_aggregator();
        let vk = agg.register_circuit("test", mock_vk(1));
        let proof = Proof::<Bn254, Pending>::new(vec![], vec![], vk);
        let verified = unsafe_create_verified_proof(proof);
        agg.submit(verified).unwrap();
        assert_eq!(agg.queue_len(), 1);
        let _aggregated = agg.aggregate().unwrap();
        assert!(agg.queue_is_empty());
    }

    /// Test-only shortcut: relabels a pending proof as verified without running
    /// verification. (Not `unsafe` in the Rust sense; the name flags that it skips checks.)
    fn unsafe_create_verified_proof<E: PairingEngine>(
        proof: Proof<E, Pending>,
    ) -> Proof<E, Verified> {
        Proof::new_verified(
            proof.data().to_vec(),
            proof.public_inputs().to_vec(),
            proof.vk_id(),
        )
    }
}