polynomen 1.0.0

Polynomial library
Documentation
use std::ops::{Add, Div, Mul, Neg, Sub};

use crate::{complex::Complex, Abs, Const, Cos, Inv, NumCast, One, Sin, Zero};

/// Forward (direct) Fast Fourier Transform.
///
/// The samples are zero-padded to the next power of two before the
/// transform, so the returned vector has that padded length (an empty
/// input yields a single output value).
///
/// # Arguments
///
/// * `a` - samples to transform
pub(super) fn fft<T>(a: &[Complex<T>]) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let direction = Transform::Direct;
    iterative_fft(a, direction)
}

/// Inverse Fast Fourier Transform.
///
/// The spectrum is zero-padded to the next power of two before the
/// transform, and every output value is divided by that padded length.
///
/// # Arguments
///
/// * `y` - spectrum to transform back
pub(super) fn ifft<T>(y: &[Complex<T>]) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let direction = Transform::Inverse;
    iterative_fft(y, direction)
}

/// Type of Fourier transform.
///
/// Selects the sign of the twiddle-factor exponent and whether the result
/// is divided by the transform length.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Transform {
    /// Direct (forward) fast Fourier transform.
    Direct,
    /// Inverse fast Fourier transform.
    Inverse,
}

/// Base-2 logarithm of a power of two.
///
/// Computed as the number of trailing zero bits, which equals `log2(n)`
/// exactly when `n` is a power of two. For other values it is only the
/// trailing-zero count, and `0` maps to `usize::BITS`.
///
/// # Arguments
///
/// * `n` - power of two
fn log2(n: usize) -> u32 {
    usize::trailing_zeros(n)
}

/// Permute the elements so that index `i` moves to the index whose low
/// bits are `i`'s bits in reverse order (bit-reversal permutation).
///
/// # Arguments
///
/// * `a` - vector with a power of two length
fn bit_reverse_vec<T>(mut a: Vec<T>) -> Vec<T> {
    // With a power-of-two length, `log2` is exactly the index width in bits.
    let width = log2(a.len());
    (0..a.len())
        .map(|i| (i, rev(i, width)))
        // Every pair shows up twice (i -> j and j -> i); swap it only once.
        .filter(|&(i, j)| i < j)
        .for_each(|(i, j)| a.swap(i, j));
    a
}

/// Reverse the lowest `l` bits of `k`.
///
/// Meant for `k < 2^l`; higher bits of `k` would wrap into the result.
///
/// # Arguments
///
/// * `k` - number on which the permutation acts.
/// * `l` - number of lower bits to reverse.
fn rev(k: usize, l: u32) -> usize {
    // Mirror the whole word, which puts the reversed low bits at the top,
    // then rotate those `l` bits back down to the bottom.
    let mirrored = usize::reverse_bits(k);
    usize::rotate_left(mirrored, l)
}

/// Copy the slice into a vector padded with zeros up to the next power of
/// two. A slice whose length is already a power of two is copied unchanged,
/// and an empty slice becomes a single zero.
///
/// # Arguments
///
/// * `a` - vector
fn extend_to_power_of_two<T: Clone + Zero>(a: &[T]) -> Vec<T> {
    // `next_power_of_two` leaves powers of two unchanged and maps 0 to 1.
    let target = a.len().next_power_of_two();
    let mut padded = a.to_vec();
    if target > padded.len() {
        padded.resize(target, T::zero());
    }
    padded
}

/// Iterative fast Fourier transform algorithm.
/// T. H. Cormen, C. E. Leiserson, R. L. Rivest, C. Stein, Introduction to Algorithms, 3rd edition, 2009
///
/// The input is first zero-padded to the next power of two `n` (an empty
/// input becomes a single zero sample), so the output always has length `n`.
///
/// * `Transform::Direct` uses stage twiddle factors `e^{-i·2π/m}`.
/// * `Transform::Inverse` uses `e^{+i·2π/m}` and divides every output by `n`.
///
/// # Arguments
///
/// * `a` - input vector for the transform
/// * `dir` - transform "direction" (direct or inverse)
///
/// # Panics
///
/// Panics if a stage size `m` or the length `n` cannot be converted to `T`
/// through `NumCast`.
// Single-letter names (s, m, k, j, w, t, u) follow the CLRS pseudocode cited above.
#[allow(clippy::many_single_char_names)]
fn iterative_fft<T>(a: &[Complex<T>], dir: Transform) -> Vec<Complex<T>>
where
    T: Abs
        + Add<Output = T>
        + Clone
        + Const
        + Cos
        + Div<Output = T>
        + Inv
        + Mul<Output = T>
        + Neg<Output = T>
        + NumCast
        + One
        + PartialOrd
        + Sin
        + Sub<Output = T>
        + Zero,
{
    let padded = extend_to_power_of_two(a);
    let n = padded.len();
    debug_assert!(n.is_power_of_two());
    // Bit-reversed order lets the butterflies below run in place.
    let mut data = bit_reverse_vec(padded);

    let sign = match dir {
        Transform::Direct => -T::one(),
        Transform::Inverse => T::one(),
    };

    let tau = T::tau();
    for s in 1..=log2(n) {
        // Stage `s` merges pairs of sub-transforms of size `half` into size `m`.
        let m: usize = 1 << s;
        let half = m / 2;
        let m_f = T::from(m).expect("FFT stage size must be representable in T");
        let exp = sign.clone() * tau.clone() / m_f;
        // Principal m-th root of unity (conjugated for the direct transform).
        let w_n = Complex::from_polar(T::one(), exp);
        for k in (0..n).step_by(m) {
            let mut w = Complex::one();
            for j in 0..half {
                // Butterfly: operation order kept identical so results stay bit-exact.
                let t = data[k + j + half].clone() * w.clone();
                let u = data[k + j].clone();
                data[k + j] = u.clone() + t.clone();
                data[k + j + half] = u - t;
                w = w * w_n.clone();
            }
        }
    }

    match dir {
        Transform::Direct => data,
        Transform::Inverse => {
            // The inverse transform is normalized by the padded length.
            let n_f = T::from(n).expect("FFT length must be representable in T");
            data.iter().map(|x| x / n_f.clone()).collect()
        }
    }
}

