// samaharam 0.2.0
// Scalable heterogeneous zero-knowledge proof aggregation for EVM chains.

//! Main aggregator for heterogeneous proof aggregation.

use std::collections::HashMap;
use std::sync::Arc;

use crate::config::{AggregatorBuilder, Srs};
use crate::error::Error;
use crate::proof::{Aggregated, Batched, Proof, Verified};
use crate::registry::{VkId, VkRegistry};
use crate::traits::PairingEngine;
use ff::PrimeField;

/// Handle for a submitted proof in the aggregation queue.
///
/// Wraps the sequence number handed out by `Aggregator::submit`: the counter
/// starts at zero and grows by one for every proof that is accepted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ProofHandle(usize);

/// Heterogeneous proof aggregator.
///
/// Aggregates proofs from different circuits (with different VKs)
/// into a single aggregated proof.
///
/// # Example
///
/// ```rust,ignore
/// let mut aggregator = Aggregator::<Bn254>::builder()
///     .with_srs(srs)
///     .build()?;
///
/// // Register different circuit types (name + typed verification key)
/// let transfer_vk = aggregator.register_circuit("transfer", transfer_vk_data);
/// let deposit_vk = aggregator.register_circuit("deposit", deposit_vk_data);
///
/// // Submit proofs from different circuits
/// aggregator.submit(verified_transfer_proof)?;
/// aggregator.submit(verified_deposit_proof)?;
///
/// // Aggregate all pending proofs
/// let aggregated = aggregator.aggregate()?;
/// ```
pub struct Aggregator<E: PairingEngine> {
    /// Shared structured reference string; exposed read-only through `srs()`.
    #[allow(dead_code)] // Used in actual aggregation implementation
    srs: Arc<Srs<E>>,
    /// Verification keys registered through `register_circuit`.
    registry: VkRegistry<E>,
    /// Proofs accepted by `submit` that are waiting for the next `aggregate` call.
    queue: Vec<Proof<E, Batched>>,
    /// SHA-256 digests (proof bytes followed by public inputs) of the queued proofs,
    /// used to reject duplicates; cleared whenever the queue is drained.
    seen_proofs: std::collections::HashSet<[u8; 32]>,
    /// Maximum number of proofs the queue may hold at once.
    max_batch_size: usize,
    /// Whether `aggregate` dispatches to the parallel aggregation path.
    parallel: bool,
    /// Value of the next `ProofHandle` to hand out.
    next_handle: usize,
    /// Accumulator for external proofs (Groth16, gnark, etc.)
    external_accumulator: crate::crypto::ProofAccumulator<E>,
}

impl<E: PairingEngine> Aggregator<E> {
    /// Start configuring a new aggregator.
    ///
    /// Returns a fresh [`AggregatorBuilder`] for this pairing engine.
    pub fn builder() -> AggregatorBuilder<E> {
        AggregatorBuilder::<E>::new()
    }

    /// Construct an aggregator from already-validated settings (used by the builder).
    ///
    /// The aggregator starts with an empty registry, an empty queue, no recorded
    /// proof digests, a handle counter of zero, and an empty external accumulator
    /// labelled `"external_proofs"`.
    pub(crate) fn new(srs: Arc<Srs<E>>, max_batch_size: usize, parallel: bool) -> Self {
        let registry = VkRegistry::new();
        let external_accumulator = crate::crypto::ProofAccumulator::new("external_proofs");

        Self {
            srs,
            max_batch_size,
            parallel,
            registry,
            external_accumulator,
            queue: Vec::new(),
            seen_proofs: std::collections::HashSet::new(),
            next_handle: 0,
        }
    }

    /// Maximum number of proofs that may be queued before `submit` starts
    /// returning `Error::BatchTooLarge`.
    pub fn max_batch_size(&self) -> usize {
        self.max_batch_size
    }

    /// Borrow the registry holding the verification keys known to this aggregator.
    pub fn registry(&self) -> &VkRegistry<E> {
        &self.registry
    }

