use icicle_cuda_runtime::device_context::{get_default_device_context, DeviceContext};
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
use crate::{error::IcicleResult, traits::FieldImpl};
// Test module: only compiled when the `arkworks` feature is enabled, and
// hidden from the generated documentation.
#[cfg(feature = "arkworks")]
#[doc(hidden)]
pub mod tests;
/// Direction of the transform requested from the NTT backend.
///
/// The enum is `#[repr(C)]` because it is passed by value to the
/// `extern "C"` function declared by `impl_ntt!`. The discriminants are
/// written out so the numeric values handed across that boundary are visible
/// here and do not silently depend on declaration order.
// The `k`-prefixed variant names deviate from Rust's CamelCase convention,
// hence the lint allowance (presumably they mirror the names on the foreign
// side — confirm against the backend headers).
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NTTDir {
    /// Forward transform.
    kForward = 0,
    /// Inverse transform.
    kInverse = 1,
}
/// Element-ordering mode selected through [`NTTConfig::ordering`].
///
/// NOTE(review): the two letters presumably describe the order of the inputs
/// and of the outputs respectively (`N` = natural, `R` = bit-reversed) —
/// confirm against the backend documentation before relying on this.
///
/// The enum is `#[repr(C)]` since it is embedded in the `#[repr(C)]`
/// [`NTTConfig`]; discriminants are written out so the numeric values are
/// visible here and do not silently depend on declaration order.
// The `k`-prefixed variant names deviate from Rust's CamelCase convention,
// hence the lint allowance.
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Ordering {
    /// Default used by `NTTConfig::default_config`.
    kNN = 0,
    kNR = 1,
    kRN = 2,
    kRR = 3,
}
/// Settings passed alongside the data to the NTT backend.
///
/// The struct is `#[repr(C)]` and is handed by reference to the `extern "C"`
/// function declared in `impl_ntt!`, so the field order and field types are
/// part of the foreign interface — do not reorder or retype fields without
/// changing the foreign definition as well.
///
/// Use [`NTTConfig::default_config`] to obtain a value; the two
/// `are_*_on_device` fields are private and cannot be set from outside this
/// module.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct NTTConfig<'a, S> {
/// Device context forwarded to the backend (defaults to
/// `get_default_device_context()`).
pub ctx: DeviceContext<'a>,
/// Coset generator; the default is `S::one()`.
pub coset_gen: S,
/// Number of transforms in the batch (default `1`). The generated backend
/// glue divides the input length by this value to get the per-transform size.
pub batch_size: i32,
/// Element-ordering mode (default `Ordering::kNN`).
pub ordering: Ordering,
// Filled in by `ntt` from `input.is_on_device()` on a private copy of the
// config; whatever the caller's copy holds is ignored.
are_inputs_on_device: bool,
// Filled in by `ntt` from `output.is_on_device()`, same as above.
are_outputs_on_device: bool,
/// Default `false`. NOTE(review): presumably asks the backend to return
/// without waiting for completion — confirm against the backend docs.
pub is_async: bool,
/// Default `false`. NOTE(review): presumably forces a radix-2 algorithm in
/// the backend — confirm against the backend docs.
pub is_force_radix2: bool,
}
impl<'a, S: FieldImpl> NTTConfig<'a, S> {
    /// Builds the baseline configuration.
    ///
    /// The result uses the default device context, `S::one()` as the coset
    /// generator, a batch of one transform, `Ordering::kNN`, and has every
    /// boolean flag switched off.
    pub fn default_config() -> Self {
        Self {
            // Evaluated first, exactly as a separate `let` binding would be.
            ctx: get_default_device_context(),
            coset_gen: S::one(),
            batch_size: 1,
            ordering: Ordering::kNN,
            // Placeholders: `ntt` overwrites both on its private copy.
            are_inputs_on_device: false,
            are_outputs_on_device: false,
            is_async: false,
            is_force_radix2: false,
        }
    }
}
/// Backend hook implemented per field config type (see `impl_ntt!`).
///
/// Hidden from the docs: callers are expected to go through the free
/// functions `ntt`, `initialize_domain` and `get_default_ntt_config`, which
/// dispatch to this trait via `<F as FieldImpl>::Config`.
#[doc(hidden)]
pub trait NTT<F: FieldImpl> {
/// Runs the transform of `input` into `output` without any of the argument
/// checks or config fix-ups performed by the free function `ntt`.
fn ntt_unchecked(
input: &HostOrDeviceSlice<F>,
dir: NTTDir,
cfg: &NTTConfig<F>,
output: &mut HostOrDeviceSlice<F>,
) -> IcicleResult<()>;
/// Initializes the backend's NTT domain from `primitive_root` using `ctx`.
fn initialize_domain(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>;
/// Returns the default configuration for field `F`.
fn get_default_ntt_config() -> NTTConfig<'static, F>;
}
/// Computes the NTT of `input` in direction `dir` and stores it in `output`.
///
/// The caller's `cfg` is not modified: the function works on a clone in which
/// the private `are_inputs_on_device` / `are_outputs_on_device` flags are set
/// from where `input` and `output` actually live, and then dispatches to
/// `NTT::ntt_unchecked` of `F`'s config type.
///
/// # Errors
///
/// Returns whatever error the backend implementation reports.
///
/// # Panics
///
/// * if `input` and `output` have different lengths;
/// * if `cfg.batch_size` is zero or negative.
pub fn ntt<F>(
    input: &HostOrDeviceSlice<F>,
    dir: NTTDir,
    cfg: &NTTConfig<F>,
    output: &mut HostOrDeviceSlice<F>,
) -> IcicleResult<()>
where
    F: FieldImpl,
    <F as FieldImpl>::Config: NTT<F>,
{
    if input.len() != output.len() {
        panic!(
            "input and output lengths {}; {} do not match",
            input.len(),
            output.len()
        );
    }
    // The `impl_ntt!`-generated backend computes the per-transform size as
    // `input.len() / (cfg.batch_size as usize)`. A zero batch size would abort
    // there with an opaque divide-by-zero panic, and a negative one would wrap
    // to a huge divisor and silently hand size 0 to the foreign call, so the
    // value is rejected here with a clear message instead.
    if cfg.batch_size <= 0 {
        panic!("batch size {} must be positive", cfg.batch_size);
    }

    // Work on a private copy so the caller's config stays untouched.
    let mut local_cfg = cfg.clone();
    local_cfg.are_inputs_on_device = input.is_on_device();
    local_cfg.are_outputs_on_device = output.is_on_device();

    <<F as FieldImpl>::Config as NTT<F>>::ntt_unchecked(input, dir, &local_cfg, output)
}
/// Initializes the NTT domain for field `F` from `primitive_root`.
///
/// Thin dispatcher: the call is forwarded unchanged, together with `ctx`, to
/// `NTT::initialize_domain` of the config type associated with `F`, and that
/// implementation's result is returned as is.
pub fn initialize_domain<F>(primitive_root: F, ctx: &DeviceContext) -> IcicleResult<()>
where
    F: FieldImpl,
    <F as FieldImpl>::Config: NTT<F>,
{
    <F::Config as NTT<F>>::initialize_domain(primitive_root, ctx)
}
/// Returns the default [`NTTConfig`] for field `F`.
///
/// Thin dispatcher to `NTT::get_default_ntt_config` of the config type
/// associated with `F`.
pub fn get_default_ntt_config<F>() -> NTTConfig<'static, F>
where
    F: FieldImpl,
    <F as FieldImpl>::Config: NTT<F>,
{
    <F::Config as NTT<F>>::get_default_ntt_config()
}
/// Generates the foreign-function declarations and the `NTT` trait
/// implementation for one field.
///
/// Arguments:
/// * `$field_prefix` — string literal prepended to `"NTTCuda"` and
///   `"InitializeDomain"` to form the link names of the foreign symbols;
/// * `$field_prefix_ident` — name of the private module that wraps the
///   `extern "C"` declarations;
/// * `$field` — the field element type;
/// * `$field_config` — the type that receives the `NTT<$field>` implementation.
///
/// The expansion refers to `crate::ntt::{..}` (plain `crate`, not `$crate`), so
/// the invoking crate must have an `ntt` module exposing `$field`,
/// `$field_config`, `CudaError`, `DeviceContext`, `NTTConfig` and `NTTDir`.
/// `NTT`, `HostOrDeviceSlice`, `IcicleResult` and a `wrap` method on
/// `CudaError` must be in scope at the invocation site.
#[macro_export]
macro_rules! impl_ntt {
(
$field_prefix:literal,
$field_prefix_ident:ident,
$field:ident,
$field_config:ident
) => {
// Private module keeping the raw foreign declarations out of the
// surrounding namespace.
mod $field_prefix_ident {
use crate::ntt::{$field, $field_config, CudaError, DeviceContext, NTTConfig, NTTDir};
extern "C" {
// `size` is the length of a single transform; the batch count travels
// inside `config`.
#[link_name = concat!($field_prefix, "NTTCuda")]
pub(crate) fn ntt_cuda(
input: *const $field,
size: i32,
dir: NTTDir,
config: &NTTConfig<$field>,
output: *mut $field,
) -> CudaError;
#[link_name = concat!($field_prefix, "InitializeDomain")]
pub(crate) fn initialize_ntt_domain(primitive_root: $field, ctx: &DeviceContext) -> CudaError;
}
}
impl NTT<$field> for $field_config {
fn ntt_unchecked(
input: &HostOrDeviceSlice<$field>,
dir: NTTDir,
cfg: &NTTConfig<$field>,
output: &mut HostOrDeviceSlice<$field>,
) -> IcicleResult<()> {
// NOTE(review): soundness depends on the foreign function touching no
// more elements than `input` / `output` hold — not verifiable from here.
unsafe {
$field_prefix_ident::ntt_cuda(
input.as_ptr(),
// Per-transform size = total length / batch size.
// NOTE(review): this divides by zero when `cfg.batch_size` is 0, a
// negative batch size wraps to a huge `usize` divisor, and the final
// `as i32` silently truncates lengths above `i32::MAX`.
(input.len() / (cfg.batch_size as usize)) as i32,
dir,
cfg,
output.as_mut_ptr(),
)
.wrap()
}
}
fn initialize_domain(primitive_root: $field, ctx: &DeviceContext) -> IcicleResult<()> {
unsafe { $field_prefix_ident::initialize_ntt_domain(primitive_root, ctx).wrap() }
}
fn get_default_ntt_config() -> NTTConfig<'static, $field> {
NTTConfig::<$field>::default_config()
}
}
};
}
/// Generates the standard set of NTT `#[test]` functions for `$field`.
///
/// The expansion defines `MAX_SIZE`, a shared `INIT` cell and five tests.
/// `OnceLock`, `init_domain` and the `check_*` helpers must be in scope at
/// the invocation site.
#[macro_export]
macro_rules! impl_ntt_tests {
(
$field:ident
) => {
// Size passed to `init_domain` by every test.
const MAX_SIZE: u64 = 1 << 17;
// Guards the domain setup: each test calls `get_or_init`, so
// `init_domain` runs once no matter which test executes first.
static INIT: OnceLock<()> = OnceLock::new();
#[test]
fn test_ntt() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
check_ntt::<$field>()
}
#[test]
fn test_ntt_coset_from_subgroup() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
check_ntt_coset_from_subgroup::<$field>()
}
#[test]
fn test_ntt_arbitrary_coset() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
check_ntt_arbitrary_coset::<$field>()
}
#[test]
fn test_ntt_batch() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
check_ntt_batch::<$field>()
}
#[test]
fn test_ntt_device_async() {
INIT.get_or_init(move || init_domain::<$field>(MAX_SIZE));
check_ntt_device_async::<$field>()
}
};
}