use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
use crate::error::{FFTError, FFTResult};
/// Builds the forward twiddle-factor table of length `n`.
///
/// Entry `k` is `cos(theta) + i*sin(theta)` with `theta = (-2*PI / n) * k`,
/// i.e. the unit phasor `exp(-2*PI*i*k/n)`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `n == 0`.
pub fn generate_twiddle_table(n: usize) -> FFTResult<Vec<Complex64>> {
    match n {
        0 => Err(FFTError::ValueError(
            "generate_twiddle_table: n must be > 0".into(),
        )),
        // A single-entry table is exactly 1 + 0i; no trigonometry needed.
        1 => Ok(vec![Complex64::new(1.0, 0.0)]),
        _ => {
            let step = -2.0 * PI / n as f64;
            let unit_phasor = |k: usize| {
                let theta = step * k as f64;
                Complex64::new(theta.cos(), theta.sin())
            };
            Ok((0..n).map(unit_phasor).collect())
        }
    }
}
/// Builds the inverse twiddle-factor table of length `n`.
///
/// Entry `k` is `cos(theta) + i*sin(theta)` with `theta = (2*PI / n) * k`,
/// i.e. the unit phasor `exp(+2*PI*i*k/n)`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `n == 0`.
pub fn generate_inverse_twiddle_table(n: usize) -> FFTResult<Vec<Complex64>> {
    if n == 0 {
        return Err(FFTError::ValueError(
            "generate_inverse_twiddle_table: n must be > 0".into(),
        ));
    }
    if n == 1 {
        // A single-entry table is exactly 1 + 0i.
        return Ok(vec![Complex64::new(1.0, 0.0)]);
    }

    let step = 2.0 * PI / n as f64;
    let mut table = Vec::with_capacity(n);
    for k in 0..n {
        let theta = step * k as f64;
        table.push(Complex64::new(theta.cos(), theta.sin()));
    }
    Ok(table)
}
/// In-place two-point butterfly.
///
/// With `t = twiddle * b`, overwrites `a` with `a + t` and `b` with `a - t`
/// (both computed from the values held on entry).
#[inline(always)]
pub fn butterfly2(a: &mut Complex64, b: &mut Complex64, twiddle: Complex64) {
    // Snapshot `a` first: it is needed for both outputs.
    let upper = *a;
    let rotated = twiddle * *b;
    *a = upper + rotated;
    *b = upper - rotated;
}
/// In-place four-point butterfly.
///
/// `twiddles` is read as `[w1, w2, w3]`. The outputs are
///
/// * `a[0] = x0 + x1 + x2 + x3`
/// * `a[1] = x0 + w1*x1 + w2*x2 + w3*x3`
/// * `a[2] = x0 + w2*x1 + w2^2*x2 + w2^3*x3`
/// * `a[3] = x0 + w3*x1 + w2^3*x2 + w3^3*x3`
///
/// where `x0..x3` are the values of `a` on entry. When `w_k` is the k-th
/// power of the primitive 4th root of unity this is the 4-point DFT.
#[inline]
pub fn butterfly4(a: &mut [Complex64; 4], twiddles: &[Complex64; 3]) {
    let [x0, x1, x2, x3] = *a;
    let [w1, w2, w3] = *twiddles;

    // Higher powers are derived by multiplication from the supplied factors.
    let w4 = w2 * w2;
    let w6 = w4 * w2;
    let w9 = w3 * w3 * w3;

    *a = [
        x0 + x1 + x2 + x3,
        x0 + w1 * x1 + w2 * x2 + w3 * x3,
        x0 + w2 * x1 + w4 * x2 + w6 * x3,
        x0 + w3 * x1 + w6 * x2 + w9 * x3,
    ];
}
/// In-place eight-point butterfly evaluated as a direct sum.
///
/// A root table `[1, twiddles[0], ..., twiddles[6]]` is formed, and output `k`
/// is the sum over `n` of `input[n] * root[(n * k) % 8]`, where `input` is a
/// copy of `a` taken before any element is overwritten.
#[inline]
pub fn butterfly8(a: &mut [Complex64; 8], twiddles: &[Complex64; 7]) {
    // Slot 0 is the trivial factor 1; slots 1..8 are the caller's twiddles.
    let mut roots = [Complex64::new(1.0, 0.0); 8];
    roots[1..].copy_from_slice(twiddles);

    let samples = *a;
    for (k, out) in a.iter_mut().enumerate() {
        let mut acc = Complex64::new(0.0, 0.0);
        for (n, &x) in samples.iter().enumerate() {
            acc += x * roots[(n * k) % 8];
        }
        *out = acc;
    }
}
/// In-place forward FFT of a power-of-two-length buffer.
///
/// The implementation is an iterative radix-2 decimation-in-time transform:
/// the input is first permuted into bit-reversed order, then `log2(n)` stages
/// of two-point butterflies are applied with twiddles
/// `exp(-2*PI*i*k/size)` for stage width `size = 2, 4, ..., n`.
///
/// # Errors
///
/// Returns [`FFTError::ValueError`] when `data.len() < 4` or when the length
/// is not a power of two.
pub fn split_radix_butterfly(data: &mut [Complex64]) -> FFTResult<()> {
    let n = data.len();
    if n < 4 {
        return Err(FFTError::ValueError(
            "split_radix_butterfly: length must be >= 4".into(),
        ));
    }
    if !n.is_power_of_two() {
        return Err(FFTError::ValueError(
            "split_radix_butterfly: length must be a power of two".into(),
        ));
    }

    // Bit-reversal permutation; `i < j` ensures each pair is swapped once.
    let bits = n.trailing_zeros();
    for i in 0..n {
        let j = reverse_bits(i, bits);
        if i < j {
            data.swap(i, j);
        }
    }

    // Stage `s` combines blocks of width `2^s`. Driving the loop from `bits`
    // (rather than doubling until `size > n`) cannot overflow `usize`.
    for stage in 1..=bits {
        let size = 1usize << stage;
        let half = size / 2;
        let angle_step = -2.0 * PI / size as f64;

        // The twiddle depends only on `k` and the stage, so it is evaluated
        // once here and reused for every block, instead of once per block.
        for k in 0..half {
            let angle = angle_step * k as f64;
            let twiddle = Complex64::new(angle.cos(), angle.sin());

            for group_start in (0..n).step_by(size) {
                let i = group_start + k;
                let j = i + half;
                let t = twiddle * data[j];
                data[j] = data[i] - t;
                data[i] = data[i] + t;
            }
        }
    }
    Ok(())
}
/// Reverses the lowest `bits` bits of `x`.
///
/// Bits of `x` at positions `>= bits` are ignored. `bits == 0` yields `0`.
/// For `bits > usize::BITS` the fully reversed word is shifted left by the
/// excess (yielding `0` once the excess reaches the word width), matching a
/// bit-by-bit shift loop.
fn reverse_bits(x: usize, bits: u32) -> usize {
    if bits == 0 {
        return 0;
    }
    match usize::BITS.checked_sub(bits) {
        // Full-word reversal puts bit `i` at `BITS-1-i`; shifting right by
        // `BITS - bits` moves it to `bits-1-i` and discards the higher input bits.
        Some(shift) => x.reverse_bits() >> shift,
        None => x
            .reverse_bits()
            .checked_shl(bits - usize::BITS)
            .unwrap_or(0),
    }
}
pub fn direct_dft(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
let n = data.len();
if n == 0 {
return Err(FFTError::ValueError("direct_dft: empty input".into()));
}
if n == 1 {
return Ok(data.to_vec());
}
let angle_base = -2.0 * PI / n as f64;
let mut result = vec![Complex64::new(0.0, 0.0); n];
for k in 0..n {
let mut sum = Complex64::new(0.0, 0.0);
for j in 0..n {
let angle = angle_base * (k * j) as f64;
let w = Complex64::new(angle.cos(), angle.sin());
sum += data[j] * w;
}
result[k] = sum;
}
Ok(result)
}
pub fn direct_idft(data: &[Complex64]) -> FFTResult<Vec<Complex64>> {
let n = data.len();
if n == 0 {
return Err(FFTError::ValueError("direct_idft: empty input".into()));
}
if n == 1 {
return Ok(data.to_vec());
}
let angle_base = 2.0 * PI / n as f64;
let inv_n = 1.0 / n as f64;
let mut result = vec![Complex64::new(0.0, 0.0); n];
for k in 0..n {
let mut sum = Complex64::new(0.0, 0.0);
for j in 0..n {
let angle = angle_base * (k * j) as f64;
let w = Complex64::new(angle.cos(), angle.sin());
sum += data[j] * w;
}
result[k] = sum * inv_n;
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Largest element-wise distance between two complex sequences.
    /// Pairs are formed with `zip`, so comparison stops at the shorter input.
    fn max_abs_err(a: &[Complex64], b: &[Complex64]) -> f64 {
        let mut worst = 0.0_f64;
        for (lhs, rhs) in a.iter().zip(b.iter()) {
            worst = worst.max((lhs - rhs).norm());
        }
        worst
    }

    /// Runs `split_radix_butterfly` on `input` and checks the result against
    /// `direct_dft` of the same input.
    fn assert_split_radix_matches_dft(input: Vec<Complex64>) {
        let n = input.len();
        let expected = direct_dft(&input).expect("dft failed");
        let mut data = input;
        split_radix_butterfly(&mut data).expect("split_radix failed");
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "split_radix error (n={n}) = {err}");
    }

    #[test]
    fn test_twiddle_table_size_1() {
        let table = generate_twiddle_table(1).expect("should succeed");
        assert_eq!(table.len(), 1);
        let only = table[0];
        assert_relative_eq!(only.re, 1.0, epsilon = 1e-15);
        assert_relative_eq!(only.im, 0.0, epsilon = 1e-15);
    }

    #[test]
    fn test_twiddle_table_values() {
        let n = 8;
        let table = generate_twiddle_table(n).expect("should succeed");
        assert_eq!(table.len(), n);

        // (index, expected real part, expected imaginary part)
        let landmarks = [(0, 1.0, 0.0), (n / 4, 0.0, -1.0), (n / 2, -1.0, 0.0)];
        for (idx, re, im) in landmarks {
            assert_relative_eq!(table[idx].re, re, epsilon = 1e-14);
            assert_relative_eq!(table[idx].im, im, epsilon = 1e-14);
        }

        // Every entry must lie on the unit circle.
        for w in table.iter() {
            assert_relative_eq!(w.norm(), 1.0, epsilon = 1e-14);
        }
    }

    #[test]
    fn test_twiddle_table_error_on_zero() {
        let outcome = generate_twiddle_table(0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_butterfly2_trivial_twiddle() {
        let (mut a, mut b) = (Complex64::new(3.0, 0.0), Complex64::new(1.0, 0.0));
        let one = Complex64::new(1.0, 0.0);
        butterfly2(&mut a, &mut b, one);
        assert_relative_eq!(a.re, 4.0, epsilon = 1e-14);
        assert_relative_eq!(b.re, 2.0, epsilon = 1e-14);
    }

    #[test]
    fn test_butterfly2_with_twiddle() {
        let (mut a, mut b) = (Complex64::new(5.0, 0.0), Complex64::new(3.0, 0.0));
        let minus_one = Complex64::new(-1.0, 0.0);
        butterfly2(&mut a, &mut b, minus_one);
        assert_relative_eq!(a.re, 2.0, epsilon = 1e-14);
        assert_relative_eq!(b.re, 8.0, epsilon = 1e-14);
    }

    #[test]
    fn test_butterfly4_matches_direct_dft() {
        let input = [1.0, 2.0, 3.0, 4.0].map(|re| Complex64::new(re, 0.0));
        let expected = direct_dft(&input).expect("direct_dft failed");

        // -i, -1, +i
        let twiddles = [
            Complex64::new(0.0, -1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, 1.0),
        ];

        let mut data = input;
        butterfly4(&mut data, &twiddles);
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-12, "butterfly4 error = {err}");
    }

    #[test]
    fn test_butterfly8_matches_direct_dft() {
        let input: [Complex64; 8] = [
            (1.0, 0.0),
            (2.0, -1.0),
            (0.5, 0.5),
            (3.0, 0.0),
            (-1.0, 1.0),
            (0.0, 2.0),
            (1.5, -0.5),
            (-0.5, 0.0),
        ]
        .map(|(re, im)| Complex64::new(re, im));
        let expected = direct_dft(&input).expect("direct_dft failed");

        let mut twiddles = [Complex64::new(0.0, 0.0); 7];
        for (k, slot) in twiddles.iter_mut().enumerate() {
            let angle = -2.0 * PI * (k + 1) as f64 / 8.0;
            *slot = Complex64::new(angle.cos(), angle.sin());
        }

        let mut data = input;
        butterfly8(&mut data, &twiddles);
        let err = max_abs_err(&data, &expected);
        assert!(err < 1e-10, "butterfly8 error = {err}");
    }

    #[test]
    fn test_direct_dft_known_result() {
        let ones = vec![Complex64::new(1.0, 0.0); 4];
        let result = direct_dft(&ones).expect("direct_dft failed");
        assert_relative_eq!(result[0].re, 4.0, epsilon = 1e-12);
        for (k, bin) in result.iter().enumerate().skip(1) {
            assert!(bin.norm() < 1e-12, "non-zero at k={k}");
        }
    }

    #[test]
    fn test_direct_dft_idft_roundtrip() {
        let input: Vec<Complex64> = [(1.0, 2.0), (3.0, -1.0), (0.5, 0.5), (-2.0, 1.5)]
            .iter()
            .map(|&(re, im)| Complex64::new(re, im))
            .collect();
        let spectrum = direct_dft(&input).expect("dft failed");
        let recovered = direct_idft(&spectrum).expect("idft failed");
        let err = max_abs_err(&input, &recovered);
        assert!(err < 1e-12, "roundtrip error = {err}");
    }

    #[test]
    fn test_direct_dft_empty() {
        let outcome = direct_dft(&[]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_split_radix_butterfly_size_4() {
        let input = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(-1.0, 0.0),
            Complex64::new(0.0, -1.0),
        ];
        assert_split_radix_matches_dft(input);
    }

    #[test]
    fn test_split_radix_butterfly_size_8() {
        let mut input = Vec::with_capacity(8);
        for k in 0..8 {
            input.push(Complex64::new(k as f64, -(k as f64) * 0.5));
        }
        assert_split_radix_matches_dft(input);
    }

    #[test]
    fn test_split_radix_butterfly_size_16() {
        let mut input = Vec::with_capacity(16);
        for k in 0..16 {
            input.push(Complex64::new((k as f64 * 0.5).sin(), (k as f64 * 0.3).cos()));
        }
        assert_split_radix_matches_dft(input);
    }

    #[test]
    fn test_split_radix_butterfly_not_power_of_two() {
        let mut six = vec![Complex64::new(1.0, 0.0); 6];
        let outcome = split_radix_butterfly(&mut six);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_split_radix_butterfly_too_small() {
        let mut two = vec![Complex64::new(1.0, 0.0); 2];
        let outcome = split_radix_butterfly(&mut two);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_inverse_twiddle_table() {
        let n = 8;
        let forward = generate_twiddle_table(n).expect("forward failed");
        let inverse = generate_inverse_twiddle_table(n).expect("inverse failed");
        // Forward and inverse factors are reciprocals, so each product is 1.
        for k in 0..n {
            let product = forward[k] * inverse[k];
            assert_relative_eq!(product.re, 1.0, epsilon = 1e-14);
            assert_relative_eq!(product.im, 0.0, epsilon = 1e-14);
        }
    }
}