    /// Borrow the structured reference string this aggregator was built with.
    pub fn srs(&self) -> &Srs<E> {
        // Explicitly look through the `Arc` to hand out a plain reference.
        &*self.srs
    }

    /// Mutably borrow the accumulator reserved for external proofs.
    ///
    /// Use this to add instances from external proof systems (Groth16, gnark).
    /// Note that the aggregation methods in this file build their own
    /// accumulators and do not read this one.
    pub fn accumulator_mut(&mut self) -> &mut crate::crypto::ProofAccumulator<E> {
        &mut self.external_accumulator
    }

    /// Number of instances currently held by the external proof accumulator.
    pub fn external_count(&self) -> usize {
        self.external_accumulator.len()
    }

    /// Add a circuit's verification key to the registry.
    ///
    /// Registration goes through a shared reference, so it does not require
    /// exclusive access to the aggregator.
    ///
    /// # Arguments
    ///
    /// * `name` - Human-readable circuit name
    /// * `vk` - The typed verification key
    ///
    /// # Returns
    ///
    /// The identifier the registry assigned to this verification key.
    pub fn register_circuit(&self, name: &str, vk: crate::crypto::VerificationKey<E>) -> VkId {
        let registry = &self.registry;
        registry.register(name, vk)
    }

    /// Number of proofs currently waiting to be aggregated.
    pub fn queue_len(&self) -> usize {
        self.queue.len()
    }

    /// `true` when no proofs are waiting to be aggregated.
    pub fn queue_is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Queue a verified proof for the next aggregation round.
    ///
    /// # Arguments
    ///
    /// * `proof` - A verified proof to aggregate
    ///
    /// # Returns
    ///
    /// A handle identifying the accepted proof; handles are numbered in
    /// acceptance order.
    ///
    /// # Errors
    ///
    /// * `Error::BatchTooLarge` if the queue already holds `max_batch_size` proofs.
    /// * The error from the registry lookup if the proof's verification key is
    ///   not registered.
    /// * `Error::VerificationFailed` if a proof with identical bytes and public
    ///   inputs is already queued.
    pub fn submit(&mut self, proof: Proof<E, Verified>) -> Result<ProofHandle, Error> {
        use sha2::{Digest, Sha256};

        // Capacity guard: report the size the queue would have reached.
        let queued = self.queue.len();
        if queued >= self.max_batch_size {
            return Err(Error::BatchTooLarge {
                got: queued + 1,
                max: self.max_batch_size,
            });
        }

        // The proof must reference a registered verification key.
        self.registry.require(proof.vk_id())?;

        // Fingerprint the proof: SHA-256 over its bytes followed by every public input.
        let mut digest = Sha256::new();
        digest.update(proof.data());
        for input in proof.public_inputs().iter() {
            digest.update(input.to_repr());
        }
        let fingerprint: [u8; 32] = digest.finalize().into();

        // `insert` reports `false` when the fingerprint was already recorded.
        if !self.seen_proofs.insert(fingerprint) {
            return Err(Error::VerificationFailed("Duplicate proof in batch".to_string()));
        }

        let handle = ProofHandle(self.next_handle);
        self.next_handle += 1;
        self.queue.push(proof.submit());

        Ok(handle)
    }

    /// Aggregate every queued proof into a single proof.
    ///
    /// The queue and the duplicate-detection set are emptied before aggregation
    /// starts, so the queued proofs are consumed even when aggregation fails.
    ///
    /// # Returns
    ///
    /// An aggregated proof combining all submitted proofs.
    ///
    /// # Errors
    ///
    /// Returns `Error::EmptyBatch` if no proofs are queued, or whatever error the
    /// selected aggregation path reports.
    pub fn aggregate(&mut self) -> Result<Proof<E, Aggregated>, Error> {
        if self.queue.is_empty() {
            return Err(Error::EmptyBatch);
        }

        // Take ownership of the pending proofs and reset duplicate tracking.
        let pending = self.queue.drain(..).collect::<Vec<_>>();
        self.seen_proofs.clear();

        if !self.parallel {
            return self.aggregate_sequential(pending);
        }
        self.aggregate_parallel(pending)
    }

