use std::ops::{Add, Div, Mul, Neg, Sub};
use crate::{complex::Complex, Abs, Const, Cos, Inv, NumCast, One, Sin, Zero};
/// Computes the forward discrete Fourier transform of `a`.
///
/// The input is zero-padded to the next power of two before transforming, so the result
/// has `a.len().next_power_of_two()` elements (a single zero element for an empty input).
/// The transform uses the `e^{-2πi·jk/N}` kernel and applies no normalisation.
pub(super) fn fft<T>(a: &[Complex<T>]) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let direction = Transform::Direct;
    iterative_fft(a, direction)
}
/// Computes the inverse discrete Fourier transform of `y`.
///
/// Like [`fft`], the input is zero-padded to the next power of two `N`. The transform uses
/// the `e^{+2πi·jk/N}` kernel and divides every output by `N`, so `ifft(&fft(x))` gives back
/// the (padded) `x` up to floating-point rounding.
pub(super) fn ifft<T>(y: &[Complex<T>]) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let direction = Transform::Inverse;
    iterative_fft(y, direction)
}
/// Direction of the transform computed by `iterative_fft`.
///
/// `Debug`/`PartialEq`/`Eq` are derived so a direction can be printed in diagnostics and
/// compared directly (e.g. in assertions).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Transform {
    /// Forward transform: twiddle angles are negative and the output is not scaled.
    Direct,
    /// Inverse transform: twiddle angles are positive and the output is divided by the
    /// (padded) input length.
    Inverse,
}
/// Base-2 logarithm of `n`, valid only when `n` is a power of two.
///
/// Implemented as the number of trailing zero bits, which equals `log2(n)` exactly for powers
/// of two. Other inputs yield their trailing-zero count instead (e.g. `6 -> 1`), and `0` maps
/// to the bit width of `usize`.
fn log2(n: usize) -> u32 {
    usize::trailing_zeros(n)
}
/// Reorders `a` so that element `i` moves to the index whose bits are `i`'s bits reversed.
///
/// The number of reversed bits is the trailing-zero count of the length, so the length is
/// expected to be a power of two (an empty or single-element vector is returned unchanged).
/// Other lengths are not supported: the computed partner index can exceed the length, and
/// the swap then panics.
fn bit_reverse_vec<T>(mut a: Vec<T>) -> Vec<T> {
    let width = a.len().trailing_zeros();
    for i in 0..a.len() {
        // Low `width` bits of `i`, mirrored: reverse the whole word, then rotate the
        // mirrored low bits back down from the top.
        let partner = i.reverse_bits().rotate_left(width);
        // Each pair is visited twice; only swap from the smaller index so it moves once.
        if i < partner {
            a.swap(i, partner);
        }
    }
    a
}
/// Reverses the lowest `l` bits of `k`.
///
/// `reverse_bits` moves bit 0 of `k` to the most significant position; rotating left by `l`
/// then brings those top `l` bits back down to the low end. For `k < 2^l` the result is the
/// `l`-bit reversal of `k` (e.g. `rev(1, 3) == 4`). Bits of `k` at or above position `l` are
/// not discarded; they end up at positions `>= l` of the result.
fn rev(k: usize, l: u32) -> usize {
    let mirrored = k.reverse_bits();
    mirrored.rotate_left(l)
}
/// Copies `a` into a new vector, zero-padding it up to the next power of two.
///
/// Slices whose length is already a power of two are copied unchanged. An empty slice
/// becomes a single `T::zero()` element, because `0usize.next_power_of_two()` is `1`.
fn extend_to_power_of_two<T: Clone + Zero>(a: &[T]) -> Vec<T> {
    let target = a.len().next_power_of_two();
    let mut padded = Vec::with_capacity(target);
    padded.extend_from_slice(a);
    // Only build the padding value when the length actually has to grow.
    if padded.len() < target {
        padded.resize(target, T::zero());
    }
    padded
}
/// Iterative radix-2 Cooley–Tukey transform shared by [`fft`] and [`ifft`].
///
/// The input is zero-padded to a power-of-two length `N`, permuted into bit-reversed order,
/// and then combined in `log2(N)` butterfly stages whose sub-transform size doubles each
/// stage. `dir` picks the sign of the twiddle angle (negative for [`Transform::Direct`],
/// positive for [`Transform::Inverse`]); the inverse additionally divides every output by `N`.
///
/// Within a stage, twiddle factors are produced by repeatedly multiplying by that stage's
/// root of unity rather than being evaluated directly, so rounding error can accumulate for
/// large stages.
fn iterative_fft<T>(a: &[Complex<T>], dir: Transform) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let padded = extend_to_power_of_two(a);
    let len = padded.len();
    debug_assert!(len.is_power_of_two());
    let mut data = bit_reverse_vec(padded);

    let sign = match dir {
        Transform::Direct => -T::one(),
        Transform::Inverse => T::one(),
    };
    // `sign * tau` is the same for every stage; only the divisor changes.
    let signed_tau = sign * T::tau();

    // `len` is a power of two, so its trailing-zero count is exactly log2(len).
    for stage in 1..=len.trailing_zeros() {
        // Size of the sub-transforms produced by this stage.
        let span: usize = 1 << stage;
        let half = span / 2;
        let angle = signed_tau.clone() / T::from(span).unwrap();
        let root = Complex::from_polar(T::one(), angle);

        // `len` is a multiple of `span`, so every chunk is a full sub-transform.
        for block in data.chunks_exact_mut(span) {
            let (lower, upper) = block.split_at_mut(half);
            let mut twiddle = Complex::one();
            for (low, high) in lower.iter_mut().zip(upper.iter_mut()) {
                let product = high.clone() * twiddle.clone();
                let base = low.clone();
                *low = base.clone() + product.clone();
                *high = base - product;
                twiddle = twiddle * root.clone();
            }
        }
    }

    match dir {
        Transform::Direct => data,
        Transform::Inverse => {
            let scale = T::from(len).unwrap();
            data.iter().map(|x| x / scale.clone()).collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reverse_bit() {
        // 3-bit reversal table for the indices 0..8.
        let table = [0, 4, 2, 6, 1, 5, 3, 7];
        for (k, &reversed) in table.iter().enumerate() {
            assert_eq!(reversed, rev(k, 3));
        }
    }

    #[test]
    fn reverse_copy() {
        let input: Vec<i32> = (0..8).collect();
        let expected = vec![0, 4, 2, 6, 1, 5, 3, 7];
        assert_eq!(expected, bit_reverse_vec(input));
    }

    #[test]
    fn reverse_copy_long() {
        let input: Vec<i32> = (0..16).collect();
        let expected = vec![0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15];
        assert_eq!(expected, bit_reverse_vec(input));
    }

    #[test]
    fn fft_iterative() {
        // [1, 0, 1] is zero-padded to [1, 0, 1, 0], whose DFT is [2, 0, 2, 0].
        let zero = Complex::<f32>::zero();
        let one = Complex::<f32>::one();
        let two = Complex::<f32>::new(2., 0.);
        let spectrum = iterative_fft(&[one, zero, one], Transform::Direct);
        assert_eq!(vec![two, zero, two, zero], spectrum);
    }

    #[test]
    fn fft_ifft() {
        // A forward transform followed by the inverse must give back the input.
        let zero = Complex::<f64>::zero();
        let one = Complex::<f64>::one();
        let signal = vec![one, zero, one, zero];
        let roundtrip = ifft(&fft(&signal));
        assert_eq!(signal, roundtrip);
    }

    #[test]
    fn fft_direct() {
        let samples = [0.0_f32, 2., 3., -1., 4., 5., 7., 9.];
        let input: Vec<_> = samples
            .iter()
            .copied()
            .map(|x| Complex::new(x, 0.))
            .collect();
        let spectrum = iterative_fft(&input, Transform::Direct);
        let expected = vec![
            Complex::new(29., 0.),
            Complex::new(0.949_748_3, 13.192_388),
            Complex::new(-6., 0.999_999_94),
            Complex::new(-8.949_746, 5.192_387_6),
            Complex::new(-1., 0.),
            Complex::new(-8.949_748, -5.192_387_6),
            Complex::new(-6., -0.999_999_94),
            Complex::new(0.949_746_13, -13.192_388),
        ];
        assert_eq!(expected, spectrum);
    }

    #[test]
    fn fft_inverse() {
        let samples = [
            0.0_f32, 2., 3., -1., 4., 5., 7., 9., 0., 0., 0., 0., 0., 0., 0., 0.,
        ];
        let input: Vec<_> = samples
            .iter()
            .copied()
            .map(|x| Complex::new(x, 0.))
            .collect();
        let signal = iterative_fft(&input, Transform::Inverse);
        let expected = vec![
            Complex::new(1.812_5, 0.0),
            Complex::new(-0.724_480_33, 1.186_006_4),
            Complex::new(0.059_359_223, -0.824_524_16),
            Complex::new(0.355_807_45, 0.731_437_9),
            Complex::new(-0.375, -0.062_5),
            Complex::new(-0.002_254_262_6, 0.347_554_56),
            Complex::new(-0.559_359_1, -0.324_524_28),
            Complex::new(0.370_926_68, -0.197_876_87),
            Complex::new(-0.062_5, 0.0),
            Complex::new(0.370_926_92, 0.197_877_05),
            Complex::new(-0.559_359_2, 0.324_524_16),
            Complex::new(-0.002_254_068_9, -0.347_554_53),
            Complex::new(-0.375, 0.062_5),
            Complex::new(0.355_807_66, -0.731_438_04),
            Complex::new(0.059_359_044, 0.824_524_3),
            Complex::new(-0.724_48, -1.186_006_5),
        ];
        assert_eq!(expected, signal);
    }
}