use super::num::{CommonField, Complex, PI};
use std::ops::{Add, Div, Mul, Neg, Sub};
/// Yields the indices `0..n` in bit-reversed order, where `n` is a power of two.
///
/// The counter is stored doubled (one spare low bit) so that the all-ones
/// starting value `2n - 1` rolls over to zero on the first step.
struct BitRevIterator {
    /// Twice the index produced most recently; `2 * len - 1` before the first call.
    state: usize,
    /// How many indices the iterator produces in total.
    len: usize,
}

impl BitRevIterator {
    /// Builds an iterator over `0..n`.
    ///
    /// # Panics
    /// Panics when `n` is not a power of two.
    fn new(n: usize) -> Self {
        assert!(n.is_power_of_two());
        let state = 2 * n - 1;
        Self { state, len: n }
    }
}

impl Iterator for BitRevIterator {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        // The final index is `len - 1`; its doubled form marks exhaustion.
        let exhausted = 2 * self.len - 2;
        if self.state == exhausted {
            return None;
        }
        // Increment from the most significant bit downwards: clear the run of
        // leading set bits, then set the first clear one. `bit` reaches zero
        // only on the very first call, which turns the all-ones start into 0.
        let mut bit = self.len;
        loop {
            if self.state & bit == 0 {
                break;
            }
            self.state ^= bit;
            bit >>= 1;
        }
        self.state |= bit;
        Some(self.state >> 1)
    }
}
/// Scalar types that [`fft`] can transform, paired with the type in which the
/// transform arithmetic is performed.
// The trait keeps the algorithm's customary all-caps acronym, hence the lint exemption.
#[allow(clippy::upper_case_acronyms)]
pub trait FFT: Copy + Sized {
    /// Element type used while transforming; every `Self` value converts into
    /// it through `From`, and it supports the four arithmetic operators.
    type F: Copy
        + Sized
        + From<Self>
        + Neg
        + Add<Output = Self::F>
        + Sub<Output = Self::F>
        + Mul<Output = Self::F>
        + Div<Output = Self::F>;

    /// Additive identity of `Self`; used to pad inputs up to the transform length.
    const ZERO: Self;

    /// Returns the table of roots of unity that `fft` indexes for a transform
    /// of length `n`. `inverse` selects the roots for the inverse transform.
    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F>;

    /// Returns the scale that `fft` multiplies into every input element for a
    /// transform of length `n` in the given direction.
    fn get_factor(n: usize, inverse: bool) -> Self::F;

    /// Converts a transform-domain element back into a `Self` value.
    fn extract(f: Self::F) -> Self;
}
/// Floating-point transform: real samples are lifted into `Complex` values.
impl FFT for f64 {
    type F = Complex;
    const ZERO: f64 = 0.0;

    /// Builds `n / 2` unit-modulus values via `Complex::from_polar`, at angles
    /// `k * 2π / n` for `k` in `0..n / 2`; the angles are negated when
    /// `inverse` is set.
    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
        let turn = if inverse { -2.0 } else { 2.0 };
        let step = turn * PI / n as f64;
        let mut roots = Vec::with_capacity(n / 2);
        for i in 0..n / 2 {
            roots.push(Complex::from_polar(1.0, step * i as f64));
        }
        roots
    }

    /// Returns `1 / n` for the inverse transform and `1` for the forward one.
    fn get_factor(n: usize, inverse: bool) -> Self::F {
        let scale = if inverse { (n as f64).recip() } else { 1.0 };
        Complex::from(scale)
    }

    /// Keeps the real component and drops the rest.
    fn extract(f: Self::F) -> f64 {
        f.real
    }
}
/// Integer transform: values are lifted into `CommonField` and transformed there.
impl FFT for i64 {
    type F = CommonField;
    const ZERO: Self = 0;

