samaharam 0.2.0

Scalable heterogeneous zero-knowledge proof aggregation for EVM chains
Documentation
//! Parallel aggregation using rayon.

use std::collections::HashMap;
use std::sync::Arc;

use rayon::prelude::*;

use crate::config::Srs;
use crate::error::Error;
use crate::proof::{Aggregated, Batched, Proof};
use crate::registry::{VkId, VkRegistry};
use crate::traits::PairingEngine;

/// Parallel aggregator for high-throughput proof processing.
///
/// Bundles shared (`Arc`) handles to the SRS and the verification-key
/// registry together with a recorded thread count. Construct it with
/// `ParallelAggregator::new` and drive it through `aggregate`.
pub struct ParallelAggregator<E: PairingEngine> {
    /// Shared SRS handle. Stored at construction; not read by the code
    /// visible in this file, hence the `dead_code` allowance.
    #[allow(dead_code)] // Used in actual aggregation implementation
    srs: Arc<Srs<E>>,
    /// Shared registry used to check that every VK referenced by a batch is known.
    registry: Arc<VkRegistry<E>>,
    /// Thread count: defaults to `rayon::current_num_threads()` and can be
    /// overridden via `with_threads`.
    // NOTE(review): no code visible in this file reads this field — confirm
    // whether it is meant to size a dedicated rayon pool.
    thread_pool_size: usize,
}

impl<E: PairingEngine> ParallelAggregator<E> {
    /// Create a new parallel aggregator.
    pub fn new(srs: Arc<Srs<E>>, registry: Arc<VkRegistry<E>>) -> Self {
        Self {
            srs,
            registry,
            thread_pool_size: rayon::current_num_threads(),
        }
    }

    /// Set custom thread pool size.
    pub fn with_threads(mut self, num_threads: usize) -> Self {
        self.thread_pool_size = num_threads;
        self
    }

    /// Aggregate proofs in parallel.
    ///
    /// Proofs are grouped by VK and processed concurrently.
    pub fn aggregate(&self, proofs: Vec<Proof<E, Batched>>) -> Result<Proof<E, Aggregated>, Error> {
        if proofs.is_empty() {
            return Err(Error::EmptyBatch);
        }

        // Group by VK
        let grouped = self.group_by_vk(proofs);

        // Process each VK group in parallel
        let sub_results: Result<Vec<SubAggregateResult<E>>, Error> = grouped
            .into_par_iter()
            .map(|(vk_id, batch)| self.aggregate_homogeneous(vk_id, batch))
            .collect();

        let sub_results = sub_results?;

        // Combine sub-aggregates
        self.combine_sub_aggregates(sub_results)
    }

    /// Group proofs by their VK.
    fn group_by_vk(&self, proofs: Vec<Proof<E, Batched>>) -> HashMap<VkId, Vec<Proof<E, Batched>>> {
        let mut grouped: HashMap<VkId, Vec<Proof<E, Batched>>> = HashMap::new();

        for proof in proofs {
            grouped.entry(proof.vk_id()).or_default().push(proof);
        }

        grouped
    }

    /// Aggregate proofs with the same VK (homogeneous aggregation).
    fn aggregate_homogeneous(
        &self,
        vk_id: VkId,
        proofs: Vec<Proof<E, Batched>>,
    ) -> Result<SubAggregateResult<E>, Error> {
        use crate::crypto::{AccumulatorInstance, ProofAccumulator};
        use group::GroupEncoding;

        // Verify VK exists
        let _vk = self.registry.require(vk_id)?;

        // Create accumulator for this VK group
        let mut accumulator = ProofAccumulator::<E>::new("parallel_sub_aggregate");

        // Collect public inputs and add to accumulator
        let mut public_inputs = Vec::new();

        for proof in &proofs {
            public_inputs.extend(proof.public_inputs().iter().cloned());

            // Parse and accumulate proof
            if let Ok(plonk_proof) = crate::crypto::PlonkProof::<E>::from_bytes(proof.data()) {
                let instance = AccumulatorInstance {
                    commitment: plonk_proof.wire_commitments[0],
                    evaluation: plonk_proof.evaluations.a_eval,
                    point: E::Fr::from(7u64),
                    quotient: plonk_proof.opening_proof,
                };
                accumulator.add(instance);
            }
        }

        // Fold accumulated proofs
        let aggregated_data = if !accumulator.is_empty() {
            let accumulated = accumulator.fold().map_err(Error::VerificationFailed)?;

            let mut data = Vec::new();
            data.extend_from_slice(accumulated.adjusted_commitment.to_bytes().as_ref());
            data.extend_from_slice(accumulated.combined_quotient.to_bytes().as_ref());
            data
        } else {
            vec![]
        };

        Ok(SubAggregateResult {
            vk_id,
            proof_count: proofs.len(),
            public_inputs,
            aggregated_data,
        })
    }

