use cubecl::prelude::*;
use std::f32::consts::PI;
#[cube(launch)]
/// Runs a single radix-2 butterfly pass, in place, over split real/imaginary data.
///
/// Every invocation whose `ABSOLUTE_POS` is below `n / 2` owns one pair of
/// elements: index `lo` and index `lo + half_stride`. The upper element is
/// rotated by the angle `±π·k / half_stride` (`k` being the position inside the
/// current group), then the pair is replaced by its sum and difference.
/// Invocations at or past `n / 2` do nothing.
///
/// * `real`, `imag` - real and imaginary components, overwritten in place.
/// * `n` - compile-time element count; `n / 2` pairs are processed.
/// * `half_stride` - compile-time distance between the two elements of a pair.
/// * `forward` - compile-time direction; `true` selects a negative rotation
///   angle, `false` a positive one.
pub fn butterfly_stage<F: Float>(
    real: &mut Array<F>,
    imag: &mut Array<F>,
    #[comptime] n: usize,
    #[comptime] half_stride: usize,
    #[comptime] forward: bool,
) {
    let unit = ABSOLUTE_POS;
    if unit < n / 2 {
        let direction = if forward { F::new(-1.0) } else { F::new(1.0) };

        // Locate the pair handled by this invocation.
        let lane = unit % half_stride;
        let group = unit / half_stride;
        let lo = group * (half_stride * 2) + lane;
        let hi = lo + half_stride;

        // Rotation factor (cos θ + i·sin θ) for this lane.
        let theta = direction * F::new(PI) * F::cast_from(lane) / F::cast_from(half_stride);
        let w_re = F::cos(theta);
        let w_im = F::sin(theta);

        // Read both elements before anything is written back.
        let top_re = real[lo];
        let top_im = imag[lo];
        let bot_re = real[hi];
        let bot_im = imag[hi];

        // Complex product of the rotation factor and the upper element.
        let rot_re = w_re * bot_re - w_im * bot_im;
        let rot_im = w_im * bot_re + w_re * bot_im;

        real[lo] = top_re + rot_re;
        imag[lo] = top_im + rot_im;
        real[hi] = top_re - rot_re;
        imag[hi] = top_im - rot_im;
    }
}
#[cube(launch)]
/// Runs `stages` consecutive radix-2 butterfly passes on one `tile`-element
/// segment, staging the segment in shared memory.
///
/// Each invocation copies two elements of its segment (offsets `lane` and
/// `lane + tile / 2`) into shared memory, then for pass `s = 0..stages`
/// performs one butterfly with pair distance `1 << s`, with a `sync_cube()`
/// barrier after the load and after every pass. Finally the same two offsets
/// are copied back to `real` / `imag`.
///
/// * `real`, `imag` - real and imaginary components, overwritten in place.
/// * `tile` - compile-time segment length (also the shared-memory size).
/// * `stages` - compile-time number of passes to run.
/// * `forward` - compile-time direction; `true` selects a negative rotation
///   angle, `false` a positive one.
///
/// NOTE(review): there is no bounds guard, and `lane` / `origin` are derived
/// from `ABSOLUTE_POS`; this presumably relies on the launcher using exactly
/// `tile / 2` units per cube and an element count that is a multiple of
/// `tile` — confirm against the launch code.
pub fn butterfly_inner<F: Float>(
    real: &mut Array<F>,
    imag: &mut Array<F>,
    #[comptime] tile: usize,
    #[comptime] stages: usize,
    #[comptime] forward: bool,
) {
    let mut tile_re = SharedMemory::<F>::new(tile);
    let mut tile_im = SharedMemory::<F>::new(tile);

    let half_tile = tile / 2;
    let unit = ABSOLUTE_POS;
    let lane = unit % half_tile;
    let origin = (unit / half_tile) * tile;

    // Global and shared indices of the two elements this invocation moves.
    let global_lo = origin + lane;
    let global_hi = global_lo + half_tile;
    let shared_hi = lane + half_tile;

    tile_re[lane] = real[global_lo];
    tile_re[shared_hi] = real[global_hi];
    tile_im[lane] = imag[global_lo];
    tile_im[shared_hi] = imag[global_hi];
    sync_cube();

    // The direction only depends on a compile-time flag, so it is fixed for all passes.
    let direction = if forward { F::new(-1.0) } else { F::new(1.0) };

    for stage in 0..stages {
        let span = 1_usize << stage;
        let offset_in_group = lane % span;
        let lo = (lane / span) * (span * 2) + offset_in_group;
        let hi = lo + span;

        let theta = direction * F::new(PI) * F::cast_from(offset_in_group) / F::cast_from(span);
        let w_re = F::cos(theta);
        let w_im = F::sin(theta);

        // Read both elements before anything is written back.
        let top_re = tile_re[lo];
        let top_im = tile_im[lo];
        let bot_re = tile_re[hi];
        let bot_im = tile_im[hi];

        let rot_re = w_re * bot_re - w_im * bot_im;
        let rot_im = w_im * bot_re + w_re * bot_im;

        tile_re[lo] = top_re + rot_re;
        tile_im[lo] = top_im + rot_im;
        tile_re[hi] = top_re - rot_re;
        tile_im[hi] = top_im - rot_im;

        // The next pass reads values written by other invocations.
        sync_cube();
    }

    real[global_lo] = tile_re[lane];
    real[global_hi] = tile_re[shared_hi];
    imag[global_lo] = tile_im[lane];
    imag[global_hi] = tile_im[shared_hi];
}
#[cube(launch)]
/// Runs a single radix-2 butterfly pass, in place, over `batch_size` signals
/// of `n` elements each stored back to back.
///
/// An invocation whose `ABSOLUTE_POS` is below `batch_size * (n / 2)` is
/// assigned signal `ABSOLUTE_POS / (n / 2)` and pair `ABSOLUTE_POS % (n / 2)`
/// within that signal; signal `s` starts at element `s * n`. The pair is
/// processed exactly as in a single-signal pass: the upper element is rotated
/// by `±π·k / half_stride` and the pair is replaced by its sum and difference.
/// Other invocations do nothing.
///
/// * `real`, `imag` - real and imaginary components, overwritten in place.
/// * `n` - compile-time length of one signal.
/// * `half_stride` - compile-time distance between the two elements of a pair.
/// * `batch_size` - compile-time number of signals.
/// * `forward` - compile-time direction; `true` selects a negative rotation
///   angle, `false` a positive one.
pub fn butterfly_stage_batch<F: Float>(
    real: &mut Array<F>,
    imag: &mut Array<F>,
    #[comptime] n: usize,
    #[comptime] half_stride: usize,
    #[comptime] batch_size: usize,
    #[comptime] forward: bool,
) {
    let unit = ABSOLUTE_POS;
    if unit < batch_size * (n / 2) {
        let direction = if forward { F::new(-1.0) } else { F::new(1.0) };

        // Split the flat id into (signal, pair-within-signal).
        let signal = unit / (n / 2);
        let pair = unit % (n / 2);
        let offset = signal * n;

        let lane = pair % half_stride;
        let group = pair / half_stride;
        let local_lo = group * (half_stride * 2) + lane;
        let lo = offset + local_lo;
        let hi = offset + (local_lo + half_stride);

        let theta = direction * F::new(PI) * F::cast_from(lane) / F::cast_from(half_stride);
        let w_re = F::cos(theta);
        let w_im = F::sin(theta);

        // Read both elements before anything is written back.
        let top_re = real[lo];
        let top_im = imag[lo];
        let bot_re = real[hi];
        let bot_im = imag[hi];

        let rot_re = w_re * bot_re - w_im * bot_im;
        let rot_im = w_im * bot_re + w_re * bot_im;

        real[lo] = top_re + rot_re;
        imag[lo] = top_im + rot_im;
        real[hi] = top_re - rot_re;
        imag[hi] = top_im - rot_im;
    }
}
#[cube(launch)]
/// Batched variant of the shared-memory multi-pass butterfly: runs `stages`
/// radix-2 passes on one `tile`-element segment of one signal.
///
/// Signals of `n` elements are stored back to back, and each signal is split
/// into `max(n / tile, 1)` segments. The invocation derives its segment from
/// `ABSOLUTE_POS / (tile / 2)`, copies two elements of that segment into
/// shared memory, performs pass `s = 0..stages` with pair distance `1 << s`
/// (a `sync_cube()` barrier follows the load and every pass), and copies the
/// two elements back.
///
/// * `real`, `imag` - real and imaginary components, overwritten in place.
/// * `n` - compile-time length of one signal.
/// * `tile` - compile-time segment length (also the shared-memory size).
/// * `stages` - compile-time number of passes to run.
/// * `forward` - compile-time direction; `true` selects a negative rotation
///   angle, `false` a positive one.
///
/// NOTE(review): there is no bounds guard, and the indexing presumably relies
/// on a launch with exactly `tile / 2` units per cube and `tile <= n` —
/// confirm against the launch code.
pub fn butterfly_inner_batch<F: Float>(
    real: &mut Array<F>,
    imag: &mut Array<F>,
    #[comptime] n: usize,
    #[comptime] tile: usize,
    #[comptime] stages: usize,
    #[comptime] forward: bool,
) {
    let mut tile_re = SharedMemory::<F>::new(tile);
    let mut tile_im = SharedMemory::<F>::new(tile);

    let half_tile = tile / 2;
    let tiles_per_signal = (n / tile).max(1);

    // Decompose the flat id into (signal, segment-in-signal, lane-in-segment).
    let unit = ABSOLUTE_POS;
    let lane = unit % half_tile;
    let segment = unit / half_tile;
    let signal = segment / tiles_per_signal;
    let segment_in_signal = segment % tiles_per_signal;
    let origin = signal * n + segment_in_signal * tile;

    // Global and shared indices of the two elements this invocation moves.
    let global_lo = origin + lane;
    let global_hi = global_lo + half_tile;
    let shared_hi = lane + half_tile;

    tile_re[lane] = real[global_lo];
    tile_re[shared_hi] = real[global_hi];
    tile_im[lane] = imag[global_lo];
    tile_im[shared_hi] = imag[global_hi];
    sync_cube();

    // The direction only depends on a compile-time flag, so it is fixed for all passes.
    let direction = if forward { F::new(-1.0) } else { F::new(1.0) };

    for stage in 0..stages {
        let span = 1_usize << stage;
        let offset_in_group = lane % span;
        let lo = (lane / span) * (span * 2) + offset_in_group;
        let hi = lo + span;

        let theta = direction * F::new(PI) * F::cast_from(offset_in_group) / F::cast_from(span);
        let w_re = F::cos(theta);
        let w_im = F::sin(theta);

        // Read both elements before anything is written back.
        let top_re = tile_re[lo];
        let top_im = tile_im[lo];
        let bot_re = tile_re[hi];
        let bot_im = tile_im[hi];

        let rot_re = w_re * bot_re - w_im * bot_im;
        let rot_im = w_im * bot_re + w_re * bot_im;

        tile_re[lo] = top_re + rot_re;
        tile_im[lo] = top_im + rot_im;
        tile_re[hi] = top_re - rot_re;
        tile_im[hi] = top_im - rot_im;

        // The next pass reads values written by other invocations.
        sync_cube();
    }

    real[global_lo] = tile_re[lane];
    real[global_hi] = tile_re[shared_hi];
    imag[global_lo] = tile_im[lane];
    imag[global_hi] = tile_im[shared_hi];
}
#[cube(launch)]
/// Performs two consecutive radix-2 butterfly passes (pair distances `q` and
/// `2 * q`) in one kernel, working on four elements per invocation.
///
/// Each invocation whose `ABSOLUTE_POS` is below `n / 4` reads the elements at
/// `i0`, `i0 + q`, `i0 + 2q` and `i0 + 3q`, where
/// `i0 = (ABSOLUTE_POS / q) * 4q + ABSOLUTE_POS % q`, and writes all four back
/// in place. Other invocations do nothing.
///
/// * `real`, `imag` - real and imaginary components, overwritten in place.
/// * `n` - compile-time element count; `n / 4` quadruples are processed.
/// * `q` - compile-time distance between neighbouring elements of a quadruple.
/// * `forward` - compile-time direction; `true` selects negative rotation
///   angles, `false` positive ones.
pub fn butterfly_stage_radix4<F: Float>(
    real: &mut Array<F>,
    imag: &mut Array<F>,
    #[comptime] n: usize,
    #[comptime] q: usize,
    #[comptime] forward: bool,
) {
    let unit = ABSOLUTE_POS;
    if unit < n / 4 {
        let lane = unit % q;
        let block = unit / q;

        // Indices of the four elements handled here.
        let i0 = block * (q * 4) + lane;
        let i1 = i0 + q;
        let i2 = i0 + q * 2;
        let i3 = i0 + q * 3;

        let a_re = real[i0];
        let a_im = imag[i0];
        let b_re = real[i1];
        let b_im = imag[i1];
        let c_re = real[i2];
        let c_im = imag[i2];
        let d_re = real[i3];
        let d_im = imag[i3];

        let direction = if forward { F::new(-1.0) } else { F::new(1.0) };
        let opposite = if forward { F::new(1.0) } else { F::new(-1.0) };

        // First pass (pair distance q): both pairs share the same rotation.
        let theta_q = direction * F::new(PI) * F::cast_from(lane) / F::cast_from(q);
        let w1_re = F::cos(theta_q);
        let w1_im = F::sin(theta_q);

        let rb_re = w1_re * b_re - w1_im * b_im;
        let rb_im = w1_im * b_re + w1_re * b_im;
        let rd_re = w1_re * d_re - w1_im * d_im;
        let rd_im = w1_im * d_re + w1_re * d_im;

        let ab_sum_re = a_re + rb_re;
        let ab_sum_im = a_im + rb_im;
        let ab_diff_re = a_re - rb_re;
        let ab_diff_im = a_im - rb_im;
        let cd_sum_re = c_re + rd_re;
        let cd_sum_im = c_im + rd_im;
        let cd_diff_re = c_re - rd_re;
        let cd_diff_im = c_im - rd_im;

        // Second pass (pair distance 2q).
        let theta_2q = direction * F::new(PI) * F::cast_from(lane) / F::cast_from(q * 2);
        let w2_re = F::cos(theta_2q);
        let w2_im = F::sin(theta_2q);

        // The rotation for the (i1, i3) pair is the one above advanced by a
        // quarter turn, i.e. multiplied by ±i, so no extra trig call is needed.
        let w2q_re = opposite * w2_im;
        let w2q_im = direction * w2_re;

        let r_sum_re = w2_re * cd_sum_re - w2_im * cd_sum_im;
        let r_sum_im = w2_im * cd_sum_re + w2_re * cd_sum_im;
        let r_diff_re = w2q_re * cd_diff_re - w2q_im * cd_diff_im;
        let r_diff_im = w2q_im * cd_diff_re + w2q_re * cd_diff_im;

        real[i0] = ab_sum_re + r_sum_re;
        imag[i0] = ab_sum_im + r_sum_im;
        real[i2] = ab_sum_re - r_sum_re;
        imag[i2] = ab_sum_im - r_sum_im;
        real[i1] = ab_diff_re + r_diff_re;
        imag[i1] = ab_diff_im + r_diff_im;
        real[i3] = ab_diff_re - r_diff_re;
        imag[i3] = ab_diff_im - r_diff_im;
    }
}
#[cube(launch)]
/// Batched variant of the fused two-pass butterfly: performs the radix-2
/// passes with pair distances `q` and `2 * q` on four elements per invocation,
/// over `batch_size` signals of `n` elements stored back to back.
///
/// An invocation whose `ABSOLUTE_POS` is below `batch_size * (n / 4)` is
/// assigned signal `ABSOLUTE_POS / (n / 4)` and quadruple
/// `ABSOLUTE_POS % (n / 4)` within it; signal `s` starts at element `s * n`.
/// Other invocations do nothing.
///
/// * `real`, `imag` - real and imaginary components, overwritten in place.
/// * `n` - compile-time length of one signal.
/// * `q` - compile-time distance between neighbouring elements of a quadruple.
/// * `batch_size` - compile-time number of signals.
/// * `forward` - compile-time direction; `true` selects negative rotation
///   angles, `false` positive ones.
pub fn butterfly_stage_radix4_batch<F: Float>(
    real: &mut Array<F>,
    imag: &mut Array<F>,
    #[comptime] n: usize,
    #[comptime] q: usize,
    #[comptime] batch_size: usize,
    #[comptime] forward: bool,
) {
    let unit = ABSOLUTE_POS;
    if unit < batch_size * (n / 4) {
        // Split the flat id into (signal, quadruple-within-signal).
        let signal = unit / (n / 4);
        let quad = unit % (n / 4);
        let offset = signal * n;

        let lane = quad % q;
        let block = quad / q;

        // Indices of the four elements handled here.
        let i0 = offset + (block * (q * 4) + lane);
        let i1 = i0 + q;
        let i2 = i0 + q * 2;
        let i3 = i0 + q * 3;

        let a_re = real[i0];
        let a_im = imag[i0];
        let b_re = real[i1];
        let b_im = imag[i1];
        let c_re = real[i2];
        let c_im = imag[i2];
        let d_re = real[i3];
        let d_im = imag[i3];

        let direction = if forward { F::new(-1.0) } else { F::new(1.0) };
        let opposite = if forward { F::new(1.0) } else { F::new(-1.0) };

        // First pass (pair distance q): both pairs share the same rotation.
        let theta_q = direction * F::new(PI) * F::cast_from(lane) / F::cast_from(q);
        let w1_re = F::cos(theta_q);
        let w1_im = F::sin(theta_q);

        let rb_re = w1_re * b_re - w1_im * b_im;
        let rb_im = w1_im * b_re + w1_re * b_im;
        let rd_re = w1_re * d_re - w1_im * d_im;
        let rd_im = w1_im * d_re + w1_re * d_im;

        let ab_sum_re = a_re + rb_re;
        let ab_sum_im = a_im + rb_im;
        let ab_diff_re = a_re - rb_re;
        let ab_diff_im = a_im - rb_im;
        let cd_sum_re = c_re + rd_re;
        let cd_sum_im = c_im + rd_im;
        let cd_diff_re = c_re - rd_re;
        let cd_diff_im = c_im - rd_im;

        // Second pass (pair distance 2q).
        let theta_2q = direction * F::new(PI) * F::cast_from(lane) / F::cast_from(q * 2);
        let w2_re = F::cos(theta_2q);
        let w2_im = F::sin(theta_2q);

        // The rotation for the (i1, i3) pair is the one above advanced by a
        // quarter turn, i.e. multiplied by ±i, so no extra trig call is needed.
        let w2q_re = opposite * w2_im;
        let w2q_im = direction * w2_re;

        let r_sum_re = w2_re * cd_sum_re - w2_im * cd_sum_im;
        let r_sum_im = w2_im * cd_sum_re + w2_re * cd_sum_im;
        let r_diff_re = w2q_re * cd_diff_re - w2q_im * cd_diff_im;
        let r_diff_im = w2q_im * cd_diff_re + w2q_re * cd_diff_im;

        real[i0] = ab_sum_re + r_sum_re;
        imag[i0] = ab_sum_im + r_sum_im;
        real[i2] = ab_sum_re - r_sum_re;
        imag[i2] = ab_sum_im - r_sum_im;
        real[i1] = ab_diff_re + r_diff_re;
        imag[i1] = ab_diff_im + r_diff_im;
        real[i3] = ab_diff_re - r_diff_re;
        imag[i3] = ab_diff_im - r_diff_im;
    }
}
/// Reverses the order of the lowest `bits` bits of `x`.
///
/// Bit 0 of `x` becomes bit `bits - 1` of the result, bit 1 becomes bit
/// `bits - 2`, and so on. Bits of `x` at position `bits` and above are
/// ignored. With `bits == 0` the result is `0`.
#[inline]
pub fn bit_reverse(x: usize, bits: u32) -> usize {
    // Carry (accumulated result, remaining input) through `bits` steps, moving
    // the lowest remaining input bit onto the low end of the result each time.
    let (reversed, _) = (0..bits).fold((0usize, x), |(acc, rest), _| {
        ((acc << 1) | (rest & 1), rest >> 1)
    });
    reversed
}