use crate::curve::{Affine, Curve, Projective};
use crate::error::IcicleResult;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
#[cfg(feature = "arkworks")]
#[doc(hidden)]
pub mod tests;
/// Configuration for a multi-scalar multiplication (MSM) call.
///
/// The struct is `#[repr(C)]` and is passed by reference to, and returned by value from,
/// the native `<prefix>MSMCuda` / `<prefix>DefaultMSMConfig` symbols bound in `impl_msm!`.
/// Its field order and field types therefore must stay in sync with the native definition.
///
/// The non-`pub` fields (`points_size`, `batch_size` and the `are_*_on_device` flags) are
/// overwritten by [`msm`] from the slices it is given, so callers cannot set them directly.
/// Obtain a starting value from [`get_default_msm_config`].
#[repr(C)]
#[derive(Debug, Clone)]
pub struct MSMConfig<'a> {
    /// Device context forwarded to the backend together with this config.
    pub ctx: DeviceContext<'a>,
    /// Number of points; set by [`msm`] to `points.len()`.
    points_size: i32,
    /// NOTE(review): presumably the number of precomputed multiples stored per base point
    /// in `points` — confirm against the backend documentation.
    pub precompute_factor: i32,
    /// NOTE(review): presumably the MSM window (bucket) size in bits — confirm against the
    /// backend documentation.
    pub c: i32,
    /// NOTE(review): presumably the number of significant bits in each scalar — confirm.
    pub bitsize: i32,
    /// NOTE(review): presumably a tuning threshold for treating buckets as "large" — confirm.
    pub large_bucket_factor: i32,
    /// Number of MSMs computed in one call; set by [`msm`] to `results.len()`.
    batch_size: i32,
    /// Whether `scalars` lives in device memory; set by [`msm`] from the slice.
    are_scalars_on_device: bool,
    /// Whether scalars are given in Montgomery form (not touched by [`msm`]; caller's choice).
    pub are_scalars_montgomery_form: bool,
    /// Whether `points` lives in device memory; set by [`msm`] from the slice.
    are_points_on_device: bool,
    /// Whether points are given in Montgomery form (not touched by [`msm`]; caller's choice).
    pub are_points_montgomery_form: bool,
    /// Whether `results` lives in device memory; set by [`msm`] from the slice.
    are_results_on_device: bool,
    /// NOTE(review): selects an alternative backend algorithm variant — confirm semantics.
    pub is_big_triangle: bool,
    /// NOTE(review): presumably lets the call return before the work on `ctx`'s stream has
    /// completed — confirm against the backend documentation.
    pub is_async: bool,
}
/// Per-curve backend hook behind [`msm`] and [`get_default_msm_config`].
///
/// The `impl_msm!` macro in this module generates implementations of this trait.
#[doc(hidden)]
pub trait MSM<C: Curve> {
    /// Returns the backend's default [`MSMConfig`] for curve `C`.
    fn get_default_msm_config() -> MSMConfig<'static>;