    /// Combine sub-aggregates into final proof.
    fn combine_sub_aggregates(
        &self,
        sub_results: Vec<SubAggregateResult<E>>,
    ) -> Result<Proof<E, Aggregated>, Error> {
        // Combine all public inputs and aggregated data
        let mut all_public_inputs = Vec::new();
        let mut combined_data = Vec::new();
        let mut total_count = 0usize;

        for result in sub_results {
            all_public_inputs.extend(result.public_inputs);
            combined_data.extend(result.aggregated_data);
            total_count += result.proof_count;
        }

        // Append proof count
        combined_data.extend_from_slice(&(total_count as u32).to_le_bytes());

        Ok(Proof::new_aggregated(combined_data, all_public_inputs))
    }
}

/// Result of aggregating proofs with the same VK.
///
/// Produced once per VK group by `aggregate_homogeneous` and consumed by
/// `combine_sub_aggregates`.
#[allow(dead_code)] // Used in full aggregation implementation
struct SubAggregateResult<E: PairingEngine> {
    /// VK shared by every proof in the group (not read when combining).
    vk_id: VkId,
    /// Number of proofs in the group, including any whose bytes failed to parse.
    proof_count: usize,
    /// Public inputs of all proofs in the group, concatenated in batch order.
    public_inputs: Vec<E::Fr>,
    /// Serialized fold output: adjusted commitment bytes followed by combined
    /// quotient bytes; empty when no proof in the group was accumulated.
    aggregated_data: Vec<u8>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::bn254::Bn254;
    use crate::crypto::VerificationKey;
    use crate::proof::Verified;
    use group::{Curve, Group};
    use halo2curves::bn256::{G1, G2};
    use rand::rngs::OsRng;

    /// Build a verification key filled with random curve points.
    fn mock_vk(num_public_inputs: usize) -> VerificationKey<Bn254> {
        let random_g1 = || G1::random(OsRng).to_affine();

        let selector_commitments = vec![random_g1(), random_g1()];
        let permutation_commitments = vec![random_g1()];

        VerificationKey {
            num_public_inputs,
            domain_size: 1024,
            selector_commitments,
            permutation_commitments,
            x_g2: G2::random(OsRng).to_affine(),
            g2_generator: G2::generator().to_affine(),
        }
    }

    /// Build a batched proof with no data and no public inputs for `vk_id`.
    fn mock_batched_proof(vk_id: VkId) -> Proof<Bn254, Batched> {
        let verified = Proof::<Bn254, Verified>::new_verified(vec![], vec![], vk_id);
        verified.submit()
    }

    /// Aggregator backed by a mock SRS and a registry holding two test VKs.
    fn setup_aggregator() -> ParallelAggregator<Bn254> {
        let registry = Arc::new(VkRegistry::<Bn254>::new());

        // Register test VKs
        for &(name, num_public_inputs) in [("test1", 5usize), ("test2", 3usize)].iter() {
            registry.register(name, mock_vk(num_public_inputs));
        }

        let srs = Arc::new(Srs::<Bn254>::mock(10));
        ParallelAggregator::new(srs, registry)
    }

    #[test]
    fn parallel_aggregator_empty_batch_fails() {
        let outcome = setup_aggregator().aggregate(Vec::new());
        assert!(matches!(outcome, Err(Error::EmptyBatch)));
    }

    #[test]
    fn parallel_aggregator_groups_by_vk() {
        let aggregator = setup_aggregator();

        let batch: Vec<_> = [0, 1, 0]
            .iter()
            .map(|&raw_id| mock_batched_proof(VkId::new(raw_id)))
            .collect();

        let by_vk = aggregator.group_by_vk(batch);

        assert_eq!(by_vk.len(), 2);
        assert_eq!(by_vk.get(&VkId::new(0)).map(Vec::len), Some(2));
        assert_eq!(by_vk.get(&VkId::new(1)).map(Vec::len), Some(1));
    }

    #[test]
    fn parallel_aggregator_requires_known_vk() {
        let aggregator = setup_aggregator();

        // A VK id that was never registered
        let unknown = mock_batched_proof(VkId::new(999));

        let outcome = aggregator.aggregate(vec![unknown]);
        assert!(matches!(outcome, Err(Error::UnknownVk(_))));
    }

    #[test]
    fn parallel_aggregator_processes_multiple_vks() {
        let aggregator = setup_aggregator();

        let first = mock_batched_proof(VkId::new(0)); // test1
        let second = mock_batched_proof(VkId::new(1)); // test2

        assert!(aggregator.aggregate(vec![first, second]).is_ok());
    }
}