    fn aggregate_sequential(
        &self,
        batch: Vec<Proof<E, Batched>>,
    ) -> Result<Proof<E, Aggregated>, Error> {
        use crate::crypto::{AccumulatorInstance, ProofAccumulator};

        // Group proofs by VK for homogeneous sub-aggregation
        let grouped = self.group_by_vk(batch);

        // Create accumulator for batched verification
        let mut accumulator = ProofAccumulator::<E>::new("samaharam_aggregation");

        // Collect all public inputs
        let mut all_public_inputs = Vec::new();

        // Process each VK group
        for (vk_id, proofs) in grouped {
            // Get VK from registry for this group (required for transcript binding)
            let registered_vk = self.registry.get(vk_id).ok_or(Error::UnknownVk(vk_id))?;

            for proof in proofs {
                // Collect public inputs
                all_public_inputs.extend(proof.public_inputs().iter().cloned());

                // Create accumulator instance from proof data
                // The proof data contains serialized PlonkProof
                if let Ok(plonk_proof) = crate::crypto::PlonkProof::<E>::from_bytes(proof.data()) {
                    // Compute evaluation point (zeta) using Fiat-Shamir transcript
                    // VK binding ensures cross-circuit replay attacks are prevented
                    let zeta = Self::compute_evaluation_point(
                        &plonk_proof,
                        proof.public_inputs(),
                        &registered_vk.vk,
                    );

                    // Create instance from wire commitment (a(X))
                    let instance = AccumulatorInstance {
                        commitment: plonk_proof.wire_commitments[0],
                        evaluation: plonk_proof.evaluations.a_eval,
                        point: zeta,
                        quotient: plonk_proof.opening_proof,
                    };
                    accumulator.add(instance);
                }
            }
        }



        // Handle case where no valid proof data was parsed
        let aggregated_data = if accumulator.is_empty() {
            // Return empty aggregated data if no instances could be parsed
            vec![]
        } else {
            // Fold all instances into accumulated proof
            let accumulated = accumulator.fold().map_err(Error::VerificationFailed)?;

            // Serialize accumulated proof
            let mut data = Vec::new();

            // Serialize accumulated commitments
            use group::GroupEncoding;
            data.extend_from_slice(accumulated.adjusted_commitment.to_bytes().as_ref());
            data.extend_from_slice(accumulated.combined_quotient.to_bytes().as_ref());
            data.extend_from_slice(&(accumulated.count as u32).to_le_bytes());
            data
        };

        Ok(Proof::new_aggregated(aggregated_data, all_public_inputs))
    }