    /// Runs the MSM backend on `scalars` and `points`, writing into `results`.
    ///
    /// No consistency checks are performed between the slice lengths and `cfg`; the
    /// checked entry point is [`msm`].
    fn msm_unchecked(
        scalars: &HostOrDeviceSlice<C::ScalarField>,
        points: &HostOrDeviceSlice<Affine<C>>,
        cfg: &MSMConfig<'_>,
        results: &mut HostOrDeviceSlice<Projective<C>>,
    ) -> IcicleResult<()>;
}
/// Computes a batch of multi-scalar multiplications (MSMs) on curve `C`.
///
/// The batch size is `results.len()`. The backend receives
/// `scalars.len() / results.len()` as the size of each MSM.
///
/// The caller's `cfg` is not modified. A copy is made, and in that copy
/// `points_size`, `batch_size` and the `are_*_on_device` flags are derived from
/// the actual slices. All other fields of `cfg` are passed through unchanged.
///
/// # Panics
///
/// Panics if:
/// - `points` is empty, or `points.len()` does not divide `scalars.len()`;
/// - `results` is empty, or `results.len()` does not divide `scalars.len()`;
/// - `points.len()`, `results.len()` or the per-MSM size exceeds `i32::MAX`,
///   because the backend receives these sizes as `i32`.
///
/// # Errors
///
/// Returns whatever error the curve's [`MSM::msm_unchecked`] implementation reports.
pub fn msm<C: Curve + MSM<C>>(
    scalars: &HostOrDeviceSlice<C::ScalarField>,
    points: &HostOrDeviceSlice<Affine<C>>,
    cfg: &MSMConfig,
    results: &mut HostOrDeviceSlice<Projective<C>>,
) -> IcicleResult<()> {
    let num_scalars = scalars.len();
    let num_points = points.len();
    let num_results = results.len();

    // Check for empty slices explicitly. Otherwise the `%` below panics with a bare
    // "remainder with a divisor of zero" instead of this descriptive message.
    if num_points == 0 || num_scalars % num_points != 0 {
        panic!(
            "Number of points {} does not divide the number of scalars {}",
            num_points, num_scalars
        );
    }
    if num_results == 0 || num_scalars % num_results != 0 {
        panic!(
            "Number of results {} does not divide the number of scalars {}",
            num_results, num_scalars
        );
    }

    // The backend takes these sizes as `i32`. A plain `as` cast would silently wrap
    // oversized lengths and hand wrong sizes to native code.
    // (`impl_msm!` also forwards `num_scalars / num_results` as an `i32`.)
    let max_len = i32::MAX as usize;
    assert!(
        num_points <= max_len,
        "Number of points {} does not fit in an i32",
        num_points
    );
    assert!(
        num_results <= max_len,
        "Number of results {} does not fit in an i32",
        num_results
    );
    assert!(
        num_scalars / num_results <= max_len,
        "MSM size {} does not fit in an i32",
        num_scalars / num_results
    );

    // Work on a copy so the caller's config stays untouched. The size and location
    // fields always reflect the slices actually passed in.
    let mut local_cfg = cfg.clone();
    local_cfg.points_size = num_points as i32;
    local_cfg.batch_size = num_results as i32;
    local_cfg.are_scalars_on_device = scalars.is_on_device();
    local_cfg.are_points_on_device = points.is_on_device();
    local_cfg.are_results_on_device = results.is_on_device();
    C::msm_unchecked(scalars, points, &local_cfg, results)
}
/// Returns the default [`MSMConfig`] for curve `C`.
///
/// This delegates to the curve's [`MSM`] implementation.
pub fn get_default_msm_config<C>() -> MSMConfig<'static>
where
    C: Curve + MSM<C>,
{
    <C as MSM<C>>::get_default_msm_config()
}
/// Implements [`MSM`] for `$curve` on top of the native backend.
///
/// Arguments:
/// - `$curve_prefix`: string literal prepended to the native symbol names
///   (`<prefix>MSMCuda` and `<prefix>DefaultMSMConfig`).
/// - `$curve_prefix_ident`: name of the private module that holds the `extern` bindings.
/// - `$curve`: the curve type receiving the `MSM` implementation.
///
/// The invocation site must have `Affine`, `CudaError`, `Curve`, `MSMConfig` and
/// `Projective` in its parent scope. It also needs `MSM`, `HostOrDeviceSlice` and
/// `IcicleResult`, plus a `wrap()` method on `CudaError`, in scope.
#[macro_export]
macro_rules! impl_msm {
    (
        $curve_prefix:literal,
        $curve_prefix_ident:ident,
        $curve:ident
    ) => {
        mod $curve_prefix_ident {
            use super::{$curve, Affine, CudaError, Curve, MSMConfig, Projective};

            extern "C" {
                #[link_name = concat!($curve_prefix, "MSMCuda")]
                pub(crate) fn msm_cuda(
                    scalars: *const <$curve as Curve>::ScalarField,
                    points: *const Affine<$curve>,
                    count: i32,
                    config: &MSMConfig,
                    out: *mut Projective<$curve>,
                ) -> CudaError;

                #[link_name = concat!($curve_prefix, "DefaultMSMConfig")]
                pub(crate) fn default_msm_config() -> MSMConfig<'static>;
            }
        }

        impl MSM<$curve> for $curve {
            fn msm_unchecked(
                scalars: &HostOrDeviceSlice<<$curve as Curve>::ScalarField>,
                points: &HostOrDeviceSlice<Affine<$curve>>,
                cfg: &MSMConfig,
                results: &mut HostOrDeviceSlice<Projective<$curve>>,
            ) -> IcicleResult<()> {
                // Each of the `results.len()` outputs is computed from this many scalars.
                let msm_size = (scalars.len() / results.len()) as i32;

                // SAFETY: every pointer comes from a slice that stays borrowed for the whole
                // call. The backend relies on `cfg` describing those slices' sizes and
                // locations. `msm` fills those fields in, and direct callers of
                // `msm_unchecked` must ensure the same.
                let status = unsafe {
                    $curve_prefix_ident::msm_cuda(
                        scalars.as_ptr(),
                        points.as_ptr(),
                        msm_size,
                        cfg,
                        results.as_mut_ptr(),
                    )
                };
                status.wrap()
            }

            fn get_default_msm_config() -> MSMConfig<'static> {
                // SAFETY: the native function takes no arguments and returns the config by value.
                unsafe { $curve_prefix_ident::default_msm_config() }
            }
        }
    };
}
/// Generates `#[test]` functions that run the generic MSM checks for `$curve`.
///
/// `check_msm`, `check_msm_batch` and `check_msm_skewed_distributions` must be in scope
/// at the invocation site.
#[macro_export]
macro_rules! impl_msm_tests {
    ($curve:ident) => {
        /// Single MSM.
        #[test]
        fn test_msm() {
            check_msm::<$curve>();
        }

        /// Batched MSMs.
        #[test]
        fn test_msm_batch() {
            check_msm_batch::<$curve>();
        }

        /// MSMs over skewed scalar distributions.
        #[test]
        fn test_msm_skewed_distributions() {
            check_msm_skewed_distributions::<$curve>();
        }
    };
}