#![allow(non_camel_case_types)]
#![allow(clippy::too_many_arguments)]
use core::ffi::c_void;
use super::{
baracuda_kernels_scale_inplace_c32_run, baracuda_kernels_scale_inplace_c64_run,
baracuda_kernels_scale_inplace_real_f32_run, baracuda_kernels_scale_inplace_real_f64_run,
cufftComplex, cufftDestroy, cufftDoubleComplex, cufftExecC2C, cufftExecC2R, cufftExecD2Z,
cufftExecR2C, cufftExecZ2D, cufftExecZ2Z, cufftHandle, cufftPlan1d, cufftPlanMany,
cufftSetStream, CUFFT_C2C, CUFFT_C2R, CUFFT_D2Z, CUFFT_FORWARD, CUFFT_INVERSE, CUFFT_R2C,
CUFFT_Z2D, CUFFT_Z2Z,
};
/// Status code returned by the entry points in this file on success.
const OK: i32 = 0;
/// Status code returned when an argument fails validation (null pointer, non-positive size, ...).
const INVALID: i32 = 2;
/// Status code returned when a cuFFT call reports a non-zero status.
const INTERNAL: i32 = 5;

/// Collapses a raw cuFFT status into this file's status codes.
///
/// A status of `0` becomes [`OK`]; every other value becomes [`INTERNAL`].
#[inline]
fn map_cufft(status: i32) -> i32 {
    match status {
        0 => OK,
        _ => INTERNAL,
    }
}
/// Sentinel held by a [`CufftPlan`] until a plan-creation call overwrites the handle.
const HANDLE_UNINIT: cufftHandle = -1;

/// Scope guard around a cuFFT plan handle.
///
/// The handle starts out as [`HANDLE_UNINIT`]; callers pass `&mut plan.h` to a
/// plan-creation function. When the guard is dropped, any handle that differs
/// from the sentinel is handed to `cufftDestroy`.
struct CufftPlan {
    /// Raw plan handle, or [`HANDLE_UNINIT`] if no plan has been stored yet.
    h: cufftHandle,
}

impl CufftPlan {
    /// Builds a guard that does not yet own a plan.
    #[inline]
    fn new() -> Self {
        CufftPlan { h: HANDLE_UNINIT }
    }
}

