// samaharam 0.2.0
//
// Scalable heterogeneous zero-knowledge proof aggregation for EVM chains.
//! Aggregation circuit for heterogeneous proof verification.
//!
//! This circuit verifies multiple proofs with variable public inputs
//! and potentially different verification keys.

use std::marker::PhantomData;

use crate::registry::VkId;
use crate::traits::PairingEngine;

/// Configuration for the aggregation circuit.
///
/// Every field is an upper bound checked by `AggregationCircuit::add_proof`
/// when a proof is added; a proof that would exceed any of them is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregationConfig {
    /// Maximum number of proofs to aggregate, summed across all VKs.
    pub max_proofs: usize,

    /// Maximum public inputs per proof.
    pub max_public_inputs_per_proof: usize,

    /// Maximum number of distinct VKs (i.e. distinct proof batches).
    pub max_vks: usize,
}

impl Default for AggregationConfig {
    /// Default limits: 32 proofs, 16 public inputs per proof, 8 distinct VKs.
    fn default() -> Self {
        Self {
            max_proofs: 32,
            max_public_inputs_per_proof: 16,
            max_vks: 8,
        }
    }
}

/// One proof together with the metadata needed to place it in an aggregation.
///
/// `AggregationCircuit::add_proof` routes each proof into a per-VK batch
/// using `vk_id`.
#[derive(Debug, Clone)]
pub struct PreparedProof<E: PairingEngine> {
    /// Identifier of the verification key this proof is checked against.
    pub vk_id: VkId,

    /// Public inputs for this proof; variable length, expected not to exceed
    /// the configured per-proof maximum.
    pub public_inputs: Vec<E::Fr>,

    /// Number of meaningful entries in `public_inputs`.
    pub public_input_count: usize,

    /// The proof itself, in serialized byte form.
    pub proof_data: Vec<u8>,
}

/// A homogeneous group of proofs that all share one verification key.
///
/// Keeping same-VK proofs together lets them be aggregated as a unit.
#[derive(Debug)]
pub struct ProofBatch<E: PairingEngine> {
    /// The verification key shared by every proof in `proofs`.
    pub vk_id: VkId,

    /// The proofs in this group; `AggregationCircuit::add_proof` appends them
    /// in the order they arrive.
    pub proofs: Vec<PreparedProof<E>>,
}

/// The main aggregation circuit: collects proofs, grouped by verification key.
///
/// Intended pipeline:
/// 1. Verify each proof in-circuit.
/// 2. Accumulate the resulting pairing elements.
/// 3. Emit a single aggregated proof.
///
/// NOTE(review): constraint synthesis is not implemented yet —
/// `synthesize_summary` only reports how proofs were grouped. The steps above
/// describe the design, not current behavior.
///
/// # Variable Public Inputs
///
/// Each proof carries its own public-input vector plus a declared count, and
/// the declared count is bounded by
/// `AggregationConfig::max_public_inputs_per_proof`, so proofs with different
/// input arities can be mixed. Padding unused slots is the intended strategy
/// for fixing the circuit shape — TODO confirm where that padding happens.
///
/// # Heterogeneous VKs
///
/// Incoming proofs are sorted into one `ProofBatch` per distinct `VkId`
/// (at most `max_vks` batches); the batches are meant to be combined at the
/// aggregation level.
pub struct AggregationCircuit<E: PairingEngine> {
    /// Limits checked whenever a proof is added.
    config: AggregationConfig,

    /// One batch per distinct VK, in the order each VK was first seen.
    batches: Vec<ProofBatch<E>>,

    /// Number of proofs held across all batches.
    total_proofs: usize,

    /// Marker tying the circuit to its pairing engine type.
    _engine: PhantomData<E>,
}

impl<E: PairingEngine> AggregationCircuit<E> {
    /// Create an empty aggregation circuit with [`AggregationConfig::default`] limits.
    pub fn new() -> Self {
        Self::with_config(AggregationConfig::default())
    }

