msm_webgpu/cuzk/msm.rs

use group::Group;
use halo2curves::CurveAffine;
use halo2curves::CurveExt;
use num_bigint::BigUint;
use num_traits::Num;
use once_cell::sync::Lazy;
use wgpu::{Buffer, CommandEncoder, CommandEncoderDescriptor, Device, Queue};

use crate::cuzk::gpu::{
    create_and_write_storage_buffer, create_and_write_uniform_buffer, create_bind_group,
    create_bind_group_layout, create_compute_pipeline, create_storage_buffer, execute_pipeline,
    get_adapter, get_device, read_from_gpu,
};
use crate::cuzk::shader_manager::ShaderManager;
use crate::cuzk::utils::to_biguint_le;
use crate::{points_to_bytes, scalars_to_bytes};

use super::utils::bytes_to_field;
use super::utils::calc_bitwidth;
use super::utils::{MiscParams, compute_misc_params};

/// Calculate the number of word_size-bit words needed to represent the field characteristic.
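/// For the 254-bit BN254 base field and 13-bit words, this is ceil(254 / 13) = 20.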
pub fn calc_num_words(word_size: usize) -> usize {
    let p_bit_length = calc_bitwidth(&P);
    p_bit_length.div_ceil(word_size)
}

/// 13-bit limbs.
pub const WORD_SIZE: usize = 13;

/// Field characteristic (the BN254 base field modulus).
pub static P: Lazy<BigUint> = Lazy::new(|| {
    BigUint::from_str_radix(
        "21888242871839275222246405745257275088696311157297823662689037894645226208583",
        10,
    )
    .expect("Invalid modulus")
});

/// Miscellaneous parameters.
pub static PARAMS: Lazy<MiscParams> = Lazy::new(|| compute_misc_params(&P, WORD_SIZE));

/// End-to-end implementation of the modified cuZK MSM algorithm by Lu et al.
/// (2022): https://eprint.iacr.org/2022/1321.pdf
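///
/// A minimal usage sketch (illustrative only; it assumes the BN254 types from
/// `halo2curves::bn256` and a blocking executor such as `pollster`, neither of
/// which this module itself provides):
///
/// ```ignore
/// use halo2curves::bn256::{Fr, G1Affine};
///
/// let points = vec![G1Affine::generator(); 4];
/// let scalars = vec![Fr::from(2u64); 4];
/// // Drive the async MSM to completion on a blocking executor.
/// let result = pollster::block_on(compute_msm::<G1Affine>(&points, &scalars));
/// ```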
pub async fn compute_msm<C: CurveAffine>(points: &[C], scalars: &[C::Scalar]) -> C::Curve {
    let input_size = scalars.len();
    let chunk_size = if input_size >= 65536 { 16 } else { 4 };
    let num_columns = 1 << chunk_size;
    let num_rows = input_size.div_ceil(num_columns);
    let num_subtasks = 256_usize.div_ceil(chunk_size);
    let num_words = PARAMS.num_words;
    let point_bytes = points_to_bytes(points);
    let scalar_bytes = scalars_to_bytes(scalars);

    let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size);

    let adapter = get_adapter().await;
    let (device, queue) = get_device(&adapter).await;
    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("MSM Encoder"),
    });

    ////////////////////////////////////////////////////////////////////////////////////////////
    // 1. Decompose scalars into chunk_size windows using signed bucket indices.             /
    ////////////////////////////////////////////////////////////////////////////////////////////
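    // For example, with chunk_size = 16 a 256-bit scalar splits into
    // num_subtasks = 16 windows of 16 bits each. Signed bucket indices recode
    // each window into the range [-2^15, 2^15), carrying into the next
    // window, which is what lets the SMVP step use half as many buckets.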

    // Total thread count = workgroup_size * #x workgroups * #y workgroups * #z workgroups.
    let mut c_workgroup_size = 64;
    let mut c_num_x_workgroups = 128;
    let mut c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    let c_num_z_workgroups = 1;

    if input_size <= 256 {
        c_workgroup_size = input_size;
        c_num_x_workgroups = 1;
        c_num_y_workgroups = 1;
    } else if input_size > 256 && input_size <= 32768 {
        c_workgroup_size = 64;
        c_num_x_workgroups = 4;
        c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    } else if input_size > 32768 && input_size <= 131072 {
        c_workgroup_size = 256;
        c_num_x_workgroups = 8;
        c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    } else if input_size > 131072 && input_size <= 1048576 {
        c_workgroup_size = 256;
        c_num_x_workgroups = 32;
        c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    }

    let c_shader = shader_manager.gen_decomp_scalars_shader(
        c_workgroup_size,
        c_num_y_workgroups,
        num_subtasks,
        num_columns,
    );

    let (point_x_sb, point_y_sb, scalar_chunks_sb) = convert_point_coords_and_decompose_shaders(
        &c_shader,
        c_num_x_workgroups,
        c_num_y_workgroups,
        c_num_z_workgroups,
        &device,
        &queue,
        &mut encoder,
        &point_bytes,
        &scalar_bytes,
        num_subtasks,
        chunk_size,
        num_words,
    )
    .await;

    ////////////////////////////////////////////////////////////////////////////////////////////
    // 2. Sparse Matrix Transposition                                                         /
    //                                                                                        /
    // Compute the indices of the points which share the same                                 /
    // scalar chunks, enabling the parallel accumulation of points                            /
    // into buckets. Transposing each subtask (CSR sparse matrix)                             /
    // is a serial computation.                                                               /
    //                                                                                        /
    // The transpose step generates the CSR sparse matrix and                                 /
    // transposes the matrix simultaneously, resulting in a                                   /
    // wide and flat matrix where width of the matrix (n) = 2 ^ chunk_size                    /
    // and height of the matrix (m) = 1.                                                      /
    ////////////////////////////////////////////////////////////////////////////////////////////

    let t_num_x_workgroups = 1;
    let t_num_y_workgroups = 1;
    let t_num_z_workgroups = 1;

    let t_shader = shader_manager.gen_transpose_shader(num_subtasks);

    let (all_csc_col_ptr_sb, all_csc_val_idxs_sb) = transpose_gpu(
        &t_shader,
        &device,
        &queue,
        &mut encoder,
        t_num_x_workgroups,
        t_num_y_workgroups,
        t_num_z_workgroups,
        input_size,
        num_columns,
        num_rows,
        num_subtasks,
        scalar_chunks_sb,
    )
    .await;

    ////////////////////////////////////////////////////////////////////////////////////////////
    // 3. Sparse Matrix Vector Product (SMVP)                                                 /
    //                                                                                        /
    // Each thread handles accumulating points in a single bucket.                            /
    // The workgroup size and number of workgroups are designed around                        /
    // minimizing shader invocations.                                                         /
    ////////////////////////////////////////////////////////////////////////////////////////////
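    // Because the scalars were decomposed with signed bucket indices, each
    // subtask needs only num_columns / 2 buckets: the sign of an index is
    // folded into the point inside the shader, so bucket |idx| accumulates
    // both signs.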

    let half_num_columns = num_columns / 2;
    let mut s_workgroup_size = 256;
    let mut s_num_x_workgroups = 64;
    let mut s_num_y_workgroups = (half_num_columns / s_workgroup_size) / s_num_x_workgroups;
    let mut s_num_z_workgroups = num_subtasks;

    if half_num_columns < 32768 {
        s_workgroup_size = 32;
        s_num_x_workgroups = 1;
        s_num_y_workgroups =
            (half_num_columns / s_workgroup_size).div_ceil(s_num_x_workgroups);
    }

    if num_columns < 256 {
        s_workgroup_size = 1;
        s_num_x_workgroups = half_num_columns;
        s_num_y_workgroups = 1;
        s_num_z_workgroups = 1;
    }

    // Number of CSR matrices (subtasks) processed per invocation of the SMVP
    // shader. A safe default is 1.
    let num_subtask_chunk_size = 4;

    // Buffers that store the SMVP result, i.e. the bucket sums. They are
    // overwritten per iteration.
    let bucket_sum_coord_bytelength = (num_columns / 2) * num_words * 4 * num_subtasks;
    let bucket_sum_x_sb = create_storage_buffer(
        Some("Bucket sum X buffer"),
        &device,
        bucket_sum_coord_bytelength as u64,
    );
    let bucket_sum_y_sb = create_storage_buffer(
        Some("Bucket sum Y buffer"),
        &device,
        bucket_sum_coord_bytelength as u64,
    );
    let bucket_sum_z_sb = create_storage_buffer(
        Some("Bucket sum Z buffer"),
        &device,
        bucket_sum_coord_bytelength as u64,
    );
    let smvp_shader = shader_manager.gen_smvp_shader(s_workgroup_size, num_columns);

    for offset in (0..num_subtasks).step_by(num_subtask_chunk_size) {
        smvp_gpu(
            &smvp_shader,
            s_num_x_workgroups / (num_subtasks / num_subtask_chunk_size),
            s_num_y_workgroups,
            s_num_z_workgroups,
            offset,
            &device,
            &queue,
            &mut encoder,
            input_size,
            &all_csc_col_ptr_sb,
            &point_x_sb,
            &point_y_sb,
            &all_csc_val_idxs_sb,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
        )
        .await;
    }

    /////////////////////////////////////////////////////////////////////////////////////////////
    // 4. Bucket Reduction                                                                     /
    //                                                                                         /
    // Performs a parallelized running sum by computing a series of point additions,           /
    // followed by a scalar multiplication (Algorithm 4 of the cuZK paper).                    /
    /////////////////////////////////////////////////////////////////////////////////////////////
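    // The running sum exploits the identity
    //   sum_{j=1..m} j * B_j = sum_{k=1..m} (B_k + B_{k+1} + ... + B_m),
    // so each thread only performs point additions over its slice of buckets,
    // plus one scalar multiplication to place its partial sum.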

    // Number of CSR matrices (subtasks) processed per invocation of the BPR
    // shader. A safe default is 1.
    let num_subtasks_per_bpr_1 = 16;

    let b_num_x_workgroups = num_subtasks_per_bpr_1;
    let b_num_y_workgroups = 1;
    let b_num_z_workgroups = 1;
    let b_workgroup_size = 256;

    // Buffers that store the bucket points reduction (BPR) output.
    let g_points_coord_bytelength = num_subtasks * b_workgroup_size * num_words * 4;
    let g_points_x_sb = create_storage_buffer(
        Some("Bucket points reduction X buffer"),
        &device,
        g_points_coord_bytelength as u64,
    );
    let g_points_y_sb = create_storage_buffer(
        Some("Bucket points reduction Y buffer"),
        &device,
        g_points_coord_bytelength as u64,
    );
    let g_points_z_sb = create_storage_buffer(
        Some("Bucket points reduction Z buffer"),
        &device,
        g_points_coord_bytelength as u64,
    );

    let bpr_shader = shader_manager.gen_bpr_shader(b_workgroup_size);

    // Stage 1: bucket points reduction (BPR).
    for subtask_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_1) {
        bpr_1(
            &bpr_shader,
            subtask_idx,
            b_num_x_workgroups,
            b_num_y_workgroups,
            b_num_z_workgroups,
            num_columns,
            &device,
            &queue,
            &mut encoder,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
        )
        .await;
    }

    let num_subtasks_per_bpr_2 = 16;
    let b_2_num_x_workgroups = num_subtasks_per_bpr_2;

    // Stage 2: bucket points reduction (BPR).
    for subtask_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_2) {
        bpr_2(
            &bpr_shader,
            subtask_idx,
            b_2_num_x_workgroups,
            1,
            1,
            num_columns,
            &device,
            &queue,
            &mut encoder,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
        )
        .await;
    }

    // Map results back from GPU to CPU.
    let data = read_from_gpu(
        &device,
        &queue,
        encoder,
        vec![g_points_x_sb, g_points_y_sb, g_points_z_sb],
    )
    .await;

    // Destroy the GPU device object.
    device.destroy();

    let mut reduced_points = vec![];
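    // Each coordinate comes back from the GPU in Montgomery form (x * R mod p)
    // as num_words limbs of WORD_SIZE bits; multiplying by PARAMS.rinv
    // (R^{-1} mod p) recovers the canonical value.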
    let g_points_x = bytemuck::cast_slice::<u8, u32>(&data[0])
        .chunks(num_words)
        .map(|x| {
            let x_biguint_montgomery = to_biguint_le(x, num_words, WORD_SIZE as u32);
            let x_biguint = x_biguint_montgomery * &PARAMS.rinv % P.clone();
            bytes_to_field(&x_biguint.to_bytes_le())
        })
        .collect::<Vec<_>>();
    let g_points_y = bytemuck::cast_slice::<u8, u32>(&data[1])
        .chunks(num_words)
        .map(|y| {
            let y_biguint_montgomery = to_biguint_le(y, num_words, WORD_SIZE as u32);
            let y_biguint = y_biguint_montgomery * &PARAMS.rinv % P.clone();
            bytes_to_field(&y_biguint.to_bytes_le())
        })
        .collect::<Vec<_>>();
    let g_points_z = bytemuck::cast_slice::<u8, u32>(&data[2])
        .chunks(num_words)
        .map(|z| {
            let z_biguint_montgomery = to_biguint_le(z, num_words, WORD_SIZE as u32);
            let z_biguint = z_biguint_montgomery * &PARAMS.rinv % P.clone();
            bytes_to_field(&z_biguint.to_bytes_le())
        })
        .collect::<Vec<_>>();

    // TODO: Use from_montgomery_repr passing a valid R^2 as a parameter for performance
    // let g_points_x = data[0]
    //     .chunks(num_words * 4)
    //     .map(|x| {
    //         let x_field = u8s_to_field_without_assertion(&x, num_words, WORD_SIZE);
    //         x_field
    //     })
    //     .collect::<Vec<_>>();

    for i in 0..num_subtasks {
        let mut point = C::Curve::identity();
        for j in 0..b_workgroup_size {
            let reduced_point = C::Curve::new_jacobian(
                g_points_x[i * b_workgroup_size + j],
                g_points_y[i * b_workgroup_size + j],
                g_points_z[i * b_workgroup_size + j],
            )
            .unwrap();
            point += reduced_point;
        }
        reduced_points.push(point);
    }

    ////////////////////////////////////////////////////////////////////////////////////////////
    // 5. Horner's Method                                                                     /
    //                                                                                        /
    // Calculate the final result using Horner's method (Formula 3 of the cuZK paper).        /
    ////////////////////////////////////////////////////////////////////////////////////////////
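    // Evaluates result = sum_i reduced_points[i] * m^i with m = 2^chunk_size,
    // folded as (((p_{k-1} * m + p_{k-2}) * m + ...) * m + p_0).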

    let m = C::ScalarExt::from(1 << chunk_size);
    let mut result = reduced_points[reduced_points.len() - 1];
    for i in (0..reduced_points.len() - 1).rev() {
        result = result * m + reduced_points[i];
    }
    result
}

/****************************************************** WGSL Shader Invocations ******************************************************/

/*
 * Prepares and executes the shader for decomposing scalars into chunk_size
 * windows using the signed bucket index technique.
 *
 * ASSUMPTION: the vast majority of WebGPU-enabled consumer devices have a
 * maximum buffer size of at least 268435456 bytes.
 *
 * The default maximum buffer size is 268435456 bytes. Since each point
 * consumes 320 bytes, a maximum of around 2 ** 19 points can be stored in a
 * single buffer. If, however, we use 2 buffers - one for each point coordinate
 * X and Y - we can support larger input sizes.
 * Our implementation, however, will only support up to 2 ** 20 points.
 *
 * Furthermore, there is a limit of 8 storage buffers per shader. As such, we
 * do not calculate the T and Z coordinates in this shader. Rather, we do so in
 * the SMVP shader.
 */

/// Convert point coordinates into limb form and decompose scalars into
/// chunk_size windows.
pub async fn convert_point_coords_and_decompose_shaders(
    shader_code: &str,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    device: &Device,
    queue: &Queue,
    encoder: &mut CommandEncoder,
    points_bytes: &[u8],
    scalars_bytes: &[u8],
    num_subtasks: usize,
    chunk_size: usize,
    num_words: usize,
) -> (Buffer, Buffer, Buffer) {
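    // Each scalar is 256 bits, so the chunk_size windows must tile it exactly.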
    assert!(num_subtasks * chunk_size == 256);
    let input_size = scalars_bytes.len() / 32;
    let points_sb = create_and_write_storage_buffer(Some("Points buffer"), device, points_bytes);
    let scalars_sb = create_and_write_storage_buffer(Some("Scalars buffer"), device, scalars_bytes);

    // Output storage buffers.
    let points_x_sb = create_storage_buffer(
        Some("Point X buffer"),
        device,
        (input_size * num_words * 4) as u64,
    );
    let points_y_sb = create_storage_buffer(
        Some("Point Y buffer"),
        device,
        (input_size * num_words * 4) as u64,
    );
    let scalar_chunks_sb = create_storage_buffer(
        Some("Scalar chunks buffer"),
        device,
        (input_size * num_subtasks * 4) as u64, // TODO: Check this
    );

    // Uniform storage buffer.
    let params_bytes = to_u8s_for_gpu(vec![input_size]);
    let params_ub =
        create_and_write_uniform_buffer(Some("Params buffer"), device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![&points_sb, &scalars_sb],
        vec![&points_x_sb, &points_y_sb, &scalar_chunks_sb],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            &points_sb,
            &scalars_sb,
            &points_x_sb,
            &points_y_sb,
            &scalar_chunks_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Convert point coords and decompose shader"),
        device,
        &bind_group_layout,
        shader_code,
        "main",
    )
    .await;

    execute_pipeline(
        encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;

    (points_x_sb, points_y_sb, scalar_chunks_sb)
}

/*
 * Perform a modified version of CSR matrix transposition, which comes before
 * SMVP. Essentially, this step generates the point indices for each thread in
 * the SMVP step which correspond to a particular bucket.
 */
pub async fn transpose_gpu(
    shader_code: &str,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    input_size: usize,
    num_columns: usize,
    num_rows: usize,
    num_subtasks: usize,
    scalar_chunks_sb: Buffer,
) -> (Buffer, Buffer) {
    // Output storage buffers.
    let all_csc_col_ptr_sb = create_storage_buffer(
        Some("All CSC col"),
        device,
        (num_subtasks * (num_columns + 1) * 4) as u64,
    );
    let all_csc_val_idxs_sb =
        create_storage_buffer(Some("All CSC Val Indexes"), device, scalar_chunks_sb.size());
    let all_curr_sb = create_storage_buffer(
        Some("All Current"),
        device,
        (num_subtasks * num_columns * 4) as u64,
    );

    // Uniform storage buffer.
    let params_bytes = to_u8s_for_gpu(vec![num_rows, num_columns, input_size]);
    let params_ub = create_and_write_uniform_buffer(
        Some("Transpose GPU Uniform Params"),
        device,
        queue,
        &params_bytes,
    );

    let bind_group_layout = create_bind_group_layout(
        Some("Transpose GPU Bind Group Layout"),
        device,
        vec![&scalar_chunks_sb],
        vec![&all_csc_col_ptr_sb, &all_csc_val_idxs_sb, &all_curr_sb],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Transpose GPU Bind Group"),
        device,
        &bind_group_layout,
        vec![
            &scalar_chunks_sb,
            &all_csc_col_ptr_sb,
            &all_csc_val_idxs_sb,
            &all_curr_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Transpose GPU Compute Pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "main",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;

    (all_csc_col_ptr_sb, all_csc_val_idxs_sb)
}

// TODO: Use bytemuck
pub fn to_u8s_for_gpu(vals: Vec<usize>) -> Vec<u8> {
    let max: u64 = 1 << 32;
    let mut buf = vec![];
    for val in vals {
        assert!((val as u64) < max);
        buf.extend_from_slice(&(val as u32).to_le_bytes());
    }
    buf
}
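// A sketch of how the bytemuck TODO above could look (assuming, as the assert
// enforces, that every value fits in a u32):
//
//     let vals_u32: Vec<u32> = vals.into_iter().map(|v| u32::try_from(v).unwrap()).collect();
//     bytemuck::cast_slice::<u32, u8>(&vals_u32).to_vec()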

/*
 * Compute the bucket sums and perform scalar multiplication with the bucket indices.
 */
pub async fn smvp_gpu(
    shader_code: &str,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    offset: usize,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    input_size: usize,
    all_csc_col_ptr_sb: &Buffer,
    point_x_sb: &Buffer,
    point_y_sb: &Buffer,
    all_csc_val_idxs_sb: &Buffer,
    bucket_sum_x_sb: &Buffer,
    bucket_sum_y_sb: &Buffer,
    bucket_sum_z_sb: &Buffer,
) {
    // Uniform storage buffer.
    let params_bytes = to_u8s_for_gpu(vec![input_size, num_y_workgroups, num_z_workgroups, offset]);
    let params_ub = create_and_write_uniform_buffer(None, device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![
            &all_csc_col_ptr_sb,
            &all_csc_val_idxs_sb,
            &point_x_sb,
            &point_y_sb,
        ],
        vec![&bucket_sum_x_sb, &bucket_sum_y_sb, &bucket_sum_z_sb],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            &all_csc_col_ptr_sb,
            &all_csc_val_idxs_sb,
            &point_x_sb,
            &point_y_sb,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Compute pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "main",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;
}

/// Bucket points reduction (BPR) shader, stage 1.
pub async fn bpr_1(
    shader_code: &str,
    subtask_idx: usize,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    num_columns: usize,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    bucket_sum_x_sb: &Buffer,
    bucket_sum_y_sb: &Buffer,
    bucket_sum_z_sb: &Buffer,
    g_points_x_sb: &Buffer,
    g_points_y_sb: &Buffer,
    g_points_z_sb: &Buffer,
) {
    // Uniform storage buffer.
    let params_bytes = to_u8s_for_gpu(vec![subtask_idx, num_columns, num_x_workgroups]);
    let params_ub = create_and_write_uniform_buffer(None, device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![],
        vec![
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
        ],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Compute pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "stage_1",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;
}

/// Bucket points reduction (BPR) shader, stage 2.
pub async fn bpr_2(
    shader_code: &str,
    subtask_idx: usize,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    num_columns: usize,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    bucket_sum_x_sb: &Buffer,
    bucket_sum_y_sb: &Buffer,
    bucket_sum_z_sb: &Buffer,
    g_points_x_sb: &Buffer,
    g_points_y_sb: &Buffer,
    g_points_z_sb: &Buffer,
) {
    // Uniform storage buffer.
    let params_bytes = to_u8s_for_gpu(vec![subtask_idx, num_columns, num_x_workgroups]);
    let params_ub = create_and_write_uniform_buffer(None, device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![],
        vec![
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
        ],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Compute pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "stage_2",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;
}