contest_algorithms/math/
fft.rs

1//! The Fast Fourier Transform (FFT) and Number Theoretic Transform (NTT)
2use super::num::{CommonField, Complex, PI};
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
/// Iterates over the indices `0..n` in bit-reversed order, where `n` is a power of two:
/// the i-th item is `i` with its lowest log2(n) bits reversed.
///
/// Built on `usize::reverse_bits()`; kept as a named iterator so that callers can
/// keep writing `BitRevIterator::new(n)`.
struct BitRevIterator {
    /// Next natural-order index whose bit reversal will be yielded.
    i: usize,
    /// Number of indices to yield; always a power of two.
    n: usize,
    /// Right shift that moves the top log2(n) bits of a fully reversed word to the bottom.
    shift: u32,
}
impl BitRevIterator {
    /// Creates an iterator over the bit-reversed indices `0..n`.
    ///
    /// # Panics
    /// Panics if `n` is not a power of two.
    fn new(n: usize) -> Self {
        assert!(n.is_power_of_two());
        Self {
            i: 0,
            n,
            // For n = 2^k, leading_zeros() is (word width - 1 - k), so this equals
            // (word width - k): exactly the bits that are not part of an index.
            shift: n.leading_zeros() + 1,
        }
    }
}
impl Iterator for BitRevIterator {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        if self.i == self.n {
            return None;
        }
        // When n == 1 the shift equals the full word width, which a plain `>>` rejects;
        // the only index is 0, whose reversal is 0.
        let reversed = self.i.reverse_bits().checked_shr(self.shift).unwrap_or(0);
        self.i += 1;
        Some(reversed)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.n - self.i;
        (remaining, Some(remaining))
    }
}
// The length is known exactly, which lets `collect()` allocate once.
impl ExactSizeIterator for BitRevIterator {}
32
// `FFT` is the customary spelling of the acronym, so it is kept as the trait name.
#[allow(clippy::upper_case_acronyms)]
pub trait FFT: Copy + Sized {
    /// The type in which the transform's arithmetic is carried out.
    /// Values of `Self` convert into it via `From`, and `extract` converts back.
    type F: Copy
        + Sized
        + From<Self>
        + Add<Output = Self::F>
        + Sub<Output = Self::F>
        + Mul<Output = Self::F>
        + Div<Output = Self::F>
        + Neg;

    /// The zero value of `Self`.
    const ZERO: Self;

