use std::marker::PhantomData;
use crate::registry::VkId;
use crate::traits::PairingEngine;
/// Upper bounds enforced by `AggregationCircuit` while proofs are being added.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregationConfig {
    /// Maximum total number of proofs across all verifying keys.
    pub max_proofs: usize,
    /// Maximum number of public inputs a single proof may carry.
    pub max_public_inputs_per_proof: usize,
    /// Maximum number of distinct verifying keys (each key gets its own batch).
    pub max_vks: usize,
}

impl Default for AggregationConfig {
    /// Default limits: 32 proofs in total, 16 public inputs per proof, 8 verifying keys.
    fn default() -> Self {
        Self {
            max_proofs: 32,
            max_public_inputs_per_proof: 16,
            max_vks: 8,
        }
    }
}
/// A single proof staged for aggregation, tagged with the verifying key it belongs to.
///
/// NOTE(review): `public_input_count` duplicates `public_inputs.len()`. The two are expected
/// to agree (the test helper `mock_proof` sets them that way). TODO: confirm whether a
/// constructor should derive the count from the vector.
#[derive(Debug, Clone)]
pub struct PreparedProof<E: PairingEngine> {
/// Identifier of the verifying key; proofs are grouped into batches by this id.
pub vk_id: VkId,
/// Public inputs of the proof, as `E::Fr` values (presumably the engine's scalar field).
pub public_inputs: Vec<E::Fr>,
/// Declared number of public inputs.
pub public_input_count: usize,
/// Raw proof bytes; this module stores them but never inspects them.
pub proof_data: Vec<u8>,
}
/// The staged proofs that share one verifying key, as grouped by `AggregationCircuit`.
#[derive(Debug)]
pub struct ProofBatch<E: PairingEngine> {
/// Verifying key shared by the proofs in `proofs`.
pub vk_id: VkId,
/// Proofs for this key, in the order they were added.
pub proofs: Vec<PreparedProof<E>>,
}
/// Accumulates proofs for aggregation, grouping them into one `ProofBatch` per verifying key
/// and bounding the totals with an `AggregationConfig`.
pub struct AggregationCircuit<E: PairingEngine> {
/// Limits checked whenever a proof is added.
config: AggregationConfig,
/// One batch per distinct `VkId`, in the order each key was first seen.
batches: Vec<ProofBatch<E>>,
/// Number of proofs across all batches.
total_proofs: usize,
// NOTE(review): `batches` already uses `E`, so this marker is not needed to keep the type
// parameter in use; it remains because the constructor initializes it.
_engine: PhantomData<E>,
}
impl<E: PairingEngine> AggregationCircuit<E> {
    /// Creates an empty circuit that uses the [`AggregationConfig::default`] limits.
    pub fn new() -> Self {
        Self::with_config(AggregationConfig::default())
    }

    /// Creates an empty circuit that enforces the limits in `config`.
    pub fn with_config(config: AggregationConfig) -> Self {
        Self {
            config,
            batches: Vec::new(),
            total_proofs: 0,
            _engine: PhantomData,
        }
    }

    /// Adds `prepared` to the batch for its verifying key, creating that batch if needed.
    ///
    /// Every check runs before any state is modified, so a rejected proof leaves the
    /// circuit unchanged.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when:
    /// - the circuit already holds `max_proofs` proofs;
    /// - `public_input_count` disagrees with `public_inputs.len()`;
    /// - the proof carries more than `max_public_inputs_per_proof` public inputs;
    /// - the proof would open a new batch while `max_vks` batches already exist.
    pub fn add_proof(&mut self, prepared: PreparedProof<E>) -> Result<(), String> {
        if self.total_proofs >= self.config.max_proofs {
            return Err(format!(
                "Exceeded max proofs: {} >= {}",
                self.total_proofs, self.config.max_proofs
            ));
        }
        // The declared count is caller-supplied. Without this check, a proof could understate
        // it and slip past the limit below while carrying arbitrarily many inputs.
        let actual_inputs = prepared.public_inputs.len();
        if prepared.public_input_count != actual_inputs {
            return Err(format!(
                "Public input count mismatch: declared {} but found {}",
                prepared.public_input_count, actual_inputs
            ));
        }
        if actual_inputs > self.config.max_public_inputs_per_proof {
            return Err(format!(
                "Too many public inputs: {} > {}",
                actual_inputs, self.config.max_public_inputs_per_proof
            ));
        }
        let batch = self.batches.iter_mut().find(|b| b.vk_id == prepared.vk_id);
        match batch {
            Some(batch) => batch.proofs.push(prepared),
            None => {
                if self.batches.len() >= self.config.max_vks {
                    return Err(format!(
                        "Exceeded max VKs: {} >= {}",
                        self.batches.len(),
                        self.config.max_vks
                    ));
                }
                self.batches.push(ProofBatch {
                    vk_id: prepared.vk_id,
                    proofs: vec![prepared],
                });
            }
        }
        self.total_proofs += 1;
        Ok(())
    }