#[cfg(test)]
mod tests {
    // The FFT expectations below are literal outputs of this implementation.
    // They are compared with `assert_eq!` and no tolerance (assuming
    // `Complex`'s `PartialEq` is exact), so any change to the arithmetic order
    // in `iterative_fft` may require regenerating them.
    use super::*;

    #[test]
    fn reverse_bit() {
        assert_eq!(0, rev(0, 3));
        assert_eq!(4, rev(1, 3));
        assert_eq!(2, rev(2, 3));
        assert_eq!(6, rev(3, 3));
        assert_eq!(1, rev(4, 3));
        assert_eq!(5, rev(5, 3));
        assert_eq!(3, rev(6, 3));
        assert_eq!(7, rev(7, 3));
    }

    #[test]
    fn log2_of_powers_of_two() {
        assert_eq!(0, log2(1));
        assert_eq!(1, log2(2));
        assert_eq!(3, log2(8));
        assert_eq!(10, log2(1024));
    }

    #[test]
    fn extend_pads_with_zeros() {
        let zero = Complex::<f64>::zero();
        let one = Complex::<f64>::one();
        // Non-power-of-two length is padded with zeros.
        assert_eq!(
            vec![one, one, one, zero],
            extend_to_power_of_two(&[one, one, one])
        );
        // Power-of-two length is copied unchanged.
        assert_eq!(vec![one, zero], extend_to_power_of_two(&[one, zero]));
        // Empty input becomes a single zero.
        assert_eq!(vec![zero], extend_to_power_of_two::<Complex<f64>>(&[]));
    }

    #[test]
    fn reverse_copy() {
        let a = vec![0, 1, 2, 3, 4, 5, 6, 7];
        let b = bit_reverse_vec(a);
        let expected = vec![0, 4, 2, 6, 1, 5, 3, 7];
        assert_eq!(expected, b);
    }

    #[test]
    fn reverse_copy_long() {
        let a = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let b = bit_reverse_vec(a);
        let expected = vec![0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15];
        assert_eq!(expected, b);
    }

    #[test]
    fn fft_iterative() {
        let zero = Complex::<f32>::zero();
        let one = Complex::<f32>::one();
        let two = Complex::<f32>::new(2., 0.);
        let a = vec![one, zero, one];
        // `a` is extended to four elements
        let f = iterative_fft(&a, Transform::Direct);
        let expected = vec![two, zero, two, zero];
        assert_eq!(expected, f);
    }

    #[test]
    fn fft_ifft() {
        let zero = Complex::<f64>::zero();
        let one = Complex::<f64>::one();
        let a = vec![one, zero, one, zero];
        let f = fft(&a);
        let a2 = ifft(&f);
        assert_eq!(a, a2);
    }

    #[test]
    fn fft_direct() {
        let a = [0.0_f32, 2., 3., -1., 4., 5., 7., 9.];
        let a: Vec<_> = a.iter().map(|&x| Complex::new(x, 0.)).collect();
        let f = iterative_fft(&a, Transform::Direct);
        let expected = vec![
            Complex::new(29., 0.),
            Complex::new(0.949_748_3, 13.192_388),
            Complex::new(-6., 0.999_999_94),
            Complex::new(-8.949_746, 5.192_387_6),
            Complex::new(-1., 0.),
            Complex::new(-8.949_748, -5.192_387_6),
            Complex::new(-6., -0.999_999_94),
            Complex::new(0.949_746_13, -13.192_388),
        ];
        assert_eq!(expected, f);
    }

    #[test]
    fn fft_inverse() {
        let a = [
            0.0_f32, 2., 3., -1., 4., 5., 7., 9., 0., 0., 0., 0., 0., 0., 0., 0.,
        ];
        let a: Vec<_> = a.iter().map(|&x| Complex::new(x, 0.)).collect();
        let f = iterative_fft(&a, Transform::Inverse);
        let expected = vec![
            Complex::new(1.812_5, 0.0),
            Complex::new(-0.724_480_33, 1.186_006_4),
            Complex::new(0.059_359_223, -0.824_524_16),
            Complex::new(0.355_807_45, 0.731_437_9),
            Complex::new(-0.375, -0.062_5),
            Complex::new(-0.002_254_262_6, 0.347_554_56),
            Complex::new(-0.559_359_1, -0.324_524_28),
            Complex::new(0.370_926_68, -0.197_876_87),
            Complex::new(-0.062_5, 0.0),
            Complex::new(0.370_926_92, 0.197_877_05),
            Complex::new(-0.559_359_2, 0.324_524_16),
            Complex::new(-0.002_254_068_9, -0.347_554_53),
            Complex::new(-0.375, 0.062_5),
            Complex::new(0.355_807_66, -0.731_438_04),
            Complex::new(0.059_359_044, 0.824_524_3),
            Complex::new(-0.724_48, -1.186_006_5),
        ];
        assert_eq!(expected, f);
    }
}