    /// Aggregate a batch with one rayon task per verification key.
    ///
    /// Each VK group is folded in its own accumulator; the per-group results are
    /// then summed and serialized as
    /// `adjusted_commitment || combined_quotient || count (u32, little-endian)`.
    ///
    /// Kept consistent with `aggregate_sequential` in three respects:
    /// * a group whose verification key is missing from the registry is an error
    ///   rather than being dropped silently;
    /// * proofs that fail to decode are skipped while their public inputs are kept,
    ///   and a group with no decodable proof is not folded at all;
    /// * if no proof in the whole batch decodes, the aggregated data is empty.
    ///
    /// # Errors
    ///
    /// * `Error::UnknownVk` if a group's verification key is not in the registry.
    /// * `Error::VerificationFailed` if folding any group's accumulator fails.
    /// * `Error::BatchTooLarge` if the total instance count does not fit in the
    ///   32-bit count field of the serialized proof.
    #[cfg(feature = "parallel")]
    fn aggregate_parallel(
        &self,
        batch: Vec<Proof<E, Batched>>,
    ) -> Result<Proof<E, Aggregated>, Error> {
        use crate::crypto::{AccumulatedProof, AccumulatorInstance, ProofAccumulator};
        use group::{Curve, Group, GroupEncoding};
        use rayon::prelude::*;
        use std::convert::TryFrom;

        // Per-group outcome: the group's public inputs plus its folded proof, or
        // `None` when no proof in the group could be decoded.
        type AggResult<E> =
            Result<(Vec<<E as PairingEngine>::Fr>, Option<AccumulatedProof<E>>), Error>;

        let grouped = self.group_by_vk(batch);

        // Resolve every VK up front (before the parallel section) and fail on a
        // missing key instead of discarding the group.
        let vk_proofs = grouped
            .into_iter()
            .map(|(vk_id, proofs)| {
                self.registry
                    .get(vk_id)
                    .map(|registered| (registered, proofs))
                    .ok_or(Error::UnknownVk(vk_id))
            })
            .collect::<Result<Vec<_>, Error>>()?;

        // Fold each VK group independently.
        let sub_results: Vec<AggResult<E>> = vk_proofs
            .into_par_iter()
            .map(|(registered_vk, proofs)| -> AggResult<E> {
                let mut accumulator = ProofAccumulator::<E>::new("samaharam_parallel");
                let mut public_inputs = Vec::new();

                for proof in proofs {
                    public_inputs.extend(proof.public_inputs().iter().cloned());

                    if let Ok(plonk_proof) = crate::crypto::PlonkProof::<E>::from_bytes(proof.data()) {
                        // Compute evaluation point using Fiat-Shamir transcript with VK binding
                        let zeta = Self::compute_evaluation_point(
                            &plonk_proof,
                            proof.public_inputs(),
                            &registered_vk.vk,
                        );

                        accumulator.add(AccumulatorInstance {
                            commitment: plonk_proof.wire_commitments[0],
                            evaluation: plonk_proof.evaluations.a_eval,
                            point: zeta,
                            quotient: plonk_proof.opening_proof,
                        });
                    }
                }

                // Same guard as the sequential path: never fold an empty accumulator.
                if accumulator.is_empty() {
                    return Ok((public_inputs, None));
                }

                let accumulated = accumulator.fold().map_err(Error::VerificationFailed)?;
                Ok((public_inputs, Some(accumulated)))
            })
            .collect();

        // Combine sub-aggregates
        let mut all_public_inputs = Vec::new();
        let mut total_count = 0usize;
        let mut folded_any = false;
        let mut combined_adjusted = E::G1::identity();
        let mut combined_quotient = E::G1::identity();

        for result in sub_results {
            let (inputs, maybe_acc) = result?;
            all_public_inputs.extend(inputs);
            if let Some(acc) = maybe_acc {
                folded_any = true;
                total_count += acc.count;
                combined_adjusted += acc.adjusted_commitment.into();
                combined_quotient += acc.combined_quotient.into();
            }
        }

        // Nothing decodable anywhere: mirror the sequential path's empty payload.
        if !folded_any {
            return Ok(Proof::new_aggregated(Vec::new(), all_public_inputs));
        }

        // The wire format stores the count in 32 bits; refuse to truncate silently.
        let count = u32::try_from(total_count).map_err(|_| Error::BatchTooLarge {
            got: total_count,
            max: u32::MAX as usize,
        })?;

        // Serialize combined result
        let mut aggregated_data = Vec::new();
        aggregated_data.extend_from_slice(combined_adjusted.to_affine().to_bytes().as_ref());
        aggregated_data.extend_from_slice(combined_quotient.to_affine().to_bytes().as_ref());
        aggregated_data.extend_from_slice(&count.to_le_bytes());

        Ok(Proof::new_aggregated(aggregated_data, all_public_inputs))
    }

