use std::f64::consts::PI;
use scirs2_core::numeric::Complex64;
use super::types::{FftDirection, GpuFftError, GpuFftResult, NormalizationMode};
/// Computes the first `n / 2` forward-FFT twiddle factors
/// `W_n^k = exp(-2*pi*i*k / n)` for `k = 0..n/2`.
///
/// Each factor is evaluated directly from its own angle rather than by
/// the recurrence `w_{k+1} = w_k * w_1`, so rounding error stays O(eps)
/// per entry instead of accumulating across the whole table for large `n`.
///
/// # Errors
/// Returns `GpuFftError::SizeTooSmall` when `n < 2`.
pub fn compute_twiddles_gpu(n: usize) -> GpuFftResult<Vec<Complex64>> {
    if n < 2 {
        return Err(GpuFftError::SizeTooSmall(n));
    }
    let half = n / 2;
    let angle_step = -2.0 * PI / n as f64;
    Ok((0..half)
        .map(|k| {
            let angle = angle_step * k as f64;
            Complex64::new(angle.cos(), angle.sin())
        })
        .collect())
}
/// Computes the first `n / 2` inverse-FFT twiddle factors
/// `W_n^{-k} = exp(+2*pi*i*k / n)` for `k = 0..n/2`.
///
/// Mirrors `compute_twiddles_gpu` with the opposite angle sign. Each
/// factor is evaluated directly from its own angle rather than by
/// repeated multiplication, so rounding error does not accumulate
/// across the table for large `n`.
///
/// # Errors
/// Returns `GpuFftError::SizeTooSmall` when `n < 2`.
pub fn compute_inverse_twiddles_gpu(n: usize) -> GpuFftResult<Vec<Complex64>> {
    if n < 2 {
        return Err(GpuFftError::SizeTooSmall(n));
    }
    let half = n / 2;
    let angle_step = 2.0 * PI / n as f64;
    Ok((0..half)
        .map(|k| {
            let angle = angle_step * k as f64;
            Complex64::new(angle.cos(), angle.sin())
        })
        .collect())
}
/// Reorders `data` in place into bit-reversed index order, the input
/// permutation required by the iterative Cooley-Tukey passes.
///
/// Each element at index `i` is swapped with the element whose index is
/// the bit-reversal of `i` over `log2(len)` bits; the `i < rev` guard
/// ensures every pair is exchanged exactly once. Slices of length 0 or 1
/// are left untouched (after the power-of-two debug check).
pub fn bit_reverse_permute_gpu(data: &mut [Complex64]) {
    let len = data.len();
    debug_assert!(
        len.is_power_of_two(),
        "bit_reverse_permute_gpu: n must be a power of two"
    );
    if len <= 1 {
        return;
    }
    let bits = len.trailing_zeros() as usize;
    (0..len)
        .map(|idx| (idx, bit_reverse(idx, bits)))
        .filter(|&(idx, rev)| idx < rev)
        .for_each(|(idx, rev)| data.swap(idx, rev));
}
/// Reverses the low `bits` bits of `x`; any higher bits are discarded.
#[inline]
fn bit_reverse(mut x: usize, bits: usize) -> usize {
    (0..bits).fold(0usize, |reversed, _| {
        let lowest = x & 1;
        x >>= 1;
        (reversed << 1) | lowest
    })
}
/// Executes one radix-2 butterfly stage over `data` in place.
///
/// Within each consecutive group of `2 * stride` elements, element `k`
/// is paired with element `k + stride`; the upper element is multiplied
/// by its twiddle factor before the add/subtract. The twiddle table may
/// be larger than needed (entries are subsampled by `len / stride`);
/// an out-of-range or empty table falls back to a unit twiddle.
pub fn butterfly_pass_gpu(data: &mut [Complex64], stride: usize, twiddles: &[Complex64]) {
    let len = data.len();
    let group = 2 * stride;
    // Subsampling distance between consecutive twiddles for this stage.
    let skip = if twiddles.is_empty() {
        1
    } else {
        twiddles.len() / stride
    };
    let unity = Complex64::new(1.0, 0.0);
    let mut base = 0;
    while base < len {
        for k in 0..stride {
            let w = twiddles.get(k * skip).copied().unwrap_or(unity);
            let lower = data[base + k];
            let upper = w * data[base + k + stride];
            data[base + k] = lower + upper;
            data[base + k + stride] = lower - upper;
        }
        base += group;
    }
}
/// Runs an in-place iterative radix-2 Cooley-Tukey FFT over `data`.
///
/// `twiddles` is the forward table (typically from
/// `compute_twiddles_gpu(n)`); for the inverse direction the table is
/// conjugated and the result scaled by `1 / n`. Previously the forward
/// path also cloned the entire twiddle table via `to_vec()` — it now
/// borrows the caller's slice and allocates only for the inverse.
///
/// # Errors
/// Returns `GpuFftError::SizeTooSmall` when `data.len() < 2` and
/// `GpuFftError::NonPowerOfTwo` when the length is not a power of two.
pub fn cooley_tukey_gpu(
    data: &mut [Complex64],
    direction: FftDirection,
    twiddles: &[Complex64],
) -> GpuFftResult<()> {
    let n = data.len();
    if n < 2 {
        return Err(GpuFftError::SizeTooSmall(n));
    }
    if !n.is_power_of_two() {
        return Err(GpuFftError::NonPowerOfTwo(n));
    }
    // Storage for the conjugated table; only initialized (and borrowed)
    // on the inverse path, so the forward path stays allocation-free.
    let conjugated;
    let effective_twiddles: &[Complex64] = match direction {
        FftDirection::Forward => twiddles,
        FftDirection::Inverse => {
            conjugated = twiddles.iter().map(|w| w.conj()).collect::<Vec<_>>();
            &conjugated
        }
    };
    bit_reverse_permute_gpu(data);
    let mut stride = 1usize;
    while stride < n {
        butterfly_pass_gpu(data, stride, effective_twiddles);
        stride <<= 1;
    }
    if direction == FftDirection::Inverse {
        let scale = 1.0 / n as f64;
        for x in data.iter_mut() {
            *x = *x * scale;
        }
    }
    Ok(())
}
/// Computes an FFT of arbitrary length via Bluestein's chirp-z
/// algorithm, delegating to Cooley-Tukey when `n` is a power of two.
///
/// The length-`n` DFT is re-expressed as a circular convolution of the
/// chirp-modulated input with the conjugate chirp, evaluated with three
/// power-of-two FFTs of size `m = next_pow2(2n - 1)`.
///
/// # Errors
/// Returns `GpuFftError::SizeTooSmall` when `n < 2`; sub-FFT failures
/// are wrapped in `GpuFftError::KernelLaunchFailed`.
pub fn bluestein_gpu(data: &mut [Complex64], direction: FftDirection) -> GpuFftResult<()> {
    let n = data.len();
    if n < 2 {
        return Err(GpuFftError::SizeTooSmall(n));
    }
    if n.is_power_of_two() {
        let twiddles = compute_twiddles_gpu(n)?;
        return cooley_tukey_gpu(data, direction, &twiddles);
    }
    let sign: f64 = match direction {
        FftDirection::Forward => -1.0,
        FftDirection::Inverse => 1.0,
    };
    // exp(i*sign*pi*k^2/n) has period 2n in k^2, so reduce k^2 mod 2n in
    // integer arithmetic (via u128 to rule out usize overflow) before
    // converting to f64. The previous `(k * k) as f64` lost low-order
    // bits of k^2 once k^2 exceeded 2^53, corrupting the chirp angles.
    let modulus = 2 * n as u128;
    let chirp: Vec<Complex64> = (0..n)
        .map(|k| {
            let k2 = ((k as u128 * k as u128) % modulus) as f64;
            let angle = sign * PI * k2 / n as f64;
            Complex64::new(angle.cos(), angle.sin())
        })
        .collect();
    // a = chirp-modulated input, zero-padded to the convolution size m.
    let mut a: Vec<Complex64> = data
        .iter()
        .zip(chirp.iter())
        .map(|(&x, &c)| x * c)
        .collect();
    let m = next_pow2(2 * n - 1);
    a.resize(m, Complex64::new(0.0, 0.0));
    // b = conjugate chirp laid out for circular convolution:
    // b[k] and b[m - k] both hold conj(chirp[k]).
    let mut b = vec![Complex64::new(0.0, 0.0); m];
    for k in 0..n {
        b[k] = chirp[k].conj();
    }
    for k in 1..n {
        b[m - k] = chirp[k].conj();
    }
    // Circular convolution via pointwise product in the frequency domain.
    let tw = compute_twiddles_gpu(m)?;
    cooley_tukey_gpu(&mut a, FftDirection::Forward, &tw)
        .map_err(|e| GpuFftError::KernelLaunchFailed(format!("bluestein sub-fft a: {e}")))?;
    cooley_tukey_gpu(&mut b, FftDirection::Forward, &tw)
        .map_err(|e| GpuFftError::KernelLaunchFailed(format!("bluestein sub-fft b: {e}")))?;
    for (ai, bi) in a.iter_mut().zip(b.iter()) {
        *ai = *ai * *bi;
    }
    cooley_tukey_gpu(&mut a, FftDirection::Inverse, &tw)
        .map_err(|e| GpuFftError::KernelLaunchFailed(format!("bluestein ifft: {e}")))?;
    // Demodulate; the inverse direction additionally applies the 1/n
    // scaling that cooley_tukey_gpu applies for a direct inverse, keeping
    // the two code paths consistent.
    let inv_scale = if direction == FftDirection::Inverse {
        1.0 / n as f64
    } else {
        1.0
    };
    for (k, out) in data.iter_mut().enumerate() {
        *out = a[k] * chirp[k] * inv_scale;
    }
    Ok(())
}
/// Rounds `n` up to the nearest power of two.
///
/// Delegates to `usize::next_power_of_two`, replacing the hand-rolled
/// shift (which overflowed for `n == 0`: a debug-build panic from
/// `1 << usize::BITS`). The stdlib version returns 1 for `n == 0` and
/// only panics (debug) when the result would not fit in `usize`.
fn next_pow2(n: usize) -> usize {
    n.next_power_of_two()
}
/// Applies an FFT to `data` one tile at a time.
///
/// NOTE(review): each tile of at most `tile_size` elements is
/// transformed independently, so for `n > tile_size` this computes a
/// batch of smaller FFTs over consecutive windows rather than a single
/// length-`n` transform — presumably intentional for tiled GPU
/// dispatch; confirm against callers.
///
/// * `tile_size` — maximum tile length; values below 2 are clamped to 2.
/// * `twiddles` — precomputed forward twiddle table. A tile reuses it
///   only when the tile length is a power of two no larger than
///   `2 * twiddles.len()` (so the subsampled twiddle indices stay in
///   range); otherwise the tile falls back to `bluestein_gpu`.
///
/// # Errors
/// Returns `GpuFftError::SizeTooSmall` when `data.len() < 2`; per-tile
/// failures are wrapped in `GpuFftError::KernelLaunchFailed`.
pub fn tiled_fft_1d(
    data: &mut [Complex64],
    tile_size: usize,
    twiddles: &[Complex64],
    direction: FftDirection,
) -> GpuFftResult<()> {
    let n = data.len();
    if n < 2 {
        return Err(GpuFftError::SizeTooSmall(n));
    }
    // Fast path: the whole input fits in one tile and is a power of two.
    if n <= tile_size && n.is_power_of_two() {
        return cooley_tukey_gpu(data, direction, twiddles);
    }
    let effective_tile = tile_size.max(2);
    let mut offset = 0;
    while offset < n {
        let end = (offset + effective_tile).min(n);
        let chunk = &mut data[offset..end];
        let chunk_n = chunk.len();
        // A trailing 1-element tile cannot be transformed; leave it as-is.
        if chunk_n < 2 {
            offset += effective_tile;
            continue;
        }
        if chunk_n.is_power_of_two() && chunk_n <= twiddles.len() * 2 {
            cooley_tukey_gpu(chunk, direction, twiddles)
                .map_err(|e| GpuFftError::KernelLaunchFailed(format!("tiled chunk: {e}")))?;
        } else {
            bluestein_gpu(chunk, direction)
                .map_err(|e| GpuFftError::KernelLaunchFailed(format!("tiled bluestein: {e}")))?;
        }
        offset += effective_tile;
    }
    Ok(())
}
/// Scales `data` in place according to the requested normalization mode.
///
/// `None` leaves the samples untouched, `Forward` and `Backward` both
/// divide by the length, and `Ortho` divides by its square root. An
/// empty slice is a no-op.
pub fn apply_normalization(data: &mut [Complex64], mode: NormalizationMode) {
    let len = data.len();
    if len == 0 {
        return;
    }
    let factor = match mode {
        NormalizationMode::None => return,
        NormalizationMode::Forward | NormalizationMode::Backward => 1.0 / len as f64,
        NormalizationMode::Ortho => 1.0 / (len as f64).sqrt(),
    };
    for sample in data.iter_mut() {
        *sample = *sample * factor;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Absolute tolerance for element-wise complex comparisons.
    const EPS: f64 = 1e-9;
    // True when both real and imaginary parts agree within EPS.
    fn nearly_equal(a: Complex64, b: Complex64) -> bool {
        (a.re - b.re).abs() < EPS && (a.im - b.im).abs() < EPS
    }
    // Sizes below 2 must be rejected by the twiddle constructor.
    #[test]
    fn test_twiddle_size_zero_fails() {
        assert!(compute_twiddles_gpu(0).is_err());
        assert!(compute_twiddles_gpu(1).is_err());
    }
    // n = 2 yields a single unit twiddle (exp(0) = 1).
    #[test]
    fn test_twiddle_size_two() {
        let tw = compute_twiddles_gpu(2).expect("twiddles for n=2");
        assert_eq!(tw.len(), 1);
        assert!(nearly_equal(tw[0], Complex64::new(1.0, 0.0)));
    }
    // Bit-reversal of 0..8 must match the canonical 3-bit permutation.
    #[test]
    fn test_bit_reverse_size8() {
        let mut data: Vec<Complex64> = (0..8_u64).map(|i| Complex64::new(i as f64, 0.0)).collect();
        bit_reverse_permute_gpu(&mut data);
        let expected = [0.0, 4.0, 2.0, 6.0, 1.0, 5.0, 3.0, 7.0];
        for (i, &e) in expected.iter().enumerate() {
            assert!(
                (data[i].re - e).abs() < EPS,
                "index {i}: got {}",
                data[i].re
            );
        }
    }
    // Single butterfly with unit twiddle: (u, v) -> (u + v, u - v).
    #[test]
    fn test_butterfly_pass_size2() {
        let mut data = vec![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];
        let twiddles = vec![Complex64::new(1.0, 0.0)];
        butterfly_pass_gpu(&mut data, 1, &twiddles);
        assert!(nearly_equal(data[0], Complex64::new(3.0, 0.0)));
        assert!(nearly_equal(data[1], Complex64::new(-1.0, 0.0)));
    }
    // Round-trip: inverse(forward(x)) must reconstruct x (real parts checked).
    #[test]
    fn test_cooley_tukey_identity() {
        let original: Vec<Complex64> = (0..8).map(|i| Complex64::new(i as f64, 0.0)).collect();
        let mut data = original.clone();
        let tw = compute_twiddles_gpu(8).expect("twiddles");
        cooley_tukey_gpu(&mut data, FftDirection::Forward, &tw).expect("fft");
        cooley_tukey_gpu(&mut data, FftDirection::Inverse, &tw).expect("ifft");
        for (i, (got, exp)) in data.iter().zip(original.iter()).enumerate() {
            assert!(
                (got.re - exp.re).abs() < 1e-10,
                "index {i}: re mismatch got {} exp {}",
                got.re,
                exp.re
            );
        }
    }
}