msm_webgpu/
lib.rs

1#![allow(clippy::too_many_arguments)]
2
3pub mod cuzk;
4
5use crate::cuzk::msm::compute_msm;
6use ff::PrimeField;
7use group::{Curve, Group};
8use rand::thread_rng;
9
10use halo2curves::{CurveAffine, msm::best_multiexp};
11
12use crate::cuzk::utils::field_to_bytes;
13
14/// Sample random scalars
15pub fn sample_scalars<F: PrimeField>(n: usize) -> Vec<F> {
16    let mut rng = thread_rng();
17    (0..n).map(|_| F::random(&mut rng)).collect::<Vec<_>>()
18}
19
20/// Sample random affine points
21pub fn sample_points<C: CurveAffine>(n: usize) -> Vec<C> {
22    let mut rng = thread_rng();
23    (0..n)
24        .map(|_| C::Curve::random(&mut rng).to_affine())
25        .collect::<Vec<_>>()
26}
27
28/// Run CPU MSM computation
29pub fn cpu_msm<C: CurveAffine>(g: &[C], v: &[C::Scalar]) -> C::Curve {
30    best_multiexp(v, g)
31}
32
33/// Convert scalars to bytes
34pub fn scalars_to_bytes<F: PrimeField>(v: &[F]) -> Vec<u8> {
35    v.iter().flat_map(|x| field_to_bytes(x)).collect::<Vec<_>>()
36}
37
38/// Convert points to bytes as [x0, y0, x1, y1, ...]
39pub fn points_to_bytes<C: CurveAffine>(g: &[C]) -> Vec<u8> {
40    g.iter()
41        .flat_map(|affine| {
42            let coords = affine.coordinates().unwrap();
43            let x = field_to_bytes(coords.x());
44            let y = field_to_bytes(coords.y());
45            [x, y].concat()
46        })
47        .collect::<Vec<_>>()
48}
49
50#[cfg(not(target_arch = "wasm32"))]
51/// Run WebGPU MSM computation synchronously
52pub fn run_webgpu_msm<C: CurveAffine>(g: &[C], v: &[C::Scalar]) -> C::Curve {
53    pollster::block_on(compute_msm(g, v))
54}
55
/// Run the WebGPU MSM on wasm32 targets.
///
/// On this target the function is `async`; callers `.await` it to obtain the
/// result of `compute_msm`.
#[cfg(target_arch = "wasm32")]
pub async fn run_webgpu_msm<C: CurveAffine>(g: &[C], v: &[C::Scalar]) -> C::Curve {
    let pending = compute_msm(g, v);
    pending.await
}
61
62pub mod tests_wasm_pack {
63    use crate::cuzk::msm::compute_msm;
64
65    use super::*;
66
67    use halo2curves::bn256::{Fr, G1Affine};
68    use wasm_bindgen::prelude::*;
69    use web_sys::console;
70
71    #[wasm_bindgen]
72    extern "C" {
73        #[wasm_bindgen(js_namespace = performance)]
74        fn now() -> f64;
75    }
76
77    pub async fn test_webgpu_msm_cuzk(sample_size: usize) {
78        console::log_1(&format!("Testing with sample size: {sample_size}").into());
79        let points = sample_points::<G1Affine>(sample_size);
80        let scalars = sample_scalars::<Fr>(sample_size);
81
82        let cpu_start = now();
83        let fast = cpu_msm(&points, &scalars);
84        console::log_1(&format!("CPU Elapsed: {} ms", now() - cpu_start).into());
85
86        let result_start = now();
87        let result = compute_msm::<G1Affine>(&points, &scalars).await;
88        console::log_1(&format!("GPU Elapsed: {} ms", now() - result_start).into());
89
90        console::log_1(&format!("Result: {result:?}").into());
91        assert_eq!(fast, result);
92    }
93}