use std::sync::Arc;
use crate::aggregator::Aggregator;
use crate::backend::bn254::Bn254;
use crate::config::{AggregatorBuilder, Srs};
use crate::error::Error;
use crate::registry::VkId;
use crate::traits::PairingEngine;
use tracing::{debug, info, instrument, warn};
#[cfg(feature = "solidity")]
use crate::solidity::{SolidityConfig, SolidityGenerator, SolidityVk};
#[derive(Debug, Clone)]
pub struct ClientConfig {
pub max_batch_size: usize,
pub parallel: bool,
pub verifier_contract_name: String,
}
impl Default for ClientConfig {
fn default() -> Self {
Self {
max_batch_size: 32,
parallel: true,
verifier_contract_name: "AggregatedVerifier".to_string(),
}
}
}
#[derive(Debug, Clone)]
pub struct ProofRequest {
pub circuit_type: String,
pub proof_data: Vec<u8>,
pub public_inputs: Vec<[u8; 32]>,
}
#[derive(Debug)]
pub struct AggregationResult {
pub aggregated_proof: Vec<u8>,
pub public_inputs: Vec<[u8; 32]>,
pub proof_count: usize,
pub solidity_verifier: Option<String>,
}
pub struct AggregationClient {
aggregator: Aggregator<Bn254>,
#[allow(dead_code)]
config: ClientConfig,
circuit_vk_map: std::collections::HashMap<String, VkId>,
}
impl AggregationClient {
pub fn new(srs: Arc<Srs<Bn254>>) -> Result<Self, Error> {
Self::with_config(srs, ClientConfig::default())
}
pub fn with_config(srs: Arc<Srs<Bn254>>, config: ClientConfig) -> Result<Self, Error> {
let mut builder = AggregatorBuilder::<Bn254>::new()
.with_srs(srs)
.max_batch_size(config.max_batch_size);
if config.parallel {
builder = builder.enable_parallelism();
}
let aggregator = builder.build().map_err(|e| Error::VerificationFailed(e.to_string()))?;
Ok(Self {
aggregator,
config,
circuit_vk_map: std::collections::HashMap::new(),
})
}
#[instrument(skip(self, vk), fields(vk_domain_size = vk.domain_size))]
pub fn register_circuit(
&mut self,
circuit_type: &str,
vk: crate::crypto::VerificationKey<Bn254>,
) -> VkId {
let vk_id = self.aggregator.register_circuit(circuit_type, vk);
self.circuit_vk_map.insert(circuit_type.to_string(), vk_id);
info!(circuit_type, ?vk_id, "Registered new circuit");
vk_id
}
#[instrument(skip(self, request), fields(circuit = %request.circuit_type))]
pub fn submit(&mut self, request: ProofRequest) -> Result<(), Error> {
let vk_id = self.circuit_vk_map.get(&request.circuit_type).ok_or_else(|| {
warn!("Rejected proof for unknown circuit: {}", request.circuit_type);
Error::VerificationFailed(format!("Unknown circuit type: {}", request.circuit_type))
})?;
let public_inputs = self.decode_public_inputs(&request.public_inputs)?;
let proof = crate::proof::Proof::<Bn254, crate::proof::Pending>::new(
request.proof_data,
public_inputs,
*vk_id,
);
let verified = match proof.verify(self.aggregator.registry()) {
Ok(v) => v,
Err(e) => {
warn!(error = ?e, "Proof verification failed");
return Err(e);
}
};
self.aggregator.submit(verified)?;
debug!("Proof submitted for aggregation");
Ok(())
}
#[cfg(any(test, feature = "testing"))]
pub fn submit_unchecked(&mut self, request: ProofRequest) -> Result<(), Error> {
let vk_id = self.circuit_vk_map.get(&request.circuit_type).ok_or_else(|| {
Error::VerificationFailed(format!("Unknown circuit type: {}", request.circuit_type))
})?;
let public_inputs = self.decode_public_inputs(&request.public_inputs)?;
let verified = crate::proof::Proof::<Bn254, crate::proof::Verified>::new_verified(
request.proof_data,
public_inputs,
*vk_id,
);
self.aggregator.submit(verified)?;
Ok(())
}
pub fn pending_count(&self) -> usize {
self.aggregator.queue_len()
}
#[instrument(skip(self))]
pub fn aggregate(&mut self) -> Result<AggregationResult, Error> {
let queue_len = self.aggregator.queue_len();
info!(queue_len, "Starting aggregation batch");
let aggregated = self.aggregator.aggregate()?;
let public_inputs = self.encode_public_inputs(aggregated.public_inputs());
#[cfg(feature = "solidity")]
let solidity_verifier = Some(self.generate_solidity_verifier()?);
#[cfg(not(feature = "solidity"))]
let solidity_verifier = None;
info!(proof_count = queue_len, "Aggregation complete");
Ok(AggregationResult {
aggregated_proof: aggregated.data().to_vec(),
public_inputs,
proof_count: queue_len, solidity_verifier,
})
}
#[cfg(feature = "solidity")]
pub fn generate_solidity_verifier(&self) -> Result<String, Error> {
let config = SolidityConfig::with_name(&self.config.verifier_contract_name);
let generator = SolidityGenerator::<Bn254>::with_config(config);
let vk = SolidityVk {
commitments: vec![],
num_public_inputs: 0,
proof_length: 256,
};
Ok(generator.generate(&vk))
}
fn decode_public_inputs(
&self,
inputs: &[[u8; 32]],
) -> Result<Vec<<Bn254 as PairingEngine>::Fr>, Error> {
use ff::PrimeField;
inputs
.iter()
.map(|bytes| {
let mut repr = <halo2curves::bn256::Fr as PrimeField>::Repr::default();
repr.as_mut().copy_from_slice(bytes);
halo2curves::bn256::Fr::from_repr(repr)
.into_option()
.ok_or_else(|| Error::VerificationFailed("Invalid field element".to_string()))
})
.collect()
}
fn encode_public_inputs(&self, inputs: &[<Bn254 as PairingEngine>::Fr]) -> Vec<[u8; 32]> {
use ff::PrimeField;
inputs
.iter()
.map(|f| {
let repr = f.to_repr();
let mut bytes = [0u8; 32];
bytes.copy_from_slice(repr.as_ref());
bytes
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto::VerificationKey;
use group::{Curve, Group};
use halo2curves::bn256::{G1, G2};
use rand::rngs::OsRng;
fn mock_vk(num_public_inputs: usize) -> VerificationKey<Bn254> {
VerificationKey {
num_public_inputs,
domain_size: 1024,
selector_commitments: vec![
G1::random(OsRng).to_affine(),
G1::random(OsRng).to_affine(),
],
permutation_commitments: vec![G1::random(OsRng).to_affine()],
x_g2: G2::random(OsRng).to_affine(),
g2_generator: G2::generator().to_affine(),
}
}
fn mock_proof_data() -> Vec<u8> {
use crate::crypto::{PlonkProof, ProofEvaluations};
use ff::Field;
use halo2curves::bn256::Fr;
let proof = PlonkProof::<Bn254> {
wire_commitments: [
G1::random(OsRng).to_affine(),
G1::random(OsRng).to_affine(),
G1::random(OsRng).to_affine(),
],
z_commitment: G1::random(OsRng).to_affine(),
t_commitments: vec![
G1::random(OsRng).to_affine(),
G1::random(OsRng).to_affine(),
G1::random(OsRng).to_affine(),
],
opening_proof: G1::random(OsRng).to_affine(),
shifted_opening_proof: G1::random(OsRng).to_affine(),
evaluations: ProofEvaluations {
a_eval: Fr::random(OsRng),
b_eval: Fr::random(OsRng),
c_eval: Fr::random(OsRng),
s1_eval: Fr::random(OsRng),
s2_eval: Fr::random(OsRng),
z_shifted_eval: Fr::random(OsRng),
},
};
proof.to_bytes()
}
fn setup_client() -> AggregationClient {
let srs = Arc::new(Srs::<Bn254>::mock(10));
AggregationClient::new(srs).unwrap()
}
#[test]
fn client_creates_with_default_config() {
let client = setup_client();
assert_eq!(client.pending_count(), 0);
}
#[test]
fn client_registers_circuits() {
let mut client = setup_client();
let vk1 = client.register_circuit("circuit_a", mock_vk(5));
let vk2 = client.register_circuit("circuit_b", mock_vk(3));
assert_ne!(vk1, vk2);
}
#[test]
fn client_rejects_unknown_circuit() {
let mut client = setup_client();
let request = ProofRequest {
circuit_type: "unknown".to_string(),
proof_data: vec![],
public_inputs: vec![],
};
let result = client.submit(request);
assert!(result.is_err());
}
#[test]
fn client_submits_proofs() {
let mut client = setup_client();
client.register_circuit("test_circuit", mock_vk(0));
let request = ProofRequest {
circuit_type: "test_circuit".to_string(),
proof_data: mock_proof_data(),
public_inputs: vec![],
};
client.submit(request).unwrap();
assert_eq!(client.pending_count(), 1);
}
#[test]
fn client_aggregates() {
let mut client = setup_client();
client.register_circuit("test_circuit", mock_vk(0));
let request = ProofRequest {
circuit_type: "test_circuit".to_string(),
proof_data: mock_proof_data(),
public_inputs: vec![],
};
client.submit(request).unwrap();
let result = client.aggregate().unwrap();
assert!(result.aggregated_proof.is_empty() || true);
}
}