use core::f32::consts::PI;
use cubecl::prelude::*;
use cubecl::std::tensor::{
AsView as _, AsViewExpand, AsViewMut as _, AsViewMutExpand, TensorHandle,
};
use crate::{
fft::{
FftMode,
fft_parallel::{bit_reverse, fft_butterfly_parallel},
},
layout::BatchSignalLayout,
};
/// Largest transform length that is dispatched to the single-pass shared-memory
/// kernel. Longer transforms take the four-step path, where this value also
/// caps each of the two factors chosen by `factor_four_step`.
pub(crate) const MAX_SHARED_N_FFT: usize = 4096;
/// Upper bound used when clamping the number of units per cube for the
/// butterfly kernels (the lower bound of the clamp is 1).
const MAX_UNITS_PER_CUBE: usize = 256;
/// Tensor bindings consumed by one complex FFT launch.
///
/// The complex signal is carried as two separate tensors (`*_re` / `*_im`);
/// the kernels in this file read the `input_*` pair and write the `output_*`
/// pair. The launch code derives every size from `input_re.shape` only.
// NOTE(review): presumably all four tensors share the same shape — nothing
// here checks it; confirm against callers.
pub(crate) struct CfftBindings<R: Runtime> {
/// Real part of the input signal.
pub(crate) input_re: TensorBinding<R>,
/// Imaginary part of the input signal.
pub(crate) input_im: TensorBinding<R>,
/// Real part of the transformed output.
pub(crate) output_re: TensorBinding<R>,
/// Imaginary part of the transformed output.
pub(crate) output_im: TensorBinding<R>,
}
/// Launch parameters computed once by `cfft_launch_any_size` and handed to
/// whichever launch path is selected.
#[derive(Clone, Copy)]
struct CfftPlan {
/// Axis of the input tensor along which the transform is taken.
dim: usize,
/// Number of independent transforms: the product of every extent of the
/// input shape except the one at `dim`.
count: usize,
/// Transform length, i.e. the extent of the input shape at `dim`.
n_fft: usize,
/// Mode forwarded unchanged to the kernels as a comptime argument.
fft_mode: FftMode,
}
/// Splits a power-of-two `n_fft` into `(n1, n2)` with `n1 * n2 == n_fft`.
///
/// The exponent is divided as evenly as possible, with the larger half going
/// to `n2`. If that half would exceed `MAX_SHARED_N_FFT`, `n2` is pinned to
/// `MAX_SHARED_N_FFT` and the remainder goes to `n1`.
///
/// # Panics
///
/// Panics if `n_fft` is not a power of two, or if either factor would still
/// exceed `MAX_SHARED_N_FFT` after the split.
pub(crate) fn factor_four_step(n_fft: usize) -> (usize, usize) {
    assert!(
        n_fft.is_power_of_two(),
        "four-step needs power-of-two n_fft"
    );

    let cap_bits = MAX_SHARED_N_FFT.trailing_zeros() as usize;
    let total_bits = n_fft.trailing_zeros() as usize;

    // `n2` receives the rounded-up half of the exponent, clamped to the cap;
    // `n1` takes whatever is left over.
    let cols_bits = (total_bits - total_bits / 2).min(cap_bits);
    let rows_bits = total_bits - cols_bits;

    assert!(
        rows_bits <= cap_bits && cols_bits <= cap_bits,
        "four-step cannot handle n_fft = {n_fft} with MAX_SHARED_N_FFT = {MAX_SHARED_N_FFT}",
    );

    (1 << rows_bits, 1 << cols_bits)
}
/// Launches a complex FFT along axis `dim` of the bound tensors.
///
/// The transform length is the extent of `input_re` at `dim`; every other
/// axis contributes to the number of independent transforms. Lengths up to
/// `MAX_SHARED_N_FFT` use the single-pass shared-memory kernel, longer ones
/// use the four-step path. `dtype` is forwarded to the four-step path only.
///
/// Returns `Ok(())` without launching anything when there are no transforms
/// to run (some non-`dim` extent is zero).
///
/// # Panics
///
/// Panics if `dim` is out of range for the input shape, if the transform
/// length is not a power of two, or if it is smaller than 2.
pub(crate) fn cfft_launch_any_size<R: Runtime>(
    client: &ComputeClient<R>,
    bindings: CfftBindings<R>,
    dim: usize,
    dtype: StorageType,
    fft_mode: FftMode,
) -> Result<(), LaunchError> {
    let shape = &bindings.input_re.shape;

    let n_fft = shape[dim];
    assert!(n_fft.is_power_of_two(), "cfft needs power-of-two n_fft");
    assert!(n_fft >= 2);

    // Number of independent signals: product of all extents except `dim`.
    let mut count: usize = 1;
    for (axis, extent) in shape.iter().enumerate() {
        if axis != dim {
            count *= *extent;
        }
    }
    if count == 0 {
        return Ok(());
    }

    let plan = CfftPlan {
        dim,
        count,
        n_fft,
        fft_mode,
    };

    if n_fft > MAX_SHARED_N_FFT {
        cfft_four_step_launch::<R>(client, bindings, dtype, plan)
    } else {
        cfft_shared_launch::<R>(client, bindings, plan)
    }
}
/// Launches `cfft_shared_kernel` with one cube per transform.
///
/// Each cube runs with `n_fft / 2` units, clamped to `1..=MAX_UNITS_PER_CUBE`.
/// The cube count is derived from `plan.count` with a single-unit cube
/// dimension, so at least `plan.count` cubes are requested; the kernel
/// discards any surplus cubes using the `num_windows` argument.
///
/// # Panics
///
/// Panics if `plan.count` does not fit in the kernel's `u32` window-count
/// argument. A truncated count would make the kernel terminate valid cubes
/// and leave part of the output unwritten.
// NOTE(review): the kernel is instantiated for `f32` regardless of the
// element type of the bound tensors — confirm callers only pass f32 data.
fn cfft_shared_launch<R: Runtime>(
    client: &ComputeClient<R>,
    bindings: CfftBindings<R>,
    plan: CfftPlan,
) -> Result<(), LaunchError> {
    let num_windows = u32::try_from(plan.count)
        .expect("cfft: window count does not fit in the u32 kernel argument");

    let log2_n = plan.n_fft.trailing_zeros() as usize;
    let threads_per_cube = (plan.n_fft / 2).clamp(1, MAX_UNITS_PER_CUBE);
    // Lossless cast: the value was just clamped to at most MAX_UNITS_PER_CUBE.
    let cube_dim = CubeDim::new_1d(threads_per_cube as u32);
    let cube_count =
        cubecl::calculate_cube_count_elemwise(client, plan.count, CubeDim::new_single());

    cfft_shared_kernel::launch::<f32, R>(
        client,
        cube_count,
        cube_dim,
        bindings.input_re.into_tensor_arg(),
        bindings.input_im.into_tensor_arg(),
        bindings.output_re.into_tensor_arg(),
        bindings.output_im.into_tensor_arg(),
        num_windows,
        plan.n_fft,
        log2_n,
        threads_per_cube,
        plan.dim,
        plan.fft_mode,
    );
    Ok(())
}
/// Transforms one signal per cube entirely inside shared memory.
///
/// Steps performed by each cube:
/// 1. copy the signal into two `n_fft`-element shared buffers, storing input
///    element `i` at index `bit_reverse(i, log2_n)`;
/// 2. run `fft_butterfly_parallel` over the shared buffers;
/// 3. copy the shared buffers to the output in natural order.
///
/// Units cover the `n_fft` elements with a stride of `threads_per_cube`.
/// Cubes whose position is not below `num_windows` exit immediately.
// NOTE(review): step 3 reads shared memory right after
// `fft_butterfly_parallel` with no barrier in between — this assumes the
// helper synchronizes the cube before returning; confirm.
#[cube(launch)]
fn cfft_shared_kernel<F: Float>(
    input_re: &Tensor<F>,
    input_im: &Tensor<F>,
    output_re: &mut Tensor<F>,
    output_im: &mut Tensor<F>,
    num_windows: u32,
    #[comptime] n_fft: usize,
    #[comptime] log2_n: usize,
    #[comptime] threads_per_cube: usize,
    #[comptime] dim: usize,
    #[comptime] fft_mode: FftMode,
) {
    // One cube per signal; surplus cubes from the launch grid do nothing.
    let window = CUBE_POS;
    if (window as u32) >= num_windows {
        terminate!();
    }

    // Views are indexed by position within the selected signal.
    let src_re = input_re.view(BatchSignalLayout::new(input_re, window, dim));
    let src_im = input_im.view(BatchSignalLayout::new(input_im, window, dim));
    let mut dst_re = output_re.view_mut(BatchSignalLayout::new(output_re, window, dim));
    let mut dst_im = output_im.view_mut(BatchSignalLayout::new(output_im, window, dim));

    let mut buf_re = SharedMemory::<F>::new(n_fft);
    let mut buf_im = SharedMemory::<F>::new(n_fft);

    // Load with the index permutation applied on the shared-memory side.
    let mut src_idx = UNIT_POS as usize;
    while src_idx < n_fft {
        let dst_idx = bit_reverse(src_idx, log2_n);
        buf_re[dst_idx] = src_re[src_idx];
        buf_im[dst_idx] = src_im[src_idx];
        src_idx += threads_per_cube;
    }
    sync_cube();

    fft_butterfly_parallel::<F>(
        &mut buf_re,
        &mut buf_im,
        n_fft,
        log2_n,
        threads_per_cube,
        fft_mode,
    );

    // Store the result in natural order.
    let mut bin = UNIT_POS as usize;
    while bin < n_fft {
        dst_re[bin] = buf_re[bin];
        dst_im[bin] = buf_im[bin];
        bin += threads_per_cube;
    }
}
/// Launches the three-kernel four-step FFT used when `n_fft` exceeds
/// `MAX_SHARED_N_FFT`.
///
/// With `(n1, n2) = factor_four_step(n_fft)` the stages are:
/// 1. `cfft_four_step_radix1_kernel`: `count * n2` cubes, each running a
///    length-`n1` transform over input elements strided by `n2`, applying
///    twiddle factors and writing to scratch;
/// 2. `cfft_four_step_radix2_kernel`: `count * n1` cubes, each running a
///    length-`n2` transform in place on one contiguous scratch row;
/// 3. `cfft_four_step_transpose_kernel`: one unit per element, copying
///    scratch to the output with the `n1 x n2` index transposition.
///
/// Two scratch tensors with the input's shape are allocated per call.
///
/// # Panics
///
/// Panics if `factor_four_step` rejects `plan.n_fft`, or if any of the work
/// sizes passed to the kernels (`count * n2`, `count * n1`,
/// `count * n_fft`) does not fit in a `u32`. A truncated size would make the
/// kernels terminate valid cubes/units and silently leave output unwritten.
// NOTE(review): scratch is sized from `dtype.size()` while all three kernels
// are instantiated for `f32` — confirm `dtype` is always f32 here.
fn cfft_four_step_launch<R: Runtime>(
    client: &ComputeClient<R>,
    bindings: CfftBindings<R>,
    dtype: StorageType,
    plan: CfftPlan,
) -> Result<(), LaunchError> {
    let (n1, n2) = factor_four_step(plan.n_fft);

    // Validate every u32 kernel argument before allocating or launching
    // anything, so an oversized request fails without doing partial work.
    let to_kernel_u32 = |len: usize, what: &str| -> u32 {
        u32::try_from(len).unwrap_or_else(|_| {
            panic!("cfft four-step: {what} = {len} does not fit in a u32 kernel argument")
        })
    };
    let stage1_cubes = plan.count * n2;
    let stage2_cubes = plan.count * n1;
    let total = plan.count * plan.n_fft;
    let stage1_cubes_u32 = to_kernel_u32(stage1_cubes, "count * n2");
    let stage2_cubes_u32 = to_kernel_u32(stage2_cubes, "count * n1");
    let total_u32 = to_kernel_u32(total, "count * n_fft");

    // Scratch buffers mirror the input shape.
    let scratch_shape: Vec<usize> = bindings.input_re.shape.to_vec();
    let elems: usize = scratch_shape.iter().product();
    let scratch_re = TensorHandle::<R>::new_contiguous(
        scratch_shape.clone(),
        client.empty(elems * dtype.size()),
        dtype,
    );
    let scratch_im = TensorHandle::<R>::new_contiguous(
        scratch_shape,
        client.empty(elems * dtype.size()),
        dtype,
    );

    // Stage 1: length-n1 transforms plus twiddles, input -> scratch.
    {
        let threads_per_cube = (n1 / 2).clamp(1, MAX_UNITS_PER_CUBE);
        let log2_n1 = n1.trailing_zeros() as usize;
        // Lossless cast: clamped to at most MAX_UNITS_PER_CUBE.
        let cube_dim = CubeDim::new_1d(threads_per_cube as u32);
        let cube_count =
            cubecl::calculate_cube_count_elemwise(client, stage1_cubes, CubeDim::new_single());
        cfft_four_step_radix1_kernel::launch::<f32, R>(
            client,
            cube_count,
            cube_dim,
            bindings.input_re.into_tensor_arg(),
            bindings.input_im.into_tensor_arg(),
            scratch_re.clone().binding().into_tensor_arg(),
            scratch_im.clone().binding().into_tensor_arg(),
            stage1_cubes_u32,
            n1,
            n2,
            log2_n1,
            threads_per_cube,
            plan.dim,
            plan.fft_mode,
        );
    }

    // Stage 2: length-n2 transforms, in place on scratch.
    {
        let threads_per_cube = (n2 / 2).clamp(1, MAX_UNITS_PER_CUBE);
        let log2_n2 = n2.trailing_zeros() as usize;
        // Lossless cast: clamped to at most MAX_UNITS_PER_CUBE.
        let cube_dim = CubeDim::new_1d(threads_per_cube as u32);
        let cube_count =
            cubecl::calculate_cube_count_elemwise(client, stage2_cubes, CubeDim::new_single());
        cfft_four_step_radix2_kernel::launch::<f32, R>(
            client,
            cube_count,
            cube_dim,
            scratch_re.clone().binding().into_tensor_arg(),
            scratch_im.clone().binding().into_tensor_arg(),
            stage2_cubes_u32,
            n1,
            n2,
            log2_n2,
            threads_per_cube,
            plan.dim,
            plan.fft_mode,
        );
    }

    // Stage 3: transposed copy, scratch -> output. Last use of the scratch
    // handles, so they are consumed here instead of cloned.
    {
        let cube_dim = CubeDim::new_1d(256);
        let cube_count = cubecl::calculate_cube_count_elemwise(client, total, cube_dim);
        cfft_four_step_transpose_kernel::launch::<f32, R>(
            client,
            cube_count,
            cube_dim,
            scratch_re.binding().into_tensor_arg(),
            scratch_im.binding().into_tensor_arg(),
            bindings.output_re.into_tensor_arg(),
            bindings.output_im.into_tensor_arg(),
            total_u32,
            n1,
            n2,
            plan.dim,
        );
    }

    Ok(())
}
/// First four-step stage: length-`n1` transforms followed by twiddles.
///
/// The cube position encodes `(window, column)` as `window * n2 + column`.
/// Each cube gathers the `n1` input elements at `i * n2 + column`, storing
/// element `i` at shared index `bit_reverse(i, log2_n1)`, runs
/// `fft_butterfly_parallel` on them, then multiplies output `k1` by
/// `exp(i * theta)` with
/// `theta = fft_mode.sign() * 2*pi * k1 * column / (n1 * n2)` and writes the
/// product to scratch at `k1 * n2 + column`.
///
/// Cubes whose position is not below `num_cubes` exit immediately.
// NOTE(review): the twiddle loop reads shared memory right after
// `fft_butterfly_parallel` with no barrier in between — this assumes the
// helper synchronizes the cube before returning; confirm.
#[cube(launch)]
fn cfft_four_step_radix1_kernel<F: Float>(
    input_re: &Tensor<F>,
    input_im: &Tensor<F>,
    scratch_re: &mut Tensor<F>,
    scratch_im: &mut Tensor<F>,
    num_cubes: u32,
    #[comptime] n1: usize,
    #[comptime] n2: usize,
    #[comptime] log2_n1: usize,
    #[comptime] threads_per_cube: usize,
    #[comptime] dim: usize,
    #[comptime] fft_mode: FftMode,
) {
    let cube_pos = CUBE_POS;
    if cube_pos >= num_cubes as usize {
        terminate!();
    }

    // Decode which signal and which of its n2 columns this cube owns.
    let window = cube_pos / n2;
    let column = cube_pos % n2;

    let src_re = input_re.view(BatchSignalLayout::new(input_re, window, dim));
    let src_im = input_im.view(BatchSignalLayout::new(input_im, window, dim));
    let mut dst_re = scratch_re.view_mut(BatchSignalLayout::new(scratch_re, window, dim));
    let mut dst_im = scratch_im.view_mut(BatchSignalLayout::new(scratch_im, window, dim));

    let mut buf_re = SharedMemory::<F>::new(n1);
    let mut buf_im = SharedMemory::<F>::new(n1);

    // Strided gather of the column, permuted on the shared-memory side.
    let mut row = UNIT_POS as usize;
    while row < n1 {
        let slot = bit_reverse(row, log2_n1);
        let src = row * n2 + column;
        buf_re[slot] = src_re[src];
        buf_im[slot] = src_im[src];
        row += threads_per_cube;
    }
    sync_cube();

    fft_butterfly_parallel::<F>(
        &mut buf_re,
        &mut buf_im,
        n1,
        log2_n1,
        threads_per_cube,
        fft_mode,
    );

    let n_total = comptime![n1 * n2];
    let two_pi = F::new(2.0 * PI);
    let sign = F::new(fft_mode.sign());

    // Complex multiply each output by its twiddle and scatter to scratch.
    let mut k1 = UNIT_POS as usize;
    while k1 < n1 {
        let theta = sign * two_pi * F::cast_from(k1 * column) / F::cast_from(n_total);
        let w_re = theta.cos();
        let w_im = theta.sin();
        let val_re = buf_re[k1];
        let val_im = buf_im[k1];
        let dst = k1 * n2 + column;
        dst_re[dst] = w_re * val_re - w_im * val_im;
        dst_im[dst] = w_re * val_im + w_im * val_re;
        k1 += threads_per_cube;
    }
}
/// Second four-step stage: in-place length-`n2` transforms on scratch rows.
///
/// The cube position encodes `(window, row)` as `window * n1 + row`. Each
/// cube loads the `n2` contiguous scratch elements starting at `row * n2`,
/// storing element `i` at shared index `bit_reverse(i, log2_n2)`, runs
/// `fft_butterfly_parallel` on them, and writes the result back over the
/// same scratch row.
///
/// Cubes whose position is not below `num_cubes` exit immediately.
// NOTE(review): the write-back reads shared memory right after
// `fft_butterfly_parallel` with no barrier in between — this assumes the
// helper synchronizes the cube before returning; confirm.
#[cube(launch)]
fn cfft_four_step_radix2_kernel<F: Float>(
    scratch_re: &mut Tensor<F>,
    scratch_im: &mut Tensor<F>,
    num_cubes: u32,
    #[comptime] n1: usize,
    #[comptime] n2: usize,
    #[comptime] log2_n2: usize,
    #[comptime] threads_per_cube: usize,
    #[comptime] dim: usize,
    #[comptime] fft_mode: FftMode,
) {
    let cube_pos = CUBE_POS;
    if cube_pos >= num_cubes as usize {
        terminate!();
    }

    // Decode which signal and which of its n1 rows this cube owns.
    let window = cube_pos / n1;
    let row = cube_pos % n1;
    let row_start = row * n2;

    let mut data_re = scratch_re.view_mut(BatchSignalLayout::new(scratch_re, window, dim));
    let mut data_im = scratch_im.view_mut(BatchSignalLayout::new(scratch_im, window, dim));

    let mut buf_re = SharedMemory::<F>::new(n2);
    let mut buf_im = SharedMemory::<F>::new(n2);

    // Load the row, permuted on the shared-memory side.
    let mut col = UNIT_POS as usize;
    while col < n2 {
        let slot = bit_reverse(col, log2_n2);
        buf_re[slot] = data_re[row_start + col];
        buf_im[slot] = data_im[row_start + col];
        col += threads_per_cube;
    }
    sync_cube();

    fft_butterfly_parallel::<F>(
        &mut buf_re,
        &mut buf_im,
        n2,
        log2_n2,
        threads_per_cube,
        fft_mode,
    );

    // Overwrite the same row with the transformed values.
    let mut k2 = UNIT_POS as usize;
    while k2 < n2 {
        data_re[row_start + k2] = buf_re[k2];
        data_im[row_start + k2] = buf_im[k2];
        k2 += threads_per_cube;
    }
}
/// Final four-step stage: transposed copy from scratch to the output.
///
/// Each unit handles one element. Its absolute position is split into a
/// window (`pos / (n1 * n2)`) and an output index within that window. The
/// output index is read as `k2 * n1 + k1`, and the value is fetched from
/// scratch index `k1 * n2 + k2`.
///
/// Units whose position is not below `total` exit immediately.
#[cube(launch)]
fn cfft_four_step_transpose_kernel<F: Float>(
    scratch_re: &Tensor<F>,
    scratch_im: &Tensor<F>,
    output_re: &mut Tensor<F>,
    output_im: &mut Tensor<F>,
    total: u32,
    #[comptime] n1: usize,
    #[comptime] n2: usize,
    #[comptime] dim: usize,
) {
    let pos = ABSOLUTE_POS;
    if pos >= total as usize {
        terminate!();
    }

    // Split the flat position into (window, index within the signal).
    let signal_len = comptime![n1 * n2];
    let window = pos / signal_len;
    let dst = pos - window * signal_len;

    let src_re = scratch_re.view(BatchSignalLayout::new(scratch_re, window, dim));
    let src_im = scratch_im.view(BatchSignalLayout::new(scratch_im, window, dim));
    let mut dst_re = output_re.view_mut(BatchSignalLayout::new(output_re, window, dim));
    let mut dst_im = output_im.view_mut(BatchSignalLayout::new(output_im, window, dim));

    // Output index is k2 * n1 + k1; scratch holds the value at k1 * n2 + k2.
    let k2 = dst / n1;
    let k1 = dst % n1;
    let src = k1 * n2 + k2;

    dst_re[dst] = src_re[src];
    dst_im[dst] = src_im[src];
}