use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
baracuda_kernels_scale_inplace_c32_run, baracuda_kernels_scale_inplace_c64_run, cufftComplex,
cufftDestroy, cufftDoubleComplex, cufftExecC2C, cufftExecZ2Z, cufftHandle, cufftPlan1d,
cufftSetStream, CUFFT_C2C, CUFFT_FORWARD, CUFFT_INVERSE, CUFFT_Z2Z,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Complex32, Complex64, Element, ElementKind, FftKind, KernelSku,
MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
/// Sentinel stored in `FftPlan`'s handle cell while no cuFFT plan has been created yet
/// (set by `select`, replaced by `ensure_handle`, reset by `Drop`).
// NOTE(review): assumes cuFFT never hands out -1 as a valid plan handle — confirm.
const HANDLE_UNINIT: cufftHandle = -1;
/// Problem description for a batched 1-D complex-to-complex FFT.
#[derive(Copy, Clone, Debug)]
pub struct FftDescriptor {
    /// Transform length; `FftPlan::select` rejects values <= 0.
    pub n: i32,
    /// Number of independent transforms; `FftPlan::select` rejects values <= 0.
    pub batch: i32,
    /// `true` selects the inverse transform; `run` then also scales the output by `1 / n`.
    pub inverse: bool,
    /// Element type; must equal the plan's `T::KIND` and be `Complex32` or `Complex64`.
    pub element: ElementKind,
}
/// Operands for a single `FftPlan::run` call.
pub struct FftArgs<'a, T: Element> {
    /// Transform input; `run` requires shape `[batch, n]` and at least `batch * n` elements.
    pub x: TensorRef<'a, T, 2>,
    /// Transform output; `run` requires shape `[batch, n]` and at least `batch * n` elements.
    pub y: TensorMut<'a, T, 2>,
}
/// A validated cuFFT complex-to-complex plan for element type `T` (`Complex32` or `Complex64`).
///
/// The cuFFT handle is created lazily on the first `run` and cached in a `Cell`, which also
/// makes the plan `!Sync`. The handle is destroyed when the plan is dropped.
pub struct FftPlan<T: Element> {
    /// Descriptor captured (and validated) by `select`.
    desc: FftDescriptor,
    /// Kernel identity reported through `sku()`.
    sku: KernelSku,
    /// cuFFT plan handle, or `HANDLE_UNINIT` until `ensure_handle` creates one.
    handle: Cell<cufftHandle>,
    _marker: PhantomData<T>,
}
impl<T: Element> FftPlan<T> {
pub fn select(_stream: &Stream, desc: &FftDescriptor, _pref: PlanPreference) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::FftPlan: descriptor.element != T::KIND",
));
}
if !matches!(T::KIND, ElementKind::Complex32 | ElementKind::Complex64) {
return Err(Error::Unsupported(
"baracuda-kernels::FftPlan: C2C FFT supports Complex32 + Complex64 only",
));
}
if desc.n <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftPlan: n must be > 0",
));
}
if desc.batch <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftPlan: batch must be > 0",
));
}
let math_precision = match T::KIND {
ElementKind::Complex64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: false,
deterministic: true,
};
let op = if desc.inverse {
FftKind::Ifft
} else {
FftKind::Fft
};
let sku = KernelSku {
category: OpCategory::Fft,
op: op as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Cufft,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(HANDLE_UNINIT),
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
fn ensure_handle(&self) -> Result<cufftHandle> {
let h = self.handle.get();
if h != HANDLE_UNINIT {
return Ok(h);
}
let fft_type = match T::KIND {
ElementKind::Complex32 => CUFFT_C2C,
ElementKind::Complex64 => CUFFT_Z2Z,
_ => unreachable!("select() gates on Complex32 / Complex64"),
};
let mut handle: cufftHandle = HANDLE_UNINIT;
let status = unsafe {
cufftPlan1d(
&mut handle as *mut _,
self.desc.n,
fft_type,
self.desc.batch,
)
};
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe { cufftSetStream(handle, stream_ptr) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
Ok(())
}
fn check_args(&self, x: &TensorRef<'_, T, 2>, y: &TensorMut<'_, T, 2>) -> Result<i64> {
let expected = [self.desc.batch, self.desc.n];
if x.shape != expected {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftPlan: x shape != [batch, n]",
));
}
if y.shape != expected {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftPlan: y shape != [batch, n]",
));
}
let numel = (self.desc.batch as i64) * (self.desc.n as i64);
if (x.data.len() as i64) < numel {
return Err(Error::BufferTooSmall {
needed: numel as usize,
got: x.data.len(),
});
}
if (y.data.len() as i64) < numel {
return Err(Error::BufferTooSmall {
needed: numel as usize,
got: y.data.len(),
});
}
Ok(numel)
}
}
impl FftPlan<Complex32> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: FftArgs<'_, Complex32>,
) -> Result<()> {
let numel = self.check_args(&args.x, &args.y)?;
if numel == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let direction = if self.desc.inverse {
CUFFT_INVERSE
} else {
CUFFT_FORWARD
};
let idata = args.x.data.as_raw().0 as *mut cufftComplex;
let odata = args.y.data.as_raw().0 as *mut cufftComplex;
let status = unsafe { cufftExecC2C(handle, idata, odata, direction) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
if self.desc.inverse {
let scale = 1.0_f32 / (self.desc.n as f32);
let stream_ptr = stream.as_raw() as *mut c_void;
let s = unsafe {
baracuda_kernels_scale_inplace_c32_run(
numel,
scale,
odata as *mut c_void,
core::ptr::null_mut(),
0,
stream_ptr,
)
};
map_status(s)?;
}
Ok(())
}
}
impl FftPlan<Complex64> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: FftArgs<'_, Complex64>,
) -> Result<()> {
let numel = self.check_args(&args.x, &args.y)?;
if numel == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let direction = if self.desc.inverse {
CUFFT_INVERSE
} else {
CUFFT_FORWARD
};
let idata = args.x.data.as_raw().0 as *mut cufftDoubleComplex;
let odata = args.y.data.as_raw().0 as *mut cufftDoubleComplex;
let status = unsafe { cufftExecZ2Z(handle, idata, odata, direction) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
if self.desc.inverse {
let scale = 1.0_f64 / (self.desc.n as f64);
let stream_ptr = stream.as_raw() as *mut c_void;
let s = unsafe {
baracuda_kernels_scale_inplace_c64_run(
numel,
scale,
odata as *mut c_void,
core::ptr::null_mut(),
0,
stream_ptr,
)
};
map_status(s)?;
}
Ok(())
}
}
impl<T: Element> Drop for FftPlan<T> {
    /// Releases the lazily-created cuFFT plan, if one was ever built.
    ///
    /// Teardown is best-effort: the status returned by `cufftDestroy` is discarded because
    /// `drop` cannot report it.
    fn drop(&mut self) {
        let stale = self.handle.get();
        if stale == HANDLE_UNINIT {
            return;
        }
        // SAFETY: `stale` was stored by `ensure_handle` after a successful `cufftPlan1d`. The
        // cell is reset immediately below, so the same handle is never destroyed twice.
        let _ = unsafe { cufftDestroy(stale) };
        self.handle.set(HANDLE_UNINIT);
    }
}
/// Converts a raw `cufftResult` code into the status value stored in `Error::CutlassInternal`.
///
/// Success (`0`) passes through unchanged and every other code is negated, presumably so
/// cuFFT failures stay distinguishable from the small positive codes that `map_status`
/// interprets.
pub(crate) fn cufft_to_status(cufft_code: i32) -> i32 {
    match cufft_code {
        0 => 0,
        failure => -failure,
    }
}
pub(crate) fn map_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}