    #[cfg(not(feature = "parallel"))]
    fn aggregate_parallel(
        &self,
        batch: Vec<Proof<E, Batched>>,
    ) -> Result<Proof<E, Aggregated>, Error> {
        // Fall back to sequential if parallel feature not enabled
        self.aggregate_sequential(batch)
    }

    /// Partition a batch into per-verification-key groups.
    ///
    /// Groups are returned in the order their verification key first appears in
    /// `batch`, and proofs keep their submission order inside each group. A
    /// `HashMap` would iterate in an arbitrary, run-dependent order, which would
    /// make the order of the aggregated public inputs (and of the folded
    /// instances) irreproducible; the map is therefore used only as an index.
    fn group_by_vk(
        &self,
        batch: Vec<Proof<E, Batched>>,
    ) -> Vec<(VkId, Vec<Proof<E, Batched>>)> {
        // Maps each VK id to the position of its group in `grouped`.
        let mut slot_of: HashMap<VkId, usize> = HashMap::new();
        let mut grouped: Vec<(VkId, Vec<Proof<E, Batched>>)> = Vec::new();

        for proof in batch {
            let vk_id = proof.vk_id();
            let slot = *slot_of.entry(vk_id).or_insert_with(|| {
                grouped.push((vk_id, Vec::new()));
                grouped.len() - 1
            });
            grouped[slot].1.push(proof);
        }

        grouped
    }

    /// Compute the evaluation point (zeta) using Fiat-Shamir transcript.
    ///
    /// This matches the challenge derivation in PlonkVerifier::verify.
    /// The zeta challenge is derived after:
    /// 0. Adding verification key selector commitments (VK binding)
    /// 1. Adding public inputs
    /// 2. Adding wire commitments (a, b, c)
    /// 3. Deriving beta, gamma challenges
    /// 4. Adding z commitment
    /// 5. Deriving alpha challenge
    /// 6. Adding t commitments
    /// 7. Finally deriving zeta
    ///
    /// # Security
    /// VK binding prevents cross-circuit replay attacks by ensuring the
    /// transcript is specific to each circuit's verification key.
    fn compute_evaluation_point(
        plonk_proof: &crate::crypto::PlonkProof<E>,
        public_inputs: &[E::Fr],
        vk: &crate::crypto::VerificationKey<E>,
    ) -> E::Fr {
        use crate::crypto::Transcript;

        // Initialize transcript with same domain as PlonkVerifier
        let mut transcript = Transcript::new("PLONK");

        // Add verification key to transcript (VK binding for security)
        // This matches the order in PlonkVerifier::verify
        for commitment in &vk.selector_commitments {
            transcript.append_g1::<E>("selector", commitment);
        }

        // Add public inputs to transcript
        for pi in public_inputs {
            transcript.append_scalar::<E>("public_input", pi);
        }

        // Add wire commitments [a], [b], [c]
        for wc in &plonk_proof.wire_commitments {
            transcript.append_g1::<E>("wire", wc);
        }

        // Get beta and gamma challenges (matching verifier)
        let _beta: E::Fr = transcript.challenge_scalar::<E>("beta");
        let _gamma: E::Fr = transcript.challenge_scalar::<E>("gamma");

        // Add z commitment
        transcript.append_g1::<E>("z", &plonk_proof.z_commitment);

        // Get alpha challenge
        let _alpha: E::Fr = transcript.challenge_scalar::<E>("alpha");

        // Add t commitments
        for tc in &plonk_proof.t_commitments {
            transcript.append_g1::<E>("t", tc);
        }

        // Finally derive zeta - the evaluation point
        transcript.challenge_scalar::<E>("zeta")
    }
}