    /// Create an empty aggregation circuit with custom limits.
    pub fn with_config(config: AggregationConfig) -> Self {
        Self {
            config,
            batches: Vec::new(),
            total_proofs: 0,
            _engine: PhantomData,
        }
    }

    /// Add a proof to the circuit.
    ///
    /// The proof is appended to the batch for its `vk_id`; a new batch is
    /// opened the first time a VK is seen.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving the circuit unchanged, if:
    /// - the circuit already holds `max_proofs` proofs;
    /// - `public_input_count` or `public_inputs.len()` exceeds
    ///   `max_public_inputs_per_proof`;
    /// - `public_input_count` exceeds `public_inputs.len()`, i.e. the declared
    ///   count claims inputs that were not supplied;
    /// - the proof would introduce a new VK beyond `max_vks`.
    pub fn add_proof(&mut self, prepared: PreparedProof<E>) -> Result<(), String> {
        if self.total_proofs >= self.config.max_proofs {
            return Err(format!(
                "Exceeded max proofs: {} >= {}",
                self.total_proofs, self.config.max_proofs
            ));
        }

        let max_inputs = self.config.max_public_inputs_per_proof;
        if prepared.public_input_count > max_inputs {
            return Err(format!(
                "Too many public inputs: {} > {}",
                prepared.public_input_count, max_inputs
            ));
        }

        // The declared count is caller-supplied and can disagree with the
        // vector it describes; bound the real data too so an oversized or
        // under-populated input vector cannot slip past the limit.
        let supplied = prepared.public_inputs.len();
        if supplied > max_inputs {
            return Err(format!(
                "Too many public inputs: {} > {}",
                supplied, max_inputs
            ));
        }
        if prepared.public_input_count > supplied {
            return Err(format!(
                "Public input count {} exceeds supplied public inputs {}",
                prepared.public_input_count, supplied
            ));
        }

        // Find or create the batch for this VK.
        let batch = self.batches.iter_mut().find(|b| b.vk_id == prepared.vk_id);

        match batch {
            Some(batch) => {
                batch.proofs.push(prepared);
            }
            None => {
                if self.batches.len() >= self.config.max_vks {
                    return Err(format!(
                        "Exceeded max VKs: {} >= {}",
                        self.batches.len(),
                        self.config.max_vks
                    ));
                }
                self.batches.push(ProofBatch {
                    vk_id: prepared.vk_id,
                    proofs: vec![prepared],
                });
            }
        }

        self.total_proofs += 1;
        Ok(())
    }

    /// Number of distinct VKs (one per batch).
    pub fn vk_count(&self) -> usize {
        self.batches.len()
    }

    /// Total number of proofs across all batches.
    pub fn proof_count(&self) -> usize {
        self.total_proofs
    }

    /// Proofs grouped by VK, in the order each VK was first added.
    pub fn batches(&self) -> &[ProofBatch<E>] {
        &self.batches
    }

    /// Synthesize the aggregation circuit.
    ///
    /// Constraint generation is not implemented yet; for now this returns a
    /// human-readable summary of what would be synthesized: one header line
    /// with the proof and VK totals, then one line per VK batch.
    pub fn synthesize_summary(&self) -> String {
        use std::fmt::Write as _;

        let mut summary = format!(
            "AggregationCircuit: {} proofs across {} VKs\n",
            self.total_proofs,
            self.batches.len()
        );

        for batch in &self.batches {
            // Formatting into a `String` cannot fail, so the `fmt::Result` is ignored.
            let _ = writeln!(
                summary,
                "  - VK {:?}: {} proofs",
                batch.vk_id,
                batch.proofs.len()
            );
        }

        summary
    }
}

impl<E: PairingEngine> Default for AggregationCircuit<E> {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::bn254::Bn254;
    use crate::registry::VkId;
    use ff::Field;
    use halo2curves::bn256::Fr;
    use rand::rngs::OsRng;

