// lurk_pasta_msm/lib.rs

// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

5extern crate semolina;
6
// GPU support. Everything below exists only when the crate is built with
// the `cuda` feature; without it the MSM entry points always use the CPU.

#[cfg(feature = "cuda")]
sppark::cuda_error!();

/// Switch for disabling the GPU path at runtime.
///
/// When `true`, the MSM functions generated by `multi_scalar_mult!` skip the
/// CUDA routine even if a device is available and the input is large enough.
/// Being a `static mut`, it can only be read or written inside `unsafe`; do not
/// modify it while another thread may be running an MSM.
#[cfg(feature = "cuda")]
pub static mut CUDA_OFF: bool = false;

#[cfg(feature = "cuda")]
extern "C" {
    /// Implemented by the linked native code and consulted before taking the
    /// GPU path (presumably reports whether a usable CUDA device exists —
    /// confirm against the C/CUDA side).
    fn cuda_available() -> bool;
}

/// Generates the multi-scalar-multiplication entry point for one pasta curve.
///
/// Arguments:
/// * `$pasta` — name of a curve module inside `pasta_curves`; it is also the
///   name of the generated public function.
/// * `$mult` — C symbol of the CPU routine.
/// * `$cuda_mult` — C symbol of the CUDA routine; only referenced when the
///   `cuda` feature is enabled.
///
/// The expansion brings `pasta_curves::$pasta` into scope, declares `$mult`,
/// and defines `pub fn $pasta(points, scalars) -> $pasta::Point`.
macro_rules! multi_scalar_mult {
    (
        $pasta:ident,
        $mult:ident,
        $cuda_mult:ident
    ) => {
        use pasta_curves::$pasta;

        extern "C" {
            fn $mult(
                out: *mut $pasta::Point,
                points: *const $pasta::Affine,
                npoints: usize,
                scalars: *const $pasta::Scalar,
                is_mont: bool,
            );
        }

        /// Multi-scalar multiplication of `points` by `scalars`, pairing the
        /// two slices element-wise and delegating to the native implementation.
        ///
        /// When the crate is built with the `cuda` feature, inputs of at least
        /// 2^16 points run on the GPU unless `CUDA_OFF` is set or no CUDA
        /// device is reported. Everything else runs on the CPU routine.
        ///
        /// Empty input returns `Point::default()` without calling into native
        /// code. For `pasta_curves` this is expected to be the group identity,
        /// i.e. the value of an empty sum; confirm against `pasta_curves`.
        ///
        /// # Panics
        ///
        /// * With `"length mismatch"` if `points` and `scalars` have different
        ///   lengths.
        /// * With the CUDA error message if the GPU routine reports a non-zero
        ///   error code.
        pub fn $pasta(
            points: &[$pasta::Affine],
            scalars: &[$pasta::Scalar],
        ) -> $pasta::Point {
            let npoints = points.len();
            assert!(npoints == scalars.len(), "length mismatch");

            let mut ret = $pasta::Point::default();

            // Nothing to accumulate. Returning here also keeps the native
            // routines from ever being handed a zero-length input.
            if npoints == 0 {
                return ret;
            }

            // Only large inputs are offloaded to the GPU (>= 2^16 points).
            // SAFETY (condition): `CUDA_OFF` is read by value; callers must not
            // write it concurrently with this call. `cuda_available` takes no
            // arguments.
            #[cfg(feature = "cuda")]
            if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
                extern "C" {
                    fn $cuda_mult(
                        out: *mut $pasta::Point,
                        points: *const $pasta::Affine,
                        npoints: usize,
                        scalars: *const $pasta::Scalar,
                        is_mont: bool,
                    ) -> cuda::Error;
                }
                // SAFETY: both slices hold exactly `npoints` (> 0) elements
                // (checked above) and outlive the call; `ret` is a valid,
                // writable output slot.
                let err = unsafe {
                    $cuda_mult(&mut ret, points.as_ptr(), npoints, scalars.as_ptr(), true)
                };
                if err.code != 0 {
                    panic!("{}", String::from(err));
                }
                return ret;
            }

            // `is_mont = true`: scalars are passed in the representation
            // `pasta_curves` stores them in (presumably Montgomery form —
            // confirm against the native implementation).
            // SAFETY: same invariants as the GPU call above: non-empty slices
            // of equal length `npoints`, and a valid output pointer.
            unsafe { $mult(&mut ret, points.as_ptr(), npoints, scalars.as_ptr(), true) };
            ret
        }
    };
}

70multi_scalar_mult!(pallas, mult_pippenger_pallas, cuda_pippenger_pallas);
71multi_scalar_mult!(vesta, mult_pippenger_vesta, cuda_pippenger_vesta);
72
73include!("tests.rs");