    /// Returns `n / 2` consecutive powers (starting at the zeroth) of the root
    /// used for a length-`n` transform.
    ///
    /// The base constant `15_311_432` is presumably a primitive 2^23-th root
    /// of unity in `CommonField` (verify against the field's modulus); it is
    /// inverted for the inverse transform and then squared once for every
    /// halving needed to bring 2^23 down to `n`.
    ///
    /// # Panics
    /// Panics when `n` exceeds 2^23.
    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
        assert!(n <= 1 << 23);
        let mut prim_root = Self::F::from(15_311_432);
        if inverse {
            prim_root = prim_root.recip();
        }
        // Square while the root's order is still larger than `n`. Tracking the
        // order directly (instead of shifting by `23 - i`) means the loop also
        // terminates for `n == 0`, where it used to shift by a negative amount.
        let mut order = 1usize << 23;
        while n < order {
            prim_root = prim_root * prim_root;
            order >>= 1;
        }
        // Size the table by `n / 2` itself: `Vec::capacity()` is only
        // guaranteed to be *at least* what was requested, so it must not be
        // used as the element count.
        let count = n / 2;
        let mut roots = Vec::with_capacity(count);
        let mut root = Self::F::from(1);
        for _ in 0..count {
            roots.push(root);
            root = root * prim_root;
        }
        roots
    }

    /// Returns the reciprocal of `n` for the inverse transform and of `1`
    /// (i.e. one) for the forward transform.
    fn get_factor(n: usize, inverse: bool) -> Self::F {
        let divisor = if inverse { n as Self } else { 1 };
        Self::F::from(divisor).recip()
    }

    /// Returns the field element's stored `val`.
    fn extract(f: Self::F) -> Self {
        f.val
    }
}
/// Computes the discrete Fourier transform of `v`, or its inverse when
/// `inverse` is set, using iterative radix-2 butterflies.
///
/// Every input element is first multiplied by `T::get_factor` and placed in
/// bit-reversed position; the butterfly passes then combine blocks of doubling
/// size using the table from `T::get_roots`.
///
/// # Panics
/// Panics when `v.len()` is not a power of two.
pub fn fft<T: FFT>(v: &[T::F], inverse: bool) -> Vec<T::F> {
    let n = v.len();
    assert!(n.is_power_of_two());
    let factor = T::get_factor(n, inverse);
    let roots_of_unity = T::get_roots(n, inverse);

    let mut out: Vec<T::F> = Vec::with_capacity(n);
    out.extend(BitRevIterator::new(n).map(|src| v[src] * factor));

    let mut half = 1;
    while half < n {
        // Distance between consecutive twiddle factors for this pass.
        let stride = n / 2 / half;
        for block in out.chunks_exact_mut(2 * half) {
            let (lower, upper) = block.split_at_mut(half);
            for (j, (lo, hi)) in lower.iter_mut().zip(upper.iter_mut()).enumerate() {
                let u = *lo;
                let t = *hi * roots_of_unity[stride * j];
                *lo = u + t;
                *hi = u - t;
            }
        }
        half *= 2;
    }
    out
}
/// Lifts `v` into the transform domain type, zero-pads it to the next power
/// of two at or above `desired_len`, and returns its forward transform.
///
/// # Panics
/// Panics when `v` is longer than `desired_len`.
pub fn dft_from_reals<T: FFT>(v: &[T], desired_len: usize) -> Vec<T::F> {
    assert!(v.len() <= desired_len);
    let padded_len = desired_len.next_power_of_two();
    let mut lifted: Vec<T::F> = Vec::with_capacity(padded_len);
    lifted.extend(v.iter().map(|&x| T::F::from(x)));
    // `v.len() <= desired_len <= padded_len`, so this only ever appends zeros.
    lifted.resize(padded_len, T::F::from(T::ZERO));
    fft::<T>(&lifted, false)
}
/// Applies the inverse transform to `dft_v`, keeps the first `desired_len`
/// results, and converts each back to `T` with `T::extract`.
///
/// # Panics
/// Panics when `dft_v` holds fewer than `desired_len` elements.
pub fn idft_to_reals<T: FFT>(dft_v: &[T::F], desired_len: usize) -> Vec<T> {
    assert!(dft_v.len() >= desired_len);
    let mut time_domain = fft::<T>(dft_v, true);
    // Anything past `desired_len` is padding the caller did not ask for.
    time_domain.truncate(desired_len);
    let mut reals = Vec::with_capacity(time_domain.len());
    for element in time_domain {
        reals.push(T::extract(element));
    }
    reals
}
/// Computes the convolution of `a` and `b` by transforming both, multiplying
/// the transforms element-wise, and transforming the product back.
///
/// The result has `a.len() + b.len() - 1` elements. When either input is
/// empty the result is an empty vector.
pub fn convolution<T: FFT>(a: &[T], b: &[T]) -> Vec<T> {
    // Without this guard `a.len() + b.len() - 1` underflows for two empty
    // inputs, and a single empty input trips the length assertion downstream.
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let len_c = a.len() + b.len() - 1;
    let dft_a = dft_from_reals(a, len_c);
    let dft_b = dft_from_reals(b, len_c);
    let dft_c = dft_a
        .into_iter()
        .zip(dft_b)
        .map(|(x, y)| x * y)
        .collect::<Vec<_>>();
    idft_to_reals(&dft_c, len_c)
}
#[cfg(test)]
mod test {
    use super::*;

