pasta_msm/
lib.rs

1// Copyright Supranational LLC
2// Licensed under the Apache License, Version 2.0, see LICENSE for details.
3// SPDX-License-Identifier: Apache-2.0
4
// NOTE(review): `unexpected_cfgs` is silenced crate-wide, presumably because
// cfg names appearing in macro expansions (e.g. from `sppark::cuda_error!`)
// are not declared in this crate's check-cfg list — confirm and narrow the
// allow if possible.
#![allow(unexpected_cfgs)]

// `semolina` is never referenced by name in this file; presumably the
// `extern crate` is kept so its native code gets linked into the final
// artifact — TODO confirm against build.rs.
extern crate semolina;
8
// Presumably provides the `cuda` module (including the `cuda::Error` type
// with its `code` field and `String` conversion) that the GPU path of the
// generated MSM functions relies on — confirm against sppark.
#[cfg(feature = "cuda")]
sppark::cuda_error!();
#[cfg(feature = "cuda")]
extern "C" {
    // Queried before dispatching to the GPU; presumably reports whether a
    // usable CUDA device is present — confirm against the native side.
    fn cuda_available() -> bool;
}
/// Runtime switch: when `true`, the generated MSM functions stay on the CPU
/// path even for large inputs and even if `cuda_available()` reports a device.
///
/// NOTE(review): this is a `static mut`. Writing it requires `unsafe`, and it
/// must not be written while another thread may be inside an MSM call that
/// reads it (that would be a data race). An `AtomicBool` would be safer, but
/// switching to it is a breaking change to this public item.
#[cfg(feature = "cuda")]
pub static mut CUDA_OFF: bool = false;
17
/// Generates a safe MSM entry point `pub fn $pasta(points, scalars)` for one
/// pasta curve, backed by the native CPU routine `$mult` and, with the `cuda`
/// feature, by the GPU routine `$cuda_mult` for large inputs.
macro_rules! multi_scalar_mult {
    (
        $pasta:ident,
        $mult:ident,
        $cuda_mult:ident
    ) => {
        use pasta_curves::$pasta;

        extern "C" {
            fn $mult(
                out: *mut $pasta::Point,
                points: *const $pasta::Affine,
                npoints: usize,
                scalars: *const $pasta::Scalar,
                is_mont: bool,
            );
        }

        /// Computes the multi-scalar multiplication (MSM) of `points` by
        /// `scalars`, pairing them index by index.
        ///
        /// An empty input yields `Point::default()` without calling into
        /// native code.
        ///
        /// With the `cuda` feature enabled, inputs of at least `1 << 16`
        /// points are sent to the GPU, unless `CUDA_OFF` is set or
        /// `cuda_available()` returns `false`.
        ///
        /// # Panics
        ///
        /// * If `points.len() != scalars.len()` (message `"length mismatch"`).
        /// * With the `cuda` feature, if the GPU routine reports a non-zero
        ///   error code; the panic message is the converted error text.
        pub fn $pasta(
            points: &[$pasta::Affine],
            scalars: &[$pasta::Scalar],
        ) -> $pasta::Point {
            let npoints = points.len();
            assert!(npoints == scalars.len(), "length mismatch");

            // Nothing to sum. Return before touching the native code so it
            // never receives pointers into zero-length buffers. This assumes
            // `Point::default()` is the group identity, which is how
            // pasta_curves defines `Default` for its projective points —
            // confirm if that dependency changes.
            if npoints == 0 {
                return $pasta::Point::default();
            }

            // GPU dispatch threshold and runtime switches.
            #[cfg(feature = "cuda")]
            if npoints >= 1 << 16 && unsafe { !CUDA_OFF && cuda_available() } {
                extern "C" {
                    fn $cuda_mult(
                        out: *mut $pasta::Point,
                        points: *const $pasta::Affine,
                        npoints: usize,
                        scalars: *const $pasta::Scalar,
                        is_mont: bool,
                    ) -> cuda::Error;
                }
                let mut ret = $pasta::Point::default();
                // SAFETY: both slices are non-empty and hold exactly `npoints`
                // elements (checked above). `ret` is a valid, exclusively
                // borrowed output slot for the duration of the call.
                // `is_mont = true` presumably tells the native side that
                // scalars are in pasta_curves' in-memory (Montgomery) form.
                let err = unsafe {
                    $cuda_mult(&mut ret, points.as_ptr(), npoints, scalars.as_ptr(), true)
                };
                if err.code != 0 {
                    panic!("{}", String::from(err));
                }
                return ret;
            }

            let mut ret = $pasta::Point::default();
            // SAFETY: same invariants as the GPU call. Both slices are
            // non-empty, have length `npoints`, and `ret` is a valid output.
            unsafe { $mult(&mut ret, points.as_ptr(), npoints, scalars.as_ptr(), true) };
            ret
        }
    };
}
71
72multi_scalar_mult!(pallas, mult_pippenger_pallas, cuda_pippenger_pallas);
73multi_scalar_mult!(vesta, mult_pippenger_vesta, cuda_pippenger_vesta);
74
// Splices `tests.rs` directly into this module, so its code can use the
// generated `pallas`/`vesta` functions and the private extern declarations
// without extra imports. Presumably `tests.rs` gates itself with
// `#[cfg(test)]` — confirm, otherwise its contents ship in release builds.
include!("tests.rs");