/// Compact debug view: only the batch limit, the current queue length and the
/// parallel flag are printed; the SRS, registry and queued proofs are omitted.
impl<E: PairingEngine> std::fmt::Debug for Aggregator<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let queue_len = self.queue.len();

        let mut out = f.debug_struct("Aggregator");
        out.field("max_batch_size", &self.max_batch_size);
        out.field("queue_len", &queue_len);
        out.field("parallel", &self.parallel);
        out.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::bn254::Bn254;
    use crate::crypto::VerificationKey;
    use crate::proof::Pending;
    use group::{Curve, Group};
    use halo2curves::bn256::{G1, G2};
    use rand::rngs::OsRng;

    /// Build a verification key filled with random commitments.
    fn mock_vk(num_public_inputs: usize) -> VerificationKey<Bn254> {
        let random_g1 = || G1::random(OsRng).to_affine();

        VerificationKey {
            num_public_inputs,
            domain_size: 1024,
            selector_commitments: vec![random_g1(), random_g1()],
            permutation_commitments: vec![random_g1()],
            x_g2: G2::random(OsRng).to_affine(),
            g2_generator: G2::generator().to_affine(),
        }
    }

    /// Aggregator over a mock SRS with room for four proofs.
    fn setup_aggregator() -> Aggregator<Bn254> {
        let srs = Arc::new(Srs::<Bn254>::mock(10));
        Aggregator::builder()
            .with_srs(srs)
            .max_batch_size(4)
            .build()
            .unwrap()
    }

    /// Test-only shortcut: promote a pending proof to `Verified` without running
    /// any verification.
    fn unsafe_create_verified_proof<E: PairingEngine>(
        proof: Proof<E, Pending>,
    ) -> Proof<E, Verified> {
        let data = proof.data().to_vec();
        let public_inputs = proof.public_inputs().to_vec();
        Proof::new_verified(data, public_inputs, proof.vk_id())
    }

    /// "Verified" proof with the given bytes, no public inputs, bound to `vk`.
    fn verified_with_data(data: Vec<u8>, vk: VkId) -> Proof<Bn254, Verified> {
        unsafe_create_verified_proof(Proof::<Bn254, Pending>::new(data, vec![], vk))
    }

    #[test]
    fn aggregator_starts_empty() {
        let agg = setup_aggregator();

        assert_eq!(agg.queue_len(), 0);
        assert!(agg.queue_is_empty());
    }

    #[test]
    fn aggregator_registers_circuits() {
        let agg = setup_aggregator();

        let transfer = agg.register_circuit("transfer", mock_vk(5));
        let deposit = agg.register_circuit("deposit", mock_vk(3));

        assert_ne!(transfer, deposit);
        assert!(agg.registry().contains(transfer));
        assert!(agg.registry().contains(deposit));
    }

    #[test]
    fn aggregator_rejects_unknown_vk() {
        let mut agg = setup_aggregator();

        // No circuit was registered, so this id cannot be known to the registry.
        let verified = verified_with_data(vec![], VkId::new(999));

        let result = agg.submit(verified);
        assert!(matches!(result, Err(Error::UnknownVk(_))));
    }

    #[test]
    fn aggregator_rejects_when_full() {
        let mut agg = setup_aggregator(); // max_batch_size = 4
        let vk = agg.register_circuit("test", mock_vk(1));

        // Distinct payloads so the duplicate check does not interfere.
        for byte in 0u8..4 {
            agg.submit(verified_with_data(vec![byte], vk)).unwrap();
        }

        // The fifth submission must hit the batch size limit.
        let result = agg.submit(verified_with_data(vec![99], vk));

        assert!(matches!(
            result,
            Err(Error::BatchTooLarge { got: 5, max: 4 })
        ));
    }

    #[test]
    fn aggregator_empty_batch_fails() {
        let mut agg = setup_aggregator();

        assert!(matches!(agg.aggregate(), Err(Error::EmptyBatch)));
    }

    #[test]
    fn aggregator_drains_queue_on_aggregate() {
        let mut agg = setup_aggregator();
        let vk = agg.register_circuit("test", mock_vk(1));

        agg.submit(verified_with_data(vec![], vk)).unwrap();
        assert_eq!(agg.queue_len(), 1);

        let _aggregated = agg.aggregate().unwrap();

        assert!(agg.queue_is_empty());
    }
}