use num_complex::Complex;
use num_traits::{Float, FloatConst, NumCast, One, Zero};
/// Base-2 logarithm of `n`, valid only when `n` is a power of two.
///
/// For a power of two, the exponent equals the number of trailing zero bits.
/// For any other input the result is just that trailing-zero count, not a
/// logarithm; `0` yields `usize::BITS`.
fn log2(n: usize) -> u32 {
    usize::trailing_zeros(n)
}
/// Returns `a` permuted into bit-reversed index order.
///
/// With `bits = log2(a.len())`, the element at index `k` ends up at index
/// `rev(k, bits)`. `a.len()` is expected to be a power of two; an empty
/// vector is returned unchanged. For other lengths `rev` may yield
/// out-of-range indices, and the swap then panics.
fn bit_reverse_vec<T>(mut a: Vec<T>) -> Vec<T> {
    let len = a.len();
    let bits = log2(len);
    for k in 0..len {
        let r = rev(k, bits);
        // Bit reversal pairs indices up (k <-> r). Swap each pair exactly once,
        // from its smaller index. Swapping unconditionally would undo pairs
        // that were already exchanged. Stopping at len / 2 would miss pairs
        // lying entirely in the upper half.
        if k < r {
            a.swap(k, r);
        }
    }
    a
}
/// Reverses the lowest `l` bits of `k` (intended for `k < 2^l`).
///
/// `reverse_bits` mirrors the whole machine word, which parks the low `l`
/// bits of `k` at the top of the word in reversed order. Rotating left by
/// `l` then brings them back down to positions `0..l`.
fn rev(k: usize, l: u32) -> usize {
    let mirrored = usize::reverse_bits(k);
    usize::rotate_left(mirrored, l)
}
/// Forward discrete Fourier transform of `a`.
///
/// `a` is zero-padded to the next power of two first, so the returned vector
/// may be longer than the input. The stage twiddles use a positive angle,
/// `exp(+2πi / m)`, which is the opposite sign convention from many FFT
/// libraries. `ifft` is the matching inverse.
pub(super) fn fft<T>(a: Vec<Complex<T>>) -> Vec<Complex<T>>
where
    T: Float + FloatConst + NumCast,
{
    let direction = Transform::Direct;
    iterative_fft(a, direction)
}
/// Inverse discrete Fourier transform of `y`; undoes `fft`.
///
/// `y` is zero-padded to the next power of two first. The stage twiddles use
/// a negative angle, `exp(-2πi / m)`, and every output is divided by the
/// padded length.
pub(super) fn ifft<T>(y: Vec<Complex<T>>) -> Vec<Complex<T>>
where
    T: Float + FloatConst + NumCast,
{
    let direction = Transform::Inverse;
    iterative_fft(y, direction)
}
/// Zero-pads `a` so that its length becomes the next power of two.
///
/// A vector whose length is already a power of two is returned untouched. An
/// empty vector becomes a single zero, because `0usize.next_power_of_two()`
/// is `1`.
fn extend_to_power_of_two<T: Clone + Zero>(mut a: Vec<T>) -> Vec<T> {
    let target_len = a.len().next_power_of_two();
    if target_len != a.len() {
        a.resize(target_len, T::zero());
    }
    a
}
/// Direction of the transform computed by `iterative_fft`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Transform {
    /// Forward transform: stage twiddles use a positive angle and no scaling is applied.
    Direct,
    /// Inverse transform: stage twiddles use a negative angle and the output is
    /// divided by its (padded) length.
    Inverse,
}
/// Iterative radix-2 Cooley–Tukey FFT shared by `fft` and `ifft`.
///
/// Steps:
/// 1. zero-pad the input to the next power of two (`extend_to_power_of_two`);
/// 2. permute it into bit-reversed order (`bit_reverse_vec`);
/// 3. run `log2(n)` butterfly stages. A stage of width `m` uses the root
///    `exp(sign * 2πi / m)`, with `sign = +1` for `Transform::Direct` and
///    `sign = -1` for `Transform::Inverse`;
/// 4. for the inverse transform, divide every output by `n`.
///
/// The returned vector has the padded length `n`, which may exceed the input
/// length.
fn iterative_fft<T>(a: Vec<Complex<T>>, dir: Transform) -> Vec<Complex<T>>
where
    T: Float + FloatConst + NumCast,
{
    let padded = extend_to_power_of_two(a);
    let len = padded.len();
    debug_assert!(len.is_power_of_two());
    let mut data = bit_reverse_vec(padded);

    let angle_sign = match dir {
        Transform::Direct => T::one(),
        Transform::Inverse => -T::one(),
    };
    let full_turn = T::TAU();

    for stage in 1..=log2(len) {
        let width: usize = 1 << stage;
        let half = width / 2;
        let angle = angle_sign * full_turn / T::from(width).unwrap();
        let root = Complex::from_polar(T::one(), angle);

        // Each chunk of `width` elements merges two already-transformed halves.
        for chunk in data.chunks_exact_mut(width) {
            let (evens, odds) = chunk.split_at_mut(half);
            // The twiddle is advanced by repeated multiplication with `root`
            // rather than being recomputed from its angle each time.
            let mut twiddle = Complex::one();
            for (even, odd) in evens.iter_mut().zip(odds.iter_mut()) {
                let odd_term = *odd * twiddle;
                let even_term = *even;
                *even = even_term + odd_term;
                *odd = even_term - odd_term;
                twiddle = twiddle * root;
            }
        }
    }

    match dir {
        Transform::Direct => data,
        Transform::Inverse => {
            let scale = T::from(len).unwrap();
            data.into_iter().map(|value| value / scale).collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reverse_bit() {
        // (input, expected reversal of its low 3 bits)
        let table = [
            (0, 0),
            (1, 4),
            (2, 2),
            (3, 6),
            (4, 1),
            (5, 5),
            (6, 3),
            (7, 7),
        ];
        for &(k, reversed) in table.iter() {
            assert_eq!(reversed, rev(k, 3));
        }
    }

    #[test]
    fn reverse_copy() {
        let input: Vec<i32> = (0..8).collect();
        let permuted = bit_reverse_vec(input);
        assert_eq!(vec![0, 4, 2, 6, 1, 5, 3, 7], permuted);
    }

    #[test]
    fn fft_iterative() {
        let one: Complex<f64> = Complex::one();
        let zero = one * 0.;
        // Three samples are zero-padded to four before transforming.
        let spectrum = iterative_fft(vec![one, zero, one], Transform::Direct);
        assert_eq!(vec![one * 2., zero, one * 2., zero], spectrum);
    }

    #[test]
    fn fft_ifft() {
        let one: Complex<f64> = Complex::one();
        let zero = one * 0.;
        let signal = vec![one, zero, one, zero];
        let round_trip = ifft(fft(signal.clone()));
        assert_eq!(signal, round_trip);
    }
}