// grumpkin_msm/lib.rs

// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
4#![allow(improper_ctypes)]
5#![allow(unused)]
6
7pub mod utils;
8
9extern crate blst;
10
// CUDA backend glue; every item here exists only with the `cuda` feature.

// Defines the `cuda` error support used as `cuda::Error` by the GPU
// dispatch paths of `bn256` and `grumpkin`.
#[cfg(feature = "cuda")]
sppark::cuda_error!();

/// When set to `true`, the MSM entry points skip the GPU and use the CPU
/// backend even if `cuda_available()` returns `true`.
///
/// NOTE(review): this is an unsynchronized `static mut`; writes must not race
/// with running MSMs.
#[cfg(feature = "cuda")]
pub static mut CUDA_OFF: bool = false;

#[cfg(feature = "cuda")]
extern "C" {
    /// Returns whether the CUDA backend may be used; checked before every GPU
    /// dispatch.
    pub fn cuda_available() -> bool;
}

20use halo2curves::bn256;
21use halo2curves::CurveExt;
22
23extern "C" {
24    fn mult_pippenger_bn254(
25        out: *mut bn256::G1,
26        points: *const bn256::G1Affine,
27        npoints: usize,
28        scalars: *const bn256::Fr,
29    );
30
31}
32
33pub fn bn256(points: &[bn256::G1Affine], scalars: &[bn256::Fr]) -> bn256::G1 {
34    let npoints = points.len();
35    assert!(npoints == scalars.len(), "length mismatch");
36
37    #[cfg(feature = "cuda")]
38    if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
39        extern "C" {
40            fn cuda_pippenger_bn254(
41                out: *mut bn256::G1,
42                points: *const bn256::G1Affine,
43                npoints: usize,
44                scalars: *const bn256::Fr,
45            ) -> cuda::Error;
46
47        }
48        let mut ret = bn256::G1::default();
49        let err = unsafe {
50            cuda_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0])
51        };
52        assert!(err.code == 0, "{}", String::from(err));
53
54        return bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap();
55    }
56    let mut ret = bn256::G1::default();
57    unsafe { mult_pippenger_bn254(&mut ret, &points[0], npoints, &scalars[0]) };
58    bn256::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap()
59}
60
61use halo2curves::grumpkin;
62
63extern "C" {
64    fn mult_pippenger_grumpkin(
65        out: *mut grumpkin::G1,
66        points: *const grumpkin::G1Affine,
67        npoints: usize,
68        scalars: *const grumpkin::Fr,
69    );
70
71}
72
73pub fn grumpkin(
74    points: &[grumpkin::G1Affine],
75    scalars: &[grumpkin::Fr],
76) -> grumpkin::G1 {
77    let npoints = points.len();
78    assert!(npoints == scalars.len(), "length mismatch");
79
80    #[cfg(feature = "cuda")]
81    if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
82        extern "C" {
83            fn cuda_pippenger_grumpkin(
84                out: *mut grumpkin::G1,
85                points: *const grumpkin::G1Affine,
86                npoints: usize,
87                scalars: *const grumpkin::Fr,
88            ) -> cuda::Error;
89
90        }
91        let mut ret = grumpkin::G1::default();
92        let err = unsafe {
93            cuda_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0])
94        };
95        assert!(err.code == 0, "{}", String::from(err));
96
97        return grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap();
98    }
99    let mut ret = grumpkin::G1::default();
100    unsafe {
101        mult_pippenger_grumpkin(&mut ret, &points[0], npoints, &scalars[0])
102    };
103    grumpkin::G1::new_jacobian(ret.x, ret.y, ret.z).unwrap()
104}
105
#[cfg(test)]
mod tests {
    use crate::utils::{gen_points, gen_scalars, naive_multiscalar_mul};
    use halo2curves::group::Curve;

    /// Compares `crate::bn256` against the naive reference MSM on generated
    /// inputs.
    ///
    /// Release builds use 128Ki points, which exceeds the `1 << 16` GPU
    /// threshold in `crate::bn256`. Debug builds use 8Ki points, presumably to
    /// keep the naive reference quick.
    #[test]
    fn it_works() {
        let count: usize = if cfg!(debug_assertions) {
            8 * 1024
        } else {
            128 * 1024
        };

        let (points, scalars) = (gen_points(count), gen_scalars(count));

        let expected = naive_multiscalar_mul(&points, &scalars);
        println!("{:?}", expected);

        let actual = crate::bn256(&points, &scalars).to_affine();
        println!("{:?}", actual);

        assert_eq!(actual, expected);
    }
}
130}