    /// Build a proof for `vk_id` with `num_inputs` random public inputs and a
    /// consistent declared count.
    fn mock_proof(vk_id: VkId, num_inputs: usize) -> PreparedProof<Bn254> {
        let public_inputs: Vec<_> = (0..num_inputs).map(|_| Fr::random(OsRng)).collect();

        PreparedProof {
            vk_id,
            public_inputs,
            public_input_count: num_inputs,
            proof_data: vec![0u8; 32],
        }
    }

    #[test]
    fn circuit_starts_empty() {
        let circuit = AggregationCircuit::<Bn254>::new();
        assert_eq!(circuit.proof_count(), 0);
        assert_eq!(circuit.vk_count(), 0);
    }

    #[test]
    fn circuit_adds_proofs() {
        let mut circuit = AggregationCircuit::<Bn254>::new();

        let proof = mock_proof(VkId::new(1), 3);
        circuit.add_proof(proof).unwrap();

        assert_eq!(circuit.proof_count(), 1);
        assert_eq!(circuit.vk_count(), 1);
    }

    #[test]
    fn circuit_groups_by_vk() {
        let mut circuit = AggregationCircuit::<Bn254>::new();

        // Add proofs from same VK
        circuit.add_proof(mock_proof(VkId::new(1), 2)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(1), 3)).unwrap();

        // Add proof from different VK
        circuit.add_proof(mock_proof(VkId::new(2), 4)).unwrap();

        assert_eq!(circuit.proof_count(), 3);
        assert_eq!(circuit.vk_count(), 2);

        // Check batch sizes
        let batches = circuit.batches();
        let vk1_batch = batches.iter().find(|b| b.vk_id == VkId::new(1)).unwrap();
        let vk2_batch = batches.iter().find(|b| b.vk_id == VkId::new(2)).unwrap();

        assert_eq!(vk1_batch.proofs.len(), 2);
        assert_eq!(vk2_batch.proofs.len(), 1);
    }

    #[test]
    fn circuit_respects_max_proofs() {
        let config = AggregationConfig {
            max_proofs: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);

        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();

        let result = circuit.add_proof(mock_proof(VkId::new(1), 1));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("max proofs"));
    }

    #[test]
    fn circuit_respects_max_vks() {
        let config = AggregationConfig {
            max_vks: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);

        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(2), 1)).unwrap();

        let result = circuit.add_proof(mock_proof(VkId::new(3), 1));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("max VKs"));
    }

    #[test]
    fn circuit_respects_max_public_inputs() {
        let config = AggregationConfig {
            max_public_inputs_per_proof: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);

        let result = circuit.add_proof(mock_proof(VkId::new(1), 5));
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("public inputs"));
    }

    #[test]
    fn rejected_proof_leaves_circuit_unchanged() {
        let config = AggregationConfig {
            max_vks: 1,
            max_public_inputs_per_proof: 2,
            ..Default::default()
        };
        let mut circuit = AggregationCircuit::<Bn254>::with_config(config);

        circuit.add_proof(mock_proof(VkId::new(1), 1)).unwrap();

        // A second VK exceeds max_vks.
        assert!(circuit.add_proof(mock_proof(VkId::new(2), 1)).is_err());
        // Too many public inputs, even for an already-known VK.
        assert!(circuit.add_proof(mock_proof(VkId::new(1), 3)).is_err());

        assert_eq!(circuit.proof_count(), 1);
        assert_eq!(circuit.vk_count(), 1);
        assert_eq!(circuit.batches()[0].proofs.len(), 1);
    }

    #[test]
    fn circuit_synthesize_summary() {
        let mut circuit = AggregationCircuit::<Bn254>::new();

        circuit.add_proof(mock_proof(VkId::new(1), 2)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(1), 3)).unwrap();
        circuit.add_proof(mock_proof(VkId::new(2), 4)).unwrap();

        let summary = circuit.synthesize_summary();
        assert!(summary.contains("3 proofs"));
        assert!(summary.contains("2 VKs"));

        // One header line plus one line per VK batch.
        assert_eq!(summary.lines().count(), 3);
        assert!(summary.contains(": 2 proofs\n"));
        assert!(summary.contains(": 1 proofs\n"));
    }
}