use cubecl::prelude::*;
use crate::butterfly::{
bit_reverse,
butterfly_inner, butterfly_inner_batch,
butterfly_stage, butterfly_stage_batch,
butterfly_stage_radix4, butterfly_stage_radix4_batch,
};
use crate::{TILE_BITS, TILE_SIZE, WORKGROUP_SIZE};
/// Computes the discrete Fourier transform of a real-valued signal on the GPU.
///
/// `input` is zero-padded to `n = input.len().next_power_of_two()` samples and written in
/// bit-reversed index order on the host. It is then uploaded together with an all-zero
/// imaginary buffer. The butterfly passes run in place on the device:
///
/// 1. The first `min(log2(n), TILE_BITS)` stages run in a single `butterfly_inner` launch
///    (the shared-memory kernel), one cube per tile of `min(TILE_SIZE, n)` samples.
/// 2. The remaining stages run two at a time with `butterfly_stage_radix4`.
/// 3. A final radix-2 `butterfly_stage` launch runs when an odd number of stages is left.
///
/// Returns `(real, imag)`, each of length `n`, read back without further reordering.
/// Inputs of length 0 or 1 never touch the device. Their result has length 1 and holds the
/// single sample (or `0.0` for an empty input) with a zero imaginary part.
///
/// NOTE(review): the twiddle sign convention and any normalization are decided by the
/// kernels in `crate::butterfly` (the trailing `true` argument) — confirm there.
///
/// # Panics
///
/// Panics if a kernel launch fails, or if a dispatch size does not fit in a `u32`.
#[must_use]
pub fn fft<R: Runtime>(device: &R::Device, input: &[f32]) -> (Vec<f32>, Vec<f32>) {
    /// Converts a host-side dispatch size to `u32`, refusing to truncate silently.
    fn to_dispatch(count: usize) -> u32 {
        u32::try_from(count).expect("FFT dispatch size exceeds u32::MAX")
    }

    // `0usize.next_power_of_two()` is 1, so `n >= 1` always holds here.
    let n = input.len().next_power_of_two();
    if n == 1 {
        // A length-1 DFT is the identity; an empty input becomes a single zero.
        return (vec![input.first().copied().unwrap_or(0.0)], vec![0.0]);
    }
    let log_n = n.ilog2();
    let m = log_n as usize;

    // Host-side bit-reversal permutation. Slots for indices at or beyond `input.len()`
    // stay zero (padding), and the imaginary part of a real signal starts at zero.
    let mut real = vec![0.0f32; n];
    for (i, &v) in input.iter().enumerate() {
        real[bit_reverse(i, log_n)] = v;
    }
    let imag = vec![0.0f32; n];

    let client = R::client(device);
    let real_handle = client.create_from_slice(f32::as_bytes(&real));
    let imag_handle = client.create_from_slice(f32::as_bytes(&imag));

    // Phase 1: the first `inner_stages` stages, one cube per tile of `tile` samples and
    // `tile / 2` threads per cube. Since `tile <= n`, `n / tile >= 1`.
    let inner_stages = m.min(TILE_BITS);
    let tile = TILE_SIZE.min(n);
    let num_tiles = to_dispatch(n / tile);
    let wg_threads = to_dispatch(tile / 2);
    // SAFETY: both array arguments describe buffers created above from `n`-element slices,
    // and the lengths passed to `from_raw_parts` match that element count.
    // NOTE(review): soundness also relies on the kernel indexing only within `[0, n)`.
    unsafe {
        butterfly_inner::launch::<f32, R>(
            &client,
            CubeCount::Static(num_tiles, 1, 1),
            CubeDim::new_1d(wg_threads),
            ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
            ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
            tile,
            inner_stages,
            true,
        )
        .expect("FFT inner (shared-memory) launch failed")
    };

    // Outer grids: one thread per group of 4 (radix-4) or per pair (radix-2) of samples.
    let wg_size = WORKGROUP_SIZE as usize;
    let outer_wg_r4 = to_dispatch((n / 4).div_ceil(wg_size));
    let outer_wg_r2 = to_dispatch((n / 2).div_ceil(wg_size));

    // Phase 2: consume the remaining stages two at a time while at least two are left.
    let mut s = inner_stages;
    while s + 1 < m {
        let q = 1_usize << s;
        // SAFETY: same buffers and element counts (`n`) as the inner launch above.
        unsafe {
            butterfly_stage_radix4::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r4, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
                n,
                q,
                true,
            )
            .expect("FFT outer radix-4 butterfly launch failed")
        };
        s += 2;
    }

    // Phase 3: at most one stage can remain after the radix-4 passes.
    if s < m {
        let hs = 1_usize << s;
        // SAFETY: same buffers and element counts (`n`) as the inner launch above.
        unsafe {
            butterfly_stage::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r2, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
                n,
                hs,
                true,
            )
            .expect("FFT outer radix-2 trailing butterfly launch failed")
        };
    }

    let real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
    let imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
    (real_out, imag_out)
}
/// Computes the discrete Fourier transform of several real-valued signals on the GPU using
/// one set of batched dispatches.
///
/// Every signal is zero-padded to a common length `n`: the next power of two of the longest
/// signal. Signals are stored back to back in one flat buffer, with signal `b` occupying
/// `[b * n, (b + 1) * n)`. Each signal is written in bit-reversed index order on the host,
/// and the buffer is then transformed in place using the same three phases as [`fft`]:
/// `butterfly_inner_batch`, then `butterfly_stage_radix4_batch` two stages at a time, then
/// an optional trailing radix-2 `butterfly_stage_batch`.
///
/// Returns one `(real, imag)` pair per input signal, in input order, each of length `n`.
/// An empty `signals` slice yields an empty `Vec`. When every signal has length 0 or 1 the
/// device is not used: each result has length 1 and holds that signal's single sample
/// (or `0.0` if it is empty) with a zero imaginary part.
///
/// NOTE(review): the twiddle sign convention and any normalization are decided by the
/// kernels in `crate::butterfly` (the trailing `true` argument) — confirm there.
///
/// # Panics
///
/// Panics if a kernel launch fails, or if a dispatch size does not fit in a `u32`.
#[must_use]
pub fn fft_batch<R: Runtime>(
    device: &R::Device,
    signals: &[Vec<f32>],
) -> Vec<(Vec<f32>, Vec<f32>)> {
    /// Converts a host-side dispatch size to `u32`, refusing to truncate silently.
    fn to_dispatch(count: usize) -> u32 {
        u32::try_from(count).expect("FFT batch dispatch size exceeds u32::MAX")
    }

    if signals.is_empty() {
        return Vec::new();
    }
    let batch_size = signals.len();
    let max_len = signals.iter().map(Vec::len).max().unwrap_or(0);
    // `0usize.next_power_of_two()` is 1, so `n >= 1` always holds here.
    let n = max_len.next_power_of_two();
    if n == 1 {
        // A length-1 DFT is the identity; an empty signal becomes a single zero.
        return signals
            .iter()
            .map(|s| (vec![s.first().copied().unwrap_or(0.0)], vec![0.0f32]))
            .collect();
    }
    let log_n = n.ilog2();
    let m = log_n as usize;
    let total = batch_size * n;

    // Host-side bit-reversal permutation per signal. Each `n`-sample chunk of the flat
    // buffer belongs to one signal; unused (padding) slots stay zero.
    let mut real_flat = vec![0.0f32; total];
    for (signal, chunk) in signals.iter().zip(real_flat.chunks_exact_mut(n)) {
        for (i, &v) in signal.iter().enumerate() {
            chunk[bit_reverse(i, log_n)] = v;
        }
    }
    let imag_flat = vec![0.0f32; total];

    let client = R::client(device);
    let real_handle = client.create_from_slice(f32::as_bytes(&real_flat));
    let imag_handle = client.create_from_slice(f32::as_bytes(&imag_flat));

    // Phase 1: `n / tile` cubes per signal (at least one, since `tile <= n`), each with
    // `tile / 2` threads.
    let inner_stages = m.min(TILE_BITS);
    let tile = TILE_SIZE.min(n);
    let tiles_per_signal = n / tile;
    let wg_count = to_dispatch(tiles_per_signal * batch_size);
    let wg_threads = to_dispatch(tile / 2);
    // SAFETY: both array arguments describe buffers created above from `total`-element
    // slices, and the lengths passed to `from_raw_parts` match that element count.
    // NOTE(review): soundness also relies on the kernel indexing only within `[0, total)`.
    unsafe {
        butterfly_inner_batch::launch::<f32, R>(
            &client,
            CubeCount::Static(wg_count, 1, 1),
            CubeDim::new_1d(wg_threads),
            ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
            ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
            n,
            tile,
            inner_stages,
            true,
        )
        .expect("FFT batch inner (shared-memory) launch failed")
    };

    // Outer grids cover every signal: one thread per group of 4 (radix-4) or per pair
    // (radix-2). Divide in `usize` before narrowing so large batches are not truncated.
    let wg_size = WORKGROUP_SIZE as usize;
    let outer_wg_r4 = to_dispatch((batch_size * (n / 4)).div_ceil(wg_size));
    let outer_wg_r2 = to_dispatch((batch_size * (n / 2)).div_ceil(wg_size));

    // Phase 2: consume the remaining stages two at a time while at least two are left.
    let mut s = inner_stages;
    while s + 1 < m {
        let q = 1_usize << s;
        // SAFETY: same buffers and element counts (`total`) as the inner launch above.
        unsafe {
            butterfly_stage_radix4_batch::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r4, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
                n,
                q,
                batch_size,
                true,
            )
            .expect("FFT batch outer radix-4 butterfly launch failed")
        };
        s += 2;
    }

    // Phase 3: at most one stage can remain after the radix-4 passes.
    if s < m {
        let hs = 1_usize << s;
        // SAFETY: same buffers and element counts (`total`) as the inner launch above.
        unsafe {
            butterfly_stage_batch::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r2, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
                n,
                hs,
                batch_size,
                true,
            )
            .expect("FFT batch outer radix-2 trailing butterfly launch failed")
        };
    }

    // Borrow the read-back bytes as `f32` slices and copy each signal out exactly once,
    // instead of copying the whole batch into an intermediate `Vec` first.
    let real_bytes = client.read_one(real_handle);
    let imag_bytes = client.read_one(imag_handle);
    let real_out = f32::from_bytes(&real_bytes);
    let imag_out = f32::from_bytes(&imag_bytes);
    (0..batch_size)
        .map(|b| {
            let span = b * n..(b + 1) * n;
            (real_out[span.clone()].to_vec(), imag_out[span].to_vec())
        })
        .collect()
}