use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
baracuda_kernels_fftshift_nd_16_run, baracuda_kernels_fftshift_nd_4_run,
baracuda_kernels_fftshift_nd_8_run,
};
use baracuda_kernels_types::{
contiguous_stride, ArchSku, BackendKind, Element, ElementKind, FftKind, KernelSku,
MathPrecision, OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::fft::map_status;
/// Largest tensor rank a [`FftShiftNdPlan`] accepts. It is also the fixed length
/// of the shape / shift / stride arrays handed to the kernel launcher.
pub const FFTSHIFT_ND_MAX_RANK: usize = 8;
/// Maximum number of axes one plan may shift; the capacity of
/// [`FftShiftNdDescriptor::shift_axes`].
pub const FFTSHIFT_ND_MAX_SHIFT_AXES: usize = 4;
/// Describes one N-dimensional fftshift / ifftshift problem.
///
/// The descriptor is validated by `FftShiftNdPlan::select` before a plan is
/// built from it; the constraints listed on each field are the ones enforced
/// there.
#[derive(Copy, Clone, Debug)]
pub struct FftShiftNdDescriptor {
/// Tensor extents. Only the first `ndim` entries are read, and each of those
/// must be non-negative.
pub shape: [i32; FFTSHIFT_ND_MAX_RANK],
/// Number of meaningful entries in `shape`. Must lie in
/// `1..=FFTSHIFT_ND_MAX_RANK` and equal the plan's const parameter `N`.
pub ndim: u8,
/// Axes to shift. Only the first `num_shift_axes` entries are read; each must
/// be `< ndim` and may appear at most once.
pub shift_axes: [u8; FFTSHIFT_ND_MAX_SHIFT_AXES],
/// Number of meaningful entries in `shift_axes`. Must not exceed
/// `FFTSHIFT_ND_MAX_SHIFT_AXES` nor `ndim`.
pub num_shift_axes: u8,
/// `false` selects fftshift (per-axis shift of `n / 2`); `true` selects
/// ifftshift (per-axis shift of `n - n / 2`), where `n` is the axis extent.
pub inverse: bool,
/// Element kind of the tensors; must equal `T::KIND` of the plan's element type.
pub element: ElementKind,
}
/// Per-launch tensor arguments for [`FftShiftNdPlan`].
///
/// `FftShiftNdPlan::can_implement` requires both tensors to have exactly the
/// descriptor's shape and the contiguous stride derived from that shape.
pub struct FftShiftNdArgs<'a, T: Element, const N: usize> {
/// Source tensor; its data pointer is passed to the kernel as the read-only input.
pub input: TensorRef<'a, T, N>,
/// Destination tensor; its data pointer is passed to the kernel as the output.
pub output: TensorMut<'a, T, N>,
}
/// A validated fftshift / ifftshift plan for element type `T` and rank `N`.
///
/// All fields are private, so within this file instances are only produced by
/// `select`, which is where the descriptor invariants are checked.
pub struct FftShiftNdPlan<T: Element, const N: usize> {
// Copy of the descriptor that passed validation in `select`.
desc: FftShiftNdDescriptor,
// Kernel identity and precision metadata assembled in `select`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element, const N: usize> FftShiftNdPlan<T, N> {
pub fn select(
_stream: &Stream,
desc: &FftShiftNdDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::FftShiftNdPlan: descriptor.element != T::KIND",
));
}
let size = core::mem::size_of::<T>();
if !matches!(size, 4 | 8 | 16) {
return Err(Error::Unsupported(
"baracuda-kernels::FftShiftNdPlan: only 4/8/16-byte element types supported",
));
}
let ndim = desc.ndim as usize;
if ndim == 0 || ndim > FFTSHIFT_ND_MAX_RANK {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftShiftNdPlan: ndim must be in 1..=8",
));
}
if ndim != N {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftShiftNdPlan: descriptor.ndim != const N",
));
}
for &d in &desc.shape[..ndim] {
if d < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftShiftNdPlan: shape dims must be non-negative",
));
}
}
let num_shift_axes = desc.num_shift_axes as usize;
if num_shift_axes > FFTSHIFT_ND_MAX_SHIFT_AXES || num_shift_axes > ndim {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftShiftNdPlan: num_shift_axes out of range",
));
}
let mut seen = [false; FFTSHIFT_ND_MAX_RANK];
for &axis in &desc.shift_axes[..num_shift_axes] {
let a = axis as usize;
if a >= ndim {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftShiftNdPlan: shift_axes entry out of range",
));
}
if seen[a] {
return Err(Error::InvalidProblem(
"baracuda-kernels::FftShiftNdPlan: duplicate entry in shift_axes",
));
}
seen[a] = true;
}
let math_precision = match T::KIND {
ElementKind::F64 | ElementKind::Complex64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator: T::KIND,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let op = if desc.inverse {
FftKind::IfftShift
} else {
FftKind::FftShift
};
let sku = KernelSku {
category: OpCategory::Fft,
op: op as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the precision guarantee recorded in this plan's SKU.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
    let sku = &self.sku;
    sku.precision_guarantee
}
/// Scratch space the plan requires; always zero.
#[inline]
pub fn workspace_size(&self) -> usize {
    const NO_WORKSPACE: usize = 0;
    NO_WORKSPACE
}
/// Checks that `args` are compatible with the descriptor this plan was built from.
///
/// Checks run in this order: input shape, output shape, input stride, output
/// stride, then buffer lengths.
///
/// # Errors
///
/// * [`Error::InvalidProblem`] if either tensor's shape differs from the
///   descriptor's shape.
/// * [`Error::Unsupported`] if either tensor's stride is not the contiguous
///   stride for that shape.
/// * [`Error::BufferTooSmall`] if either data buffer holds fewer elements than
///   the output tensor's element count; `got` reports the shorter buffer.
pub fn can_implement(&self, args: &FftShiftNdArgs<'_, T, N>) -> Result<()> {
    // Project the descriptor's fixed-capacity shape down to an `N`-long array
    // so it can be compared against the tensors directly.
    let rank = self.desc.ndim as usize;
    let mut wanted = [0i32; N];
    wanted[..rank].copy_from_slice(&self.desc.shape[..rank]);

    if args.input.shape != wanted {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FftShiftNdPlan: input shape != descriptor.shape",
        ));
    }
    if args.output.shape != wanted {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FftShiftNdPlan: output shape != descriptor.shape",
        ));
    }

    let dense = contiguous_stride(wanted);
    if args.input.stride != dense {
        return Err(Error::Unsupported(
            "baracuda-kernels::FftShiftNdPlan: input must be contiguous",
        ));
    }
    if args.output.stride != dense {
        return Err(Error::Unsupported(
            "baracuda-kernels::FftShiftNdPlan: output must be contiguous",
        ));
    }

    // Both buffers must hold the full element count, so it is enough to
    // compare the shorter of the two.
    let required = args.output.numel();
    let in_len = args.input.data.len() as i64;
    let out_len = args.output.data.len() as i64;
    let shortest = in_len.min(out_len);
    if shortest < required {
        return Err(Error::BufferTooSmall {
            needed: required as usize,
            got: shortest as usize,
        });
    }
    Ok(())
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: FftShiftNdArgs<'_, T, N>,
) -> Result<()> {
self.can_implement(&args)?;
let numel = args.output.numel();
if numel == 0 {
return Ok(());
}
let ndim = self.desc.ndim as usize;
let mut shape_arr = [0i32; FFTSHIFT_ND_MAX_RANK];
let mut shift_amt_arr = [0i32; FFTSHIFT_ND_MAX_RANK];
let mut stride_arr = [0i64; FFTSHIFT_ND_MAX_RANK];
for i in 0..ndim {
shape_arr[i] = self.desc.shape[i];
stride_arr[i] = args.output.stride[i];
}
let num_shift_axes = self.desc.num_shift_axes as usize;
for &axis in &self.desc.shift_axes[..num_shift_axes] {
let a = axis as usize;
let n = shape_arr[a];
let half = n / 2;
shift_amt_arr[a] = if self.desc.inverse { n - half } else { half };
}
let x_ptr = args.input.data.as_raw().0 as *const c_void;
let y_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let rank = ndim as i32;
let size = core::mem::size_of::<T>();
let status = unsafe {
match size {
4 => baracuda_kernels_fftshift_nd_4_run(
numel,
rank,
shape_arr.as_ptr(),
shift_amt_arr.as_ptr(),
stride_arr.as_ptr(),
x_ptr,
y_ptr,
core::ptr::null_mut(),
0,
stream_ptr,
),
8 => baracuda_kernels_fftshift_nd_8_run(
numel,
rank,
shape_arr.as_ptr(),
shift_amt_arr.as_ptr(),
stride_arr.as_ptr(),
x_ptr,
y_ptr,
core::ptr::null_mut(),
0,
stream_ptr,
),
16 => baracuda_kernels_fftshift_nd_16_run(
numel,
rank,
shape_arr.as_ptr(),
shift_amt_arr.as_ptr(),
stride_arr.as_ptr(),
x_ptr,
y_ptr,
core::ptr::null_mut(),
0,
stream_ptr,
),
_ => unreachable!("select() gates on size_of::<T>() in 4 / 8 / 16"),
}
};
map_status(status)
}
}