use core::f32::consts::PI;
use cubecl::prelude::*;
use crate::fft::FftMode;
/// Reverses the order of the lowest `log2_n` bits of `i`.
///
/// Bit `k` of `i` (for `k < log2_n`) ends up at bit `log2_n - 1 - k` of the
/// result. Bits of `i` at or above position `log2_n` are discarded, and with
/// `log2_n == 0` the result is `0`.
#[cube]
pub(crate) fn bit_reverse(i: usize, #[comptime] log2_n: usize) -> usize {
    let mut remaining = i;
    let mut reversed = 0usize;
    // The trip count is a comptime value, so the loop is marked for unrolling.
    #[unroll]
    for _ in 0..log2_n {
        // Peel the lowest remaining input bit and append it to the output,
        // so the first bit consumed ends up as the most significant one.
        let low_bit = remaining & 1usize;
        reversed = (reversed << 1usize) | low_bit;
        remaining >>= 1usize;
    }
    reversed
}
/// Runs `log2_n` in-place radix-2 butterfly stages over a complex signal held
/// as separate real (`shared_re`) and imaginary (`shared_im`) shared-memory
/// buffers.
///
/// In stage `s` the data is treated as blocks of `2^(s+1)` elements; each
/// butterfly combines element `i0` with element `i0 + 2^s` using the twiddle
/// factor `cos(theta) + i*sin(theta)`, where
/// `theta = sign * 2*pi * j / 2^(s+1)` and `j` is the butterfly's offset
/// inside its block. `sign` is whatever `fft_mode.sign()` returns (defined
/// outside this function).
///
/// Every stage has `n_fft / 2` butterflies. Each unit starts at `UNIT_POS`
/// and advances by `threads_per_cube`, and `sync_cube()` is called after
/// every stage before the next one begins.
///
/// NOTE(review): this looks like the decimation-in-time form, which would
/// expect the input already permuted into bit-reversed order (see
/// `bit_reverse`) — confirm against the caller.
/// NOTE(review): presumably `n_fft == 1 << log2_n` and the cube is launched
/// with exactly `threads_per_cube` units — neither is checked here.
#[cube]
pub(crate) fn fft_butterfly_parallel<F: Float>(
    shared_re: &mut SharedMemory<F>,
    shared_im: &mut SharedMemory<F>,
    #[comptime] n_fft: usize,
    #[comptime] log2_n: usize,
    #[comptime] threads_per_cube: usize,
    #[comptime] fft_mode: FftMode,
) {
    let butterflies_per_stage = comptime![n_fft / 2];
    let direction = F::new(fft_mode.sign());
    let full_turn = F::new(2.0 * PI);

    let mut stage = 0usize;
    while stage < log2_n {
        // Distance between the two operands of a butterfly, and the size of
        // the block those butterflies live in, for this stage.
        let half_span = 1usize << stage;
        let span = half_span << 1usize;

        // Each unit handles butterflies UNIT_POS, UNIT_POS + threads_per_cube, ...
        let mut butterfly = UNIT_POS as usize;
        while butterfly < butterflies_per_stage {
            // Split the flat butterfly index into (block, offset within block).
            let group_idx = butterfly / half_span;
            let offset = butterfly - group_idx * half_span;
            let upper = group_idx * span + offset;
            let lower = upper + half_span;

            // Twiddle factor for this offset: cos(angle) + i*sin(angle).
            let angle: F = direction * full_turn * F::cast_from(offset) / F::cast_from(span);
            let twiddle_re = angle.cos();
            let twiddle_im = angle.sin();

            // Load both operands before any store so the update is in place.
            let upper_re = shared_re[upper];
            let upper_im = shared_im[upper];
            let lower_re = shared_re[lower];
            let lower_im = shared_im[lower];

            // Complex product: twiddle * lower operand.
            let prod_re = twiddle_re * lower_re - twiddle_im * lower_im;
            let prod_im = twiddle_re * lower_im + twiddle_im * lower_re;

            shared_re[upper] = upper_re + prod_re;
            shared_im[upper] = upper_im + prod_im;
            shared_re[lower] = upper_re - prod_re;
            shared_im[lower] = upper_im - prod_im;

            butterfly += threads_per_cube;
        }

        // The next stage reads values written by other units in this stage,
        // so all units synchronize before moving on.
        sync_cube();
        stage += 1usize;
    }
}