1mod buffers;
16pub mod curve;
17mod h_poly;
18mod msm;
19mod ntt;
20
21use std::borrow::Cow;
22use std::marker::PhantomData;
23#[cfg(feature = "profiling")]
24use std::sync::Mutex;
25#[cfg(feature = "timing")]
26use std::time::Instant;
27
28use anyhow::Context;
29use wgpu::util::DeviceExt;
30
31use self::curve::GpuCurve;
32
/// Opens a compute pass labelled `$label`.
///
/// With the `profiling` feature the pass is opened through the profiler
/// scope (`$scope.scoped_compute_pass($label)`) so it gets timed; without
/// it, `$encoder.begin_compute_pass` is called directly with no timestamp
/// writes. Only the argument used by the active configuration is expanded,
/// so the other one is never evaluated.
macro_rules! compute_pass {
    ($scope:expr, $encoder:expr, $label:expr) => {{
        #[cfg(not(feature = "profiling"))]
        let opened = $encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some($label),
            timestamp_writes: None,
        });
        #[cfg(feature = "profiling")]
        let opened = $scope.scoped_compute_pass($label);
        opened
    }};
}
52pub(crate) use compute_pass;
53
/// Kind of buffer binding used when building a compute bind group layout.
///
/// `create_bind_group_layout` turns each variant into the corresponding
/// `wgpu::BufferBindingType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BufKind {
    /// Storage buffer bound with `read_only: true`.
    ReadOnly,
    /// Storage buffer bound with `read_only: false`.
    ReadWrite,
    /// Uniform buffer.
    Uniform,
}
60
61fn create_bind_group_layout(
63 device: &wgpu::Device,
64 label: &str,
65 bindings: &[BufKind],
66) -> wgpu::BindGroupLayout {
67 let entries: Vec<wgpu::BindGroupLayoutEntry> = bindings
68 .iter()
69 .enumerate()
70 .map(|(i, kind)| wgpu::BindGroupLayoutEntry {
71 binding: i as u32,
72 visibility: wgpu::ShaderStages::COMPUTE,
73 ty: match kind {
74 BufKind::ReadOnly => wgpu::BindingType::Buffer {
75 ty: wgpu::BufferBindingType::Storage { read_only: true },
76 has_dynamic_offset: false,
77 min_binding_size: None,
78 },
79 BufKind::ReadWrite => wgpu::BindingType::Buffer {
80 ty: wgpu::BufferBindingType::Storage { read_only: false },
81 has_dynamic_offset: false,
82 min_binding_size: None,
83 },
84 BufKind::Uniform => wgpu::BindingType::Buffer {
85 ty: wgpu::BufferBindingType::Uniform,
86 has_dynamic_offset: false,
87 min_binding_size: None,
88 },
89 },
90 count: None,
91 })
92 .collect();
93 device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
94 label: Some(label),
95 entries: &entries,
96 })
97}
98
99fn pipeline_layout(
101 device: &wgpu::Device,
102 bind_group_layouts: &[&wgpu::BindGroupLayout],
103) -> wgpu::PipelineLayout {
104 device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
105 label: None,
106 bind_group_layouts,
107 immediate_size: 0,
108 })
109}
110
111fn create_pipeline(
114 device: &wgpu::Device,
115 label: &str,
116 layout: &wgpu::PipelineLayout,
117 module: &wgpu::ShaderModule,
118 entry_point: &str,
119) -> wgpu::ComputePipeline {
120 device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
121 label: Some(label),
122 layout: Some(layout),
123 module,
124 entry_point: Some(entry_point),
125 compilation_options: Default::default(),
126 cache: None,
127 })
128}
129
130pub struct MsmBuffers<'a> {
132 pub bases: &'a wgpu::Buffer,
133 pub base_indices: &'a wgpu::Buffer,
134 pub bucket_pointers: &'a wgpu::Buffer,
135 pub bucket_sizes: &'a wgpu::Buffer,
136 pub aggregated_buckets: &'a wgpu::Buffer,
137 pub bucket_values: &'a wgpu::Buffer,
138 pub window_starts: &'a wgpu::Buffer,
139 pub window_counts: &'a wgpu::Buffer,
140 pub window_sums: &'a wgpu::Buffer,
141 pub reduce_starts: Option<&'a wgpu::Buffer>,
143 pub reduce_counts: Option<&'a wgpu::Buffer>,
144 pub orig_bucket_values: Option<&'a wgpu::Buffer>,
146 pub orig_window_starts: Option<&'a wgpu::Buffer>,
148 pub orig_window_counts: Option<&'a wgpu::Buffer>,
149}
150
151pub struct HPolyBuffers<'a> {
153 pub a: &'a wgpu::Buffer,
154 pub b: &'a wgpu::Buffer,
155 pub c: &'a wgpu::Buffer,
156 pub h: &'a wgpu::Buffer,
157 pub twiddles_inv: &'a wgpu::Buffer,
158 pub twiddles_fwd: &'a wgpu::Buffer,
159 pub shifts: &'a wgpu::Buffer,
160 pub inv_shifts: &'a wgpu::Buffer,
161 pub z_invs: &'a wgpu::Buffer,
162}
163
/// GPU state for the prover: the wgpu device and queue plus every compute
/// pipeline and bind group layout, all created once by `GpuContext::new`.
///
/// The type parameter `C` selects the curve. In this file it is only used
/// by `new` to pick the WGSL sources (`C::NTT_SOURCE`,
/// `C::MSM_G1_AGG_SOURCE`, ...), and is otherwise carried as a
/// `PhantomData` marker.
///
/// NOTE(review): fields are dropped in declaration order; keep that in mind
/// before reordering them.
pub struct GpuContext<C> {
    /// Logical device every resource below was created on.
    pub device: wgpu::Device,
    /// Submission queue for `device`.
    pub queue: wgpu::Queue,

