use group::Group;
use halo2curves::CurveAffine;
use halo2curves::CurveExt;
use num_bigint::BigUint;
use num_traits::Num;
use once_cell::sync::Lazy;
use wgpu::{Buffer, CommandEncoder, CommandEncoderDescriptor, Device, Queue};

use crate::cuzk::gpu::{
    create_and_write_storage_buffer, create_and_write_uniform_buffer, create_bind_group,
    create_bind_group_layout, create_compute_pipeline, create_storage_buffer, execute_pipeline,
    get_adapter, get_device, read_from_gpu,
};
use crate::cuzk::shader_manager::ShaderManager;
use crate::cuzk::utils::to_biguint_le;
use crate::{points_to_bytes, scalars_to_bytes};

use super::utils::{MiscParams, bytes_to_field, calc_bitwidth, compute_misc_params};

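/// Returns the number of `word_size`-bit limbs needed to represent the
/// modulus [`P`].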
pub fn calc_num_words(word_size: usize) -> usize {
    let p_bit_length = calc_bitwidth(&P);
    p_bit_length.div_ceil(word_size)
}

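/// Limb width, in bits, used to represent field elements on the GPU.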
pub const WORD_SIZE: usize = 13;

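/// The BN254 base field modulus.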
pub static P: Lazy<BigUint> = Lazy::new(|| {
    BigUint::from_str_radix(
        "21888242871839275222246405745257275088696311157297823662689037894645226208583",
        10,
    )
    .expect("Invalid modulus")
});

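/// Montgomery-arithmetic parameters (limb count and R^-1 among them)
/// precomputed for [`P`] and [`WORD_SIZE`].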
pub static PARAMS: Lazy<MiscParams> = Lazy::new(|| compute_misc_params(&P, WORD_SIZE));

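/// Computes the multi-scalar multiplication `sum_i scalars[i] * points[i]`
/// on the GPU using the cuZK pipeline: scalar decomposition, sparse-matrix
/// transposition, sparse matrix-vector products (SMVP) into buckets, bucket
/// points reduction (BPR), and a final Horner aggregation on the CPU.
///
/// A minimal usage sketch (hypothetical caller; assumes BN254 via
/// `halo2curves::bn256` and an executor such as `pollster`, neither of which
/// this module provides):
///
/// ```ignore
/// use halo2curves::bn256::{Fr, G1Affine};
///
/// let (points, scalars): (Vec<G1Affine>, Vec<Fr>) = load_inputs(); // hypothetical
/// let result = pollster::block_on(compute_msm::<G1Affine>(&points, &scalars));
/// ```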
pub async fn compute_msm<C: CurveAffine>(points: &[C], scalars: &[C::Scalar]) -> C::Curve {
    let input_size = scalars.len();
    let chunk_size = if input_size >= 65536 { 16 } else { 4 };
    let num_columns = 1 << chunk_size;
    let num_rows = input_size.div_ceil(num_columns);
    let num_subtasks = 256_usize.div_ceil(chunk_size);
    let num_words = PARAMS.num_words;
    let point_bytes = points_to_bytes(points);
    let scalar_bytes = scalars_to_bytes(scalars);

    let shader_manager = ShaderManager::new(WORD_SIZE, chunk_size, input_size);

    let adapter = get_adapter().await;
    let (device, queue) = get_device(&adapter).await;
    let mut encoder = device.create_command_encoder(&CommandEncoderDescriptor {
        label: Some("MSM Encoder"),
    });

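    // Stage 1: convert the input point coordinates and decompose each 256-bit
    // scalar into `num_subtasks` windows of `chunk_size` bits. The workgroup
    // dimensions below are tuned per input size so that
    // workgroup_size * x * y threads cover all `input_size` inputs.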
    let mut c_workgroup_size = 64;
    let mut c_num_x_workgroups = 128;
    let mut c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    let c_num_z_workgroups = 1;

    if input_size <= 256 {
        c_workgroup_size = input_size;
        c_num_x_workgroups = 1;
        c_num_y_workgroups = 1;
    } else if input_size > 256 && input_size <= 32768 {
        c_workgroup_size = 64;
        c_num_x_workgroups = 4;
        c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    } else if input_size > 32768 && input_size <= 131072 {
        c_workgroup_size = 256;
        c_num_x_workgroups = 8;
        c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    } else if input_size > 131072 && input_size <= 1048576 {
        c_workgroup_size = 256;
        c_num_x_workgroups = 32;
        c_num_y_workgroups = input_size / c_workgroup_size / c_num_x_workgroups;
    }

    let c_shader = shader_manager.gen_decomp_scalars_shader(
        c_workgroup_size,
        c_num_y_workgroups,
        num_subtasks,
        num_columns,
    );

    let (point_x_sb, point_y_sb, scalar_chunks_sb) = convert_point_coords_and_decompose_shaders(
        &c_shader,
        c_num_x_workgroups,
        c_num_y_workgroups,
        c_num_z_workgroups,
        &device,
        &queue,
        &mut encoder,
        &point_bytes,
        &scalar_bytes,
        num_subtasks,
        chunk_size,
        num_words,
    )
    .await;

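    // Stage 2: transpose the decomposed scalar chunks into CSC sparse-matrix
    // form (column pointers plus value indices), one matrix per subtask.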
    let t_num_x_workgroups = 1;
    let t_num_y_workgroups = 1;
    let t_num_z_workgroups = 1;

    let t_shader = shader_manager.gen_transpose_shader(num_subtasks);

    let (all_csc_col_ptr_sb, all_csc_val_idxs_sb) = transpose_gpu(
        &t_shader,
        &device,
        &queue,
        &mut encoder,
        t_num_x_workgroups,
        t_num_y_workgroups,
        t_num_z_workgroups,
        input_size,
        num_columns,
        num_rows,
        num_subtasks,
        scalar_chunks_sb,
    )
    .await;

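    // Stage 3: sparse matrix-vector product (SMVP), which accumulates points
    // into buckets. The dispatch is sized from `half_num_columns` threads per
    // subtask.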
    let half_num_columns = num_columns / 2;
    let mut s_workgroup_size = 256;
    let mut s_num_x_workgroups = 64;
    let mut s_num_y_workgroups = (half_num_columns / s_workgroup_size) / s_num_x_workgroups;
    let mut s_num_z_workgroups = num_subtasks;

    if half_num_columns < 32768 {
        s_workgroup_size = 32;
        s_num_x_workgroups = 1;
        s_num_y_workgroups = (half_num_columns / s_workgroup_size).div_ceil(s_num_x_workgroups);
    }

    if num_columns < 256 {
        s_workgroup_size = 1;
        s_num_x_workgroups = half_num_columns;
        s_num_y_workgroups = 1;
        s_num_z_workgroups = 1;
    }

    let num_subtask_chunk_size = 4;

    let bucket_sum_coord_bytelength = half_num_columns * num_words * 4 * num_subtasks;
    let bucket_sum_x_sb = create_storage_buffer(
        Some("Bucket sum X buffer"),
        &device,
        bucket_sum_coord_bytelength as u64,
    );
    let bucket_sum_y_sb = create_storage_buffer(
        Some("Bucket sum Y buffer"),
        &device,
        bucket_sum_coord_bytelength as u64,
    );
    let bucket_sum_z_sb = create_storage_buffer(
        Some("Bucket sum Z buffer"),
        &device,
        bucket_sum_coord_bytelength as u64,
    );
    let smvp_shader = shader_manager.gen_smvp_shader(s_workgroup_size, num_columns);

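    // Dispatch the SMVP shader over the subtasks in chunks of
    // `num_subtask_chunk_size`, scaling the x dimension down accordingly.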
    for offset in (0..num_subtasks).step_by(num_subtask_chunk_size) {
        smvp_gpu(
            &smvp_shader,
            s_num_x_workgroups / (num_subtasks / num_subtask_chunk_size),
            s_num_y_workgroups,
            s_num_z_workgroups,
            offset,
            &device,
            &queue,
            &mut encoder,
            input_size,
            &all_csc_col_ptr_sb,
            &point_x_sb,
            &point_y_sb,
            &all_csc_val_idxs_sb,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
        )
        .await;
    }

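    // Stage 4: bucket points reduction (BPR), stage 1. Each workgroup reduces
    // its subtask's buckets into per-thread partial sums in the g_points
    // buffers.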
    let num_subtasks_per_bpr_1 = 16;

    let b_num_x_workgroups = num_subtasks_per_bpr_1;
    let b_num_y_workgroups = 1;
    let b_num_z_workgroups = 1;
    let b_workgroup_size = 256;

    let g_points_coord_bytelength = num_subtasks * b_workgroup_size * num_words * 4;
    let g_points_x_sb = create_storage_buffer(
        Some("Bucket points reduction X buffer"),
        &device,
        g_points_coord_bytelength as u64,
    );
    let g_points_y_sb = create_storage_buffer(
        Some("Bucket points reduction Y buffer"),
        &device,
        g_points_coord_bytelength as u64,
    );
    let g_points_z_sb = create_storage_buffer(
        Some("Bucket points reduction Z buffer"),
        &device,
        g_points_coord_bytelength as u64,
    );

    let bpr_shader = shader_manager.gen_bpr_shader(b_workgroup_size);

    for subtask_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_1) {
        bpr_1(
            &bpr_shader,
            subtask_idx,
            b_num_x_workgroups,
            b_num_y_workgroups,
            b_num_z_workgroups,
            num_columns,
            &device,
            &queue,
            &mut encoder,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
        )
        .await;
    }

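    // BPR stage 2 completes the running-sum reduction started in stage 1.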
    let num_subtasks_per_bpr_2 = 16;
    let b_2_num_x_workgroups = num_subtasks_per_bpr_2;

    for subtask_idx in (0..num_subtasks).step_by(num_subtasks_per_bpr_2) {
        bpr_2(
            &bpr_shader,
            subtask_idx,
            b_2_num_x_workgroups,
            1,
            1,
            num_columns,
            &device,
            &queue,
            &mut encoder,
            &bucket_sum_x_sb,
            &bucket_sum_y_sb,
            &bucket_sum_z_sb,
            &g_points_x_sb,
            &g_points_y_sb,
            &g_points_z_sb,
        )
        .await;
    }

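    // Submit all recorded passes and read the per-thread partial sums back to
    // the host.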
    let data = read_from_gpu(
        &device,
        &queue,
        encoder,
        vec![g_points_x_sb, g_points_y_sb, g_points_z_sb],
    )
    .await;

    device.destroy();

    let mut points = vec![];

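    // The GPU stores each coordinate as `num_words` limbs of `WORD_SIZE` bits
    // in Montgomery form; multiply by `rinv` mod `P` to recover the canonical
    // value, then map it back into the field.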
    let read_coords = |bytes: &[u8]| -> Vec<C::Base> {
        bytemuck::cast_slice::<u8, u32>(bytes)
            .chunks(num_words)
            .map(|limbs| {
                let montgomery = to_biguint_le(limbs, num_words, WORD_SIZE as u32);
                let canonical = montgomery * &PARAMS.rinv % &*P;
                bytes_to_field(&canonical.to_bytes_le())
            })
            .collect()
    };
    let g_points_x = read_coords(&data[0]);
    let g_points_y = read_coords(&data[1]);
    let g_points_z = read_coords(&data[2]);

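    // Sum each subtask's `b_workgroup_size` partial sums into a single point.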
355 for i in 0..num_subtasks {
365 let mut point = C::Curve::identity();
366 for j in 0..b_workgroup_size {
367 let reduced_point = C::Curve::new_jacobian(
368 g_points_x[i * b_workgroup_size + j],
369 g_points_y[i * b_workgroup_size + j],
370 g_points_z[i * b_workgroup_size + j],
371 )
372 .unwrap();
373 point += reduced_point;
374 }
375 points.push(point);
376 }
377
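    // Horner's rule: combine the per-subtask results, scaling by 2^chunk_size
    // between windows, most significant subtask first.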
378 let m = C::ScalarExt::from(1 << chunk_size);
385 let mut result = points[points.len() - 1];
386 for i in (0..points.len() - 1).rev() {
387 result = result * m + points[i];
388 }
389 result
390}
391
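/// Uploads the raw point and scalar bytes, then runs a shader that converts
/// the point coordinates into `num_words`-limb form (writing X and Y limbs to
/// separate buffers) and splits each scalar into `num_subtasks` chunks of
/// `chunk_size` bits. Returns the point X, point Y, and scalar-chunk storage
/// buffers.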
pub async fn convert_point_coords_and_decompose_shaders(
    shader_code: &str,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    device: &Device,
    queue: &Queue,
    encoder: &mut CommandEncoder,
    points_bytes: &[u8],
    scalars_bytes: &[u8],
    num_subtasks: usize,
    chunk_size: usize,
    num_words: usize,
) -> (Buffer, Buffer, Buffer) {
    assert_eq!(num_subtasks * chunk_size, 256);
    let input_size = scalars_bytes.len() / 32;
    let points_sb = create_and_write_storage_buffer(Some("Points buffer"), device, points_bytes);
    let scalars_sb = create_and_write_storage_buffer(Some("Scalars buffer"), device, scalars_bytes);

    let points_x_sb = create_storage_buffer(
        Some("Point X buffer"),
        device,
        (input_size * num_words * 4) as u64,
    );
    let points_y_sb = create_storage_buffer(
        Some("Point Y buffer"),
        device,
        (input_size * num_words * 4) as u64,
    );
    let scalar_chunks_sb = create_storage_buffer(
        Some("Scalar chunks buffer"),
        device,
        (input_size * num_subtasks * 4) as u64,
    );

    let params_bytes = to_u8s_for_gpu(vec![input_size]);
    let params_ub =
        create_and_write_uniform_buffer(Some("Params buffer"), device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![&points_sb, &scalars_sb],
        vec![&points_x_sb, &points_y_sb, &scalar_chunks_sb],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            &points_sb,
            &scalars_sb,
            &points_x_sb,
            &points_y_sb,
            &scalar_chunks_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Convert point coords and decompose shader"),
        device,
        &bind_group_layout,
        shader_code,
        "main",
    )
    .await;

    execute_pipeline(
        encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;

    (points_x_sb, points_y_sb, scalar_chunks_sb)
}

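/// Transposes the decomposed scalar chunks into per-subtask CSC sparse
/// matrices, returning the column-pointer and value-index buffers
/// (`all_curr_sb` is scratch space used by the shader while filling columns).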
pub async fn transpose_gpu(
    shader_code: &str,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    input_size: usize,
    num_columns: usize,
    num_rows: usize,
    num_subtasks: usize,
    scalar_chunks_sb: Buffer,
) -> (Buffer, Buffer) {
    let all_csc_col_ptr_sb = create_storage_buffer(
        Some("All CSC col"),
        device,
        (num_subtasks * (num_columns + 1) * 4) as u64,
    );
    let all_csc_val_idxs_sb =
        create_storage_buffer(Some("All CSC Val Indexes"), device, scalar_chunks_sb.size());
    let all_curr_sb = create_storage_buffer(
        Some("All Current"),
        device,
        (num_subtasks * num_columns * 4) as u64,
    );

    let params_bytes = to_u8s_for_gpu(vec![num_rows, num_columns, input_size]);
    let params_ub = create_and_write_uniform_buffer(
        Some("Transpose GPU Uniform Params"),
        device,
        queue,
        &params_bytes,
    );

    let bind_group_layout = create_bind_group_layout(
        Some("Transpose GPU Bind Group Layout"),
        device,
        vec![&scalar_chunks_sb],
        vec![&all_csc_col_ptr_sb, &all_csc_val_idxs_sb, &all_curr_sb],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Transpose GPU Bind Group"),
        device,
        &bind_group_layout,
        vec![
            &scalar_chunks_sb,
            &all_csc_col_ptr_sb,
            &all_csc_val_idxs_sb,
            &all_curr_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Transpose GPU Compute Pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "main",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;

    (all_csc_col_ptr_sb, all_csc_val_idxs_sb)
}

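/// Serialises `usize` values as little-endian `u32`s for use in GPU buffers,
/// panicking if any value does not fit in 32 bits.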
pub fn to_u8s_for_gpu(vals: Vec<usize>) -> Vec<u8> {
    let mut buf = Vec::with_capacity(vals.len() * 4);
    for val in vals {
        let word = u32::try_from(val).expect("value must fit in a u32");
        buf.extend_from_slice(&word.to_le_bytes());
    }
    buf
}

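/// Runs the SMVP shader for one chunk of subtasks (starting at `offset`),
/// accumulating bucket sums for each CSC column into the X/Y/Z bucket-sum
/// buffers.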
pub async fn smvp_gpu(
    shader_code: &str,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    offset: usize,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    input_size: usize,
    all_csc_col_ptr_sb: &Buffer,
    point_x_sb: &Buffer,
    point_y_sb: &Buffer,
    all_csc_val_idxs_sb: &Buffer,
    bucket_sum_x_sb: &Buffer,
    bucket_sum_y_sb: &Buffer,
    bucket_sum_z_sb: &Buffer,
) {
    let params_bytes = to_u8s_for_gpu(vec![input_size, num_y_workgroups, num_z_workgroups, offset]);
    let params_ub = create_and_write_uniform_buffer(None, device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![
            all_csc_col_ptr_sb,
            all_csc_val_idxs_sb,
            point_x_sb,
            point_y_sb,
        ],
        vec![bucket_sum_x_sb, bucket_sum_y_sb, bucket_sum_z_sb],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            all_csc_col_ptr_sb,
            all_csc_val_idxs_sb,
            point_x_sb,
            point_y_sb,
            bucket_sum_x_sb,
            bucket_sum_y_sb,
            bucket_sum_z_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Compute pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "main",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;
}

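/// First stage of the bucket points reduction (BPR): launches the `stage_1`
/// entry point of the BPR shader for the subtasks starting at `subtask_idx`,
/// writing per-thread partial sums into the g_points buffers.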
pub async fn bpr_1(
    shader_code: &str,
    subtask_idx: usize,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    num_columns: usize,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    bucket_sum_x_sb: &Buffer,
    bucket_sum_y_sb: &Buffer,
    bucket_sum_z_sb: &Buffer,
    g_points_x_sb: &Buffer,
    g_points_y_sb: &Buffer,
    g_points_z_sb: &Buffer,
) {
    let params_bytes = to_u8s_for_gpu(vec![subtask_idx, num_columns, num_x_workgroups]);
    let params_ub = create_and_write_uniform_buffer(None, device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![],
        vec![
            bucket_sum_x_sb,
            bucket_sum_y_sb,
            bucket_sum_z_sb,
            g_points_x_sb,
            g_points_y_sb,
            g_points_z_sb,
        ],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            bucket_sum_x_sb,
            bucket_sum_y_sb,
            bucket_sum_z_sb,
            g_points_x_sb,
            g_points_y_sb,
            g_points_z_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Compute pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "stage_1",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;
}

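/// Second stage of the bucket points reduction: launches the `stage_2` entry
/// point of the BPR shader to finish the reduction started by [`bpr_1`].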
pub async fn bpr_2(
    shader_code: &str,
    subtask_idx: usize,
    num_x_workgroups: usize,
    num_y_workgroups: usize,
    num_z_workgroups: usize,
    num_columns: usize,
    device: &Device,
    queue: &Queue,
    command_encoder: &mut CommandEncoder,
    bucket_sum_x_sb: &Buffer,
    bucket_sum_y_sb: &Buffer,
    bucket_sum_z_sb: &Buffer,
    g_points_x_sb: &Buffer,
    g_points_y_sb: &Buffer,
    g_points_z_sb: &Buffer,
) {
    let params_bytes = to_u8s_for_gpu(vec![subtask_idx, num_columns, num_x_workgroups]);
    let params_ub = create_and_write_uniform_buffer(None, device, queue, &params_bytes);

    let bind_group_layout = create_bind_group_layout(
        Some("Bind group layout"),
        device,
        vec![],
        vec![
            bucket_sum_x_sb,
            bucket_sum_y_sb,
            bucket_sum_z_sb,
            g_points_x_sb,
            g_points_y_sb,
            g_points_z_sb,
        ],
        vec![&params_ub],
    );

    let bind_group = create_bind_group(
        Some("Bind group"),
        device,
        &bind_group_layout,
        vec![
            bucket_sum_x_sb,
            bucket_sum_y_sb,
            bucket_sum_z_sb,
            g_points_x_sb,
            g_points_y_sb,
            g_points_z_sb,
            &params_ub,
        ],
    );

    let compute_pipeline = create_compute_pipeline(
        Some("Compute pipeline"),
        device,
        &bind_group_layout,
        shader_code,
        "stage_2",
    )
    .await;

    execute_pipeline(
        command_encoder,
        compute_pipeline,
        bind_group,
        num_x_workgroups as u32,
        num_y_workgroups as u32,
        num_z_workgroups as u32,
    )
    .await;
}