use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::{DeviceSlice, DeviceSliceMut, Stream};
use baracuda_kernels_sys::{
baracuda_kernels_scale_inplace_c32_run, baracuda_kernels_scale_inplace_c64_run, cufftComplex,
cufftDestroy, cufftDoubleComplex, cufftExecC2C, cufftExecZ2Z, cufftHandle, cufftPlanMany,
cufftSetStream, CUFFT_C2C, CUFFT_FORWARD, CUFFT_INVERSE, CUFFT_Z2Z,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Complex32, Complex64, Element, ElementKind, FftKind, KernelSku,
MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, Workspace,
};
use super::fft::{cufft_to_status, map_status};
/// Sentinel kept in a plan's handle slot while no cuFFT plan has been created for it.
const HANDLE_UNINIT: cufftHandle = -1;
/// Length of the fixed-size `dims` array carried by `FftNdDescriptor`.
const MAX_RANK: usize = 4;
/// Problem description for a batched N-dimensional complex-to-complex FFT.
///
/// `FftNdPlan::select` validates these fields before a plan is built.
#[derive(Copy, Clone, Debug)]
pub struct FftNdDescriptor {
/// Extents of the transformed axes; `transform_numel` multiplies the first `rank` entries.
pub dims: [i32; MAX_RANK],
/// Number of leading `dims` entries that are in use.
pub rank: u8,
/// Number of transforms in the batch; forwarded to `cufftPlanMany` as its batch argument.
pub batch: i32,
/// `true` selects `CUFFT_INVERSE` (followed by a `1 / transform_numel` rescale in `run`),
/// `false` selects `CUFFT_FORWARD`.
pub inverse: bool,
/// Element kind of the descriptor; `FftNdPlan::select` requires it to equal `T::KIND`.
pub element: ElementKind,
}
impl FftNdDescriptor {
    /// Returns the number of elements in a single transform.
    ///
    /// This is the product of the first `rank` entries of `dims`, widened to `i64`
    /// and multiplied with saturation so that very large extents cannot wrap.
    /// A `rank` of zero yields `1` (the empty product).
    ///
    /// `rank` is a public field and may therefore exceed the length of `dims`;
    /// in that case every entry of `dims` is used instead of indexing past the
    /// end of the array, so this method never panics.
    #[inline]
    pub fn transform_numel(&self) -> i64 {
        self.dims
            .iter()
            // `take` stops at the end of `dims`, which bounds an oversized `rank`.
            .take(usize::from(self.rank))
            .fold(1_i64, |numel, &extent| numel.saturating_mul(i64::from(extent)))
    }
}
/// Device buffers handed to `FftNdPlan::run`.
pub struct FftNdArgs<'a, T: Element> {
/// Buffer passed to cuFFT as the transform input.
pub x: DeviceSlice<'a, T>,
/// Buffer passed to cuFFT as the transform output; also rescaled in place for inverse plans.
pub y: DeviceSliceMut<'a, T>,
}
/// A validated N-dimensional FFT plan that owns a lazily created cuFFT handle.
pub struct FftNdPlan<T: Element> {
// Copy of the descriptor that passed validation in `select`.
desc: FftNdDescriptor,
// Kernel identity and precision metadata assembled in `select`.
sku: KernelSku,
// `HANDLE_UNINIT` until `ensure_handle` creates the plan; destroyed again in `Drop`.
handle: Cell<cufftHandle>,
// Ties the plan to its element type without storing a value of it.
_marker: PhantomData<T>,
}
impl<T: Element> FftNdPlan<T> {
pub fn select(
_stream: &Stream,
desc: &FftNdDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::FftNdPlan: descriptor.element != T::KIND",
));
}
if !matches!(T::KIND, ElementKind::Complex32 | ElementKind::Complex64) {
return Err(Error::Unsupported(
"baracuda-kernels::FftNdPlan: C2C ND FFT supports Complex32 + Complex64 only",
));
}
if !(1..=3).contains(&desc.rank) {
return Err(Error::Unsupported(
"baracuda-kernels::FftNdPlan: rank must be in 1..=3 (trailblazer)",
));
}
if desc.batch <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftNdPlan: batch must be > 0",
));
}
for i in 0..desc.rank as usize {
if desc.dims[i] <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftNdPlan: every transformed-axis dim must be > 0",
));
}
}
let math_precision = match T::KIND {
ElementKind::Complex64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: false,
deterministic: true,
};
let op = if desc.inverse {
FftKind::Ifft
} else {
FftKind::Fft
};
let sku = KernelSku {
category: OpCategory::Fft,
op: op as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Cufft,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(HANDLE_UNINIT),
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
fn ensure_handle(&self) -> Result<cufftHandle> {
let h = self.handle.get();
if h != HANDLE_UNINIT {
return Ok(h);
}
let fft_type = match T::KIND {
ElementKind::Complex32 => CUFFT_C2C,
ElementKind::Complex64 => CUFFT_Z2Z,
_ => unreachable!("select() gates on Complex32 / Complex64"),
};
let rank = self.desc.rank as i32;
let mut n: [i32; MAX_RANK] = self.desc.dims;
let dist = self.desc.transform_numel() as i32;
let mut handle: cufftHandle = HANDLE_UNINIT;
let status = unsafe {
cufftPlanMany(
&mut handle as *mut _,
rank,
n.as_mut_ptr(),
core::ptr::null_mut(),
1,
dist,
core::ptr::null_mut(),
1,
dist,
fft_type,
self.desc.batch,
)
};
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe { cufftSetStream(handle, stream_ptr) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
Ok(())
}
fn check_args(&self, x: &DeviceSlice<'_, T>, y: &DeviceSliceMut<'_, T>) -> Result<i64> {
let per = self.desc.transform_numel();
let total = per.saturating_mul(self.desc.batch as i64);
if (x.len() as i64) < total {
return Err(Error::BufferTooSmall {
needed: total as usize,
got: x.len(),
});
}
if (y.len() as i64) < total {
return Err(Error::BufferTooSmall {
needed: total as usize,
got: y.len(),
});
}
Ok(total)
}
}
impl FftNdPlan<Complex32> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: FftNdArgs<'_, Complex32>,
) -> Result<()> {
let total = self.check_args(&args.x, &args.y)?;
if total == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let direction = if self.desc.inverse {
CUFFT_INVERSE
} else {
CUFFT_FORWARD
};
let idata = args.x.as_raw().0 as *mut cufftComplex;
let odata = args.y.as_raw().0 as *mut cufftComplex;
let status = unsafe { cufftExecC2C(handle, idata, odata, direction) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
if self.desc.inverse {
let per = self.desc.transform_numel() as f32;
let scale = 1.0_f32 / per;
let stream_ptr = stream.as_raw() as *mut c_void;
let s = unsafe {
baracuda_kernels_scale_inplace_c32_run(
total,
scale,
odata as *mut c_void,
core::ptr::null_mut(),
0,
stream_ptr,
)
};
map_status(s)?;
}
Ok(())
}
}
impl FftNdPlan<Complex64> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: FftNdArgs<'_, Complex64>,
) -> Result<()> {
let total = self.check_args(&args.x, &args.y)?;
if total == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let direction = if self.desc.inverse {
CUFFT_INVERSE
} else {
CUFFT_FORWARD
};
let idata = args.x.as_raw().0 as *mut cufftDoubleComplex;
let odata = args.y.as_raw().0 as *mut cufftDoubleComplex;
let status = unsafe { cufftExecZ2Z(handle, idata, odata, direction) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
if self.desc.inverse {
let per = self.desc.transform_numel() as f64;
let scale = 1.0_f64 / per;
let stream_ptr = stream.as_raw() as *mut c_void;
let s = unsafe {
baracuda_kernels_scale_inplace_c64_run(
total,
scale,
odata as *mut c_void,
core::ptr::null_mut(),
0,
stream_ptr,
)
};
map_status(s)?;
}
Ok(())
}
}
impl<T: Element> Drop for FftNdPlan<T> {
    /// Destroys the cuFFT plan if one was created and resets the slot to the sentinel.
    ///
    /// The status returned by `cufftDestroy` is deliberately discarded, since `drop`
    /// has no way to report it.
    fn drop(&mut self) {
        // `&mut self` gives exclusive access, so the cell can be borrowed directly.
        let slot = self.handle.get_mut();
        if *slot == HANDLE_UNINIT {
            return;
        }
        // SAFETY: a non-sentinel value is only ever stored by `ensure_handle` after a
        // successful `cufftPlanMany`, and it is cleared right after being destroyed.
        unsafe {
            let _ = cufftDestroy(*slot);
        }
        *slot = HANDLE_UNINIT;
    }
}