contest_algorithms/math/
fft.rs1use super::num::{CommonField, Complex, PI};
3use std::ops::{Add, Div, Mul, Neg, Sub};
4
/// Iterator over the indices `0..n` in bit-reversed order (`n` must be a power of two).
///
/// `acc` is the state of a "reversed" counter: each step propagates a carry from the most
/// significant bit (`n`) downwards, clearing set bits until it finds a clear one, which it
/// sets. After every step `acc >> 1` is the index just yielded. The counter starts at
/// `2n - 1` (all ones) so the first step wraps it to zero, and iteration ends once the
/// last index `n - 1` has been produced, i.e. when `acc == 2(n - 1)`.
struct BitRevIterator {
    acc: usize,
    n: usize,
}

impl BitRevIterator {
    /// Creates the iterator for a table of length `n`.
    ///
    /// # Panics
    /// Panics if `n` is not a power of two (this includes `n == 0`).
    fn new(n: usize) -> Self {
        assert!(n.is_power_of_two());
        let acc = 2 * n - 1;
        Self { acc, n }
    }
}

impl Iterator for BitRevIterator {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let last = 2 * self.n - 2;
        if self.acc == last {
            return None;
        }
        // Reversed increment: clear the run of set bits from the top down, then set the
        // first clear bit reached.
        let mut bit = self.n;
        while self.acc & bit != 0 {
            self.acc ^= bit;
            bit >>= 1;
        }
        self.acc |= bit;
        Some(self.acc >> 1)
    }
}
32
/// Element types that [`fft`] can transform.
///
/// An implementor names the type `F` in which the arithmetic is carried out (in this module
/// [`Complex`] for `f64` and [`CommonField`] for `i64`) and supplies the roots of unity and
/// the normalisation factor the transform needs.
// The all-caps trait name `FFT` triggers clippy's `upper_case_acronyms` lint; the lint is
// silenced here instead of renaming the public trait.
#[allow(clippy::upper_case_acronyms)]
pub trait FFT: Sized + Copy {
    /// Arithmetic type the transform runs in; `From<Self>` is how inputs are lifted into it.
    type F: Sized
        + Copy
        + From<Self>
        + Neg
        + Add<Output = Self::F>
        + Div<Output = Self::F>
        + Mul<Output = Self::F>
        + Sub<Output = Self::F>;

    /// Value used to zero-pad inputs up to a power-of-two length.
    const ZERO: Self;

    /// Returns the successive powers `w^0, w^1, ...` of a primitive `n`-th root of unity `w`
    /// (the impls in this module return the first `n / 2`; with `inverse` set they use the
    /// inverse root). [`fft`] reads this table at `n / 2 / m * j` for `j < m <= n / 2`,
    /// so it must contain at least `n / 2` entries.
    fn get_roots(n: usize, inverse: bool) -> Vec<Self::F>;

    /// Factor every input element is multiplied by before the butterflies run; the impls
    /// in this module return `1/n` for the inverse transform and `1` otherwise.
    fn get_factor(n: usize, inverse: bool) -> Self::F;