    /// `ntt_tile` from the NTT shader (layout: `ntt_bind_group_layout`).
    pub ntt_pipeline: wgpu::ComputePipeline,
    /// `ntt_tile_with_shift` from the NTT fused shader
    /// (layouts: `ntt_bind_group_layout`, `ntt_fused_shift_bgl`).
    pub ntt_fused_pipeline: wgpu::ComputePipeline,
    /// `ntt_tile_dit_no_bitreverse` from the NTT fused shader.
    pub ntt_tile_dit_no_bitreverse_pipeline: wgpu::ComputePipeline,
    /// `ntt_tile_dif` from the NTT fused shader.
    pub ntt_tile_dif_pipeline: wgpu::ComputePipeline,
    /// `ntt_tile_dif_with_shift` from the NTT fused shader.
    pub ntt_tile_dif_fused_pipeline: wgpu::ComputePipeline,
    /// `ntt_tile_fused_pointwise` from the NTT fused shader (adds
    /// `pointwise_fused_bind_group_layout` as a third group).
    pub ntt_tile_fused_pointwise_pipeline: wgpu::ComputePipeline,
    /// `ntt_global_stage` from the NTT shader
    /// (layout: `ntt_params_bind_group_layout`).
    pub ntt_global_stage_pipeline: wgpu::ComputePipeline,
    /// `ntt_global_stage_radix4` from the NTT shader.
    pub ntt_global_stage_radix4_pipeline: wgpu::ComputePipeline,
    /// `ntt_global_stage_dif` from the NTT shader.
    pub ntt_global_stage_dif_pipeline: wgpu::ComputePipeline,
    /// `ntt_global_stage_dif_fused_pointwise` from the NTT shader.
    pub ntt_global_stage_dif_fused_pointwise_pipeline: wgpu::ComputePipeline,
    /// `bitreverse_inplace` from the NTT shader.
    pub ntt_bitreverse_pipeline: wgpu::ComputePipeline,
    /// `bitreverse_fused_pointwise` from the NTT shader.
    pub ntt_bitreverse_fused_pointwise_pipeline: wgpu::ComputePipeline,
    /// `coset_shift` from the poly-ops shader.
    pub coset_shift_pipeline: wgpu::ComputePipeline,
    /// `pointwise_poly` from the poly-ops shader.
    pub pointwise_poly_pipeline: wgpu::ComputePipeline,
    /// `to_montgomery_array` from the poly-ops shader.
    pub to_montgomery_pipeline: wgpu::ComputePipeline,
    /// `from_montgomery_array` from the poly-ops shader.
    pub from_montgomery_pipeline: wgpu::ComputePipeline,

