use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
baracuda_kernels_scale_inplace_real_f32_run, baracuda_kernels_scale_inplace_real_f64_run,
cufftComplex, cufftDestroy, cufftDoubleComplex, cufftExecC2R, cufftExecD2Z, cufftExecR2C,
cufftExecZ2D, cufftHandle, cufftPlan1d, cufftSetStream, CUFFT_C2R, CUFFT_D2Z, CUFFT_R2C,
CUFFT_Z2D,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Complex32, Complex64, Element, ElementKind, FftKind, KernelSku,
MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::fft::{cufft_to_status, map_status};
/// Sentinel stored in a plan's `handle` cell while no cuFFT plan has been
/// created yet. `ensure_handle` replaces it on first use and the `Drop` impls
/// skip `cufftDestroy` while it is still set.
// NOTE(review): assumes cuFFT never returns -1 as a real handle — confirm.
const HANDLE_UNINIT: cufftHandle = -1;
/// Problem description for a batched 1-D real-to-complex FFT plan
/// ([`RfftPlan`]).
#[derive(Copy, Clone, Debug)]
pub struct RfftDescriptor {
/// Length of each real input signal; forwarded to `cufftPlan1d`.
/// `RfftPlan::select` rejects values `<= 0`.
pub n: i32,
/// Number of signals transformed per call; forwarded to `cufftPlan1d`.
/// `RfftPlan::select` rejects values `<= 0`.
pub batch: i32,
/// Real element kind of the input. Must equal `T::KIND` of the plan it is
/// used with, and only `F32` / `F64` are accepted by `RfftPlan::select`.
pub element: ElementKind,
}
/// Tensor arguments for `RfftPlan::run`: real input of element type `T`,
/// complex output of element type `C`.
pub struct RfftArgs<'a, T: Element, C: Element> {
/// Real input. `run` requires its shape to equal `[batch, n]`.
pub x: TensorRef<'a, T, 2>,
/// Complex output. `run` requires its shape to equal `[batch, n / 2 + 1]`.
pub y: TensorMut<'a, C, 2>,
}
/// Batched 1-D real-to-complex FFT plan backed by cuFFT.
///
/// The cuFFT handle is created lazily on the first non-empty `run` and
/// destroyed when the plan is dropped.
pub struct RfftPlan<T: Element> {
// Validated copy of the descriptor passed to `select`.
desc: RfftDescriptor,
// Kernel identity and precision metadata reported by `sku()`.
sku: KernelSku,
// cuFFT plan handle, or `HANDLE_UNINIT` until `ensure_handle` creates it.
// Held in a `Cell` so that `run(&self)` can initialise it.
handle: Cell<cufftHandle>,
// Ties the plan to its real element type without storing a value of it.
_marker: PhantomData<T>,
}
impl<T: Element> RfftPlan<T> {
pub fn select(_stream: &Stream, desc: &RfftDescriptor, _pref: PlanPreference) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::RfftPlan: descriptor.element != T::KIND",
));
}
if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
return Err(Error::Unsupported(
"baracuda-kernels::RfftPlan: R2C FFT supports f32 + f64 only",
));
}
if desc.n <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::RfftPlan: n must be > 0",
));
}
if desc.batch <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::RfftPlan: batch must be > 0",
));
}
let math_precision = match T::KIND {
ElementKind::F64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let aux = match T::KIND {
ElementKind::F32 => Some(ElementKind::Complex32),
ElementKind::F64 => Some(ElementKind::Complex64),
_ => None,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: false,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Fft,
op: FftKind::Rfft as u16,
element: T::KIND,
aux_element: aux,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Cufft,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(HANDLE_UNINIT),
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
fn ensure_handle(&self) -> Result<cufftHandle> {
let h = self.handle.get();
if h != HANDLE_UNINIT {
return Ok(h);
}
let fft_type = match T::KIND {
ElementKind::F32 => CUFFT_R2C,
ElementKind::F64 => CUFFT_D2Z,
_ => unreachable!("select() gates on F32 / F64"),
};
let mut handle: cufftHandle = HANDLE_UNINIT;
let status = unsafe {
cufftPlan1d(
&mut handle as *mut _,
self.desc.n,
fft_type,
self.desc.batch,
)
};
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe { cufftSetStream(handle, stream_ptr) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
Ok(())
}
}
impl RfftPlan<f32> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: RfftArgs<'_, f32, Complex32>,
) -> Result<()> {
let n = self.desc.n;
let batch = self.desc.batch;
let in_shape = [batch, n];
let out_shape = [batch, n / 2 + 1];
if args.x.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::RfftPlan<f32>: x shape != [batch, n]",
));
}
if args.y.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::RfftPlan<f32>: y shape != [batch, n/2 + 1]",
));
}
let in_numel = (batch as i64) * (n as i64);
let out_numel = (batch as i64) * ((n / 2 + 1) as i64);
if (args.x.data.len() as i64) < in_numel {
return Err(Error::BufferTooSmall {
needed: in_numel as usize,
got: args.x.data.len(),
});
}
if (args.y.data.len() as i64) < out_numel {
return Err(Error::BufferTooSmall {
needed: out_numel as usize,
got: args.y.data.len(),
});
}
if in_numel == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let idata = args.x.data.as_raw().0 as *mut f32;
let odata = args.y.data.as_raw().0 as *mut cufftComplex;
let status = unsafe { cufftExecR2C(handle, idata, odata) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
Ok(())
}
}
impl RfftPlan<f64> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: RfftArgs<'_, f64, Complex64>,
) -> Result<()> {
let n = self.desc.n;
let batch = self.desc.batch;
let in_shape = [batch, n];
let out_shape = [batch, n / 2 + 1];
if args.x.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::RfftPlan<f64>: x shape != [batch, n]",
));
}
if args.y.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::RfftPlan<f64>: y shape != [batch, n/2 + 1]",
));
}
let in_numel = (batch as i64) * (n as i64);
let out_numel = (batch as i64) * ((n / 2 + 1) as i64);
if (args.x.data.len() as i64) < in_numel {
return Err(Error::BufferTooSmall {
needed: in_numel as usize,
got: args.x.data.len(),
});
}
if (args.y.data.len() as i64) < out_numel {
return Err(Error::BufferTooSmall {
needed: out_numel as usize,
got: args.y.data.len(),
});
}
if in_numel == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let idata = args.x.data.as_raw().0 as *mut f64;
let odata = args.y.data.as_raw().0 as *mut cufftDoubleComplex;
let status = unsafe { cufftExecD2Z(handle, idata, odata) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
Ok(())
}
}
impl<T: Element> Drop for RfftPlan<T> {
    /// Destroys the cuFFT plan if one was ever created and leaves the cell
    /// holding the uninitialised sentinel.
    fn drop(&mut self) {
        // Take the handle out and reset the cell in a single step.
        let live = self.handle.replace(HANDLE_UNINIT);
        if live != HANDLE_UNINIT {
            // SAFETY: `live` came from a successful `cufftPlan1d` call and is
            // no longer reachable through the cell. The status is ignored on
            // purpose because `drop` cannot report errors.
            unsafe {
                let _ = cufftDestroy(live);
            }
        }
    }
}
/// Problem description for a batched 1-D complex-to-real (inverse real) FFT
/// plan ([`IrfftPlan`]).
#[derive(Copy, Clone, Debug)]
pub struct IrfftDescriptor {
/// Length of each real output signal; forwarded to `cufftPlan1d`.
/// `IrfftPlan::select` rejects values `<= 0`.
pub n: i32,
/// Number of signals transformed per call; forwarded to `cufftPlan1d`.
/// `IrfftPlan::select` rejects values `<= 0`.
pub batch: i32,
/// Real element kind of the output. Must equal `T::KIND` of the plan it is
/// used with, and only `F32` / `F64` are accepted by `IrfftPlan::select`.
pub element: ElementKind,
}
/// Tensor arguments for `IrfftPlan::run`: complex input of element type `C`,
/// real output of element type `T`.
pub struct IrfftArgs<'a, T: Element, C: Element> {
/// Complex input. `run` requires its shape to equal `[batch, n / 2 + 1]`.
pub x: TensorRef<'a, C, 2>,
/// Real output. `run` requires its shape to equal `[batch, n]`.
pub y: TensorMut<'a, T, 2>,
}
/// Batched 1-D complex-to-real FFT plan backed by cuFFT, followed by a
/// `1 / n` scaling of the output.
///
/// The cuFFT handle is created lazily on the first non-empty `run` and
/// destroyed when the plan is dropped.
pub struct IrfftPlan<T: Element> {
// Validated copy of the descriptor passed to `select`.
desc: IrfftDescriptor,
// Kernel identity and precision metadata reported by `sku()`.
sku: KernelSku,
// cuFFT plan handle, or `HANDLE_UNINIT` until `ensure_handle` creates it.
// Held in a `Cell` so that `run(&self)` can initialise it.
handle: Cell<cufftHandle>,
// Ties the plan to its real element type without storing a value of it.
_marker: PhantomData<T>,
}
impl<T: Element> IrfftPlan<T> {
pub fn select(
_stream: &Stream,
desc: &IrfftDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::IrfftPlan: descriptor.element != T::KIND",
));
}
if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
return Err(Error::Unsupported(
"baracuda-kernels::IrfftPlan: C2R FFT supports f32 + f64 only",
));
}
if desc.n <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::IrfftPlan: n must be > 0",
));
}
if desc.batch <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::IrfftPlan: batch must be > 0",
));
}
let math_precision = match T::KIND {
ElementKind::F64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let aux = match T::KIND {
ElementKind::F32 => Some(ElementKind::Complex32),
ElementKind::F64 => Some(ElementKind::Complex64),
_ => None,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: false,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Fft,
op: FftKind::Irfft as u16,
element: T::KIND,
aux_element: aux,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Cufft,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(HANDLE_UNINIT),
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
fn ensure_handle(&self) -> Result<cufftHandle> {
let h = self.handle.get();
if h != HANDLE_UNINIT {
return Ok(h);
}
let fft_type = match T::KIND {
ElementKind::F32 => CUFFT_C2R,
ElementKind::F64 => CUFFT_Z2D,
_ => unreachable!("select() gates on F32 / F64"),
};
let mut handle: cufftHandle = HANDLE_UNINIT;
let status = unsafe {
cufftPlan1d(
&mut handle as *mut _,
self.desc.n,
fft_type,
self.desc.batch,
)
};
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, handle: cufftHandle, stream: &Stream) -> Result<()> {
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe { cufftSetStream(handle, stream_ptr) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
Ok(())
}
}
impl IrfftPlan<f32> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: IrfftArgs<'_, f32, Complex32>,
) -> Result<()> {
let n = self.desc.n;
let batch = self.desc.batch;
let in_shape = [batch, n / 2 + 1];
let out_shape = [batch, n];
if args.x.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::IrfftPlan<f32>: x shape != [batch, n/2 + 1]",
));
}
if args.y.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::IrfftPlan<f32>: y shape != [batch, n]",
));
}
let in_numel = (batch as i64) * ((n / 2 + 1) as i64);
let out_numel = (batch as i64) * (n as i64);
if (args.x.data.len() as i64) < in_numel {
return Err(Error::BufferTooSmall {
needed: in_numel as usize,
got: args.x.data.len(),
});
}
if (args.y.data.len() as i64) < out_numel {
return Err(Error::BufferTooSmall {
needed: out_numel as usize,
got: args.y.data.len(),
});
}
if out_numel == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let idata = args.x.data.as_raw().0 as *mut cufftComplex;
let odata = args.y.data.as_raw().0 as *mut f32;
let status = unsafe { cufftExecC2R(handle, idata, odata) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
let scale = 1.0_f32 / (n as f32);
let stream_ptr = stream.as_raw() as *mut c_void;
let s = unsafe {
baracuda_kernels_scale_inplace_real_f32_run(
out_numel,
scale,
odata as *mut c_void,
core::ptr::null_mut(),
0,
stream_ptr,
)
};
map_status(s)
}
}
impl IrfftPlan<f64> {
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: IrfftArgs<'_, f64, Complex64>,
) -> Result<()> {
let n = self.desc.n;
let batch = self.desc.batch;
let in_shape = [batch, n / 2 + 1];
let out_shape = [batch, n];
if args.x.shape != in_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::IrfftPlan<f64>: x shape != [batch, n/2 + 1]",
));
}
if args.y.shape != out_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::IrfftPlan<f64>: y shape != [batch, n]",
));
}
let in_numel = (batch as i64) * ((n / 2 + 1) as i64);
let out_numel = (batch as i64) * (n as i64);
if (args.x.data.len() as i64) < in_numel {
return Err(Error::BufferTooSmall {
needed: in_numel as usize,
got: args.x.data.len(),
});
}
if (args.y.data.len() as i64) < out_numel {
return Err(Error::BufferTooSmall {
needed: out_numel as usize,
got: args.y.data.len(),
});
}
if out_numel == 0 {
return Ok(());
}
let handle = self.ensure_handle()?;
self.bind_stream(handle, stream)?;
let idata = args.x.data.as_raw().0 as *mut cufftDoubleComplex;
let odata = args.y.data.as_raw().0 as *mut f64;
let status = unsafe { cufftExecZ2D(handle, idata, odata) };
if status != 0 {
return Err(Error::CutlassInternal(cufft_to_status(status)));
}
let scale = 1.0_f64 / (n as f64);
let stream_ptr = stream.as_raw() as *mut c_void;
let s = unsafe {
baracuda_kernels_scale_inplace_real_f64_run(
out_numel,
scale,
odata as *mut c_void,
core::ptr::null_mut(),
0,
stream_ptr,
)
};
map_status(s)
}
}
impl<T: Element> Drop for IrfftPlan<T> {
    /// Destroys the cuFFT plan if one was ever created and leaves the cell
    /// holding the uninitialised sentinel.
    fn drop(&mut self) {
        // Take the handle out and reset the cell in a single step.
        let live = self.handle.replace(HANDLE_UNINIT);
        if live != HANDLE_UNINIT {
            // SAFETY: `live` came from a successful `cufftPlan1d` call and is
            // no longer reachable through the cell. The status is ignored on
            // purpose because `drop` cannot report errors.
            unsafe {
                let _ = cufftDestroy(live);
            }
        }
    }
}