impl Drop for CufftPlan {
    fn drop(&mut self) {
        // Nothing was ever stored in the guard: there is nothing to release.
        if self.h == HANDLE_UNINIT {
            return;
        }
        // SAFETY: the handle differs from the sentinel, so it is presumed to have
        // been written by a cuFFT plan-creation call. The returned status is
        // deliberately discarded because `drop` has no way to report it.
        let _ = unsafe { cufftDestroy(self.h) };
    }
}
/// Associates `plan` with `stream` through `cufftSetStream`.
///
/// Returns [`OK`] when cuFFT reports status `0`, otherwise [`INTERNAL`].
///
/// # Safety
///
/// `plan` and `stream` are forwarded to cuFFT unchecked; the caller is
/// responsible for passing values that `cufftSetStream` accepts.
#[inline]
unsafe fn bind_stream(plan: cufftHandle, stream: *mut c_void) -> i32 {
    // SAFETY: upheld by the caller, see the `# Safety` section above.
    match unsafe { cufftSetStream(plan, stream) } {
        0 => OK,
        _ => INTERNAL,
    }
}
// Generates the C-ABI pair (workspace-size query + runner) for a batched 1-D
// complex-to-complex transform.
//
// Macro arguments: runner name, workspace-query name, cuFFT transform type,
// cuFFT exec function, scalar type used for the scale factor, in-place scale
// helper, and the complex element type the buffers are cast to.
macro_rules! fft_1d_pair {
    (
        $run:ident,
        $ws:ident,
        $cufft_type:expr,
        $exec:ident,
        $T:ty,
        $scale_inplace:ident,
        $cell:ty
    ) => {
        /// Workspace-size query: always reports zero bytes.
        ///
        /// Both size arguments are ignored. Returns `INVALID` when `out_bytes`
        /// is null, otherwise stores `0` through it and returns `OK`.
        ///
        /// # Safety
        ///
        /// A non-null `out_bytes` must be valid for a `usize` write.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws(_n: i32, _batch: i32, out_bytes: *mut usize) -> i32 {
            if out_bytes.is_null() {
                INVALID
            } else {
                // SAFETY: non-null was just checked; validity is the caller's contract.
                unsafe { out_bytes.write(0) };
                OK
            }
        }

        /// Plans and executes a 1-D complex transform from `x` into `y` on `stream`.
        ///
        /// A plan is created with `cufftPlan1d(n, batch)`, bound to `stream`, and
        /// executed forward when `inverse == 0` or inverse otherwise. After an
        /// inverse transform the scale helper is invoked on `y` with
        /// `batch * n` elements and the factor `1 / n`.
        ///
        /// The workspace arguments are unused.
        ///
        /// Returns `INVALID` for non-positive `n`/`batch` or a null `x`/`y`,
        /// `INTERNAL` when a cuFFT call reports a non-zero status, the scale
        /// helper's own status after an inverse transform, and `OK` otherwise.
        ///
        /// # Safety
        ///
        /// `x`, `y` and `stream` are forwarded to cuFFT unchecked beyond the
        /// null test; the caller must supply buffers of sufficient size.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $run(
            n: i32,
            batch: i32,
            inverse: i32,
            x: *mut c_void,
            y: *mut c_void,
            _workspace: *mut c_void,
            _workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_shape = n <= 0 || batch <= 0;
            if bad_shape || x.is_null() || y.is_null() {
                return INVALID;
            }

            // The guard destroys the plan on every return path below.
            let mut plan = CufftPlan::new();
            // SAFETY: `plan.h` is a live local the callee may write the handle into.
            if unsafe { cufftPlan1d(&mut plan.h, n, $cufft_type, batch) } != 0 {
                return INTERNAL;
            }

            // SAFETY: the plan was created successfully just above.
            let bound = unsafe { bind_stream(plan.h, stream) };
            if bound != OK {
                return bound;
            }

            let direction = match inverse {
                0 => CUFFT_FORWARD,
                _ => CUFFT_INVERSE,
            };
            // SAFETY: pointers were null-checked; sizes are the caller's contract.
            let exec_status =
                unsafe { $exec(plan.h, x as *mut $cell, y as *mut $cell, direction) };
            if exec_status != 0 {
                return INTERNAL;
            }
            if inverse == 0 {
                return OK;
            }

            // Inverse only: hand the output to the scale helper with factor 1/n.
            let numel = i64::from(batch) * i64::from(n);
            let inv_n: $T = 1.0 / (n as $T);
            // SAFETY: `y` was null-checked; the helper's status is returned as-is.
            unsafe { $scale_inplace(numel, inv_n, y, core::ptr::null_mut(), 0, stream) }
        }
    };
}
// 1-D complex transform, `f32` scale factor: CUFFT_C2C plan run through
// cufftExecC2C on cufftComplex buffers.
fft_1d_pair!(
baracuda_kernels_fft_1d_c32_run,
baracuda_kernels_fft_1d_c32_workspace_size,
CUFFT_C2C,
cufftExecC2C,
f32,
baracuda_kernels_scale_inplace_c32_run,
cufftComplex
);
// 1-D complex transform, `f64` scale factor: CUFFT_Z2Z plan run through
// cufftExecZ2Z on cufftDoubleComplex buffers.
fft_1d_pair!(
baracuda_kernels_fft_1d_c64_run,
baracuda_kernels_fft_1d_c64_workspace_size,
CUFFT_Z2Z,
cufftExecZ2Z,
f64,
baracuda_kernels_scale_inplace_c64_run,
cufftDoubleComplex
);
// Generates the C-ABI pair (workspace-size query + runner) for a batched 1-D
// real-to-complex transform.
//
// Macro arguments: runner name, workspace-query name, cuFFT transform type,
// cuFFT exec function, real input element type, complex output element type.
macro_rules! rfft_1d {
    ($run:ident, $ws:ident, $cufft_type:expr, $exec:ident, $T:ty, $cell:ty) => {
        /// Workspace-size query: always reports zero bytes.
        ///
        /// Both size arguments are ignored. Returns `INVALID` when `out_bytes`
        /// is null, otherwise stores `0` through it and returns `OK`.
        ///
        /// # Safety
        ///
        /// A non-null `out_bytes` must be valid for a `usize` write.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws(_n: i32, _batch: i32, out_bytes: *mut usize) -> i32 {
            if out_bytes.is_null() {
                INVALID
            } else {
                // SAFETY: non-null was just checked; validity is the caller's contract.
                unsafe { out_bytes.write(0) };
                OK
            }
        }

        /// Plans and executes a 1-D real-to-complex transform from `x` into `y`.
        ///
        /// A plan is created with `cufftPlan1d(n, batch)`, bound to `stream`,
        /// and executed once. No scaling is applied. The workspace arguments
        /// are unused.
        ///
        /// Returns `INVALID` for non-positive `n`/`batch` or a null `x`/`y`,
        /// `INTERNAL` when a cuFFT call reports a non-zero status, and `OK`
        /// otherwise.
        ///
        /// # Safety
        ///
        /// `x`, `y` and `stream` are forwarded to cuFFT unchecked beyond the
        /// null test; the caller must supply buffers of sufficient size.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $run(
            n: i32,
            batch: i32,
            x: *mut c_void,
            y: *mut c_void,
            _workspace: *mut c_void,
            _workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_shape = n <= 0 || batch <= 0;
            if bad_shape || x.is_null() || y.is_null() {
                return INVALID;
            }

            // The guard destroys the plan on every return path below.
            let mut plan = CufftPlan::new();
            // SAFETY: `plan.h` is a live local the callee may write the handle into.
            if unsafe { cufftPlan1d(&mut plan.h, n, $cufft_type, batch) } != 0 {
                return INTERNAL;
            }

            // SAFETY: the plan was created successfully just above.
            let bound = unsafe { bind_stream(plan.h, stream) };
            if bound != OK {
                return bound;
            }

            // SAFETY: pointers were null-checked; sizes are the caller's contract.
            let exec_status = unsafe { $exec(plan.h, x as *mut $T, y as *mut $cell) };
            map_cufft(exec_status)
        }
    };
}
// 1-D real-to-complex, `f32` input: CUFFT_R2C plan run through cufftExecR2C
// into cufftComplex output.
rfft_1d!(
baracuda_kernels_rfft_1d_f32_run,
baracuda_kernels_rfft_1d_f32_workspace_size,
CUFFT_R2C,
cufftExecR2C,
f32,
cufftComplex
);
// 1-D real-to-complex, `f64` input: CUFFT_D2Z plan run through cufftExecD2Z
// into cufftDoubleComplex output.
rfft_1d!(
baracuda_kernels_rfft_1d_f64_run,
baracuda_kernels_rfft_1d_f64_workspace_size,
CUFFT_D2Z,
cufftExecD2Z,
f64,
cufftDoubleComplex
);
// Generates the C-ABI pair (workspace-size query + runner) for a batched 1-D
// complex-to-real transform followed by a 1/n scale of the output.
//
// Macro arguments: runner name, workspace-query name, cuFFT transform type,
// cuFFT exec function, real output element type (also used for the scale
// factor), complex input element type, in-place scale helper.
macro_rules! irfft_1d {
    (
        $run:ident,
        $ws:ident,
        $cufft_type:expr,
        $exec:ident,
        $T:ty,
        $cell:ty,
        $scale_inplace:ident
    ) => {
        /// Workspace-size query: always reports zero bytes.
        ///
        /// Both size arguments are ignored. Returns `INVALID` when `out_bytes`
        /// is null, otherwise stores `0` through it and returns `OK`.
        ///
        /// # Safety
        ///
        /// A non-null `out_bytes` must be valid for a `usize` write.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws(_n: i32, _batch: i32, out_bytes: *mut usize) -> i32 {
            if out_bytes.is_null() {
                INVALID
            } else {
                // SAFETY: non-null was just checked; validity is the caller's contract.
                unsafe { out_bytes.write(0) };
                OK
            }
        }

        /// Plans and executes a 1-D complex-to-real transform from `x` into `y`,
        /// then scales `y`.
        ///
        /// A plan is created with `cufftPlan1d(n, batch)`, bound to `stream`,
        /// and executed once. The scale helper is then invoked on `y` with
        /// `batch * n` elements and the factor `1 / n`. The workspace arguments
        /// are unused.
        ///
        /// Returns `INVALID` for non-positive `n`/`batch` or a null `x`/`y`,
        /// `INTERNAL` when a cuFFT call reports a non-zero status, and
        /// otherwise the scale helper's own status.
        ///
        /// # Safety
        ///
        /// `x`, `y` and `stream` are forwarded to cuFFT unchecked beyond the
        /// null test; the caller must supply buffers of sufficient size.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $run(
            n: i32,
            batch: i32,
            x: *mut c_void,
            y: *mut c_void,
            _workspace: *mut c_void,
            _workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let bad_shape = n <= 0 || batch <= 0;
            if bad_shape || x.is_null() || y.is_null() {
                return INVALID;
            }

            // The guard destroys the plan on every return path below.
            let mut plan = CufftPlan::new();
            // SAFETY: `plan.h` is a live local the callee may write the handle into.
            if unsafe { cufftPlan1d(&mut plan.h, n, $cufft_type, batch) } != 0 {
                return INTERNAL;
            }

            // SAFETY: the plan was created successfully just above.
            let bound = unsafe { bind_stream(plan.h, stream) };
            if bound != OK {
                return bound;
            }

            // SAFETY: pointers were null-checked; sizes are the caller's contract.
            let exec_status = unsafe { $exec(plan.h, x as *mut $cell, y as *mut $T) };
            if exec_status != 0 {
                return INTERNAL;
            }

            // Hand the real output to the scale helper with factor 1/n.
            let numel = i64::from(batch) * i64::from(n);
            let inv_n: $T = 1.0 / (n as $T);
            // SAFETY: `y` was null-checked; the helper's status is returned as-is.
            unsafe { $scale_inplace(numel, inv_n, y, core::ptr::null_mut(), 0, stream) }
        }
    };
}
// 1-D complex-to-real, `f32` output: CUFFT_C2R plan run through cufftExecC2R
// from cufftComplex input, scaled by the real-f32 helper.
irfft_1d!(
baracuda_kernels_irfft_1d_f32_run,
baracuda_kernels_irfft_1d_f32_workspace_size,
CUFFT_C2R,
cufftExecC2R,
f32,
cufftComplex,
baracuda_kernels_scale_inplace_real_f32_run
);
// 1-D complex-to-real, `f64` output: CUFFT_Z2D plan run through cufftExecZ2D
// from cufftDoubleComplex input, scaled by the real-f64 helper.
irfft_1d!(
baracuda_kernels_irfft_1d_f64_run,
baracuda_kernels_irfft_1d_f64_workspace_size,
CUFFT_Z2D,
cufftExecZ2D,
f64,
cufftDoubleComplex,
baracuda_kernels_scale_inplace_real_f64_run
);
// Largest `rank` the N-D entry points accept (valid ranks are 1..=this value);
// it also sizes the fixed local array that holds the per-axis extents.
const FFT_ND_MAX_RANK: i32 = 3;
// Generates the C-ABI pair (workspace-size query + runner) for a batched N-D
// complex-to-complex transform planned with `cufftPlanMany`.
//
// Macro arguments: runner name, workspace-query name, cuFFT transform type,
// cuFFT exec function, scalar type used for the scale factor, in-place scale
// helper, and the complex element type the buffers are cast to.
macro_rules! fft_nd_pair {
    (
        $run:ident,
        $ws:ident,
        $cufft_type:expr,
        $exec:ident,
        $T:ty,
        $scale_inplace:ident,
        $cell:ty
    ) => {
        /// Workspace-size query: always reports zero bytes.
        ///
        /// The shape arguments are ignored. Returns `INVALID` when `out_bytes`
        /// is null, otherwise stores `0` through it and returns `OK`.
        ///
        /// # Safety
        ///
        /// A non-null `out_bytes` must be valid for a `usize` write.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws(
            _rank: i32,
            _dims: *const i32,
            _batch: i32,
            out_bytes: *mut usize,
        ) -> i32 {
            if out_bytes.is_null() {
                INVALID
            } else {
                // SAFETY: non-null was just checked; validity is the caller's contract.
                unsafe { out_bytes.write(0) };
                OK
            }
        }

        /// Plans and executes an N-D complex transform from `x` into `y` on `stream`.
        ///
        /// Reads `rank` extents from `dims`, creates a `cufftPlanMany` plan with
        /// null embed arrays, stride 1 and a per-batch distance equal to the
        /// product of the extents, binds it to `stream`, and executes forward
        /// when `inverse == 0` or inverse otherwise. After an inverse transform
        /// the scale helper is invoked on `y` with `product * batch` elements
        /// and the factor `1 / product`. The workspace arguments are unused.
        ///
        /// Returns `INVALID` when `rank` is outside `1..=FFT_ND_MAX_RANK`,
        /// `batch` or any extent is non-positive, or `dims`/`x`/`y` is null;
        /// `INTERNAL` when a cuFFT call reports a non-zero status; the scale
        /// helper's own status after an inverse transform; `OK` otherwise.
        ///
        /// # Safety
        ///
        /// `dims` must be readable for `rank` `i32` values. `x`, `y` and
        /// `stream` are forwarded to cuFFT unchecked beyond the null test.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $run(
            rank: i32,
            dims: *const i32,
            batch: i32,
            inverse: i32,
            x: *mut c_void,
            y: *mut c_void,
            _workspace: *mut c_void,
            _workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let rank_ok = (1..=FFT_ND_MAX_RANK).contains(&rank);
            if !rank_ok || batch <= 0 || dims.is_null() || x.is_null() || y.is_null() {
                return INVALID;
            }

            // Copy the extents locally and accumulate their product.
            let mut extents = [0i32; FFT_ND_MAX_RANK as usize];
            let mut total: i64 = 1;
            for (axis, slot) in extents.iter_mut().enumerate().take(rank as usize) {
                // SAFETY: `axis < rank`, which the caller guarantees is readable.
                let extent = unsafe { dims.add(axis).read() };
                if extent <= 0 {
                    return INVALID;
                }
                *slot = extent;
                total = total.saturating_mul(i64::from(extent));
            }
            // NOTE(review): this cast wraps silently when the product exceeds
            // i32::MAX. Both embed pointers are null below; confirm against
            // the cuFFT docs whether the distance is consulted in that case.
            let dist = total as i32;

            // The guard destroys the plan on every return path below.
            let mut plan = CufftPlan::new();
            // SAFETY: `plan.h` and `extents` are live locals for the whole call.
            let plan_status = unsafe {
                cufftPlanMany(
                    &mut plan.h,
                    rank,
                    extents.as_mut_ptr(),
                    core::ptr::null_mut(),
                    1,
                    dist,
                    core::ptr::null_mut(),
                    1,
                    dist,
                    $cufft_type,
                    batch,
                )
            };
            if plan_status != 0 {
                return INTERNAL;
            }

            // SAFETY: the plan was created successfully just above.
            let bound = unsafe { bind_stream(plan.h, stream) };
            if bound != OK {
                return bound;
            }

            let direction = match inverse {
                0 => CUFFT_FORWARD,
                _ => CUFFT_INVERSE,
            };
            // SAFETY: pointers were null-checked; sizes are the caller's contract.
            let exec_status =
                unsafe { $exec(plan.h, x as *mut $cell, y as *mut $cell, direction) };
            if exec_status != 0 {
                return INTERNAL;
            }
            if inverse == 0 {
                return OK;
            }

            // Inverse only: hand the output to the scale helper with factor 1/product.
            let total_with_batch = total.saturating_mul(i64::from(batch));
            let inv_total: $T = 1.0 / (total as $T);
            // SAFETY: `y` was null-checked; the helper's status is returned as-is.
            unsafe {
                $scale_inplace(
                    total_with_batch,
                    inv_total,
                    y,
                    core::ptr::null_mut(),
                    0,
                    stream,
                )
            }
        }
    };
}
// N-D complex transform, `f32` scale factor: CUFFT_C2C plan run through
// cufftExecC2C on cufftComplex buffers.
fft_nd_pair!(
baracuda_kernels_fft_nd_c32_run,
baracuda_kernels_fft_nd_c32_workspace_size,
CUFFT_C2C,
cufftExecC2C,
f32,
baracuda_kernels_scale_inplace_c32_run,
cufftComplex
);
// N-D complex transform, `f64` scale factor: CUFFT_Z2Z plan run through
// cufftExecZ2Z on cufftDoubleComplex buffers.
fft_nd_pair!(
baracuda_kernels_fft_nd_c64_run,
baracuda_kernels_fft_nd_c64_workspace_size,
CUFFT_Z2Z,
cufftExecZ2Z,
f64,
baracuda_kernels_scale_inplace_c64_run,
cufftDoubleComplex
);
// Generates the C-ABI pair (workspace-size query + runner) for a batched N-D
// real-to-complex transform planned with `cufftPlanMany`.
//
// Macro arguments: runner name, workspace-query name, cuFFT transform type,
// cuFFT exec function, real input element type, complex output element type.
macro_rules! rfft_nd {
    ($run:ident, $ws:ident, $cufft_type:expr, $exec:ident, $T:ty, $cell:ty) => {
        /// Workspace-size query: always reports zero bytes.
        ///
        /// The shape arguments are ignored. Returns `INVALID` when `out_bytes`
        /// is null, otherwise stores `0` through it and returns `OK`.
        ///
        /// # Safety
        ///
        /// A non-null `out_bytes` must be valid for a `usize` write.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws(
            _rank: i32,
            _dims: *const i32,
            _batch: i32,
            out_bytes: *mut usize,
        ) -> i32 {
            if out_bytes.is_null() {
                INVALID
            } else {
                // SAFETY: non-null was just checked; validity is the caller's contract.
                unsafe { out_bytes.write(0) };
                OK
            }
        }

        /// Plans and executes an N-D real-to-complex transform from `x` into `y`.
        ///
        /// Reads `rank` extents from `dims` and creates a `cufftPlanMany` plan
        /// with null embed arrays and stride 1. The input distance is the
        /// product of all extents; the output distance is the same product
        /// with the last extent replaced by `last / 2 + 1`. The plan is bound
        /// to `stream` and executed once; no scaling is applied. The workspace
        /// arguments are unused.
        ///
        /// Returns `INVALID` when `rank` is outside `1..=FFT_ND_MAX_RANK`,
        /// `batch` or any extent is non-positive, or `dims`/`x`/`y` is null;
        /// `INTERNAL` when a cuFFT call reports a non-zero status; `OK`
        /// otherwise.
        ///
        /// # Safety
        ///
        /// `dims` must be readable for `rank` `i32` values. `x`, `y` and
        /// `stream` are forwarded to cuFFT unchecked beyond the null test.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $run(
            rank: i32,
            dims: *const i32,
            batch: i32,
            x: *mut c_void,
            y: *mut c_void,
            _workspace: *mut c_void,
            _workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let rank_ok = (1..=FFT_ND_MAX_RANK).contains(&rank);
            if !rank_ok || batch <= 0 || dims.is_null() || x.is_null() || y.is_null() {
                return INVALID;
            }

            // Copy the extents locally and accumulate both element counts.
            let last_axis = rank as usize - 1;
            let mut extents = [0i32; FFT_ND_MAX_RANK as usize];
            let mut real_numel: i64 = 1;
            let mut complex_numel: i64 = 1;
            for (axis, slot) in extents.iter_mut().enumerate().take(rank as usize) {
                // SAFETY: `axis < rank`, which the caller guarantees is readable.
                let extent = unsafe { dims.add(axis).read() };
                if extent <= 0 {
                    return INVALID;
                }
                *slot = extent;
                real_numel = real_numel.saturating_mul(i64::from(extent));
                // Only the last axis is shortened to `extent / 2 + 1` on the complex side.
                let complex_extent = if axis == last_axis {
                    extent / 2 + 1
                } else {
                    extent
                };
                complex_numel = complex_numel.saturating_mul(i64::from(complex_extent));
            }
            // NOTE(review): these casts wrap silently when a product exceeds
            // i32::MAX. Both embed pointers are null below; confirm against
            // the cuFFT docs whether the distances are consulted in that case.
            let real_dist = real_numel as i32;
            let complex_dist = complex_numel as i32;

            // The guard destroys the plan on every return path below.
            let mut plan = CufftPlan::new();
            // SAFETY: `plan.h` and `extents` are live locals for the whole call.
            let plan_status = unsafe {
                cufftPlanMany(
                    &mut plan.h,
                    rank,
                    extents.as_mut_ptr(),
                    core::ptr::null_mut(),
                    1,
                    real_dist,
                    core::ptr::null_mut(),
                    1,
                    complex_dist,
                    $cufft_type,
                    batch,
                )
            };
            if plan_status != 0 {
                return INTERNAL;
            }

            // SAFETY: the plan was created successfully just above.
            let bound = unsafe { bind_stream(plan.h, stream) };
            if bound != OK {
                return bound;
            }

            // SAFETY: pointers were null-checked; sizes are the caller's contract.
            let exec_status = unsafe { $exec(plan.h, x as *mut $T, y as *mut $cell) };
            map_cufft(exec_status)
        }
    };
}
// N-D real-to-complex, `f32` input: CUFFT_R2C plan run through cufftExecR2C
// into cufftComplex output.
rfft_nd!(
baracuda_kernels_rfft_nd_f32_run,
baracuda_kernels_rfft_nd_f32_workspace_size,
CUFFT_R2C,
cufftExecR2C,
f32,
cufftComplex
);
// N-D real-to-complex, `f64` input: CUFFT_D2Z plan run through cufftExecD2Z
// into cufftDoubleComplex output.
rfft_nd!(
baracuda_kernels_rfft_nd_f64_run,
baracuda_kernels_rfft_nd_f64_workspace_size,
CUFFT_D2Z,
cufftExecD2Z,
f64,
cufftDoubleComplex
);
// Generates the C-ABI pair (workspace-size query + runner) for a batched N-D
// complex-to-real transform planned with `cufftPlanMany`, followed by a scale
// of the real output.
//
// Macro arguments: runner name, workspace-query name, cuFFT transform type,
// cuFFT exec function, real output element type (also used for the scale
// factor), complex input element type, in-place scale helper.
macro_rules! irfft_nd {
    (
        $run:ident,
        $ws:ident,
        $cufft_type:expr,
        $exec:ident,
        $T:ty,
        $cell:ty,
        $scale_inplace:ident
    ) => {
        /// Workspace-size query: always reports zero bytes.
        ///
        /// The shape arguments are ignored. Returns `INVALID` when `out_bytes`
        /// is null, otherwise stores `0` through it and returns `OK`.
        ///
        /// # Safety
        ///
        /// A non-null `out_bytes` must be valid for a `usize` write.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $ws(
            _rank: i32,
            _dims: *const i32,
            _batch: i32,
            out_bytes: *mut usize,
        ) -> i32 {
            if out_bytes.is_null() {
                INVALID
            } else {
                // SAFETY: non-null was just checked; validity is the caller's contract.
                unsafe { out_bytes.write(0) };
                OK
            }
        }

        /// Plans and executes an N-D complex-to-real transform from `x` into
        /// `y`, then scales `y`.
        ///
        /// Reads `rank` extents from `dims` and creates a `cufftPlanMany` plan
        /// with null embed arrays and stride 1. The output distance is the
        /// product of all extents; the input distance is the same product with
        /// the last extent replaced by `last / 2 + 1`. The plan is bound to
        /// `stream` and executed once, after which the scale helper is invoked
        /// on `y` with `product * batch` elements and the factor `1 / product`.
        /// The workspace arguments are unused.
        ///
        /// Returns `INVALID` when `rank` is outside `1..=FFT_ND_MAX_RANK`,
        /// `batch` or any extent is non-positive, or `dims`/`x`/`y` is null;
        /// `INTERNAL` when a cuFFT call reports a non-zero status; otherwise
        /// the scale helper's own status.
        ///
        /// # Safety
        ///
        /// `dims` must be readable for `rank` `i32` values. `x`, `y` and
        /// `stream` are forwarded to cuFFT unchecked beyond the null test.
        #[unsafe(no_mangle)]
        pub unsafe extern "C" fn $run(
            rank: i32,
            dims: *const i32,
            batch: i32,
            x: *mut c_void,
            y: *mut c_void,
            _workspace: *mut c_void,
            _workspace_bytes: usize,
            stream: *mut c_void,
        ) -> i32 {
            let rank_ok = (1..=FFT_ND_MAX_RANK).contains(&rank);
            if !rank_ok || batch <= 0 || dims.is_null() || x.is_null() || y.is_null() {
                return INVALID;
            }

            // Copy the extents locally and accumulate both element counts.
            let last_axis = rank as usize - 1;
            let mut extents = [0i32; FFT_ND_MAX_RANK as usize];
            let mut real_numel: i64 = 1;
            let mut complex_numel: i64 = 1;
            for (axis, slot) in extents.iter_mut().enumerate().take(rank as usize) {
                // SAFETY: `axis < rank`, which the caller guarantees is readable.
                let extent = unsafe { dims.add(axis).read() };
                if extent <= 0 {
                    return INVALID;
                }
                *slot = extent;
                real_numel = real_numel.saturating_mul(i64::from(extent));
                // Only the last axis is shortened to `extent / 2 + 1` on the complex side.
                let complex_extent = if axis == last_axis {
                    extent / 2 + 1
                } else {
                    extent
                };
                complex_numel = complex_numel.saturating_mul(i64::from(complex_extent));
            }
            // NOTE(review): these casts wrap silently when a product exceeds
            // i32::MAX. Both embed pointers are null below; confirm against
            // the cuFFT docs whether the distances are consulted in that case.
            let real_dist = real_numel as i32;
            let complex_dist = complex_numel as i32;

            // The guard destroys the plan on every return path below.
            let mut plan = CufftPlan::new();
            // SAFETY: `plan.h` and `extents` are live locals for the whole call.
            let plan_status = unsafe {
                cufftPlanMany(
                    &mut plan.h,
                    rank,
                    extents.as_mut_ptr(),
                    core::ptr::null_mut(),
                    1,
                    complex_dist,
                    core::ptr::null_mut(),
                    1,
                    real_dist,
                    $cufft_type,
                    batch,
                )
            };
            if plan_status != 0 {
                return INTERNAL;
            }

            // SAFETY: the plan was created successfully just above.
            let bound = unsafe { bind_stream(plan.h, stream) };
            if bound != OK {
                return bound;
            }

            // SAFETY: pointers were null-checked; sizes are the caller's contract.
            let exec_status = unsafe { $exec(plan.h, x as *mut $cell, y as *mut $T) };
            if exec_status != 0 {
                return INTERNAL;
            }

            // Hand the real output to the scale helper with factor 1/product.
            let total_with_batch = real_numel.saturating_mul(i64::from(batch));
            let inv_total: $T = 1.0 / (real_numel as $T);
            // SAFETY: `y` was null-checked; the helper's status is returned as-is.
            unsafe {
                $scale_inplace(
                    total_with_batch,
                    inv_total,
                    y,
                    core::ptr::null_mut(),
                    0,
                    stream,
                )
            }
        }
    };
}
// N-D complex-to-real, `f32` output: CUFFT_C2R plan run through cufftExecC2R
// from cufftComplex input, scaled by the real-f32 helper.
irfft_nd!(
baracuda_kernels_irfft_nd_f32_run,
baracuda_kernels_irfft_nd_f32_workspace_size,
CUFFT_C2R,
cufftExecC2R,
f32,
cufftComplex,
baracuda_kernels_scale_inplace_real_f32_run
);
// N-D complex-to-real, `f64` output: CUFFT_Z2D plan run through cufftExecZ2D
// from cufftDoubleComplex input, scaled by the real-f64 helper.
irfft_nd!(
baracuda_kernels_irfft_nd_f64_run,
baracuda_kernels_irfft_nd_f64_workspace_size,
CUFFT_Z2D,
cufftExecZ2D,
f64,
cufftDoubleComplex,
baracuda_kernels_scale_inplace_real_f64_run
);