    /// `aggregate_buckets_g1` from the MSM G1 agg shader.
    pub msm_agg_g1_pipeline: wgpu::ComputePipeline,
    /// `subsum_accumulation_g1` from the MSM G1 subsum shader.
    pub msm_sum_g1_pipeline: wgpu::ComputePipeline,
    /// `aggregate_buckets_g2` from the MSM G2 agg shader.
    pub msm_agg_g2_pipeline: wgpu::ComputePipeline,
    /// `subsum_accumulation_g2` from the MSM G2 subsum shader.
    pub msm_sum_g2_pipeline: wgpu::ComputePipeline,
    /// `to_montgomery_bases_g1` from the MSM G1 agg shader.
    pub msm_to_mont_g1_pipeline: wgpu::ComputePipeline,
    /// `to_montgomery_bases_g2` from the MSM G2 agg shader.
    pub msm_to_mont_g2_pipeline: wgpu::ComputePipeline,
    /// `weight_buckets_g1` from the MSM G1 agg shader.
    pub msm_weight_g1_pipeline: wgpu::ComputePipeline,
    /// `subsum_phase1_g1` from the MSM G1 subsum shader.
    pub msm_subsum_phase1_g1_pipeline: wgpu::ComputePipeline,
    /// `subsum_phase2_g1` from the MSM G1 subsum shader.
    pub msm_subsum_phase2_g1_pipeline: wgpu::ComputePipeline,
    /// `weight_buckets_g2` from the MSM G2 agg shader.
    pub msm_weight_g2_pipeline: wgpu::ComputePipeline,
    /// `subsum_phase1_g2` from the MSM G2 subsum shader.
    pub msm_subsum_phase1_g2_pipeline: wgpu::ComputePipeline,
    /// `subsum_phase2_g2` from the MSM G2 subsum shader.
    pub msm_subsum_phase2_g2_pipeline: wgpu::ComputePipeline,
    /// `reduce_sub_buckets_g1` from the MSM G1 agg shader.
    pub msm_reduce_g1_pipeline: wgpu::ComputePipeline,
    /// `reduce_sub_buckets_g2` from the MSM G2 agg shader.
    pub msm_reduce_g2_pipeline: wgpu::ComputePipeline,

