Skip to main content

webgpu_groth16/gpu/
msm.rs

//! MSM (Multi-Scalar Multiplication) GPU pipeline dispatcher.
//!
//! Executes the Pippenger MSM pipeline — up to six compute kernels — in a
//! single command encoder:
//!
//! ```text
//! bases ──► [to_montgomery] ──► bases(mont)
//!                                    │
//! bucket_data ───────────────────────┤
//!                                    ▼
//!                            [aggregate_buckets] ──► agg_output
//!                                                        │
//!                               (if has_chunks)          ▼
//!                                              [reduce_sub_buckets] ──► aggregated_buckets
//!                                                                            │
//! bucket_values ─────────────────────────────────────────────────────────────┤
//!                                                                            ▼
//!                                                                    [weight_buckets]
//!                                                                            │
//!                                                                            ▼
//!                                                                    [subsum_phase1] ──► partial_sums
//!                                                                                             │
//!                                                                                             ▼
//!                                                                                     [subsum_phase2] ──► window_sums
//! ```
//!
//! `bucket_data` stands for the `base_indices`, `bucket_pointers` and
//! `bucket_sizes` buffers. The `to_montgomery` pass converts `bases` in place
//! and is skipped for persistent bases that are already in Montgomery form.
//!
//! When sub-bucket chunking is active (`has_chunks`), aggregate writes to an
//! intermediate buffer and a reduce pass sums the sub-bucket partials into the
//! final per-bucket buffer before weighting. Without chunking, aggregate writes
//! straight into `aggregated_buckets` and the reduce pass is skipped.
29
30use wgpu::util::DeviceExt;
31
32use super::curve::GpuCurve;
33use super::{GpuContext, MsmBuffers, compute_pass};
34
35impl<C: GpuCurve> GpuContext<C> {
36    #[allow(clippy::too_many_arguments)]
37    pub fn execute_msm(
38        &self,
39        is_g2: bool,
40        bufs: &MsmBuffers<'_>,
41        num_active_buckets: u32,
42        num_dispatched: u32,
43        has_chunks: bool,
44        num_windows: u32,
45        skip_montgomery: bool,
46    ) {
47        let bases_buf = bufs.bases;
48        let base_indices_buf = bufs.base_indices;
49        let bucket_pointers_buf = bufs.bucket_pointers;
50        let bucket_sizes_buf = bufs.bucket_sizes;
51        let aggregated_buckets_buf = bufs.aggregated_buckets;
52        let bucket_values_buf = bufs.bucket_values;
53        let window_starts_buf = bufs.window_starts;
54        let window_counts_buf = bufs.window_counts;
55        let window_sums_buf = bufs.window_sums;
56
57        let point_gpu_bytes: u64 = if is_g2 {
58            C::G2_GPU_BYTES as u64
59        } else {
60            C::G1_GPU_BYTES as u64
61        };
62
63        // When chunking is active, aggregate writes to a larger intermediate
64        // buffer and a reduce pass sums sub-buckets into the final
65        // aggregated_buckets buffer.
66        let intermediate_buf = if has_chunks {
67            Some(self.device.create_buffer(&wgpu::BufferDescriptor {
68                label: Some("MSM Intermediate Sub-Buckets"),
69                size: num_dispatched as u64 * point_gpu_bytes,
70                usage: wgpu::BufferUsages::STORAGE,
71                mapped_at_creation: false,
72            }))
73        } else {
74            None
75        };
76
77        // The buffer the aggregate kernel writes to: intermediate (chunked) or
78        // final (unchunked).
79        let agg_output_buf =
80            intermediate_buf.as_ref().unwrap_or(aggregated_buckets_buf);
81
82        let agg_bind_group =
83            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
84                label: Some("MSM Agg Bind Group"),
85                layout: &self.msm_agg_bind_group_layout,
86                entries: &[
87                    wgpu::BindGroupEntry {
88                        binding: 0,
89                        resource: bases_buf.as_entire_binding(),
90                    },
91                    wgpu::BindGroupEntry {
92                        binding: 1,
93                        resource: base_indices_buf.as_entire_binding(),
94                    },
95                    wgpu::BindGroupEntry {
96                        binding: 2,
97                        resource: bucket_pointers_buf.as_entire_binding(),
98                    },
99                    wgpu::BindGroupEntry {
100                        binding: 3,
101                        resource: bucket_sizes_buf.as_entire_binding(),
102                    },
103                    wgpu::BindGroupEntry {
104                        binding: 4,
105                        resource: agg_output_buf.as_entire_binding(),
106                    },
107                    wgpu::BindGroupEntry {
108                        binding: 5,
109                        resource: bucket_values_buf.as_entire_binding(),
110                    },
111                ],
112            });
113
114        let mut encoder = self.device.create_command_encoder(
115            &wgpu::CommandEncoderDescriptor {
116                label: Some("MSM Encoder"),
117            },
118        );
119
120        #[cfg(feature = "profiling")]
121        let mut profiler_guard = self.profiler.lock().unwrap();
122        #[cfg(feature = "profiling")]
123        let mut scope = profiler_guard
124            .scope(if is_g2 { "msm_g2" } else { "msm_g1" }, &mut encoder);
125
126        // Pre-pass: convert bases to Montgomery form in-place so aggregate
127        // can skip per-point to_montgomery calls (saves 3 muls/load for G1, 6
128        // for G2). Skipped when using persistent bases that are already
129        // in Montgomery form.
130        if !skip_montgomery {
131            let mont_bind_group =
132                self.device.create_bind_group(&wgpu::BindGroupDescriptor {
133                    label: Some("MSM Bases Mont Bind Group"),
134                    layout: &self.montgomery_bind_group_layout,
135                    entries: &[wgpu::BindGroupEntry {
136                        binding: 0,
137                        resource: bases_buf.as_entire_binding(),
138                    }],
139                });
140            let point_size: u64 = if is_g2 {
141                C::G2_GPU_BYTES as u64
142            } else {
143                C::G1_GPU_BYTES as u64
144            };
145            let num_bases = (bases_buf.size() / point_size) as u32;
146            let mut cpass =
147                compute_pass!(scope, encoder, "to_montgomery_bases");
148            cpass.set_pipeline(if is_g2 {
149                &self.msm_to_mont_g2_pipeline
150            } else {
151                &self.msm_to_mont_g1_pipeline
152            });
153            cpass.set_bind_group(0, &mont_bind_group, &[]);
154            cpass.dispatch_workgroups(
155                num_bases.div_ceil(C::MSM_WORKGROUP_SIZE),
156                1,
157                1,
158            );
159        }
160
161        {
162            let mut cpass = compute_pass!(scope, encoder, "bucket_aggregation");
163            cpass.set_pipeline(if is_g2 {
164                &self.msm_agg_g2_pipeline
165            } else {
166                &self.msm_agg_g1_pipeline
167            });
168            cpass.set_bind_group(0, &agg_bind_group, &[]);
169            cpass.dispatch_workgroups(
170                num_dispatched.div_ceil(C::MSM_WORKGROUP_SIZE).max(1),
171                1,
172                1,
173            );
174        }
175
176        // When sub-bucket chunking is active, reduce sub-bucket partial sums
177        // into the final per-bucket aggregated results.
178        if has_chunks {
179            let reduce_starts_buf = bufs
180                .reduce_starts
181                .expect("reduce_starts required when has_chunks");
182            let reduce_counts_buf = bufs
183                .reduce_counts
184                .expect("reduce_counts required when has_chunks");
185            let reduce_bind_group =
186                self.device.create_bind_group(&wgpu::BindGroupDescriptor {
187                    label: Some("MSM Reduce Sub-Buckets BG"),
188                    layout: &self.msm_reduce_bind_group_layout,
189                    entries: &[
190                        wgpu::BindGroupEntry {
191                            binding: 0,
192                            resource: agg_output_buf.as_entire_binding(),
193                        },
194                        wgpu::BindGroupEntry {
195                            binding: 1,
196                            resource: reduce_starts_buf.as_entire_binding(),
197                        },
198                        wgpu::BindGroupEntry {
199                            binding: 2,
200                            resource: reduce_counts_buf.as_entire_binding(),
201                        },
202                        wgpu::BindGroupEntry {
203                            binding: 3,
204                            resource: aggregated_buckets_buf
205                                .as_entire_binding(),
206                        },
207                    ],
208                });
209            let mut cpass = compute_pass!(scope, encoder, "reduce_sub_buckets");
210            cpass.set_pipeline(if is_g2 {
211                &self.msm_reduce_g2_pipeline
212            } else {
213                &self.msm_reduce_g1_pipeline
214            });
215            cpass.set_bind_group(0, &reduce_bind_group, &[]);
216            cpass.dispatch_workgroups(
217                num_active_buckets.div_ceil(C::MSM_WORKGROUP_SIZE).max(1),
218                1,
219                1,
220            );
221        }
222
223        // Weight each bucket sum by its bucket value in a separate kernel.
224        // When chunking is active, use original bucket values (not sub-bucket
225        // values).
226        let weight_values_buf = if has_chunks {
227            bufs.orig_bucket_values
228                .expect("orig_bucket_values required when has_chunks")
229        } else {
230            bucket_values_buf
231        };
232        {
233            let weight_bind_group =
234                self.device.create_bind_group(&wgpu::BindGroupDescriptor {
235                    label: Some(if is_g2 {
236                        "MSM Weight G2 BG"
237                    } else {
238                        "MSM Weight G1 BG"
239                    }),
240                    layout: if is_g2 {
241                        &self.msm_weight_g2_bind_group_layout
242                    } else {
243                        &self.msm_weight_g1_bind_group_layout
244                    },
245                    entries: &[
246                        wgpu::BindGroupEntry {
247                            binding: 0,
248                            resource: aggregated_buckets_buf
249                                .as_entire_binding(),
250                        },
251                        wgpu::BindGroupEntry {
252                            binding: 1,
253                            resource: weight_values_buf.as_entire_binding(),
254                        },
255                    ],
256                });
257            let mut cpass = compute_pass!(scope, encoder, "bucket_weighting");
258            cpass.set_pipeline(if is_g2 {
259                &self.msm_weight_g2_pipeline
260            } else {
261                &self.msm_weight_g1_pipeline
262            });
263            cpass.set_bind_group(0, &weight_bind_group, &[]);
264            cpass.dispatch_workgroups(
265                num_active_buckets.div_ceil(C::MSM_WORKGROUP_SIZE).max(1),
266                1,
267                1,
268            );
269        }
270
271        // Both G1 and G2: two-pass multi-workgroup tree reduction.
272        // Phase 1: chunks_per_window workgroups per window each sum a
273        // contiguous          slice of weighted buckets → partial_sums
274        // buffer. Phase 2: one workgroup per window reduces
275        // partial_sums → final window_sums.
276        //
277        // When chunking is active, subsum must use original window metadata
278        // (which maps to num_active_buckets layout in aggregated_buckets_buf).
279        {
280            let chunks_per_window = if is_g2 {
281                C::G2_SUBSUM_CHUNKS_PER_WINDOW
282            } else {
283                C::G1_SUBSUM_CHUNKS_PER_WINDOW
284            };
285            let subsum_window_starts = if has_chunks {
286                bufs.orig_window_starts
287                    .expect("orig_window_starts required when has_chunks")
288            } else {
289                window_starts_buf
290            };
291            let subsum_window_counts = if has_chunks {
292                bufs.orig_window_counts
293                    .expect("orig_window_counts required when has_chunks")
294            } else {
295                window_counts_buf
296            };
297
298            let partial_sums_buf =
299                self.device.create_buffer(&wgpu::BufferDescriptor {
300                    label: Some("MSM Partial Sums"),
301                    size: (num_windows * chunks_per_window) as u64
302                        * point_gpu_bytes,
303                    usage: wgpu::BufferUsages::STORAGE,
304                    mapped_at_creation: false,
305                });
306            let subsum_params: [u32; 4] = [chunks_per_window, 0, 0, 0];
307            let subsum_params_buf = self.device.create_buffer_init(
308                &wgpu::util::BufferInitDescriptor {
309                    label: Some("Subsum Params"),
310                    contents: bytemuck::cast_slice(&subsum_params),
311                    usage: wgpu::BufferUsages::UNIFORM,
312                },
313            );
314
315            let phase1_bind_group =
316                self.device.create_bind_group(&wgpu::BindGroupDescriptor {
317                    label: Some("MSM Subsum Phase1 BG"),
318                    layout: &self.msm_subsum_phase1_bind_group_layout,
319                    entries: &[
320                        wgpu::BindGroupEntry {
321                            binding: 0,
322                            resource: aggregated_buckets_buf
323                                .as_entire_binding(),
324                        },
325                        wgpu::BindGroupEntry {
326                            binding: 1,
327                            resource: subsum_window_starts.as_entire_binding(),
328                        },
329                        wgpu::BindGroupEntry {
330                            binding: 2,
331                            resource: subsum_window_counts.as_entire_binding(),
332                        },
333                        wgpu::BindGroupEntry {
334                            binding: 3,
335                            resource: partial_sums_buf.as_entire_binding(),
336                        },
337                        wgpu::BindGroupEntry {
338                            binding: 4,
339                            resource: subsum_params_buf.as_entire_binding(),
340                        },
341                    ],
342                });
343
344            let phase2_bind_group =
345                self.device.create_bind_group(&wgpu::BindGroupDescriptor {
346                    label: Some("MSM Subsum Phase2 BG"),
347                    layout: &self.msm_subsum_phase2_bind_group_layout,
348                    entries: &[
349                        wgpu::BindGroupEntry {
350                            binding: 0,
351                            resource: partial_sums_buf.as_entire_binding(),
352                        },
353                        wgpu::BindGroupEntry {
354                            binding: 1,
355                            resource: window_sums_buf.as_entire_binding(),
356                        },
357                        wgpu::BindGroupEntry {
358                            binding: 2,
359                            resource: subsum_params_buf.as_entire_binding(),
360                        },
361                    ],
362                });
363
364            // Phase 1: many workgroups per window → partial sums.
365            {
366                let mut cpass =
367                    compute_pass!(scope, encoder, "tree_reduction_ph1");
368                cpass.set_pipeline(if is_g2 {
369                    &self.msm_subsum_phase1_g2_pipeline
370                } else {
371                    &self.msm_subsum_phase1_g1_pipeline
372                });
373                cpass.set_bind_group(0, &phase1_bind_group, &[]);
374                cpass.dispatch_workgroups(
375                    num_windows * chunks_per_window,
376                    1,
377                    1,
378                );
379            }
380
381            // Phase 2: reduce partial sums → final window sums.
382            {
383                let mut cpass =
384                    compute_pass!(scope, encoder, "tree_reduction_ph2");
385                cpass.set_pipeline(if is_g2 {
386                    &self.msm_subsum_phase2_g2_pipeline
387                } else {
388                    &self.msm_subsum_phase2_g1_pipeline
389                });
390                cpass.set_bind_group(0, &phase2_bind_group, &[]);
391                cpass.dispatch_workgroups(num_windows, 1, 1);
392            }
393        }
394
395        #[cfg(feature = "profiling")]
396        {
397            drop(scope);
398            profiler_guard.resolve_queries(&mut encoder);
399        }
400
401        self.queue.submit(Some(encoder.finish()));
402    }
403
404    /// Convert a bases buffer to Montgomery form in-place (one-time, for
405    /// persistent bases).
406    pub fn convert_to_montgomery(&self, buf: &wgpu::Buffer, is_g2: bool) {
407        let mont_bind_group =
408            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
409                label: Some("Convert To Montgomery BG"),
410                layout: &self.montgomery_bind_group_layout,
411                entries: &[wgpu::BindGroupEntry {
412                    binding: 0,
413                    resource: buf.as_entire_binding(),
414                }],
415            });
416        let point_size: u64 = if is_g2 {
417            C::G2_GPU_BYTES as u64
418        } else {
419            C::G1_GPU_BYTES as u64
420        };
421        let num_bases = (buf.size() / point_size) as u32;
422        let mut encoder = self.device.create_command_encoder(
423            &wgpu::CommandEncoderDescriptor {
424                label: Some("Convert To Montgomery Encoder"),
425            },
426        );
427        {
428            let mut cpass =
429                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
430                    label: Some("to_montgomery"),
431                    timestamp_writes: None,
432                });
433            cpass.set_pipeline(if is_g2 {
434                &self.msm_to_mont_g2_pipeline
435            } else {
436                &self.msm_to_mont_g1_pipeline
437            });
438            cpass.set_bind_group(0, &mont_bind_group, &[]);
439            cpass.dispatch_workgroups(
440                num_bases.div_ceil(C::MSM_WORKGROUP_SIZE),
441                1,
442                1,
443            );
444        }
445        self.queue.submit(Some(encoder.finish()));
446    }
447}