    /// Number of distinct verifying keys (batches) currently held.
    pub fn vk_count(&self) -> usize {
        self.batches.len()
    }

    /// Total number of proofs across all batches.
    pub fn proof_count(&self) -> usize {
        self.total_proofs
    }

    /// The batches, in the order their verifying keys were first added.
    pub fn batches(&self) -> &[ProofBatch<E>] {
        &self.batches
    }

    /// Renders a human-readable summary.
    ///
    /// The output is a header line followed by one line per batch, each terminated by `\n`.
    pub fn synthesize_summary(&self) -> String {
        use std::fmt::Write as _;

        let mut summary = format!(
            "AggregationCircuit: {} proofs across {} VKs\n",
            self.total_proofs,
            self.batches.len()
        );
        for batch in &self.batches {
            // Formatting into a `String` cannot fail, so the `fmt::Result` is safely discarded.
            let _ = writeln!(
                summary,
                " - VK {:?}: {} proofs",
                batch.vk_id,
                batch.proofs.len()
            );
        }
        summary
    }
}
impl<E: PairingEngine> Default for AggregationCircuit<E> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::bn254::Bn254;
    use crate::registry::VkId;
    use ff::Field;
    use halo2curves::bn256::Fr;
    use rand::rngs::OsRng;

    /// Builds a proof for `vk_id` with `num_inputs` random public inputs, a matching
    /// `public_input_count`, and placeholder proof bytes.
    fn mock_proof(vk_id: VkId, num_inputs: usize) -> PreparedProof<Bn254> {
        let public_inputs: Vec<_> = (0..num_inputs).map(|_| Fr::random(OsRng)).collect();
        PreparedProof {
            vk_id,
            // Moved rather than cloned: the vector is not used after this point.
            public_inputs,
            public_input_count: num_inputs,
            proof_data: vec![0u8; 32],
        }
    }

    #[test]
    fn circuit_starts_empty() {
        let circuit = AggregationCircuit::<Bn254>::new();
        assert_eq!(circuit.proof_count(), 0);
        assert_eq!(circuit.vk_count(), 0);
    }

    #[test]
    fn circuit_adds_proofs() {
        let mut circuit = AggregationCircuit::<Bn254>::new();
        let proof = mock_proof(VkId::new(1), 3);
        circuit.add_proof(proof).unwrap();
        assert_eq!(circuit.proof_count(), 1);
        assert_eq!(circuit.vk_count(), 1);
    }

    #[test]
    fn circuit_groups_by_vk() {
        let mut circuit = AggregationCircuit::<Bn254>::new();
        circuit.add_proof(mock_proof(VkId::new(1), 2)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(1), 3)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(2), 4)).unwrap();
        assert_eq!(circuit.proof_count(), 3);
        assert_eq!(circuit.vk_count(), 2);
        let batches = circuit.batches();
        let vk1_batch = batches.iter().find(|b| b.vk_id == VkId::new(1)).unwrap();
        let vk2_batch = batches.iter().find(|b| b.vk_id == VkId::new(2)).unwrap();
        assert_eq!(vk1_batch.proofs.len(), 2);
        assert_eq!(vk2_batch.proofs.len(), 1);
    }

    #[test]
    fn circuit_respects_max_proofs() {
        let config = AggregationConfig {
            max_proofs: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);
        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();
        let result = circuit.add_proof(mock_proof(VkId::new(1), 1));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("max proofs"));
    }

    #[test]
    fn circuit_respects_max_vks() {
        let config = AggregationConfig {
            max_vks: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);
        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(2), 1)).unwrap();
        let result = circuit.add_proof(mock_proof(VkId::new(3), 1));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("max VKs"));
    }

    #[test]
    fn circuit_respects_max_public_inputs() {
        let config = AggregationConfig {
            max_public_inputs_per_proof: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);
        let result = circuit.add_proof(mock_proof(VkId::new(1), 5));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("public inputs"));
    }

    #[test]
    fn circuit_synthesize_summary() {
        let mut circuit = AggregationCircuit::<Bn254>::new();
        circuit.add_proof(mock_proof(VkId::new(1), 2)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(1), 3)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(2), 4)).unwrap();
        let summary = circuit.synthesize_summary();
        assert!(summary.contains("3 proofs"));
        assert!(summary.contains("2 VKs"));
    }
}