    /// Converts a transformed value back to `Self` (the `f64` impl keeps the real part).
    fn extract(f: Self::F) -> Self;
}
53
54impl FFT for f64 {
55 type F = Complex;
56
57 const ZERO: f64 = 0.0;
58
59 fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
60 let step = if inverse { -2.0 } else { 2.0 } * PI / n as f64;
61 (0..n / 2)
62 .map(|i| Complex::from_polar(1.0, step * i as f64))
63 .collect()
64 }
65
66 fn get_factor(n: usize, inverse: bool) -> Self::F {
67 Self::F::from(if inverse { (n as f64).recip() } else { 1.0 })
68 }
69
70 fn extract(f: Self::F) -> f64 {
71 f.real
72 }
73}
74
75impl FFT for i64 {
81 type F = CommonField;
82
83 const ZERO: Self = 0;
84
85 fn get_roots(n: usize, inverse: bool) -> Vec<Self::F> {
86 assert!(n <= 1 << 23);
87 let mut prim_root = Self::F::from(15_311_432);
88 if inverse {
89 prim_root = prim_root.recip();
90 }
91 for _ in (0..).take_while(|&i| n < 1 << (23 - i)) {
92 prim_root = prim_root * prim_root;
93 }
94
95 let mut roots = Vec::with_capacity(n / 2);
96 let mut root = Self::F::from(1);
97 for _ in 0..roots.capacity() {
98 roots.push(root);
99 root = root * prim_root;
100 }
101 roots
102 }
103
104 fn get_factor(n: usize, inverse: bool) -> Self::F {
105 Self::F::from(if inverse { n as Self } else { 1 }).recip()
106 }
107
108 fn extract(f: Self::F) -> Self {
109 f.val
110 }
111}
112
113pub fn fft<T: FFT>(v: &[T::F], inverse: bool) -> Vec<T::F> {
117 let n = v.len();
118 assert!(n.is_power_of_two());
119
120 let factor = T::get_factor(n, inverse);
121 let roots_of_unity = T::get_roots(n, inverse);
122 let mut dft = BitRevIterator::new(n)
123 .map(|i| v[i] * factor)
124 .collect::<Vec<_>>();
125
126 for m in (0..).map(|s| 1 << s).take_while(|&m| m < n) {
127 for k in (0..n).step_by(2 * m) {
128 for j in 0..m {
129 let u = dft[k + j];
130 let t = dft[k + j + m] * roots_of_unity[n / 2 / m * j];
131 dft[k + j] = u + t;
132 dft[k + j + m] = u - t;
133 }
134 }
135 }
136 dft
137}
138
139pub fn dft_from_reals<T: FFT>(v: &[T], desired_len: usize) -> Vec<T::F> {
141 assert!(v.len() <= desired_len);
142
143 let complex_v = v
144 .iter()
145 .cloned()
146 .chain(std::iter::repeat(T::ZERO))
147 .take(desired_len.next_power_of_two())
148 .map(T::F::from)
149 .collect::<Vec<_>>();
150 fft::<T>(&complex_v, false)
151}
152
153pub fn idft_to_reals<T: FFT>(dft_v: &[T::F], desired_len: usize) -> Vec<T> {
155 assert!(dft_v.len() >= desired_len);
156
157 let complex_v = fft::<T>(dft_v, true);
158 complex_v
159 .into_iter()
160 .take(desired_len)
161 .map(T::extract)
162 .collect()
163}
164
165pub fn convolution<T: FFT>(a: &[T], b: &[T]) -> Vec<T> {
169 let len_c = a.len() + b.len() - 1;
170 let dft_a = dft_from_reals(a, len_c).into_iter();
171 let dft_b = dft_from_reals(b, len_c).into_iter();
172 let dft_c = dft_a.zip(dft_b).map(|(a, b)| a * b).collect::<Vec<_>>();
173 idft_to_reals(&dft_c, len_c)
174}
175
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips `[7, 1, 1]` through the floating-point transform.
    ///
    /// The input is zero-padded to length 4. With the `f64` impl's roots at angles
    /// `+2πk/n`, the primitive 4th root is `i`, so the spectrum is `[9, 6 + i, 7, 6 - i]`.
    // NOTE(review): these are exact float comparisons; they depend on how `Complex`
    // implements `PartialEq` and on rounding in `from_polar` — confirm they are not flaky.
    #[test]
    fn test_complex_dft() {
        let v = vec![7.0, 1.0, 1.0];
        let dft_v = dft_from_reals(&v, v.len());
        let new_v: Vec<f64> = idft_to_reals(&dft_v, v.len());

        let six = Complex::from(6.0);
        let seven = Complex::from(7.0);
        let nine = Complex::from(9.0);
        let i = Complex::new(0.0, 1.0);

        assert_eq!(dft_v, vec![nine, six + i, seven, six - i]);
        assert_eq!(new_v, v);
    }

    /// Round-trips `[7, 1, 1]` through the modular transform.
    #[test]
    fn test_modular_dft() {
        let v = vec![7, 1, 1];
        let dft_v = dft_from_reals(&v, v.len());
        let new_v: Vec<i64> = idft_to_reals(&dft_v, v.len());

        let seven = CommonField::from(7);
        let one = CommonField::from(1);
        // `15_311_432` is the base root used by the `i64` impl. It is assumed to have order
        // 2^23, matching that impl's `n <= 1 << 23` limit. Raising it to 2^21 therefore
        // gives a primitive 4th root of unity.
        let prim = CommonField::from(15_311_432).pow(1 << 21);
        let prim2 = prim * prim;

        // eval_k = 7 + x + x^2 at x = prim^k. Because prim^4 = 1, we have
        // prim^3 = prim^-1 and prim^6 = prim^2.
        let eval0 = seven + one + one;
        let eval1 = seven + prim + prim2;
        let eval2 = seven + prim2 + one;
        let eval3 = seven + prim.recip() + prim2;

        assert_eq!(dft_v, vec![eval0, eval1, eval2, eval3]);
        assert_eq!(new_v, v);
    }

    /// Coefficients of `(7 + x + x^2)(2 + 4x) = 14 + 30x + 6x^2 + 4x^3` over `f64`, plus a
    /// single-term product.
    #[test]
    fn test_complex_convolution() {
        let x = vec![7.0, 1.0, 1.0];
        let y = vec![2.0, 4.0];
        let z = convolution(&x, &y);
        let m = convolution(&vec![999.0], &vec![1e6]);

        assert_eq!(z, vec![14.0, 30.0, 6.0, 4.0]);
        assert_eq!(m, vec![999e6]);
    }

    /// Same product in the modular field. `999 * 1_000_000` exceeds the modulus, so that
    /// result comes back reduced. The expectation assumes `999_000_000` lies in `[p, 2p)`
    /// for `p = COMMON_PRIME`.
    #[test]
    fn test_modular_convolution() {
        let x = vec![7, 1, 1];
        let y = vec![2, 4];
        let z = convolution(&x, &y);
        let m = convolution(&vec![999], &vec![1_000_000]);

        assert_eq!(z, vec![14, 30, 6, 4]);
        assert_eq!(m, vec![999_000_000 - super::super::num::COMMON_PRIME]);
    }
}