    /// Forward transform of a zero-padded real signal, then the round trip back.
    #[test]
    fn test_complex_dft() {
        let signal = vec![7.0, 1.0, 1.0];
        let spectrum = dft_from_reals(&signal, signal.len());
        let restored: Vec<f64> = idft_to_reals(&spectrum, signal.len());

        let six = Complex::from(6.0);
        let seven = Complex::from(7.0);
        let nine = Complex::from(9.0);
        let i = Complex::new(0.0, 1.0);
        let expected = vec![nine, six + i, seven, six - i];

        assert_eq!(spectrum, expected);
        assert_eq!(restored, signal);
    }

    /// Same round trip in the modular field, checked against direct
    /// evaluation at the powers of a 4th root of unity.
    #[test]
    fn test_modular_dft() {
        let signal = vec![7, 1, 1];
        let spectrum = dft_from_reals(&signal, signal.len());
        let restored: Vec<i64> = idft_to_reals(&spectrum, signal.len());

        let seven = CommonField::from(7);
        let one = CommonField::from(1);
        let root = CommonField::from(15_311_432).pow(1 << 21);
        let root_sq = root * root;
        let at_0 = seven + one + one;
        let at_1 = seven + root + root_sq;
        let at_2 = seven + root_sq + one;
        let at_3 = seven + root.recip() + root_sq;

        assert_eq!(spectrum, vec![at_0, at_1, at_2, at_3]);
        assert_eq!(restored, signal);
    }

    /// Floating-point convolution of short signals and of two single samples.
    #[test]
    fn test_complex_convolution() {
        let lhs = vec![7.0, 1.0, 1.0];
        let rhs = vec![2.0, 4.0];
        let product = convolution(&lhs, &rhs);
        let single = convolution(&[999.0], &[1e6]);

        assert_eq!(product, vec![14.0, 30.0, 6.0, 4.0]);
        assert_eq!(single, vec![999e6]);
    }

    /// Modular convolution; the single-sample product exceeds the prime and
    /// is therefore expected reduced by it.
    #[test]
    fn test_modular_convolution() {
        let lhs = vec![7, 1, 1];
        let rhs = vec![2, 4];
        let product = convolution(&lhs, &rhs);
        let single = convolution(&[999], &[1_000_000]);

        assert_eq!(product, vec![14, 30, 6, 4]);
        assert_eq!(
            single,
            vec![999_000_000 - super::super::num::COMMON_PRIME]
        );
    }
}