    // Bind group layouts. Binding kinds are listed in binding-index order:
    // RO = read-only storage, RW = read-write storage, UF = uniform.
    /// `[RW, RO]`.
    pub ntt_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO]`.
    pub ntt_fused_shift_bgl: wgpu::BindGroupLayout,
    /// `[RW, RO, UF]`.
    pub ntt_params_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RW, RO]`.
    pub coset_shift_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RO, RO, RW, RO]`.
    pub pointwise_poly_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RO, RO, RO]`.
    pub pointwise_fused_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RW]`.
    pub montgomery_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RO, RO, RO, RW, RO]`.
    pub msm_agg_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RO, RO, RO, RW]`.
    pub msm_sum_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RW, RO]`.
    pub msm_weight_g1_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RW, RO]`.
    pub msm_weight_g2_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RO, RO, RW, UF]`.
    pub msm_subsum_phase1_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RW, UF]`.
    pub msm_subsum_phase2_bind_group_layout: wgpu::BindGroupLayout,
    /// `[RO, RO, RO, RW]`.
    pub msm_reduce_bind_group_layout: wgpu::BindGroupLayout,

    // Ties the context to curve `C` without storing a value of that type.
    _marker: PhantomData<C>,

    /// GPU timer profiler; behind a `Mutex` because the profiling methods
    /// only take `&self`.
    #[cfg(feature = "profiling")]
    pub profiler: Mutex<wgpu_profiler::GpuProfiler>,
}
223
224impl<C: GpuCurve> GpuContext<C> {
225 pub async fn new() -> anyhow::Result<Self> {
226 let instance = wgpu::Instance::default();
227
228 let adapter = instance
229 .request_adapter(&wgpu::RequestAdapterOptions {
230 power_preference: wgpu::PowerPreference::HighPerformance,
231 force_fallback_adapter: false,
232 compatible_surface: None,
233 })
234 .await
235 .context("Failed to find a compatible WebGPU adapter")?;
236
237 #[cfg(feature = "profiling")]
238 let required_features = adapter.features()
239 & wgpu_profiler::GpuProfiler::ALL_WGPU_TIMER_FEATURES;
240 #[cfg(not(feature = "profiling"))]
241 let required_features = wgpu::Features::empty();
242
243 let (device, queue) = adapter
244 .request_device(&wgpu::DeviceDescriptor {
245 label: Some("Groth16 Prover Device"),
246 required_limits: adapter.limits(),
249 required_features,
250 ..Default::default()
251 })
252 .await
253 .context("Failed to request WebGPU device")?;
254
255 macro_rules! timed {
256 ($label:expr, $expr:expr) => {{
257 #[cfg(feature = "timing")]
258 let _t = Instant::now();
259 let _r = $expr;
260 #[cfg(feature = "timing")]
261 eprintln!(
262 "[init] {:<30} {:>8.1}ms",
263 $label,
264 _t.elapsed().as_secs_f64() * 1000.0
265 );
266 _r
267 }};
268 }
269
270 #[cfg(feature = "timing")]
271 let init_start = Instant::now();
272
273 #[cfg(feature = "timing")]
275 let shader_start = Instant::now();
276 let ntt_module = timed!(
277 "shader: NTT",
278 device.create_shader_module(wgpu::ShaderModuleDescriptor {
279 label: Some("NTT Shader"),
280 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(C::NTT_SOURCE)),
281 })
282 );
283
284 let msm_g1_agg_module = timed!(
285 "shader: MSM G1 Agg",
286 device.create_shader_module(wgpu::ShaderModuleDescriptor {
287 label: Some("MSM G1 Agg Shader"),
288 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
289 C::MSM_G1_AGG_SOURCE
290 )),
291 })
292 );
293 let msm_g1_subsum_module = timed!(
294 "shader: MSM G1 Subsum",
295 device.create_shader_module(wgpu::ShaderModuleDescriptor {
296 label: Some("MSM G1 Subsum Shader"),
297 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
298 C::MSM_G1_SUBSUM_SOURCE
299 )),
300 })
301 );
302 let msm_g2_agg_module = timed!(
303 "shader: MSM G2 Agg",
304 device.create_shader_module(wgpu::ShaderModuleDescriptor {
305 label: Some("MSM G2 Agg Shader"),
306 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
307 C::MSM_G2_AGG_SOURCE
308 )),
309 })
310 );
311 let msm_g2_subsum_module = timed!(
312 "shader: MSM G2 Subsum",
313 device.create_shader_module(wgpu::ShaderModuleDescriptor {
314 label: Some("MSM G2 Subsum Shader"),
315 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
316 C::MSM_G2_SUBSUM_SOURCE
317 )),
318 })
319 );
320
321 let poly_ops_module = timed!(
322 "shader: Poly Ops",
323 device.create_shader_module(wgpu::ShaderModuleDescriptor {
324 label: Some("Poly Ops Shader"),
325 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
326 C::POLY_OPS_SOURCE
327 )),
328 })
329 );
330
331 let ntt_fused_module = timed!(
332 "shader: NTT Fused",
333 device.create_shader_module(wgpu::ShaderModuleDescriptor {
334 label: Some("NTT Fused Shader"),
335 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
336 C::NTT_FUSED_SOURCE
337 )),
338 })
339 );
340 #[cfg(feature = "timing")]
341 let shader_total = shader_start.elapsed();
342
343 #[cfg(feature = "timing")]
345 let layouts_start = Instant::now();
346 use BufKind::{ReadOnly as RO, ReadWrite as RW, Uniform as UF};
347
348 let ntt_bind_group_layout =
349 create_bind_group_layout(&device, "NTT", &[RW, RO]);
350 let ntt_fused_shift_bgl =
351 create_bind_group_layout(&device, "NTT Fused Shift", &[RO]);
352 let ntt_params_bind_group_layout =
353 create_bind_group_layout(&device, "NTT Global", &[RW, RO, UF]);
354 let coset_shift_bind_group_layout =
355 create_bind_group_layout(&device, "Coset Shift", &[RW, RO]);
356 let pointwise_poly_bind_group_layout = create_bind_group_layout(
357 &device,
358 "Pointwise Poly",
359 &[RO, RO, RO, RW, RO],
360 );
361 let pointwise_fused_bind_group_layout = create_bind_group_layout(
362 &device,
363 "Pointwise Fused",
364 &[RO, RO, RO, RO],
365 );
366 let montgomery_bind_group_layout =
367 create_bind_group_layout(&device, "Montgomery", &[RW]);
368 let msm_agg_bind_group_layout = create_bind_group_layout(
369 &device,
370 "MSM Agg",
371 &[RO, RO, RO, RO, RW, RO],
372 );
373 let msm_sum_bind_group_layout =
374 create_bind_group_layout(&device, "MSM Sum", &[RO, RO, RO, RO, RW]);
375 let msm_weight_g1_bind_group_layout =
377 create_bind_group_layout(&device, "MSM Weight G1", &[RW, RO]);
378 let msm_weight_g2_bind_group_layout =
379 create_bind_group_layout(&device, "MSM Weight G2", &[RW, RO]);
380 let msm_subsum_phase1_bind_group_layout = create_bind_group_layout(
383 &device,
384 "MSM Subsum Phase1",
385 &[RO, RO, RO, RW, UF],
386 );
387 let msm_subsum_phase2_bind_group_layout = create_bind_group_layout(
389 &device,
390 "MSM Subsum Phase2",
391 &[RO, RW, UF],
392 );
393 let msm_reduce_bind_group_layout =
396 create_bind_group_layout(&device, "MSM Reduce", &[RO, RO, RO, RW]);
397
398 #[cfg(feature = "timing")]
399 let layouts_total = layouts_start.elapsed();
400 #[cfg(feature = "timing")]
401 eprintln!(
402 "[init] {:<30} {:>8.1}ms",
403 "bind group layouts (total)",
404 layouts_total.as_secs_f64() * 1000.0
405 );
406
407 #[cfg(feature = "timing")]
409 let pipelines_start = Instant::now();
410
411 let ntt_tile_layout =
413 pipeline_layout(&device, &[&ntt_bind_group_layout]);
414 let ntt_global_layout =
415 pipeline_layout(&device, &[&ntt_params_bind_group_layout]);
416 let ntt_pipeline = timed!(
417 "pipeline: NTT Tile",
418 create_pipeline(
419 &device,
420 "NTT Tile",
421 &ntt_tile_layout,
422 &ntt_module,
423 "ntt_tile"
424 )
425 );
426 let ntt_fused_layout = pipeline_layout(
427 &device,
428 &[&ntt_bind_group_layout, &ntt_fused_shift_bgl],
429 );
430 let ntt_fused_pipeline = timed!(
431 "pipeline: NTT Fused",
432 create_pipeline(
433 &device,
434 "NTT Fused",
435 &ntt_fused_layout,
436 &ntt_fused_module,
437 "ntt_tile_with_shift"
438 )
439 );
440 let ntt_tile_dif_pipeline = timed!(
441 "pipeline: NTT Tile DIF",
442 create_pipeline(
443 &device,
444 "NTT Tile DIF",
445 &ntt_fused_layout,
446 &ntt_fused_module,
447 "ntt_tile_dif"
448 )
449 );
450 let ntt_tile_dit_no_bitreverse_pipeline = timed!(
451 "pipeline: NTT Tile DIT NoRev",
452 create_pipeline(
453 &device,
454 "NTT Tile DIT NoRev",
455 &ntt_fused_layout,
456 &ntt_fused_module,
457 "ntt_tile_dit_no_bitreverse"
458 )
459 );
460 let ntt_tile_dif_fused_pipeline = timed!(
461 "pipeline: NTT Tile DIF Fused",
462 create_pipeline(
463 &device,
464 "NTT Tile DIF Fused",
465 &ntt_fused_layout,
466 &ntt_fused_module,
467 "ntt_tile_dif_with_shift"
468 )
469 );
470 let ntt_fused_pointwise_layout = pipeline_layout(
471 &device,
472 &[
473 &ntt_bind_group_layout,
474 &ntt_fused_shift_bgl,
475 &pointwise_fused_bind_group_layout,
476 ],
477 );
478 let ntt_tile_fused_pointwise_pipeline = timed!(
479 "pipeline: NTT Tile Fused Pointwise",
480 create_pipeline(
481 &device,
482 "NTT Tile Fused Pointwise",
483 &ntt_fused_pointwise_layout,
484 &ntt_fused_module,
485 "ntt_tile_fused_pointwise"
486 )
487 );
488 let ntt_global_stage_pipeline = timed!(
489 "pipeline: NTT Global Stage",
490 create_pipeline(
491 &device,
492 "NTT Global Stage",
493 &ntt_global_layout,
494 &ntt_module,
495 "ntt_global_stage"
496 )
497 );
498 let ntt_global_stage_radix4_pipeline = timed!(
499 "pipeline: NTT Global R4",
500 create_pipeline(
501 &device,
502 "NTT Global Stage Radix4",
503 &ntt_global_layout,
504 &ntt_module,
505 "ntt_global_stage_radix4"
506 )
507 );
508 let ntt_global_stage_dif_pipeline = timed!(
509 "pipeline: NTT Global DIF",
510 create_pipeline(
511 &device,
512 "NTT Global Stage DIF",
513 &ntt_global_layout,
514 &ntt_module,
515 "ntt_global_stage_dif"
516 )
517 );
518 let ntt_global_stage_dif_fused_layout = pipeline_layout(
519 &device,
520 &[
521 &ntt_params_bind_group_layout,
522 &ntt_fused_shift_bgl,
523 &pointwise_fused_bind_group_layout,
524 ],
525 );
526 let ntt_global_stage_dif_fused_pointwise_pipeline = timed!(
527 "pipeline: NTT Global DIF Fused",
528 create_pipeline(
529 &device,
530 "NTT Global DIF Fused Pointwise",
531 &ntt_global_stage_dif_fused_layout,
532 &ntt_module,
533 "ntt_global_stage_dif_fused_pointwise"
534 )
535 );
536 let ntt_bitreverse_pipeline = timed!(
537 "pipeline: NTT BitReverse",
538 create_pipeline(
539 &device,
540 "NTT BitReverse",
541 &ntt_global_layout,
542 &ntt_module,
543 "bitreverse_inplace"
544 )
545 );
546 let ntt_bitreverse_fused_layout = pipeline_layout(
547 &device,
548 &[
549 &ntt_params_bind_group_layout,
550 &ntt_fused_shift_bgl,
551 &pointwise_fused_bind_group_layout,
552 ],
553 );
554 let ntt_bitreverse_fused_pointwise_pipeline = timed!(
555 "pipeline: NTT BitReverse Fused",
556 create_pipeline(
557 &device,
558 "NTT BitReverse Fused Pointwise",
559 &ntt_bitreverse_fused_layout,
560 &ntt_module,
561 "bitreverse_fused_pointwise"
562 )
563 );
564
565 let coset_shift_layout =
567 pipeline_layout(&device, &[&coset_shift_bind_group_layout]);
568 let pointwise_layout =
569 pipeline_layout(&device, &[&pointwise_poly_bind_group_layout]);
570 let montgomery_layout =
571 pipeline_layout(&device, &[&montgomery_bind_group_layout]);
572 let coset_shift_pipeline = timed!(
573 "pipeline: Coset Shift",
574 create_pipeline(
575 &device,
576 "Coset Shift",
577 &coset_shift_layout,
578 &poly_ops_module,
579 "coset_shift"
580 )
581 );
582 let pointwise_poly_pipeline = timed!(
583 "pipeline: Pointwise Poly",
584 create_pipeline(
585 &device,
586 "Pointwise Poly",
587 &pointwise_layout,
588 &poly_ops_module,
589 "pointwise_poly"
590 )
591 );
592 let to_montgomery_pipeline = timed!(
593 "pipeline: To Montgomery",
594 create_pipeline(
595 &device,
596 "To Montgomery",
597 &montgomery_layout,
598 &poly_ops_module,
599 "to_montgomery_array"
600 )
601 );
602 let from_montgomery_pipeline = timed!(
603 "pipeline: From Montgomery",
604 create_pipeline(
605 &device,
606 "From Montgomery",
607 &montgomery_layout,
608 &poly_ops_module,
609 "from_montgomery_array"
610 )
611 );
612
613 let msm_agg_layout =
615 pipeline_layout(&device, &[&msm_agg_bind_group_layout]);
616 let msm_sum_layout =
617 pipeline_layout(&device, &[&msm_sum_bind_group_layout]);
618 let msm_weight_g1_layout =
619 pipeline_layout(&device, &[&msm_weight_g1_bind_group_layout]);
620 let msm_subsum_phase1_layout =
621 pipeline_layout(&device, &[&msm_subsum_phase1_bind_group_layout]);
622 let msm_subsum_phase2_layout =
623 pipeline_layout(&device, &[&msm_subsum_phase2_bind_group_layout]);
624
625 let msm_agg_g1_pipeline = timed!(
626 "pipeline: MSM Agg G1",
627 create_pipeline(
628 &device,
629 "MSM Agg G1",
630 &msm_agg_layout,
631 &msm_g1_agg_module,
632 "aggregate_buckets_g1"
633 )
634 );
635 let msm_sum_g1_pipeline = timed!(
636 "pipeline: MSM Sum G1",
637 create_pipeline(
638 &device,
639 "MSM Sum G1",
640 &msm_sum_layout,
641 &msm_g1_subsum_module,
642 "subsum_accumulation_g1"
643 )
644 );
645 let msm_agg_g2_pipeline = timed!(
646 "pipeline: MSM Agg G2",
647 create_pipeline(
648 &device,
649 "MSM Agg G2",
650 &msm_agg_layout,
651 &msm_g2_agg_module,
652 "aggregate_buckets_g2"
653 )
654 );
655 let msm_sum_g2_pipeline = timed!(
656 "pipeline: MSM Sum G2",
657 create_pipeline(
658 &device,
659 "MSM Sum G2",
660 &msm_sum_layout,
661 &msm_g2_subsum_module,
662 "subsum_accumulation_g2"
663 )
664 );
665 let msm_to_mont_g1_pipeline = timed!(
666 "pipeline: MSM To Mont G1",
667 create_pipeline(
668 &device,
669 "MSM To Montgomery G1",
670 &montgomery_layout,
671 &msm_g1_agg_module,
672 "to_montgomery_bases_g1"
673 )
674 );
675 let msm_to_mont_g2_pipeline = timed!(
676 "pipeline: MSM To Mont G2",
677 create_pipeline(
678 &device,
679 "MSM To Montgomery G2",
680 &montgomery_layout,
681 &msm_g2_agg_module,
682 "to_montgomery_bases_g2"
683 )
684 );
685 let msm_weight_g1_pipeline = timed!(
686 "pipeline: MSM Weight G1",
687 create_pipeline(
688 &device,
689 "MSM Weight G1",
690 &msm_weight_g1_layout,
691 &msm_g1_agg_module,
692 "weight_buckets_g1"
693 )
694 );
695 let msm_subsum_phase1_g1_pipeline = timed!(
696 "pipeline: MSM Subsum Ph1 G1",
697 create_pipeline(
698 &device,
699 "MSM Subsum Phase1 G1",
700 &msm_subsum_phase1_layout,
701 &msm_g1_subsum_module,
702 "subsum_phase1_g1"
703 )
704 );
705 let msm_subsum_phase2_g1_pipeline = timed!(
706 "pipeline: MSM Subsum Ph2 G1",
707 create_pipeline(
708 &device,
709 "MSM Subsum Phase2 G1",
710 &msm_subsum_phase2_layout,
711 &msm_g1_subsum_module,
712 "subsum_phase2_g1"
713 )
714 );
715
716 let msm_weight_g2_layout =
717 pipeline_layout(&device, &[&msm_weight_g2_bind_group_layout]);
718 let msm_weight_g2_pipeline = timed!(
719 "pipeline: MSM Weight G2",
720 create_pipeline(
721 &device,
722 "MSM Weight G2",
723 &msm_weight_g2_layout,
724 &msm_g2_agg_module,
725 "weight_buckets_g2"
726 )
727 );
728 let msm_subsum_phase1_g2_pipeline = timed!(
729 "pipeline: MSM Subsum Ph1 G2",
730 create_pipeline(
731 &device,
732 "MSM Subsum Phase1 G2",
733 &msm_subsum_phase1_layout,
734 &msm_g2_subsum_module,
735 "subsum_phase1_g2"
736 )
737 );
738 let msm_subsum_phase2_g2_pipeline = timed!(
739 "pipeline: MSM Subsum Ph2 G2",
740 create_pipeline(
741 &device,
742 "MSM Subsum Phase2 G2",
743 &msm_subsum_phase2_layout,
744 &msm_g2_subsum_module,
745 "subsum_phase2_g2"
746 )
747 );
748
749 let msm_reduce_layout =
750 pipeline_layout(&device, &[&msm_reduce_bind_group_layout]);
751 let msm_reduce_g1_pipeline = timed!(
752 "pipeline: MSM Reduce G1",
753 create_pipeline(
754 &device,
755 "MSM Reduce G1",
756 &msm_reduce_layout,
757 &msm_g1_agg_module,
758 "reduce_sub_buckets_g1"
759 )
760 );
761 let msm_reduce_g2_pipeline = timed!(
762 "pipeline: MSM Reduce G2",
763 create_pipeline(
764 &device,
765 "MSM Reduce G2",
766 &msm_reduce_layout,
767 &msm_g2_agg_module,
768 "reduce_sub_buckets_g2"
769 )
770 );
771 #[cfg(feature = "timing")]
772 {
773 let pipelines_total = pipelines_start.elapsed();
774 eprintln!("\n[init] === GpuContext::new() summary ===");
775 eprintln!(
776 "[init] {:<30} {:>8.1}ms",
777 "shader compilation",
778 shader_total.as_secs_f64() * 1000.0
779 );
780 eprintln!(
781 "[init] {:<30} {:>8.1}ms",
782 "bind group layouts",
783 layouts_total.as_secs_f64() * 1000.0
784 );
785 eprintln!(
786 "[init] {:<30} {:>8.1}ms",
787 "pipeline creation",
788 pipelines_total.as_secs_f64() * 1000.0
789 );
790 eprintln!(
791 "[init] {:<30} {:>8.1}ms",
792 "TOTAL",
793 init_start.elapsed().as_secs_f64() * 1000.0
794 );
795 eprintln!();
796 }
797
798 #[cfg(feature = "profiling")]
799 let profiler = Mutex::new(wgpu_profiler::GpuProfiler::new(
800 &device,
801 wgpu_profiler::GpuProfilerSettings {
802 enable_timer_queries: true,
803 ..Default::default()
804 },
805 )?);
806
807 Ok(Self {
808 device,
809 queue,
810 ntt_pipeline,
811 ntt_fused_pipeline,
812 ntt_tile_dit_no_bitreverse_pipeline,
813 ntt_tile_dif_pipeline,
814 ntt_tile_dif_fused_pipeline,
815 ntt_tile_fused_pointwise_pipeline,
816 ntt_global_stage_pipeline,
817 ntt_global_stage_radix4_pipeline,
818 ntt_global_stage_dif_pipeline,
819 ntt_global_stage_dif_fused_pointwise_pipeline,
820 ntt_bitreverse_pipeline,
821 ntt_bitreverse_fused_pointwise_pipeline,
822 coset_shift_pipeline,
823 pointwise_poly_pipeline,
824 to_montgomery_pipeline,
825 from_montgomery_pipeline,
826 msm_agg_g1_pipeline,
827 msm_sum_g1_pipeline,
828 msm_agg_g2_pipeline,
829 msm_sum_g2_pipeline,
830 msm_to_mont_g1_pipeline,
831 msm_to_mont_g2_pipeline,
832 msm_weight_g1_pipeline,
833 msm_subsum_phase1_g1_pipeline,
834 msm_subsum_phase2_g1_pipeline,
835 msm_weight_g2_pipeline,
836 msm_subsum_phase1_g2_pipeline,
837 msm_subsum_phase2_g2_pipeline,
838 msm_reduce_g1_pipeline,
839 msm_reduce_g2_pipeline,
840 ntt_bind_group_layout,
841 ntt_fused_shift_bgl,
842 ntt_params_bind_group_layout,
843 coset_shift_bind_group_layout,
844 pointwise_poly_bind_group_layout,
845 pointwise_fused_bind_group_layout,
846 montgomery_bind_group_layout,
847 msm_agg_bind_group_layout,
848 msm_sum_bind_group_layout,
849 msm_weight_g1_bind_group_layout,
850 msm_weight_g2_bind_group_layout,
851 msm_subsum_phase1_bind_group_layout,
852 msm_subsum_phase2_bind_group_layout,
853 msm_reduce_bind_group_layout,
854 _marker: PhantomData,
855 #[cfg(feature = "profiling")]
856 profiler,
857 })
858 }
859
860 #[cfg(feature = "profiling")]
861 pub fn end_profiler_frame(&self) {
862 #[cfg(not(target_family = "wasm"))]
866 let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
867 let mut profiler = self.profiler.lock().unwrap();
868 profiler.end_frame().expect("end_frame failed");
869 }
870
871 #[cfg(feature = "profiling")]
872 pub fn process_profiler_results(
873 &self,
874 ) -> Option<Vec<wgpu_profiler::GpuTimerQueryResult>> {
875 let mut profiler = self.profiler.lock().unwrap();
876 profiler.process_finished_frame(self.queue.get_timestamp_period())
877 }
878
879 pub fn create_storage_buffer(
880 &self,
881 label: &str,
882 data: &[u8],
883 ) -> wgpu::Buffer {
884 self.device
885 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
886 label: Some(label),
887 contents: data,
888 usage: wgpu::BufferUsages::STORAGE
889 | wgpu::BufferUsages::COPY_SRC
890 | wgpu::BufferUsages::COPY_DST,
891 })
892 }
893
894 pub fn create_empty_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
895 self.device.create_buffer(&wgpu::BufferDescriptor {
896 label: Some(label),
897 size,
898 usage: wgpu::BufferUsages::STORAGE
899 | wgpu::BufferUsages::COPY_SRC
900 | wgpu::BufferUsages::COPY_DST,
901 mapped_at_creation: false,
902 })
903 }
904}
905
906#[cfg(test)]
907mod tests;