#[cfg(not(feature = "std"))]
extern crate alloc;
mod radix2;
mod twiddle;
pub use radix2::{fft_fixed, fft_fixed_inplace, ifft_fixed, ifft_fixed_inplace};
pub use twiddle::{const_cos, const_sin, twiddle_factor};
use crate::kernel::Complex;
/// FFT over arrays whose length `N` is fixed at compile time.
///
/// This module implements it for [`ConstFftImpl`] at the power-of-two sizes
/// 2, 4, 8, …, 1024; every method forwards to the matching `radix2` routine.
///
/// The forward transform is unnormalized: a constant input `c` ends up entirely
/// in bin 0 as `c * N`. The inverse transform undoes it, so `ifft(&fft(&x))`
/// reproduces `x` up to floating-point rounding.
pub trait ConstFft<const N: usize> {
    /// Returns the forward transform of `input`; `input` itself is not modified.
    #[must_use]
    fn fft(input: &[Complex<f64>; N]) -> [Complex<f64>; N];

    /// Returns the inverse transform of `input`; `input` itself is not modified.
    #[must_use]
    fn ifft(input: &[Complex<f64>; N]) -> [Complex<f64>; N];

    /// Overwrites `data` with its forward transform.
    fn fft_inplace(data: &mut [Complex<f64>; N]);

    /// Overwrites `data` with its inverse transform.
    fn ifft_inplace(data: &mut [Complex<f64>; N]);
}
/// Zero-sized marker type that carries the [`ConstFft`] implementations.
///
/// Call the transforms through the trait, e.g.
/// `<ConstFftImpl as ConstFft<64>>::fft(&input)`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConstFftImpl;
macro_rules! impl_const_fft {
($n:expr) => {
impl ConstFft<$n> for ConstFftImpl {
#[inline]
fn fft(input: &[Complex<f64>; $n]) -> [Complex<f64>; $n] {
radix2::fft_fixed(input)
}
#[inline]
fn ifft(input: &[Complex<f64>; $n]) -> [Complex<f64>; $n] {
radix2::ifft_fixed(input)
}
#[inline]
fn fft_inplace(data: &mut [Complex<f64>; $n]) {
radix2::fft_fixed_inplace(data);
}
#[inline]
fn ifft_inplace(data: &mut [Complex<f64>; $n]) {
radix2::ifft_fixed_inplace(data);
}
}
};
}
impl_const_fft!(2);
impl_const_fft!(4);
impl_const_fft!(8);
impl_const_fft!(16);
impl_const_fft!(32);
impl_const_fft!(64);
impl_const_fft!(128);
impl_const_fft!(256);
impl_const_fft!(512);
impl_const_fft!(1024);
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for every floating-point comparison in this module.
    const TOL: f64 = 1e-10;

    /// Panics unless `actual` is within [`TOL`] of `expected` in both the real
    /// and the imaginary component. `what`/`idx` label the failure message.
    fn assert_close(actual: Complex<f64>, expected: Complex<f64>, what: &str, idx: usize) {
        assert!(
            (actual.re - expected.re).abs() < TOL,
            "{what}: real part mismatch at {idx}: {} vs {}",
            actual.re,
            expected.re
        );
        assert!(
            (actual.im - expected.im).abs() < TOL,
            "{what}: imag part mismatch at {idx}: {} vs {}",
            actual.im,
            expected.im
        );
    }

    /// A unit impulse at index 0 has a flat spectrum of all ones.
    #[test]
    fn test_const_fft_impulse_8() {
        let mut input = [Complex::<f64>::zero(); 8];
        input[0] = Complex::new(1.0, 0.0);
        let output = fft_fixed(&input);
        for (k, &c) in output.iter().enumerate() {
            assert_close(c, Complex::new(1.0, 0.0), "impulse", k);
        }
    }

    /// `ifft_fixed` undoes `fft_fixed` for a simple ramp.
    #[test]
    fn test_const_fft_roundtrip_8() {
        let input: [Complex<f64>; 8] = core::array::from_fn(|i| Complex::new((i + 1) as f64, 0.0));
        let recovered = ifft_fixed(&fft_fixed(&input));
        for (i, (&got, &want)) in recovered.iter().zip(input.iter()).enumerate() {
            assert_close(got, want, "roundtrip", i);
        }
    }

    /// A constant input puts `value * N` in bin 0 and nothing elsewhere.
    #[test]
    fn test_const_fft_dc_4() {
        let input = [Complex::new(3.0, 0.0); 4];
        let output = fft_fixed(&input);
        assert_close(output[0], Complex::new(12.0, 0.0), "dc", 0);
        for (k, &c) in output.iter().enumerate().skip(1) {
            assert_close(c, Complex::new(0.0, 0.0), "dc", k);
        }
    }

    /// The in-place forward/inverse pair restores the original samples,
    /// including the (zero) imaginary parts.
    #[test]
    fn test_const_fft_inplace_16() {
        let original: [Complex<f64>; 16] =
            core::array::from_fn(|i| Complex::new((i as f64 / 5.0).sin(), 0.0));
        let mut data = original;
        fft_fixed_inplace(&mut data);
        ifft_fixed_inplace(&mut data);
        for (i, (&got, &want)) in data.iter().zip(original.iter()).enumerate() {
            assert_close(got, want, "inplace roundtrip", i);
        }
    }

    /// The trait's forward transform agrees with the impulse identity at N = 64.
    #[test]
    fn test_const_fft_trait_64() {
        let input: [Complex<f64>; 64] =
            core::array::from_fn(|i| Complex::new(if i == 0 { 1.0 } else { 0.0 }, 0.0));
        let output = <ConstFftImpl as ConstFft<64>>::fft(&input);
        for (k, &c) in output.iter().enumerate() {
            assert_close(c, Complex::new(1.0, 0.0), "trait impulse", k);
        }
    }

    /// Smallest supported size: the FFT is a single butterfly `[a + b, a - b]`.
    #[test]
    fn test_const_fft_trait_butterfly_2() {
        let input = [Complex::new(5.0, 1.0), Complex::new(2.0, -3.0)];
        let output = <ConstFftImpl as ConstFft<2>>::fft(&input);
        assert_close(output[0], Complex::new(7.0, -2.0), "butterfly", 0);
        assert_close(output[1], Complex::new(3.0, 4.0), "butterfly", 1);
    }

    /// Exercises every trait method: the in-place and out-of-place forward
    /// transforms agree, and both inverse forms recover the original input.
    #[test]
    fn test_const_fft_trait_all_methods_32() {
        type Fft32 = ConstFftImpl;
        let original: [Complex<f64>; 32] = core::array::from_fn(|i| {
            Complex::new((i as f64 / 3.0).sin(), (i as f64 / 7.0).sin())
        });

        let spectrum = <Fft32 as ConstFft<32>>::fft(&original);
        let mut data = original;
        <Fft32 as ConstFft<32>>::fft_inplace(&mut data);
        for (k, (&got, &want)) in data.iter().zip(spectrum.iter()).enumerate() {
            assert_close(got, want, "fft_inplace vs fft", k);
        }

        <Fft32 as ConstFft<32>>::ifft_inplace(&mut data);
        for (i, (&got, &want)) in data.iter().zip(original.iter()).enumerate() {
            assert_close(got, want, "ifft_inplace roundtrip", i);
        }

        let recovered = <Fft32 as ConstFft<32>>::ifft(&spectrum);
        for (i, (&got, &want)) in recovered.iter().zip(original.iter()).enumerate() {
            assert_close(got, want, "ifft roundtrip", i);
        }
    }
}