use cubecl::prelude::*;
use crate::butterfly::{
bit_reverse,
butterfly_inner, butterfly_inner_batch,
butterfly_stage, butterfly_stage_batch,
butterfly_stage_radix4, butterfly_stage_radix4_batch,
};
use crate::{TILE_BITS, TILE_SIZE, WORKGROUP_SIZE};
/// Computes the inverse FFT of one complex signal on the GPU runtime `R`.
///
/// The signal is given as separate real and imaginary parts of equal length `n`.
/// The result is a single vector of length `2 * n`: the first `n` entries are the
/// real parts and the last `n` entries are the imaginary parts, scaled by `1 / n`.
///
/// Inputs of length 0 or 1 are returned unchanged (real parts followed by
/// imaginary parts), because the inverse transform of such a signal is the identity.
///
/// # Panics
///
/// - if `input_real` and `input_imag` have different lengths;
/// - if the length is greater than 1 and not a power of two;
/// - if any GPU kernel launch fails.
#[must_use]
pub fn ifft<R: Runtime>(
    device: &R::Device,
    input_real: &[f32],
    input_imag: &[f32],
) -> Vec<f32> {
    assert_eq!(
        input_real.len(),
        input_imag.len(),
        "ifft: real and imag slices must have the same length"
    );
    let n = input_real.len();
    // Handle trivial lengths BEFORE the power-of-two check: `0` is not a power
    // of two (and `0.ilog2()` panics), so checking first would make the empty
    // case unreachable and panic with a misleading message.
    if n <= 1 {
        let mut out = input_real.to_vec();
        out.extend_from_slice(input_imag);
        return out;
    }
    assert!(
        n.is_power_of_two(),
        "ifft: input length {n} is not a power of two (pass the direct output of fft)"
    );
    let m = n.ilog2() as usize;

    // Host-side bit-reversal permutation so the GPU stages can run in place.
    let mut real = vec![0.0f32; n];
    let mut imag = vec![0.0f32; n];
    for i in 0..n {
        let j = bit_reverse(i, m as u32);
        real[j] = input_real[i];
        imag[j] = input_imag[i];
    }

    let client = R::client(device);
    let real_handle = client.create_from_slice(f32::as_bytes(&real));
    let imag_handle = client.create_from_slice(f32::as_bytes(&imag));

    // The first `inner_stages` stages run per tile of `tile` elements, one
    // workgroup per tile with one thread per butterfly pair (`tile / 2`).
    let inner_stages = m.min(TILE_BITS);
    let tile = TILE_SIZE.min(n);
    let num_tiles = (n / tile).max(1) as u32;
    let wg_threads = (tile / 2) as u32;
    // The trailing `false` is presumably the forward/inverse direction flag of
    // the butterfly kernels (inverse here) — TODO confirm against crate::butterfly.
    unsafe {
        butterfly_inner::launch::<f32, R>(
            &client,
            CubeCount::Static(num_tiles, 1, 1),
            CubeDim::new_1d(wg_threads),
            ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
            ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
            tile,
            inner_stages,
            false,
        )
        .expect("IFFT inner (shared-memory) launch failed")
    };

    // Remaining stages: two at a time with radix-4 (n / 4 groups), then at most
    // one trailing radix-2 stage (n / 2 pairs) if the stage count left is odd.
    let outer_wg_r4 = ((n / 4) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
    let outer_wg_r2 = ((n / 2) as u32 + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
    let mut s = inner_stages;
    while s + 1 < m {
        let q = 1_usize << s;
        unsafe {
            butterfly_stage_radix4::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r4, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
                n,
                q,
                false,
            )
            .expect("IFFT outer radix-4 butterfly launch failed")
        };
        s += 2;
    }
    if s < m {
        let hs = 1_usize << s;
        unsafe {
            butterfly_stage::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r2, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, n, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, n, 1),
                n,
                hs,
                false,
            )
            .expect("IFFT outer radix-2 trailing butterfly launch failed")
        };
    }

    let mut real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
    let mut imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
    // Inverse-DFT normalization.
    let scale = (n as f32).recip();
    real_out
        .iter_mut()
        .chain(imag_out.iter_mut())
        .for_each(|v| *v *= scale);
    real_out.extend_from_slice(&imag_out);
    real_out
}
/// Computes the inverse FFT of several equal-length complex signals in one
/// batched GPU pass on runtime `R`.
///
/// Each element of `signals` is a `(real, imag)` pair. For every signal the
/// result is a vector of length `2 * n`: the real parts followed by the
/// imaginary parts, scaled by `1 / n`. Results are returned in input order.
///
/// An empty `signals` slice yields an empty result. Signals of length 0 or 1
/// are returned unchanged (real parts followed by imaginary parts), because
/// their inverse transform is the identity.
///
/// # Panics
///
/// - if any signal's real and imaginary parts have different lengths;
/// - if the signals do not all share the same length;
/// - if that length is greater than 1 and not a power of two;
/// - if any GPU kernel launch fails.
#[must_use]
pub fn ifft_batch<R: Runtime>(
    device: &R::Device,
    signals: &[(Vec<f32>, Vec<f32>)],
) -> Vec<Vec<f32>> {
    if signals.is_empty() {
        return Vec::new();
    }
    let batch_size = signals.len();
    let n = signals[0].0.len();
    for (b, (re, im)) in signals.iter().enumerate() {
        assert_eq!(
            re.len(),
            im.len(),
            "ifft_batch: signal {b}: real and imag slices have different lengths"
        );
        assert_eq!(
            re.len(),
            n,
            "ifft_batch: all signals must have the same length (expected {n}, got {})",
            re.len()
        );
    }
    // Handle trivial lengths BEFORE the power-of-two check: `0` is not a power
    // of two (and `0.ilog2()` panics), so checking first would make the
    // empty-signal case unreachable and panic with a misleading message.
    if n <= 1 {
        return signals
            .iter()
            .map(|(re, im)| {
                let mut out = re.clone();
                out.extend_from_slice(im);
                out
            })
            .collect();
    }
    assert!(
        n.is_power_of_two(),
        "ifft_batch: signal length {n} is not a power of two"
    );
    let m = n.ilog2() as usize;
    let total = batch_size * n;

    // Host-side bit-reversal permutation of every signal into one flat buffer;
    // signal `b` occupies indices `[b * n, (b + 1) * n)`.
    let mut real_flat = vec![0.0f32; total];
    let mut imag_flat = vec![0.0f32; total];
    for (b, (input_real, input_imag)) in signals.iter().enumerate() {
        let base = b * n;
        for i in 0..n {
            let j = bit_reverse(i, m as u32);
            real_flat[base + j] = input_real[i];
            imag_flat[base + j] = input_imag[i];
        }
    }

    let client = R::client(device);
    let real_handle = client.create_from_slice(f32::as_bytes(&real_flat));
    let imag_handle = client.create_from_slice(f32::as_bytes(&imag_flat));

    // The first `inner_stages` stages run per tile of `tile` elements, one
    // workgroup per tile across all signals, one thread per butterfly pair.
    let inner_stages = m.min(TILE_BITS);
    let tile = TILE_SIZE.min(n);
    let tiles_per_signal = (n / tile).max(1);
    let wg_count = (tiles_per_signal * batch_size) as u32;
    let wg_threads = (tile / 2) as u32;
    // The trailing `false` is presumably the forward/inverse direction flag of
    // the butterfly kernels (inverse here) — TODO confirm against crate::butterfly.
    unsafe {
        butterfly_inner_batch::launch::<f32, R>(
            &client,
            CubeCount::Static(wg_count, 1, 1),
            CubeDim::new_1d(wg_threads),
            ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
            ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
            n,
            tile,
            inner_stages,
            false,
        )
        .expect("IFFT batch inner (shared-memory) launch failed")
    };

    // Remaining stages: two at a time with radix-4 (n / 4 groups per signal),
    // then at most one trailing radix-2 stage (n / 2 pairs per signal).
    let total_groups_r4 = batch_size * (n / 4);
    let total_pairs_r2 = batch_size * (n / 2);
    let outer_wg_r4 = ((total_groups_r4 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
    let outer_wg_r2 = ((total_pairs_r2 as u32) + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;
    let mut s = inner_stages;
    while s + 1 < m {
        let q = 1_usize << s;
        unsafe {
            butterfly_stage_radix4_batch::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r4, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
                n,
                q,
                batch_size,
                false,
            )
            .expect("IFFT batch outer radix-4 butterfly launch failed")
        };
        s += 2;
    }
    if s < m {
        let hs = 1_usize << s;
        unsafe {
            butterfly_stage_batch::launch::<f32, R>(
                &client,
                CubeCount::Static(outer_wg_r2, 1, 1),
                CubeDim::new_1d(WORKGROUP_SIZE),
                ArrayArg::from_raw_parts::<f32>(&real_handle, total, 1),
                ArrayArg::from_raw_parts::<f32>(&imag_handle, total, 1),
                n,
                hs,
                batch_size,
                false,
            )
            .expect("IFFT batch outer radix-2 trailing butterfly launch failed")
        };
    }

    let mut real_out = f32::from_bytes(&client.read_one(real_handle)).to_vec();
    let mut imag_out = f32::from_bytes(&client.read_one(imag_handle)).to_vec();
    // Inverse-DFT normalization.
    let scale = (n as f32).recip();
    real_out
        .iter_mut()
        .chain(imag_out.iter_mut())
        .for_each(|v| *v *= scale);

    // Split the flat buffers back into one `[real | imag]` vector per signal.
    real_out[..total]
        .chunks_exact(n)
        .zip(imag_out[..total].chunks_exact(n))
        .map(|(re, im)| {
            let mut out = Vec::with_capacity(2 * n);
            out.extend_from_slice(re);
            out.extend_from_slice(im);
            out
        })
        .collect()
}