    /// Returns the powers 0, 1, 2, ..., n/2 - 1 of a primitive nth root of one.
    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F>;
    /// Returns the scaling factor: 1 for the forward transform, 1/n for the inverse.
    fn get_factor(n: usize, inverse: bool) -> Self::F;
    /// Converts back from `Self::F`; the inverse of `Self::F::from()`.
    fn extract(f: Self::F) -> Self;
}
53
54impl FFT for f64 {
55    type F = Complex;
56
57    const ZERO: f64 = 0.0;
58
59    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
60        let step = if inverse { -2.0 } else { 2.0 } * PI / n as f64;
61        (0..n / 2)
62            .map(|i| Complex::from_polar(1.0, step * i as f64))
63            .collect()
64    }
65
66    fn get_factor(n: usize, inverse: bool) -> Self::F {
67        Self::F::from(if inverse { (n as f64).recip() } else { 1.0 })
68    }
69
70    fn extract(f: Self::F) -> f64 {
71        f.real
72    }
73}
74
75// NTT notes: see problem 30-6 in CLRS for details, keeping in mind that
76//      2187 and  410692747 are inverses and 2^26th roots of 1 mod (7<<26)+1
77//  15311432 and  469870224 are inverses and 2^23rd roots of 1 mod (119<<23)+1
78// 440564289 and 1713844692 are inverses and 2^27th roots of 1 mod (15<<27)+1
79//       125 and 2267742733 are inverses and 2^30th roots of 1 mod (3<<30)+1
80impl FFT for i64 {
81    type F = CommonField;
82
83    const ZERO: Self = 0;
84
85    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
86        assert!(n <= 1 << 23);
87        let mut prim_root = Self::F::from(15_311_432);
88        if inverse {
89            prim_root = prim_root.recip();
90        }
91        for _ in (0..).take_while(|&i| n < 1 << (23 - i)) {
92            prim_root = prim_root * prim_root;
93        }
94
95        let mut roots = Vec::with_capacity(n / 2);
96        let mut root = Self::F::from(1);
97        for _ in 0..roots.capacity() {
98            roots.push(root);
99            root = root * prim_root;
100        }
101        roots
102    }
103
104    fn get_factor(n: usize, inverse: bool) -> Self::F {
105        Self::F::from(if inverse { n as Self } else { 1 }).recip()
106    }
107
108    fn extract(f: Self::F) -> Self {
109        f.val
110    }
111}
112
113/// Computes the discrete fourier transform of v, whose length is a power of 2.
114/// Forward transform: polynomial coefficients -> evaluate at roots of unity
115/// Inverse transform: values at roots of unity -> interpolated coefficients
116pub fn fft<T: FFT>(v: &[T::F], inverse: bool) -> Vec<T::F> {
117    let n = v.len();
118    assert!(n.is_power_of_two());
119
120    let factor = T::get_factor(n, inverse);
121    let roots_of_unity = T::get_roots(n, inverse);
122    let mut dft = BitRevIterator::new(n)
123        .map(|i| v[i] * factor)
124        .collect::<Vec<_>>();
125
126    for m in (0..).map(|s| 1 << s).take_while(|&m| m < n) {
127        for k in (0..n).step_by(2 * m) {
128            for j in 0..m {
129                let u = dft[k + j];
130                let t = dft[k + j + m] * roots_of_unity[n / 2 / m * j];
131                dft[k + j] = u + t;
132                dft[k + j + m] = u - t;
133            }
134        }
135    }
136    dft
137}
138
139/// From a slice of reals (f64 or i64), computes DFT of size at least desired_len
140pub fn dft_from_reals<T: FFT>(v: &[T], desired_len: usize) -> Vec<T::F> {
141    assert!(v.len() <= desired_len);
142
143    let complex_v = v
144        .iter()
145        .cloned()
146        .chain(std::iter::repeat(T::ZERO))
147        .take(desired_len.next_power_of_two())
148        .map(T::F::from)
149        .collect::<Vec<_>>();
150    fft::<T>(&complex_v, false)
151}
152
153/// The inverse of dft_from_reals()
154pub fn idft_to_reals<T: FFT>(dft_v: &[T::F], desired_len: usize) -> Vec<T> {
155    assert!(dft_v.len() >= desired_len);
156
157    let complex_v = fft::<T>(dft_v, true);
158    complex_v
159        .into_iter()
160        .take(desired_len)
161        .map(T::extract)
162        .collect()
163}
164
165/// Given two polynomials (vectors) sum_i a[i] x^i and sum_i b[i] x^i,
166/// computes their product (convolution) c[k] = sum_(i+j=k) a[i]*b[j].
167/// Uses complex FFT if inputs are f64, or modular NTT if inputs are i64.
168pub fn convolution<T: FFT>(a: &[T], b: &[T]) -> Vec<T> {
169    let len_c = a.len() + b.len() - 1;
170    let dft_a = dft_from_reals(a, len_c).into_iter();
171    let dft_b = dft_from_reals(b, len_c).into_iter();
172    let dft_c = dft_a.zip(dft_b).map(|(a, b)| a * b).collect::<Vec<_>>();
173    idft_to_reals(&dft_c, len_c)
174}
175
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the complex spectrum of [7, 1, 1] (padded to length 4) and the round trip.
    #[test]
    fn test_complex_dft() {
        let signal = vec![7.0, 1.0, 1.0];
        let spectrum = dft_from_reals(&signal, signal.len());
        let recovered: Vec<f64> = idft_to_reals(&spectrum, signal.len());

        let i = Complex::new(0.0, 1.0);
        let six = Complex::from(6.0);
        let seven = Complex::from(7.0);
        let nine = Complex::from(9.0);
        let expected = vec![nine, six + i, seven, six - i];

        assert_eq!(spectrum, expected);
        assert_eq!(recovered, signal);
    }

    /// Checks the modular spectrum of [7, 1, 1] against direct evaluation at the powers
    /// of the 4th root of one, and the round trip.
    #[test]
    fn test_modular_dft() {
        let signal = vec![7, 1, 1];
        let spectrum = dft_from_reals(&signal, signal.len());
        let recovered: Vec<i64> = idft_to_reals(&spectrum, signal.len());

        let one = CommonField::from(1);
        let seven = CommonField::from(7);
        // Squaring the 2^23rd root 21 times leaves a 4th root of one.
        let prim = CommonField::from(15_311_432).pow(1 << 21);
        let prim2 = prim * prim;

        let expected = vec![
            seven + one + one,
            seven + prim + prim2,
            seven + prim2 + one,
            seven + prim.recip() + prim2,
        ];

        assert_eq!(spectrum, expected);
        assert_eq!(recovered, signal);
    }

    /// Multiplies (7 + x + x^2)(2 + 4x) and two constants using the complex transform.
    #[test]
    fn test_complex_convolution() {
        let x = vec![7.0, 1.0, 1.0];
        let y = vec![2.0, 4.0];

        let product = convolution(&x, &y);
        let constants = convolution(&[999.0], &[1e6]);

        assert_eq!(product, vec![14.0, 30.0, 6.0, 4.0]);
        assert_eq!(constants, vec![999e6]);
    }

    /// Same products with the modular transform; the constant product wraps around the prime.
    #[test]
    fn test_modular_convolution() {
        let x = vec![7, 1, 1];
        let y = vec![2, 4];

        let product = convolution(&x, &y);
        let constants = convolution(&[999], &[1_000_000]);

        assert_eq!(product, vec![14, 30, 6, 4]);
        assert_eq!(
            constants,
            vec![999_000_000 - super::super::num::COMMON_PRIME]
        );
    }
}