webgpu_groth16/prover.rs

1//! Groth16 proof construction using GPU-accelerated MSM and NTT.
2//!
3//! The main entry point is [`create_proof`].
4//!
5//! Proof construction flow:
6//! 1. Circuit synthesis → constraint system (A, B, C linear combinations)
7//! 2. Witness evaluation → dense A/B/C coefficient vectors
8//! 3. H-polynomial: `H(x) = (A(x)·B(x) − C(x)) / Z(x)` via GPU NTT pipeline
9//! 4. Five MSMs dispatched to GPU: `a` (G1), `b1` (G1), `l` (G1), `h` (G1),
10//!    `b2` (G2)
11//! 5. CPU-side proof assembly with random blinding factors r, s
12
13mod constraint_system;
14pub(crate) mod density_masks;
15mod gpu_key;
16mod h_poly;
17mod msm;
18mod prepared_key;
19
20use anyhow::Result;
21use ff::{Field, PrimeField};
22use rand_core::RngCore;
23
24use self::constraint_system::GpuConstraintSystem;
25use self::density_masks::dense_assignment_from_masks;
26pub use self::gpu_key::{GpuProvingKey, prepare_gpu_proving_key};
27pub use self::h_poly::compute_h_poly;
28use self::h_poly::{read_h_poly_result, submit_h_poly};
29use self::msm::{MsmBases, enqueue_msm, readback_msms};
30pub use self::msm::{gpu_msm_batch, gpu_msm_g1};
31pub use self::prepared_key::{PreparedProvingKey, prepare_proving_key};
32use crate::bellman;
33use crate::bucket::{
34    compute_bucket_sorting_with_width, compute_glv_bucket_data,
35    compute_glv_bucket_sorting, optimal_glv_c,
36};
37use crate::gpu::GpuContext;
38use crate::gpu::curve::GpuCurve;
39
/// Proving key required to create a new Groth16 proof with [`create_proof`].
///
/// The two variants select different MSM base-sourcing paths in the prover:
/// an [`Uploaded`](Self::Uploaded) key dispatches MSMs against persistent GPU
/// buffers (`MsmBases::Persistent`), while a [`Serialized`](Self::Serialized)
/// key has its bases marshalled and uploaded per proof (`MsmBases::Bytes`).
#[derive(Copy, Clone)]
pub enum ProvingKey<'key, G: GpuCurve> {
    /// Use a key that has already been uploaded to the GPU.
    Uploaded(&'key GpuProvingKey<G>),
    /// Use a key that has yet to be uploaded to the GPU.
    Serialized(&'key PreparedProvingKey<G>),
}
48
49fn marshal_scalars<G: GpuCurve>(scalars: &[G::Scalar]) -> Vec<u8> {
50    let mut buffer = Vec::with_capacity(scalars.len() * 32);
51    for s in scalars {
52        buffer.extend_from_slice(&G::serialize_scalar(s));
53    }
54    buffer
55}
56
57fn eval_lc<S: PrimeField>(
58    lc: &[(bellman::Variable, S)],
59    inputs: &[S],
60    aux: &[S],
61) -> S {
62    let mut res = S::ZERO;
63    for &(var, coeff) in lc {
64        let val = match var.get_unchecked() {
65            bellman::Index::Input(i) => inputs[i],
66            bellman::Index::Aux(i) => aux[i],
67        };
68        let mut term = val;
69        term.mul_assign(&coeff);
70        res.add_assign(&term);
71    }
72    res
73}
74
/// Core proof construction with fixed randomness (deterministic for testing).
///
/// Orchestrates the full Groth16 proving pipeline:
/// 1. Synthesize the circuit into a constraint system
/// 2. Submit H-polynomial computation to GPU (non-blocking)
/// 3. Compute GLV bucket sorting on CPU (overlapped with GPU H-poly work)
/// 4. Enqueue 5 MSMs (a, b1, l, b2, then h after H-poly completes)
/// 5. Read back MSM results and assemble the final proof (A, B, C)
///
/// `r` and `s` are the Groth16 blinding factors. [`create_proof`] samples
/// them from an RNG and delegates here; passing fixed values makes the
/// output reproducible for tests.
async fn create_proof_with_fixed_randomness<E, G, C>(
    circuit: C,
    pk: ProvingKey<'_, G>,
    gpu: &GpuContext<G>,
    r: G::Scalar,
    s: G::Scalar,
) -> Result<bellman::groth16::Proof<E>>
where
    E: pairing::MultiMillerLoop,
    C: bellman::Circuit<G::Scalar>,
    G: GpuCurve<
            Engine = E,
            Scalar = E::Fr,
            G1 = E::G1,
            G2 = E::G2,
            G1Affine = E::G1Affine,
            G2Affine = E::G2Affine,
        > + Send,
{
    // Phase timers: with the "timing" feature each phase rebinds `t_phase`
    // (shadowing) and reports elapsed time to stderr.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    let mut cs = GpuConstraintSystem::<G>::new();
    circuit
        .synthesize(&mut cs)
        .map_err(|e| anyhow::anyhow!("circuit synthesis failed: {:?}", e))?;

    // Append input constraints: for each public input i, add the constraint
    // (input[i]) · (1) = (0), which encodes the public input identity.
    for i in 0..cs.inputs.len() {
        cs.a_lcs.push(vec![(
            bellman::Variable::new_unchecked(bellman::Index::Input(i)),
            G::Scalar::ONE,
        )]);
        cs.b_lcs.push(Vec::new());
        cs.c_lcs.push(Vec::new());
    }

    let num_constraints = cs.a_lcs.len();
    // Evaluation-domain size, padded up to the next power of two (the NTT
    // pipeline in h_poly presumably requires a power-of-two domain — confirm
    // there). The padding slots of A/B/C below remain zero.
    let n = num_constraints.next_power_of_two();
    #[cfg(feature = "timing")]
    eprintln!(
        "[proof] synthesis: {:?} (constraints={num_constraints}, n={n}, \
         inputs={}, aux={})",
        t_phase.elapsed(),
        cs.inputs.len(),
        cs.aux.len()
    );

    // Evaluate all linear combinations at the witness to get A, B, C vectors.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    let mut a_values = vec![G::Scalar::ZERO; n];
    let mut b_values = vec![G::Scalar::ZERO; n];
    let mut c_values = vec![G::Scalar::ZERO; n];

    for i in 0..num_constraints {
        a_values[i] = eval_lc(&cs.a_lcs[i], &cs.inputs, &cs.aux);
        b_values[i] = eval_lc(&cs.b_lcs[i], &cs.inputs, &cs.aux);
        c_values[i] = eval_lc(&cs.c_lcs[i], &cs.inputs, &cs.aux);
    }
    #[cfg(feature = "timing")]
    eprintln!("[proof] eval_lc: {:?}", t_phase.elapsed());

    // Build dense assignments using density masks before submitting H poly,
    // so we can pre-compute GLV bucket data on CPU while GPU processes H.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    // All public inputs are dense in A (the input constraints appended above
    // reference every input variable in a_lcs), so only the aux part of the
    // A assignment is filtered by its density mask.
    let mut a_assignment = cs.inputs.clone();
    for (i, v) in cs.aux.iter().enumerate() {
        if cs.a_aux_density.is_set(i) {
            a_assignment.push(*v);
        }
    }
    let b_assignment = dense_assignment_from_masks(
        &cs.inputs,
        &cs.aux,
        &cs.b_input_density,
        &cs.b_aux_density,
    );
    #[cfg(feature = "timing")]
    eprintln!(
        "[proof] assignments: {:?} (a_assign={}, b_assign={})",
        t_phase.elapsed(),
        a_assignment.len(),
        b_assignment.len()
    );

    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    // Submit H polynomial to GPU (non-blocking — GPU processes asynchronously).
    let h_pending = submit_h_poly::<G>(gpu, &a_values, &b_values, &c_values)?;
    #[cfg(feature = "timing")]
    eprintln!("[proof] h_poly submit: {:?}", t_phase.elapsed());

    // Pre-compute GLV bucket data for non-H G1 MSMs while GPU computes H.
    // GLV decomposes each scalar k into k1·P + k2·φ(P) with ~128-bit
    // sub-scalars, halving the number of Pippenger windows.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    // Adaptive bucket width: choose per-MSM c based on point count.
    let a_c = optimal_glv_c::<G>(a_assignment.len());
    let b1_c = optimal_glv_c::<G>(b_assignment.len());
    let l_c = optimal_glv_c::<G>(cs.aux.len());

    // Bucket sorting: with persistent GPU key, GLV negation is folded into sign
    // bits and no combined bases buffer is built. Without it, the original
    // path is used.
    //
    // Declared before the match so both arms initialize each binding exactly
    // once (definite-initialization checked by the compiler).
    let a_bd;
    let b1_bd;
    let l_bd;
    let b2_bd;
    // Only needed for the non-persistent path:
    let a_glv_bytes;
    let b1_glv_bytes;
    let l_glv_bytes;
    match pk {
        ProvingKey::Uploaded(_) => {
            a_bd = compute_glv_bucket_data::<G>(&a_assignment, a_c);
            b1_bd = compute_glv_bucket_data::<G>(&b_assignment, b1_c);
            l_bd = compute_glv_bucket_data::<G>(&cs.aux, l_c);
            // b2 shares the same witness as b1 but is a G2 MSM with a fixed
            // (non-GLV) bucket width.
            b2_bd = compute_bucket_sorting_with_width::<G>(
                &b_assignment,
                G::g2_bucket_width(),
            );
            a_glv_bytes = Vec::new();
            b1_glv_bytes = Vec::new();
            l_glv_bytes = Vec::new();
        }
        ProvingKey::Serialized(ppk) => {
            // Missing phi (endomorphism-mapped) bases are passed as empty
            // slices — presumably compute_glv_bucket_sorting derives them in
            // that case; confirm in crate::bucket.
            let (a_bytes, a_bd_tmp) = compute_glv_bucket_sorting::<G>(
                &a_assignment,
                &ppk.a_bytes,
                ppk.a_phi_bytes.as_deref().unwrap_or(&[]),
                a_c,
            );
            let (b1_bytes, b1_bd_tmp) = compute_glv_bucket_sorting::<G>(
                &b_assignment,
                &ppk.b_g1_bytes,
                ppk.b_g1_phi_bytes.as_deref().unwrap_or(&[]),
                b1_c,
            );
            let (l_bytes, l_bd_tmp) = compute_glv_bucket_sorting::<G>(
                &cs.aux,
                &ppk.l_bytes,
                ppk.l_phi_bytes.as_deref().unwrap_or(&[]),
                l_c,
            );
            a_bd = a_bd_tmp;
            b1_bd = b1_bd_tmp;
            l_bd = l_bd_tmp;
            b2_bd = compute_bucket_sorting_with_width::<G>(
                &b_assignment,
                G::g2_bucket_width(),
            );
            a_glv_bytes = a_bytes;
            b1_glv_bytes = b1_bytes;
            l_glv_bytes = l_bytes;
        }
    }

    #[cfg(feature = "timing")]
    {
        eprintln!(
            "[proof] bucket sorting (4x GLV): {:?} (c: a={}, b1={}, l={})",
            t_phase.elapsed(),
            a_c,
            b1_c,
            l_c
        );
        a_bd.print_distribution_stats("a_g1_glv");
        b1_bd.print_distribution_stats("b1_g1_glv");
        l_bd.print_distribution_stats("l_g1_glv");
        b2_bd.print_distribution_stats("b2_g2");
    }

    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    // Await H result (GPU likely already done by now).
    let h_coeffs = read_h_poly_result::<G>(gpu, h_pending).await?;
    #[cfg(feature = "timing")]
    eprintln!("[proof] h_poly read: {:?}", t_phase.elapsed());

    // Enqueue a/b1/l/b2 MSMs right after h_poly completes — GPU starts
    // processing them immediately while CPU computes h bucket sorting
    // below.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    let (a_job, b1_job, l_job, b2_job);
    match pk {
        ProvingKey::Uploaded(gpk) => {
            // Persistent path: bases already live on the GPU; the last
            // argument (true only for b2) selects the G2 MSM variant.
            a_job = enqueue_msm::<G>(
                gpu,
                "a",
                MsmBases::Persistent(&gpk.a_bases_buf),
                a_bd,
                false,
            )?;
            b1_job = enqueue_msm::<G>(
                gpu,
                "b1",
                MsmBases::Persistent(&gpk.b_g1_bases_buf),
                b1_bd,
                false,
            )?;
            l_job = enqueue_msm::<G>(
                gpu,
                "l",
                MsmBases::Persistent(&gpk.l_bases_buf),
                l_bd,
                false,
            )?;
            b2_job = enqueue_msm::<G>(
                gpu,
                "b2",
                MsmBases::Persistent(&gpk.b_g2_bases_buf),
                b2_bd,
                true,
            )?;
        }
        ProvingKey::Serialized(ppk) => {
            // Serialized path: upload the GLV-sorted base bytes computed
            // above with each MSM.
            a_job = enqueue_msm::<G>(
                gpu,
                "a",
                MsmBases::Bytes(&a_glv_bytes),
                a_bd,
                false,
            )?;
            b1_job = enqueue_msm::<G>(
                gpu,
                "b1",
                MsmBases::Bytes(&b1_glv_bytes),
                b1_bd,
                false,
            )?;
            l_job = enqueue_msm::<G>(
                gpu,
                "l",
                MsmBases::Bytes(&l_glv_bytes),
                l_bd,
                false,
            )?;
            b2_job = enqueue_msm::<G>(
                gpu,
                "b2",
                MsmBases::Bytes(&ppk.b_g2_bytes),
                b2_bd,
                true,
            )?;
        }
    }
    #[cfg(feature = "timing")]
    eprintln!("[proof] msm enqueue a/b1/l/b2: {:?}", t_phase.elapsed());

    // H bucket data depends on h_coeffs — also uses GLV.
    // While CPU computes this, GPU is already processing a/b1/l/b2 MSMs.
    // Note h_coeffs is truncated to the key's h_len in both arms.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    let h_job = match pk {
        ProvingKey::Uploaded(gpu_pk) => {
            let h_c = optimal_glv_c::<G>(gpu_pk.h_len);
            let h_bd =
                compute_glv_bucket_data::<G>(&h_coeffs[..gpu_pk.h_len], h_c);
            #[cfg(feature = "timing")]
            {
                eprintln!(
                    "[proof] h bucket sorting (GLV): {:?} (c={})",
                    t_phase.elapsed(),
                    h_c
                );
                h_bd.print_distribution_stats("h_g1_glv");
            }
            #[cfg(feature = "timing")]
            let t_phase = std::time::Instant::now();
            let h_job = enqueue_msm::<G>(
                gpu,
                "h",
                MsmBases::Persistent(&gpu_pk.h_bases_buf),
                h_bd,
                false,
            )?;
            #[cfg(feature = "timing")]
            eprintln!("[proof] msm enqueue h: {:?}", t_phase.elapsed());
            h_job
        }
        ProvingKey::Serialized(ppk) => {
            let h_c = optimal_glv_c::<G>(ppk.h_len);
            let (h_glv_bytes, h_bd) = compute_glv_bucket_sorting::<G>(
                &h_coeffs[..ppk.h_len],
                &ppk.h_bytes,
                ppk.h_phi_bytes.as_deref().unwrap_or(&[]),
                h_c,
            );
            #[cfg(feature = "timing")]
            {
                eprintln!(
                    "[proof] h bucket sorting (GLV): {:?} (c={})",
                    t_phase.elapsed(),
                    h_c
                );
                h_bd.print_distribution_stats("h_g1_glv");
            }
            #[cfg(feature = "timing")]
            let t_phase = std::time::Instant::now();
            let h_job = enqueue_msm::<G>(
                gpu,
                "h",
                MsmBases::Bytes(&h_glv_bytes),
                h_bd,
                false,
            )?;
            #[cfg(feature = "timing")]
            eprintln!("[proof] msm enqueue h: {:?}", t_phase.elapsed());
            h_job
        }
    };

    // Block until all five MSM results are back on the CPU.
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();
    let (a_msm, b_g1_msm, l_msm, h_msm, b_g2_msm) =
        readback_msms::<G>(gpu, a_job, b1_job, l_job, h_job, b2_job).await?;
    #[cfg(feature = "timing")]
    eprintln!("[proof] msm readback: {:?}", t_phase.elapsed());

    // Assemble the final Groth16 proof from MSM results and random blinding
    // factors.
    //
    // Groth16 proof elements:
    //   A = α + Σᵢ aᵢ·Aᵢ + r·δ
    //   B = β + Σᵢ bᵢ·Bᵢ + s·δ        (in G2)
    //   C = Σᵢ (aᵢsᵢ)·Lᵢ + h(x)·H + s·A + r·B_G1 − r·s·δ
    #[cfg(feature = "timing")]
    let t_phase = std::time::Instant::now();

    // Both key variants expose the same five toxic-waste commitments; pick
    // whichever copy this `pk` carries.
    let (alpha_g1, beta_g1, beta_g2, delta_g1, delta_g2) = match pk {
        ProvingKey::Uploaded(k) => (
            &k.alpha_g1,
            &k.beta_g1,
            &k.beta_g2,
            &k.delta_g1,
            &k.delta_g2,
        ),
        ProvingKey::Serialized(k) => (
            &k.alpha_g1,
            &k.beta_g1,
            &k.beta_g2,
            &k.delta_g1,
            &k.delta_g2,
        ),
    };

    // A = α + a_msm + r·δ
    let mut proof_a = G::add_g1_proj(&G::affine_to_proj_g1(alpha_g1), &a_msm);
    proof_a = G::add_g1_proj(&proof_a, &G::mul_g1_scalar(delta_g1, &r));

    // B = β + b_g2_msm + s·δ   (in G2)
    let mut proof_b = G::add_g2_proj(&G::affine_to_proj_g2(beta_g2), &b_g2_msm);
    proof_b = G::add_g2_proj(&proof_b, &G::mul_g2_scalar(delta_g2, &s));

    // C = l_msm + h_msm + s·A + r·(β + b_g1_msm + s·δ_G1) − r·s·δ
    let mut proof_c = G::add_g1_proj(&l_msm, &h_msm);
    let mut b_g1 = G::add_g1_proj(&G::affine_to_proj_g1(beta_g1), &b_g1_msm);
    b_g1 = G::add_g1_proj(&b_g1, &G::mul_g1_scalar(delta_g1, &s));

    let c_shift_a = G::mul_g1_proj_scalar(&proof_a, &s);
    proof_c = G::add_g1_proj(&proof_c, &c_shift_a);

    let c_shift_b = G::mul_g1_proj_scalar(&b_g1, &r);
    proof_c = G::add_g1_proj(&proof_c, &c_shift_b);

    // rs = r·s, for the final −r·s·δ correction term of C.
    let mut rs = r;
    rs *= s;
    let rs_delta = G::mul_g1_scalar(delta_g1, &rs);
    proof_c = G::sub_g1_proj(&proof_c, &rs_delta);
    #[cfg(feature = "timing")]
    eprintln!("[proof] final assembly: {:?}", t_phase.elapsed());

    Ok(bellman::groth16::Proof {
        a: G::proj_to_affine_g1(&proof_a),
        b: G::proj_to_affine_g2(&proof_b),
        c: G::proj_to_affine_g1(&proof_c),
    })
}
465
466/// Create a new Groth16 proof.
467///
468/// Uses a [`GpuProvingKey`] to skip per-proof base uploads and Montgomery
469/// conversion, reusing pre-uploaded GPU buffers across proofs.
470pub async fn create_proof<E, G, C, R>(
471    circuit: C,
472    pk: ProvingKey<'_, G>,
473    gpu: &GpuContext<G>,
474    rng: &mut R,
475) -> Result<bellman::groth16::Proof<E>>
476where
477    E: pairing::MultiMillerLoop,
478    C: bellman::Circuit<G::Scalar>,
479    G: GpuCurve<
480            Engine = E,
481            Scalar = E::Fr,
482            G1 = E::G1,
483            G2 = E::G2,
484            G1Affine = E::G1Affine,
485            G2Affine = E::G2Affine,
486        > + Send,
487    R: RngCore,
488{
489    let r = G::Scalar::random(&mut *rng);
490    let s = G::Scalar::random(&mut *rng);
491
492    create_proof_with_fixed_randomness::<E, G, C>(circuit, pk, gpu, r, s).await
493}
494
495#[cfg(test)]
496mod tests;