Skip to main content

webgpu_groth16/
gpu.rs

1//! GPU context, compute pipeline management, and kernel dispatch.
2//!
3//! [`GpuContext`] owns the wgpu device/queue and all pre-compiled compute
4//! pipelines needed for MSM, NTT, and polynomial operations.
5
6// Implementation details:
7//
8// - [`msm`] — MSM 5-kernel Pippenger pipeline (to_montgomery, aggregate,
9//   reduce, weight, subsum)
10// - [`ntt`] — NTT dispatchers (tile-local and multi-stage global), Montgomery
11//   conversion, coset shift, pointwise polynomial evaluation
12// - [`h_poly`] — H-polynomial pipeline (fused NTT+shift → pointwise → iNTT)
13// - [`curve`] — curve-agnostic trait + curve-specific implementations
14
15mod buffers;
16pub mod curve;
17mod h_poly;
18mod msm;
19mod ntt;
20
21use std::borrow::Cow;
22use std::marker::PhantomData;
23#[cfg(feature = "profiling")]
24use std::sync::Mutex;
25#[cfg(feature = "timing")]
26use std::time::Instant;
27
28use anyhow::Context;
29use wgpu::util::DeviceExt;
30
31use self::curve::GpuCurve;
32
/// Opens a compute pass named `$label` and evaluates to it.
///
/// * With the `profiling` feature the pass is opened through
///   `$scope.scoped_compute_pass($label)`, so `wgpu_profiler` records GPU
///   timestamps for the kernel; `$encoder` is not evaluated.
/// * Without it, a plain pass without timestamp writes is begun on
///   `$encoder`; `$scope` is not evaluated.
///
/// Usage: `let mut cpass = compute_pass!(scope, encoder, "kernel_label");`
macro_rules! compute_pass {
    ($scope:expr, $encoder:expr, $label:expr) => {{
        #[cfg(not(feature = "profiling"))]
        let opened = $encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some($label),
            timestamp_writes: None,
        });
        #[cfg(feature = "profiling")]
        let opened = $scope.scoped_compute_pass($label);
        opened
    }};
}
pub(crate) use compute_pass;
53
/// Shorthand for the buffer binding types used in compute bind group layouts.
///
/// Each variant is translated by `create_bind_group_layout` into the matching
/// `wgpu::BufferBindingType`. It is a plain tag, so it is `Copy` and
/// comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BufKind {
    /// Read-only storage buffer (`var<storage, read>` in WGSL).
    ReadOnly,
    /// Read-write storage buffer (`var<storage, read_write>` in WGSL).
    ReadWrite,
    /// Uniform buffer (`var<uniform>` in WGSL).
    Uniform,
}
60
61/// Creates a compute bind group layout with sequentially-numbered bindings.
62fn create_bind_group_layout(
63    device: &wgpu::Device,
64    label: &str,
65    bindings: &[BufKind],
66) -> wgpu::BindGroupLayout {
67    let entries: Vec<wgpu::BindGroupLayoutEntry> = bindings
68        .iter()
69        .enumerate()
70        .map(|(i, kind)| wgpu::BindGroupLayoutEntry {
71            binding: i as u32,
72            visibility: wgpu::ShaderStages::COMPUTE,
73            ty: match kind {
74                BufKind::ReadOnly => wgpu::BindingType::Buffer {
75                    ty: wgpu::BufferBindingType::Storage { read_only: true },
76                    has_dynamic_offset: false,
77                    min_binding_size: None,
78                },
79                BufKind::ReadWrite => wgpu::BindingType::Buffer {
80                    ty: wgpu::BufferBindingType::Storage { read_only: false },
81                    has_dynamic_offset: false,
82                    min_binding_size: None,
83                },
84                BufKind::Uniform => wgpu::BindingType::Buffer {
85                    ty: wgpu::BufferBindingType::Uniform,
86                    has_dynamic_offset: false,
87                    min_binding_size: None,
88                },
89            },
90            count: None,
91        })
92        .collect();
93    device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
94        label: Some(label),
95        entries: &entries,
96    })
97}
98
99/// Creates a pipeline layout from one or more bind group layouts.
100fn pipeline_layout(
101    device: &wgpu::Device,
102    bind_group_layouts: &[&wgpu::BindGroupLayout],
103) -> wgpu::PipelineLayout {
104    device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
105        label: None,
106        bind_group_layouts,
107        immediate_size: 0,
108    })
109}
110
111/// Creates a compute pipeline with the given layout, shader module, and entry
112/// point.
113fn create_pipeline(
114    device: &wgpu::Device,
115    label: &str,
116    layout: &wgpu::PipelineLayout,
117    module: &wgpu::ShaderModule,
118    entry_point: &str,
119) -> wgpu::ComputePipeline {
120    device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
121        label: Some(label),
122        layout: Some(layout),
123        module,
124        entry_point: Some(entry_point),
125        compilation_options: Default::default(),
126        cache: None,
127    })
128}
129
130/// GPU buffers required for an MSM dispatch.
131pub struct MsmBuffers<'a> {
132    pub bases: &'a wgpu::Buffer,
133    pub base_indices: &'a wgpu::Buffer,
134    pub bucket_pointers: &'a wgpu::Buffer,
135    pub bucket_sizes: &'a wgpu::Buffer,
136    pub aggregated_buckets: &'a wgpu::Buffer,
137    pub bucket_values: &'a wgpu::Buffer,
138    pub window_starts: &'a wgpu::Buffer,
139    pub window_counts: &'a wgpu::Buffer,
140    pub window_sums: &'a wgpu::Buffer,
141    /// Sub-bucket reduce buffers (only used when has_chunks is true).
142    pub reduce_starts: Option<&'a wgpu::Buffer>,
143    pub reduce_counts: Option<&'a wgpu::Buffer>,
144    /// Original (pre-chunking) bucket values for weight pass.
145    pub orig_bucket_values: Option<&'a wgpu::Buffer>,
146    /// Original window starts/counts for subsum passes.
147    pub orig_window_starts: Option<&'a wgpu::Buffer>,
148    pub orig_window_counts: Option<&'a wgpu::Buffer>,
149}
150
151/// GPU buffers required for the H polynomial pipeline.
152pub struct HPolyBuffers<'a> {
153    pub a: &'a wgpu::Buffer,
154    pub b: &'a wgpu::Buffer,
155    pub c: &'a wgpu::Buffer,
156    pub h: &'a wgpu::Buffer,
157    pub twiddles_inv: &'a wgpu::Buffer,
158    pub twiddles_fwd: &'a wgpu::Buffer,
159    pub shifts: &'a wgpu::Buffer,
160    pub inv_shifts: &'a wgpu::Buffer,
161    pub z_invs: &'a wgpu::Buffer,
162}
163
164pub struct GpuContext<C> {
165    pub device: wgpu::Device,
166    pub queue: wgpu::Queue,
167
168    // Polynomial Pipelines
169    pub ntt_pipeline: wgpu::ComputePipeline,
170    pub ntt_fused_pipeline: wgpu::ComputePipeline,
171    pub ntt_tile_dit_no_bitreverse_pipeline: wgpu::ComputePipeline,
172    pub ntt_tile_dif_pipeline: wgpu::ComputePipeline,
173    pub ntt_tile_dif_fused_pipeline: wgpu::ComputePipeline,
174    pub ntt_tile_fused_pointwise_pipeline: wgpu::ComputePipeline,
175    pub ntt_global_stage_pipeline: wgpu::ComputePipeline,
176    pub ntt_global_stage_radix4_pipeline: wgpu::ComputePipeline,
177    pub ntt_global_stage_dif_pipeline: wgpu::ComputePipeline,
178    pub ntt_global_stage_dif_fused_pointwise_pipeline: wgpu::ComputePipeline,
179    pub ntt_bitreverse_pipeline: wgpu::ComputePipeline,
180    pub ntt_bitreverse_fused_pointwise_pipeline: wgpu::ComputePipeline,
181    pub coset_shift_pipeline: wgpu::ComputePipeline,
182    pub pointwise_poly_pipeline: wgpu::ComputePipeline,
183    pub to_montgomery_pipeline: wgpu::ComputePipeline,
184    pub from_montgomery_pipeline: wgpu::ComputePipeline,
185
186    // MSM 2-Stage Pipelines
187    pub msm_agg_g1_pipeline: wgpu::ComputePipeline,
188    pub msm_sum_g1_pipeline: wgpu::ComputePipeline,
189    pub msm_agg_g2_pipeline: wgpu::ComputePipeline,
190    pub msm_sum_g2_pipeline: wgpu::ComputePipeline,
191    pub msm_to_mont_g1_pipeline: wgpu::ComputePipeline,
192    pub msm_to_mont_g2_pipeline: wgpu::ComputePipeline,
193    pub msm_weight_g1_pipeline: wgpu::ComputePipeline,
194    pub msm_subsum_phase1_g1_pipeline: wgpu::ComputePipeline,
195    pub msm_subsum_phase2_g1_pipeline: wgpu::ComputePipeline,
196    pub msm_weight_g2_pipeline: wgpu::ComputePipeline,
197    pub msm_subsum_phase1_g2_pipeline: wgpu::ComputePipeline,
198    pub msm_subsum_phase2_g2_pipeline: wgpu::ComputePipeline,
199    pub msm_reduce_g1_pipeline: wgpu::ComputePipeline,
200    pub msm_reduce_g2_pipeline: wgpu::ComputePipeline,
201
202    // Bind Group Layouts
203    pub ntt_bind_group_layout: wgpu::BindGroupLayout,
204    pub ntt_fused_shift_bgl: wgpu::BindGroupLayout,
205    pub ntt_params_bind_group_layout: wgpu::BindGroupLayout,
206    pub coset_shift_bind_group_layout: wgpu::BindGroupLayout,
207    pub pointwise_poly_bind_group_layout: wgpu::BindGroupLayout,
208    pub pointwise_fused_bind_group_layout: wgpu::BindGroupLayout,
209    pub montgomery_bind_group_layout: wgpu::BindGroupLayout,
210    pub msm_agg_bind_group_layout: wgpu::BindGroupLayout,
211    pub msm_sum_bind_group_layout: wgpu::BindGroupLayout,
212    pub msm_weight_g1_bind_group_layout: wgpu::BindGroupLayout,
213    pub msm_weight_g2_bind_group_layout: wgpu::BindGroupLayout,
214    pub msm_subsum_phase1_bind_group_layout: wgpu::BindGroupLayout,
215    pub msm_subsum_phase2_bind_group_layout: wgpu::BindGroupLayout,
216    pub msm_reduce_bind_group_layout: wgpu::BindGroupLayout,
217
218    _marker: PhantomData<C>,
219
220    #[cfg(feature = "profiling")]
221    pub profiler: Mutex<wgpu_profiler::GpuProfiler>,
222}
223
224impl<C: GpuCurve> GpuContext<C> {
225    pub async fn new() -> anyhow::Result<Self> {
226        let instance = wgpu::Instance::default();
227
228        let adapter = instance
229            .request_adapter(&wgpu::RequestAdapterOptions {
230                power_preference: wgpu::PowerPreference::HighPerformance,
231                force_fallback_adapter: false,
232                compatible_surface: None,
233            })
234            .await
235            .context("Failed to find a compatible WebGPU adapter")?;
236
237        #[cfg(feature = "profiling")]
238        let required_features = adapter.features()
239            & wgpu_profiler::GpuProfiler::ALL_WGPU_TIMER_FEATURES;
240        #[cfg(not(feature = "profiling"))]
241        let required_features = wgpu::Features::empty();
242
243        let (device, queue) = adapter
244            .request_device(&wgpu::DeviceDescriptor {
245                label: Some("Groth16 Prover Device"),
246                // Use adapter.limits() directly to support large buffers
247                // (>128MB) for WebAssembly
248                required_limits: adapter.limits(),
249                required_features,
250                ..Default::default()
251            })
252            .await
253            .context("Failed to request WebGPU device")?;
254
255        macro_rules! timed {
256            ($label:expr, $expr:expr) => {{
257                #[cfg(feature = "timing")]
258                let _t = Instant::now();
259                let _r = $expr;
260                #[cfg(feature = "timing")]
261                eprintln!(
262                    "[init] {:<30} {:>8.1}ms",
263                    $label,
264                    _t.elapsed().as_secs_f64() * 1000.0
265                );
266                _r
267            }};
268        }
269
270        #[cfg(feature = "timing")]
271        let init_start = Instant::now();
272
273        // 1. Compile Shader Modules
274        #[cfg(feature = "timing")]
275        let shader_start = Instant::now();
276        let ntt_module = timed!(
277            "shader: NTT",
278            device.create_shader_module(wgpu::ShaderModuleDescriptor {
279                label: Some("NTT Shader"),
280                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(C::NTT_SOURCE)),
281            })
282        );
283
284        let msm_g1_agg_module = timed!(
285            "shader: MSM G1 Agg",
286            device.create_shader_module(wgpu::ShaderModuleDescriptor {
287                label: Some("MSM G1 Agg Shader"),
288                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
289                    C::MSM_G1_AGG_SOURCE
290                )),
291            })
292        );
293        let msm_g1_subsum_module = timed!(
294            "shader: MSM G1 Subsum",
295            device.create_shader_module(wgpu::ShaderModuleDescriptor {
296                label: Some("MSM G1 Subsum Shader"),
297                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
298                    C::MSM_G1_SUBSUM_SOURCE
299                )),
300            })
301        );
302        let msm_g2_agg_module = timed!(
303            "shader: MSM G2 Agg",
304            device.create_shader_module(wgpu::ShaderModuleDescriptor {
305                label: Some("MSM G2 Agg Shader"),
306                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
307                    C::MSM_G2_AGG_SOURCE
308                )),
309            })
310        );
311        let msm_g2_subsum_module = timed!(
312            "shader: MSM G2 Subsum",
313            device.create_shader_module(wgpu::ShaderModuleDescriptor {
314                label: Some("MSM G2 Subsum Shader"),
315                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
316                    C::MSM_G2_SUBSUM_SOURCE
317                )),
318            })
319        );
320
321        let poly_ops_module = timed!(
322            "shader: Poly Ops",
323            device.create_shader_module(wgpu::ShaderModuleDescriptor {
324                label: Some("Poly Ops Shader"),
325                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
326                    C::POLY_OPS_SOURCE
327                )),
328            })
329        );
330
331        let ntt_fused_module = timed!(
332            "shader: NTT Fused",
333            device.create_shader_module(wgpu::ShaderModuleDescriptor {
334                label: Some("NTT Fused Shader"),
335                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(
336                    C::NTT_FUSED_SOURCE
337                )),
338            })
339        );
340        #[cfg(feature = "timing")]
341        let shader_total = shader_start.elapsed();
342
343        // 2. Define Bind Group Layouts
344        #[cfg(feature = "timing")]
345        let layouts_start = Instant::now();
346        use BufKind::{ReadOnly as RO, ReadWrite as RW, Uniform as UF};
347
348        let ntt_bind_group_layout =
349            create_bind_group_layout(&device, "NTT", &[RW, RO]);
350        let ntt_fused_shift_bgl =
351            create_bind_group_layout(&device, "NTT Fused Shift", &[RO]);
352        let ntt_params_bind_group_layout =
353            create_bind_group_layout(&device, "NTT Global", &[RW, RO, UF]);
354        let coset_shift_bind_group_layout =
355            create_bind_group_layout(&device, "Coset Shift", &[RW, RO]);
356        let pointwise_poly_bind_group_layout = create_bind_group_layout(
357            &device,
358            "Pointwise Poly",
359            &[RO, RO, RO, RW, RO],
360        );
361        let pointwise_fused_bind_group_layout = create_bind_group_layout(
362            &device,
363            "Pointwise Fused",
364            &[RO, RO, RO, RO],
365        );
366        let montgomery_bind_group_layout =
367            create_bind_group_layout(&device, "Montgomery", &[RW]);
368        let msm_agg_bind_group_layout = create_bind_group_layout(
369            &device,
370            "MSM Agg",
371            &[RO, RO, RO, RO, RW, RO],
372        );
373        let msm_sum_bind_group_layout =
374            create_bind_group_layout(&device, "MSM Sum", &[RO, RO, RO, RO, RW]);
375        // Weight buckets: [data(rw), bucket_values(read)]
376        let msm_weight_g1_bind_group_layout =
377            create_bind_group_layout(&device, "MSM Weight G1", &[RW, RO]);
378        let msm_weight_g2_bind_group_layout =
379            create_bind_group_layout(&device, "MSM Weight G2", &[RW, RO]);
380        // Phase1: [agg_buckets(read), window_starts(read), window_counts(read),
381        //          partial_sums(rw), subsum_params(uniform)]
382        let msm_subsum_phase1_bind_group_layout = create_bind_group_layout(
383            &device,
384            "MSM Subsum Phase1",
385            &[RO, RO, RO, RW, UF],
386        );
387        // Phase2: [partial_sums(read), window_sums(rw), subsum_params(uniform)]
388        let msm_subsum_phase2_bind_group_layout = create_bind_group_layout(
389            &device,
390            "MSM Subsum Phase2",
391            &[RO, RW, UF],
392        );
393        // Reduce sub-buckets: [input(read), starts(read), counts(read),
394        // output(rw)]
395        let msm_reduce_bind_group_layout =
396            create_bind_group_layout(&device, "MSM Reduce", &[RO, RO, RO, RW]);
397
398        #[cfg(feature = "timing")]
399        let layouts_total = layouts_start.elapsed();
400        #[cfg(feature = "timing")]
401        eprintln!(
402            "[init] {:<30} {:>8.1}ms",
403            "bind group layouts (total)",
404            layouts_total.as_secs_f64() * 1000.0
405        );
406
407        // 3. Create Compute Pipelines
408        #[cfg(feature = "timing")]
409        let pipelines_start = Instant::now();
410
411        // NTT pipelines
412        let ntt_tile_layout =
413            pipeline_layout(&device, &[&ntt_bind_group_layout]);
414        let ntt_global_layout =
415            pipeline_layout(&device, &[&ntt_params_bind_group_layout]);
416        let ntt_pipeline = timed!(
417            "pipeline: NTT Tile",
418            create_pipeline(
419                &device,
420                "NTT Tile",
421                &ntt_tile_layout,
422                &ntt_module,
423                "ntt_tile"
424            )
425        );
426        let ntt_fused_layout = pipeline_layout(
427            &device,
428            &[&ntt_bind_group_layout, &ntt_fused_shift_bgl],
429        );
430        let ntt_fused_pipeline = timed!(
431            "pipeline: NTT Fused",
432            create_pipeline(
433                &device,
434                "NTT Fused",
435                &ntt_fused_layout,
436                &ntt_fused_module,
437                "ntt_tile_with_shift"
438            )
439        );
440        let ntt_tile_dif_pipeline = timed!(
441            "pipeline: NTT Tile DIF",
442            create_pipeline(
443                &device,
444                "NTT Tile DIF",
445                &ntt_fused_layout,
446                &ntt_fused_module,
447                "ntt_tile_dif"
448            )
449        );
450        let ntt_tile_dit_no_bitreverse_pipeline = timed!(
451            "pipeline: NTT Tile DIT NoRev",
452            create_pipeline(
453                &device,
454                "NTT Tile DIT NoRev",
455                &ntt_fused_layout,
456                &ntt_fused_module,
457                "ntt_tile_dit_no_bitreverse"
458            )
459        );
460        let ntt_tile_dif_fused_pipeline = timed!(
461            "pipeline: NTT Tile DIF Fused",
462            create_pipeline(
463                &device,
464                "NTT Tile DIF Fused",
465                &ntt_fused_layout,
466                &ntt_fused_module,
467                "ntt_tile_dif_with_shift"
468            )
469        );
470        let ntt_fused_pointwise_layout = pipeline_layout(
471            &device,
472            &[
473                &ntt_bind_group_layout,
474                &ntt_fused_shift_bgl,
475                &pointwise_fused_bind_group_layout,
476            ],
477        );
478        let ntt_tile_fused_pointwise_pipeline = timed!(
479            "pipeline: NTT Tile Fused Pointwise",
480            create_pipeline(
481                &device,
482                "NTT Tile Fused Pointwise",
483                &ntt_fused_pointwise_layout,
484                &ntt_fused_module,
485                "ntt_tile_fused_pointwise"
486            )
487        );
488        let ntt_global_stage_pipeline = timed!(
489            "pipeline: NTT Global Stage",
490            create_pipeline(
491                &device,
492                "NTT Global Stage",
493                &ntt_global_layout,
494                &ntt_module,
495                "ntt_global_stage"
496            )
497        );
498        let ntt_global_stage_radix4_pipeline = timed!(
499            "pipeline: NTT Global R4",
500            create_pipeline(
501                &device,
502                "NTT Global Stage Radix4",
503                &ntt_global_layout,
504                &ntt_module,
505                "ntt_global_stage_radix4"
506            )
507        );
508        let ntt_global_stage_dif_pipeline = timed!(
509            "pipeline: NTT Global DIF",
510            create_pipeline(
511                &device,
512                "NTT Global Stage DIF",
513                &ntt_global_layout,
514                &ntt_module,
515                "ntt_global_stage_dif"
516            )
517        );
518        let ntt_global_stage_dif_fused_layout = pipeline_layout(
519            &device,
520            &[
521                &ntt_params_bind_group_layout,
522                &ntt_fused_shift_bgl,
523                &pointwise_fused_bind_group_layout,
524            ],
525        );
526        let ntt_global_stage_dif_fused_pointwise_pipeline = timed!(
527            "pipeline: NTT Global DIF Fused",
528            create_pipeline(
529                &device,
530                "NTT Global DIF Fused Pointwise",
531                &ntt_global_stage_dif_fused_layout,
532                &ntt_module,
533                "ntt_global_stage_dif_fused_pointwise"
534            )
535        );
536        let ntt_bitreverse_pipeline = timed!(
537            "pipeline: NTT BitReverse",
538            create_pipeline(
539                &device,
540                "NTT BitReverse",
541                &ntt_global_layout,
542                &ntt_module,
543                "bitreverse_inplace"
544            )
545        );
546        let ntt_bitreverse_fused_layout = pipeline_layout(
547            &device,
548            &[
549                &ntt_params_bind_group_layout,
550                &ntt_fused_shift_bgl,
551                &pointwise_fused_bind_group_layout,
552            ],
553        );
554        let ntt_bitreverse_fused_pointwise_pipeline = timed!(
555            "pipeline: NTT BitReverse Fused",
556            create_pipeline(
557                &device,
558                "NTT BitReverse Fused Pointwise",
559                &ntt_bitreverse_fused_layout,
560                &ntt_module,
561                "bitreverse_fused_pointwise"
562            )
563        );
564
565        // Polynomial pipelines
566        let coset_shift_layout =
567            pipeline_layout(&device, &[&coset_shift_bind_group_layout]);
568        let pointwise_layout =
569            pipeline_layout(&device, &[&pointwise_poly_bind_group_layout]);
570        let montgomery_layout =
571            pipeline_layout(&device, &[&montgomery_bind_group_layout]);
572        let coset_shift_pipeline = timed!(
573            "pipeline: Coset Shift",
574            create_pipeline(
575                &device,
576                "Coset Shift",
577                &coset_shift_layout,
578                &poly_ops_module,
579                "coset_shift"
580            )
581        );
582        let pointwise_poly_pipeline = timed!(
583            "pipeline: Pointwise Poly",
584            create_pipeline(
585                &device,
586                "Pointwise Poly",
587                &pointwise_layout,
588                &poly_ops_module,
589                "pointwise_poly"
590            )
591        );
592        let to_montgomery_pipeline = timed!(
593            "pipeline: To Montgomery",
594            create_pipeline(
595                &device,
596                "To Montgomery",
597                &montgomery_layout,
598                &poly_ops_module,
599                "to_montgomery_array"
600            )
601        );
602        let from_montgomery_pipeline = timed!(
603            "pipeline: From Montgomery",
604            create_pipeline(
605                &device,
606                "From Montgomery",
607                &montgomery_layout,
608                &poly_ops_module,
609                "from_montgomery_array"
610            )
611        );
612
613        // MSM pipelines
614        let msm_agg_layout =
615            pipeline_layout(&device, &[&msm_agg_bind_group_layout]);
616        let msm_sum_layout =
617            pipeline_layout(&device, &[&msm_sum_bind_group_layout]);
618        let msm_weight_g1_layout =
619            pipeline_layout(&device, &[&msm_weight_g1_bind_group_layout]);
620        let msm_subsum_phase1_layout =
621            pipeline_layout(&device, &[&msm_subsum_phase1_bind_group_layout]);
622        let msm_subsum_phase2_layout =
623            pipeline_layout(&device, &[&msm_subsum_phase2_bind_group_layout]);
624
625        let msm_agg_g1_pipeline = timed!(
626            "pipeline: MSM Agg G1",
627            create_pipeline(
628                &device,
629                "MSM Agg G1",
630                &msm_agg_layout,
631                &msm_g1_agg_module,
632                "aggregate_buckets_g1"
633            )
634        );
635        let msm_sum_g1_pipeline = timed!(
636            "pipeline: MSM Sum G1",
637            create_pipeline(
638                &device,
639                "MSM Sum G1",
640                &msm_sum_layout,
641                &msm_g1_subsum_module,
642                "subsum_accumulation_g1"
643            )
644        );
645        let msm_agg_g2_pipeline = timed!(
646            "pipeline: MSM Agg G2",
647            create_pipeline(
648                &device,
649                "MSM Agg G2",
650                &msm_agg_layout,
651                &msm_g2_agg_module,
652                "aggregate_buckets_g2"
653            )
654        );
655        let msm_sum_g2_pipeline = timed!(
656            "pipeline: MSM Sum G2",
657            create_pipeline(
658                &device,
659                "MSM Sum G2",
660                &msm_sum_layout,
661                &msm_g2_subsum_module,
662                "subsum_accumulation_g2"
663            )
664        );
665        let msm_to_mont_g1_pipeline = timed!(
666            "pipeline: MSM To Mont G1",
667            create_pipeline(
668                &device,
669                "MSM To Montgomery G1",
670                &montgomery_layout,
671                &msm_g1_agg_module,
672                "to_montgomery_bases_g1"
673            )
674        );
675        let msm_to_mont_g2_pipeline = timed!(
676            "pipeline: MSM To Mont G2",
677            create_pipeline(
678                &device,
679                "MSM To Montgomery G2",
680                &montgomery_layout,
681                &msm_g2_agg_module,
682                "to_montgomery_bases_g2"
683            )
684        );
685        let msm_weight_g1_pipeline = timed!(
686            "pipeline: MSM Weight G1",
687            create_pipeline(
688                &device,
689                "MSM Weight G1",
690                &msm_weight_g1_layout,
691                &msm_g1_agg_module,
692                "weight_buckets_g1"
693            )
694        );
695        let msm_subsum_phase1_g1_pipeline = timed!(
696            "pipeline: MSM Subsum Ph1 G1",
697            create_pipeline(
698                &device,
699                "MSM Subsum Phase1 G1",
700                &msm_subsum_phase1_layout,
701                &msm_g1_subsum_module,
702                "subsum_phase1_g1"
703            )
704        );
705        let msm_subsum_phase2_g1_pipeline = timed!(
706            "pipeline: MSM Subsum Ph2 G1",
707            create_pipeline(
708                &device,
709                "MSM Subsum Phase2 G1",
710                &msm_subsum_phase2_layout,
711                &msm_g1_subsum_module,
712                "subsum_phase2_g1"
713            )
714        );
715
716        let msm_weight_g2_layout =
717            pipeline_layout(&device, &[&msm_weight_g2_bind_group_layout]);
718        let msm_weight_g2_pipeline = timed!(
719            "pipeline: MSM Weight G2",
720            create_pipeline(
721                &device,
722                "MSM Weight G2",
723                &msm_weight_g2_layout,
724                &msm_g2_agg_module,
725                "weight_buckets_g2"
726            )
727        );
728        let msm_subsum_phase1_g2_pipeline = timed!(
729            "pipeline: MSM Subsum Ph1 G2",
730            create_pipeline(
731                &device,
732                "MSM Subsum Phase1 G2",
733                &msm_subsum_phase1_layout,
734                &msm_g2_subsum_module,
735                "subsum_phase1_g2"
736            )
737        );
738        let msm_subsum_phase2_g2_pipeline = timed!(
739            "pipeline: MSM Subsum Ph2 G2",
740            create_pipeline(
741                &device,
742                "MSM Subsum Phase2 G2",
743                &msm_subsum_phase2_layout,
744                &msm_g2_subsum_module,
745                "subsum_phase2_g2"
746            )
747        );
748
749        let msm_reduce_layout =
750            pipeline_layout(&device, &[&msm_reduce_bind_group_layout]);
751        let msm_reduce_g1_pipeline = timed!(
752            "pipeline: MSM Reduce G1",
753            create_pipeline(
754                &device,
755                "MSM Reduce G1",
756                &msm_reduce_layout,
757                &msm_g1_agg_module,
758                "reduce_sub_buckets_g1"
759            )
760        );
761        let msm_reduce_g2_pipeline = timed!(
762            "pipeline: MSM Reduce G2",
763            create_pipeline(
764                &device,
765                "MSM Reduce G2",
766                &msm_reduce_layout,
767                &msm_g2_agg_module,
768                "reduce_sub_buckets_g2"
769            )
770        );
771        #[cfg(feature = "timing")]
772        {
773            let pipelines_total = pipelines_start.elapsed();
774            eprintln!("\n[init] === GpuContext::new() summary ===");
775            eprintln!(
776                "[init] {:<30} {:>8.1}ms",
777                "shader compilation",
778                shader_total.as_secs_f64() * 1000.0
779            );
780            eprintln!(
781                "[init] {:<30} {:>8.1}ms",
782                "bind group layouts",
783                layouts_total.as_secs_f64() * 1000.0
784            );
785            eprintln!(
786                "[init] {:<30} {:>8.1}ms",
787                "pipeline creation",
788                pipelines_total.as_secs_f64() * 1000.0
789            );
790            eprintln!(
791                "[init] {:<30} {:>8.1}ms",
792                "TOTAL",
793                init_start.elapsed().as_secs_f64() * 1000.0
794            );
795            eprintln!();
796        }
797
798        #[cfg(feature = "profiling")]
799        let profiler = Mutex::new(wgpu_profiler::GpuProfiler::new(
800            &device,
801            wgpu_profiler::GpuProfilerSettings {
802                enable_timer_queries: true,
803                ..Default::default()
804            },
805        )?);
806
807        Ok(Self {
808            device,
809            queue,
810            ntt_pipeline,
811            ntt_fused_pipeline,
812            ntt_tile_dit_no_bitreverse_pipeline,
813            ntt_tile_dif_pipeline,
814            ntt_tile_dif_fused_pipeline,
815            ntt_tile_fused_pointwise_pipeline,
816            ntt_global_stage_pipeline,
817            ntt_global_stage_radix4_pipeline,
818            ntt_global_stage_dif_pipeline,
819            ntt_global_stage_dif_fused_pointwise_pipeline,
820            ntt_bitreverse_pipeline,
821            ntt_bitreverse_fused_pointwise_pipeline,
822            coset_shift_pipeline,
823            pointwise_poly_pipeline,
824            to_montgomery_pipeline,
825            from_montgomery_pipeline,
826            msm_agg_g1_pipeline,
827            msm_sum_g1_pipeline,
828            msm_agg_g2_pipeline,
829            msm_sum_g2_pipeline,
830            msm_to_mont_g1_pipeline,
831            msm_to_mont_g2_pipeline,
832            msm_weight_g1_pipeline,
833            msm_subsum_phase1_g1_pipeline,
834            msm_subsum_phase2_g1_pipeline,
835            msm_weight_g2_pipeline,
836            msm_subsum_phase1_g2_pipeline,
837            msm_subsum_phase2_g2_pipeline,
838            msm_reduce_g1_pipeline,
839            msm_reduce_g2_pipeline,
840            ntt_bind_group_layout,
841            ntt_fused_shift_bgl,
842            ntt_params_bind_group_layout,
843            coset_shift_bind_group_layout,
844            pointwise_poly_bind_group_layout,
845            pointwise_fused_bind_group_layout,
846            montgomery_bind_group_layout,
847            msm_agg_bind_group_layout,
848            msm_sum_bind_group_layout,
849            msm_weight_g1_bind_group_layout,
850            msm_weight_g2_bind_group_layout,
851            msm_subsum_phase1_bind_group_layout,
852            msm_subsum_phase2_bind_group_layout,
853            msm_reduce_bind_group_layout,
854            _marker: PhantomData,
855            #[cfg(feature = "profiling")]
856            profiler,
857        })
858    }
859
860    #[cfg(feature = "profiling")]
861    pub fn end_profiler_frame(&self) {
862        // Ensure all GPU work and timestamp query readbacks are complete
863        // before ending the frame. Without this, Metal may return stale or
864        // uninitialized timestamp data (causing negative durations).
865        #[cfg(not(target_family = "wasm"))]
866        let _ = self.device.poll(wgpu::PollType::wait_indefinitely());
867        let mut profiler = self.profiler.lock().unwrap();
868        profiler.end_frame().expect("end_frame failed");
869    }
870
871    #[cfg(feature = "profiling")]
872    pub fn process_profiler_results(
873        &self,
874    ) -> Option<Vec<wgpu_profiler::GpuTimerQueryResult>> {
875        let mut profiler = self.profiler.lock().unwrap();
876        profiler.process_finished_frame(self.queue.get_timestamp_period())
877    }
878
879    pub fn create_storage_buffer(
880        &self,
881        label: &str,
882        data: &[u8],
883    ) -> wgpu::Buffer {
884        self.device
885            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
886                label: Some(label),
887                contents: data,
888                usage: wgpu::BufferUsages::STORAGE
889                    | wgpu::BufferUsages::COPY_SRC
890                    | wgpu::BufferUsages::COPY_DST,
891            })
892    }
893
894    pub fn create_empty_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
895        self.device.create_buffer(&wgpu::BufferDescriptor {
896            label: Some(label),
897            size,
898            usage: wgpu::BufferUsages::STORAGE
899                | wgpu::BufferUsages::COPY_SRC
900                | wgpu::BufferUsages::COPY_DST,
901            mapped_at_creation: false,
902        })
903    }
904}
905
906#[cfg(test)]
907mod tests;