1mod constraint_system;
14pub(crate) mod density_masks;
15mod gpu_key;
16mod h_poly;
17mod msm;
18mod prepared_key;
19
20use anyhow::Result;
21use ff::{Field, PrimeField};
22use rand_core::RngCore;
23
24use self::constraint_system::GpuConstraintSystem;
25use self::density_masks::dense_assignment_from_masks;
26pub use self::gpu_key::{GpuProvingKey, prepare_gpu_proving_key};
27pub use self::h_poly::compute_h_poly;
28use self::h_poly::{read_h_poly_result, submit_h_poly};
29use self::msm::{MsmBases, enqueue_msm, readback_msms};
30pub use self::msm::{gpu_msm_batch, gpu_msm_g1};
31pub use self::prepared_key::{PreparedProvingKey, prepare_proving_key};
32use crate::bellman;
33use crate::bucket::{
34 compute_bucket_sorting_with_width, compute_glv_bucket_data,
35 compute_glv_bucket_sorting, optimal_glv_c,
36};
37use crate::gpu::GpuContext;
38use crate::gpu::curve::GpuCurve;
39
40#[derive(Copy, Clone)]
42pub enum ProvingKey<'key, G: GpuCurve> {
43 Uploaded(&'key GpuProvingKey<G>),
45 Serialized(&'key PreparedProvingKey<G>),
47}
48
49fn marshal_scalars<G: GpuCurve>(scalars: &[G::Scalar]) -> Vec<u8> {
50 let mut buffer = Vec::with_capacity(scalars.len() * 32);
51 for s in scalars {
52 buffer.extend_from_slice(&G::serialize_scalar(s));
53 }
54 buffer
55}
56
57fn eval_lc<S: PrimeField>(
58 lc: &[(bellman::Variable, S)],
59 inputs: &[S],
60 aux: &[S],
61) -> S {
62 let mut res = S::ZERO;
63 for &(var, coeff) in lc {
64 let val = match var.get_unchecked() {
65 bellman::Index::Input(i) => inputs[i],
66 bellman::Index::Aux(i) => aux[i],
67 };
68 let mut term = val;
69 term.mul_assign(&coeff);
70 res.add_assign(&term);
71 }
72 res
73}
74
75async fn create_proof_with_fixed_randomness<E, G, C>(
84 circuit: C,
85 pk: ProvingKey<'_, G>,
86 gpu: &GpuContext<G>,
87 r: G::Scalar,
88 s: G::Scalar,
89) -> Result<bellman::groth16::Proof<E>>
90where
91 E: pairing::MultiMillerLoop,
92 C: bellman::Circuit<G::Scalar>,
93 G: GpuCurve<
94 Engine = E,
95 Scalar = E::Fr,
96 G1 = E::G1,
97 G2 = E::G2,
98 G1Affine = E::G1Affine,
99 G2Affine = E::G2Affine,
100 > + Send,
101{
102 #[cfg(feature = "timing")]
103 let t_phase = std::time::Instant::now();
104 let mut cs = GpuConstraintSystem::<G>::new();
105 circuit
106 .synthesize(&mut cs)
107 .map_err(|e| anyhow::anyhow!("circuit synthesis failed: {:?}", e))?;
108
109 for i in 0..cs.inputs.len() {
112 cs.a_lcs.push(vec![(
113 bellman::Variable::new_unchecked(bellman::Index::Input(i)),
114 G::Scalar::ONE,
115 )]);
116 cs.b_lcs.push(Vec::new());
117 cs.c_lcs.push(Vec::new());
118 }
119
120 let num_constraints = cs.a_lcs.len();
121 let n = num_constraints.next_power_of_two();
122 #[cfg(feature = "timing")]
123 eprintln!(
124 "[proof] synthesis: {:?} (constraints={num_constraints}, n={n}, \
125 inputs={}, aux={})",
126 t_phase.elapsed(),
127 cs.inputs.len(),
128 cs.aux.len()
129 );
130
131 #[cfg(feature = "timing")]
133 let t_phase = std::time::Instant::now();
134 let mut a_values = vec![G::Scalar::ZERO; n];
135 let mut b_values = vec![G::Scalar::ZERO; n];
136 let mut c_values = vec![G::Scalar::ZERO; n];
137
138 for i in 0..num_constraints {
139 a_values[i] = eval_lc(&cs.a_lcs[i], &cs.inputs, &cs.aux);
140 b_values[i] = eval_lc(&cs.b_lcs[i], &cs.inputs, &cs.aux);
141 c_values[i] = eval_lc(&cs.c_lcs[i], &cs.inputs, &cs.aux);
142 }
143 #[cfg(feature = "timing")]
144 eprintln!("[proof] eval_lc: {:?}", t_phase.elapsed());
145
146 #[cfg(feature = "timing")]
149 let t_phase = std::time::Instant::now();
150 let mut a_assignment = cs.inputs.clone();
151 for (i, v) in cs.aux.iter().enumerate() {
152 if cs.a_aux_density.is_set(i) {
153 a_assignment.push(*v);
154 }
155 }
156 let b_assignment = dense_assignment_from_masks(
157 &cs.inputs,
158 &cs.aux,
159 &cs.b_input_density,
160 &cs.b_aux_density,
161 );
162 #[cfg(feature = "timing")]
163 eprintln!(
164 "[proof] assignments: {:?} (a_assign={}, b_assign={})",
165 t_phase.elapsed(),
166 a_assignment.len(),
167 b_assignment.len()
168 );
169
170 #[cfg(feature = "timing")]
171 let t_phase = std::time::Instant::now();
172 let h_pending = submit_h_poly::<G>(gpu, &a_values, &b_values, &c_values)?;
174 #[cfg(feature = "timing")]
175 eprintln!("[proof] h_poly submit: {:?}", t_phase.elapsed());
176
177 #[cfg(feature = "timing")]
181 let t_phase = std::time::Instant::now();
182 let a_c = optimal_glv_c::<G>(a_assignment.len());
184 let b1_c = optimal_glv_c::<G>(b_assignment.len());
185 let l_c = optimal_glv_c::<G>(cs.aux.len());
186
187 let a_bd;
191 let b1_bd;
192 let l_bd;
193 let b2_bd;
194 let a_glv_bytes;
196 let b1_glv_bytes;
197 let l_glv_bytes;
198 match pk {
199 ProvingKey::Uploaded(_) => {
200 a_bd = compute_glv_bucket_data::<G>(&a_assignment, a_c);
201 b1_bd = compute_glv_bucket_data::<G>(&b_assignment, b1_c);
202 l_bd = compute_glv_bucket_data::<G>(&cs.aux, l_c);
203 b2_bd = compute_bucket_sorting_with_width::<G>(
204 &b_assignment,
205 G::g2_bucket_width(),
206 );
207 a_glv_bytes = Vec::new();
208 b1_glv_bytes = Vec::new();
209 l_glv_bytes = Vec::new();
210 }
211 ProvingKey::Serialized(ppk) => {
212 let (a_bytes, a_bd_tmp) = compute_glv_bucket_sorting::<G>(
213 &a_assignment,
214 &ppk.a_bytes,
215 ppk.a_phi_bytes.as_deref().unwrap_or(&[]),
216 a_c,
217 );
218 let (b1_bytes, b1_bd_tmp) = compute_glv_bucket_sorting::<G>(
219 &b_assignment,
220 &ppk.b_g1_bytes,
221 ppk.b_g1_phi_bytes.as_deref().unwrap_or(&[]),
222 b1_c,
223 );
224 let (l_bytes, l_bd_tmp) = compute_glv_bucket_sorting::<G>(
225 &cs.aux,
226 &ppk.l_bytes,
227 ppk.l_phi_bytes.as_deref().unwrap_or(&[]),
228 l_c,
229 );
230 a_bd = a_bd_tmp;
231 b1_bd = b1_bd_tmp;
232 l_bd = l_bd_tmp;
233 b2_bd = compute_bucket_sorting_with_width::<G>(
234 &b_assignment,
235 G::g2_bucket_width(),
236 );
237 a_glv_bytes = a_bytes;
238 b1_glv_bytes = b1_bytes;
239 l_glv_bytes = l_bytes;
240 }
241 }
242
243 #[cfg(feature = "timing")]
244 {
245 eprintln!(
246 "[proof] bucket sorting (4x GLV): {:?} (c: a={}, b1={}, l={})",
247 t_phase.elapsed(),
248 a_c,
249 b1_c,
250 l_c
251 );
252 a_bd.print_distribution_stats("a_g1_glv");
253 b1_bd.print_distribution_stats("b1_g1_glv");
254 l_bd.print_distribution_stats("l_g1_glv");
255 b2_bd.print_distribution_stats("b2_g2");
256 }
257
258 #[cfg(feature = "timing")]
259 let t_phase = std::time::Instant::now();
260 let h_coeffs = read_h_poly_result::<G>(gpu, h_pending).await?;
262 #[cfg(feature = "timing")]
263 eprintln!("[proof] h_poly read: {:?}", t_phase.elapsed());
264
265 #[cfg(feature = "timing")]
269 let t_phase = std::time::Instant::now();
270 let (a_job, b1_job, l_job, b2_job);
271 match pk {
272 ProvingKey::Uploaded(gpk) => {
273 a_job = enqueue_msm::<G>(
274 gpu,
275 "a",
276 MsmBases::Persistent(&gpk.a_bases_buf),
277 a_bd,
278 false,
279 )?;
280 b1_job = enqueue_msm::<G>(
281 gpu,
282 "b1",
283 MsmBases::Persistent(&gpk.b_g1_bases_buf),
284 b1_bd,
285 false,
286 )?;
287 l_job = enqueue_msm::<G>(
288 gpu,
289 "l",
290 MsmBases::Persistent(&gpk.l_bases_buf),
291 l_bd,
292 false,
293 )?;
294 b2_job = enqueue_msm::<G>(
295 gpu,
296 "b2",
297 MsmBases::Persistent(&gpk.b_g2_bases_buf),
298 b2_bd,
299 true,
300 )?;
301 }
302 ProvingKey::Serialized(ppk) => {
303 a_job = enqueue_msm::<G>(
304 gpu,
305 "a",
306 MsmBases::Bytes(&a_glv_bytes),
307 a_bd,
308 false,
309 )?;
310 b1_job = enqueue_msm::<G>(
311 gpu,
312 "b1",
313 MsmBases::Bytes(&b1_glv_bytes),
314 b1_bd,
315 false,
316 )?;
317 l_job = enqueue_msm::<G>(
318 gpu,
319 "l",
320 MsmBases::Bytes(&l_glv_bytes),
321 l_bd,
322 false,
323 )?;
324 b2_job = enqueue_msm::<G>(
325 gpu,
326 "b2",
327 MsmBases::Bytes(&ppk.b_g2_bytes),
328 b2_bd,
329 true,
330 )?;
331 }
332 }
333 #[cfg(feature = "timing")]
334 eprintln!("[proof] msm enqueue a/b1/l/b2: {:?}", t_phase.elapsed());
335
336 #[cfg(feature = "timing")]
339 let t_phase = std::time::Instant::now();
340 let h_job = match pk {
341 ProvingKey::Uploaded(gpu_pk) => {
342 let h_c = optimal_glv_c::<G>(gpu_pk.h_len);
343 let h_bd =
344 compute_glv_bucket_data::<G>(&h_coeffs[..gpu_pk.h_len], h_c);
345 #[cfg(feature = "timing")]
346 {
347 eprintln!(
348 "[proof] h bucket sorting (GLV): {:?} (c={})",
349 t_phase.elapsed(),
350 h_c
351 );
352 h_bd.print_distribution_stats("h_g1_glv");
353 }
354 #[cfg(feature = "timing")]
355 let t_phase = std::time::Instant::now();
356 let h_job = enqueue_msm::<G>(
357 gpu,
358 "h",
359 MsmBases::Persistent(&gpu_pk.h_bases_buf),
360 h_bd,
361 false,
362 )?;
363 #[cfg(feature = "timing")]
364 eprintln!("[proof] msm enqueue h: {:?}", t_phase.elapsed());
365 h_job
366 }
367 ProvingKey::Serialized(ppk) => {
368 let h_c = optimal_glv_c::<G>(ppk.h_len);
369 let (h_glv_bytes, h_bd) = compute_glv_bucket_sorting::<G>(
370 &h_coeffs[..ppk.h_len],
371 &ppk.h_bytes,
372 ppk.h_phi_bytes.as_deref().unwrap_or(&[]),
373 h_c,
374 );
375 #[cfg(feature = "timing")]
376 {
377 eprintln!(
378 "[proof] h bucket sorting (GLV): {:?} (c={})",
379 t_phase.elapsed(),
380 h_c
381 );
382 h_bd.print_distribution_stats("h_g1_glv");
383 }
384 #[cfg(feature = "timing")]
385 let t_phase = std::time::Instant::now();
386 let h_job = enqueue_msm::<G>(
387 gpu,
388 "h",
389 MsmBases::Bytes(&h_glv_bytes),
390 h_bd,
391 false,
392 )?;
393 #[cfg(feature = "timing")]
394 eprintln!("[proof] msm enqueue h: {:?}", t_phase.elapsed());
395 h_job
396 }
397 };
398
399 #[cfg(feature = "timing")]
400 let t_phase = std::time::Instant::now();
401 let (a_msm, b_g1_msm, l_msm, h_msm, b_g2_msm) =
402 readback_msms::<G>(gpu, a_job, b1_job, l_job, h_job, b2_job).await?;
403 #[cfg(feature = "timing")]
404 eprintln!("[proof] msm readback: {:?}", t_phase.elapsed());
405
406 #[cfg(feature = "timing")]
414 let t_phase = std::time::Instant::now();
415
416 let (alpha_g1, beta_g1, beta_g2, delta_g1, delta_g2) = match pk {
417 ProvingKey::Uploaded(k) => (
418 &k.alpha_g1,
419 &k.beta_g1,
420 &k.beta_g2,
421 &k.delta_g1,
422 &k.delta_g2,
423 ),
424 ProvingKey::Serialized(k) => (
425 &k.alpha_g1,
426 &k.beta_g1,
427 &k.beta_g2,
428 &k.delta_g1,
429 &k.delta_g2,
430 ),
431 };
432
433 let mut proof_a = G::add_g1_proj(&G::affine_to_proj_g1(alpha_g1), &a_msm);
435 proof_a = G::add_g1_proj(&proof_a, &G::mul_g1_scalar(delta_g1, &r));
436
437 let mut proof_b = G::add_g2_proj(&G::affine_to_proj_g2(beta_g2), &b_g2_msm);
439 proof_b = G::add_g2_proj(&proof_b, &G::mul_g2_scalar(delta_g2, &s));
440
441 let mut proof_c = G::add_g1_proj(&l_msm, &h_msm);
443 let mut b_g1 = G::add_g1_proj(&G::affine_to_proj_g1(beta_g1), &b_g1_msm);
444 b_g1 = G::add_g1_proj(&b_g1, &G::mul_g1_scalar(delta_g1, &s));
445
446 let c_shift_a = G::mul_g1_proj_scalar(&proof_a, &s);
447 proof_c = G::add_g1_proj(&proof_c, &c_shift_a);
448
449 let c_shift_b = G::mul_g1_proj_scalar(&b_g1, &r);
450 proof_c = G::add_g1_proj(&proof_c, &c_shift_b);
451
452 let mut rs = r;
453 rs *= s;
454 let rs_delta = G::mul_g1_scalar(delta_g1, &rs);
455 proof_c = G::sub_g1_proj(&proof_c, &rs_delta);
456 #[cfg(feature = "timing")]
457 eprintln!("[proof] final assembly: {:?}", t_phase.elapsed());
458
459 Ok(bellman::groth16::Proof {
460 a: G::proj_to_affine_g1(&proof_a),
461 b: G::proj_to_affine_g2(&proof_b),
462 c: G::proj_to_affine_g1(&proof_c),
463 })
464}
465
466pub async fn create_proof<E, G, C, R>(
471 circuit: C,
472 pk: ProvingKey<'_, G>,
473 gpu: &GpuContext<G>,
474 rng: &mut R,
475) -> Result<bellman::groth16::Proof<E>>
476where
477 E: pairing::MultiMillerLoop,
478 C: bellman::Circuit<G::Scalar>,
479 G: GpuCurve<
480 Engine = E,
481 Scalar = E::Fr,
482 G1 = E::G1,
483 G2 = E::G2,
484 G1Affine = E::G1Affine,
485 G2Affine = E::G2Affine,
486 > + Send,
487 R: RngCore,
488{
489 let r = G::Scalar::random(&mut *rng);
490 let s = G::Scalar::random(&mut *rng);
491
492 create_proof_with_fixed_randomness::<E, G, C>(circuit, pk, gpu, r, s).await
493}
494
495#[cfg(test)]
496mod tests;