// competitive_programming_rs/math/fast_fourier_transform.rs

1pub struct FastFourierTransform {
2    modulo: i64,
3    sum_e: [i64; 30],
4    sum_ie: [i64; 30],
5}
6
7impl FastFourierTransform {
8    pub fn new(modulo: i64) -> Self {
9        let primitive_root = primitive_root(modulo);
10
11        let mut es = [0; 30];
12        let mut ies = [0; 30];
13        let count2 = (modulo - 1).trailing_zeros();
14        let mut e = mod_pow(primitive_root, (modulo - 1) >> count2, modulo);
15        let mut ie = mod_inv(e, modulo);
16        let count2 = count2 as usize;
17        for i in (2..=count2).rev() {
18            es[i - 2] = e;
19            ies[i - 2] = ie;
20            e = (e * e) % modulo;
21            ie = (ie * ie) % modulo;
22        }
23
24        let mut sum_e = [0; 30];
25        let mut now = 1;
26        for i in 0..=(count2 - 2) {
27            sum_e[i] = (es[i] * now) % modulo;
28            now = (now * ies[i]) % modulo;
29        }
30
31        let mut es = [0; 30];
32        let mut ies = [0; 30];
33        let count2 = (modulo - 1).trailing_zeros();
34        let mut e = mod_pow(primitive_root, (modulo - 1) >> count2, modulo);
35        let mut ie = mod_inv(e, modulo);
36        let count2 = count2 as usize;
37        for i in (2..=count2).rev() {
38            es[i - 2] = e;
39            ies[i - 2] = ie;
40            e = (e * e) % modulo;
41            ie = (ie * ie) % modulo;
42        }
43
44        let mut sum_ie = [0; 30];
45        let mut now = 1;
46        for i in 0..=(count2 - 2) {
47            sum_ie[i] = (ies[i] * now) % modulo;
48            now = (now * es[i]) % modulo;
49        }
50
51        Self {
52            sum_e,
53            modulo,
54            sum_ie,
55        }
56    }
57    fn butterfly(&self, a: &mut [i64]) {
58        let h = a.len().next_power_of_two().trailing_zeros();
59        for ph in 1..=h {
60            let w = 1 << (ph - 1);
61            let p = 1 << (h - ph);
62            let mut now = 1;
63            for s in 0..w {
64                let offset = s << (h - ph + 1);
65                for i in 0..p {
66                    let l = a[i + offset];
67                    let r = (a[i + offset + p] * now) % self.modulo;
68
69                    a[i + offset] = l + r;
70                    if a[i + offset] >= self.modulo {
71                        a[i + offset] -= self.modulo;
72                    }
73
74                    a[i + offset + p] = l + self.modulo - r;
75                    if a[i + offset + p] >= self.modulo {
76                        a[i + offset + p] -= self.modulo;
77                    }
78                }
79
80                now = (self.sum_e[(!s).trailing_zeros() as usize] * now) % self.modulo;
81            }
82        }
83    }
84
85    fn butterfly_inv(&self, a: &mut [i64]) {
86        let h = a.len().next_power_of_two().trailing_zeros();
87        for ph in (1..=h).rev() {
88            let w = 1 << (ph - 1);
89            let p = 1 << (h - ph);
90            let mut inv_now = 1;
91            for s in 0..w {
92                let offset = s << (h - ph + 1);
93                for i in 0..p {
94                    let l = a[i + offset];
95                    let r = a[i + offset + p];
96
97                    a[i + offset] = l + r;
98                    if a[i + offset] >= self.modulo {
99                        a[i + offset] -= self.modulo;
100                    }
101
102                    a[i + offset + p] = ((l + self.modulo - r) * inv_now) % self.modulo;
103                }
104
105                inv_now = (self.sum_ie[(!s).trailing_zeros() as usize] * inv_now) % self.modulo;
106            }
107        }
108    }
109
110    pub fn convolution(&self, a: &[i64], b: &[i64]) -> Vec<i64> {
111        if a.is_empty() || b.is_empty() {
112            return Vec::new();
113        }
114
115        let n = a.len();
116        let m = b.len();
117
118        let z = (n + m - 1).next_power_of_two();
119        let mut a = a.iter().map(|&v| v % self.modulo).collect::<Vec<_>>();
120        a.resize(z, 0);
121        self.butterfly(&mut a);
122
123        let mut b = b.iter().map(|&v| v % self.modulo).collect::<Vec<_>>();
124        b.resize(z, 0);
125        self.butterfly(&mut b);
126
127        for i in 0..z {
128            a[i] = (a[i] * b[i]) % self.modulo;
129        }
130
131        self.butterfly_inv(&mut a);
132        a.resize(n + m - 1, 0);
133        let iz = mod_inv(z as i64, self.modulo);
134        for i in 0..a.len() {
135            a[i] = (a[i] * iz) % self.modulo;
136        }
137        a
138    }
139}
140
141fn mod_inv(x: i64, m: i64) -> i64 {
142    mod_pow(x, m - 2, m)
143}
144
/// Computes `x^e mod m` by binary exponentiation. The result lies in `[0, m)` for `m > 0`.
///
/// The base is reduced with `rem_euclid` first. Negative bases therefore give a
/// normalized result. `cur * cur` also cannot overflow as long as `(m - 1)^2` fits
/// in an `i64` (roughly `m <= 3.03e9`), whatever the size of `x`. A non-positive
/// exponent yields `1 % m`.
///
/// # Panics
/// Panics if `m == 0`.
fn mod_pow(x: i64, mut e: i64, m: i64) -> i64 {
    let mut cur = x.rem_euclid(m);
    let mut result = 1 % m;
    while e > 0 {
        if e & 1 == 1 {
            result = (result * cur) % m;
        }
        e >>= 1;
        cur = (cur * cur) % m;
    }
    result
}
157
158fn primitive_root(m: i64) -> i64 {
159    if m == 2 {
160        return 1;
161    };
162    if m == 167772161 {
163        return 3;
164    };
165    if m == 469762049 {
166        return 3;
167    };
168    if m == 754974721 {
169        return 11;
170    };
171    if m == 998244353 {
172        return 3;
173    };
174    let mut divs = [0; 20];
175    divs[0] = 2;
176    let mut cnt = 1;
177    let mut x = (m - 1) / 2;
178    while x % 2 == 0 {
179        x /= 2
180    }
181
182    let mut i = 3;
183    while i * i <= x {
184        if x % i == 0 {
185            divs[cnt] = i;
186            cnt += 1;
187            while x % i == 0 {
188                x /= i;
189            }
190        }
191        i += 2;
192    }
193    if x > 1 {
194        divs[cnt] = x;
195        cnt += 1;
196    }
197
198    for g in 2.. {
199        let mut ok = true;
200        for i in 0..cnt {
201            if mod_pow(g, (m - 1) / divs[i], m) == 1 {
202                ok = false;
203                break;
204            }
205        }
206        if ok {
207            return g;
208        }
209    }
210    unreachable!()
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the two-argument `gen_range(low, high)` calls below match the
    // older `rand` API (0.7 and earlier). Confirm against the crate version in Cargo.toml.
    use rand::{thread_rng, Rng};

    /// Hand-checked product of two short sequences modulo 998244353.
    #[test]
    fn test_fft() {
        let a = vec![1, 2, 3, 4];
        let b = vec![5, 6, 7, 8, 9];
        let m = 998244353;
        let fft = FastFourierTransform::new(m);
        let c = fft.convolution(&a, &b);
        assert_eq!(vec![5, 16, 34, 60, 70, 70, 59, 36], c);
    }

    /// Random sequences of length 5000..10000 with values in `[0, 2 * modulo)`, so
    /// the input reduction is exercised. Results are compared against the naive
    /// O(n * m) product.
    #[test]
    fn test_fft_rand() {
        let mut rng = thread_rng();
        let modulo = 998244353;
        let fft = FastFourierTransform::new(modulo);

        for _ in 0..10 {
            let n: usize = 5000 + rng.gen_range(0, 5000);
            let m: usize = 5000 + rng.gen_range(0, 5000);
            let a = (0..n)
                .map(|_| rng.gen_range(0, modulo * 2))
                .collect::<Vec<_>>();
            let b = (0..m)
                .map(|_| rng.gen_range(0, modulo * 2))
                .collect::<Vec<_>>();
            let c = fft.convolution(&a, &b);

            // a[i] * b[j] < (2 * modulo)^2 ≈ 4e18, which still fits in i64.
            let mut check = vec![0; n + m - 1];
            for i in 0..n {
                for j in 0..m {
                    check[i + j] += a[i] * b[j];
                    check[i + j] %= modulo;
                }
            }

            assert_eq!(check, c);
        }
    }

    /// `primitive_root` must return a genuine primitive root for small primes,
    /// the precomputed NTT primes, and a few other primes.
    #[test]
    fn test_primitive_root() {
        assert!(is_primitive_root(2, primitive_root(2)));
        assert!(is_primitive_root(3, primitive_root(3)));
        assert!(is_primitive_root(5, primitive_root(5)));
        assert!(is_primitive_root(7, primitive_root(7)));
        assert!(is_primitive_root(11, primitive_root(11)));
        assert!(is_primitive_root(998244353, primitive_root(998244353)));
        assert!(is_primitive_root(1000000007, primitive_root(1000000007)));
        assert!(is_primitive_root(469762049, primitive_root(469762049)));
        assert!(is_primitive_root(167772161, primitive_root(167772161)));
        assert!(is_primitive_root(754974721, primitive_root(754974721)));
        assert!(is_primitive_root(324013369, primitive_root(324013369)));
        assert!(is_primitive_root(831143041, primitive_root(831143041)));
        assert!(is_primitive_root(1685283601, primitive_root(1685283601)));
    }

    /// Independent check that `g` is a primitive root modulo the prime `m`:
    /// `g^((m-1)/q) != 1 (mod m)` for every prime factor `q` of `m - 1`,
    /// with the factors found by plain trial division.
    fn is_primitive_root(m: i64, g: i64) -> bool {
        let mut factors = vec![];
        let mut cur = 2;
        let mut t = m - 1;
        while cur * cur <= t {
            if t % cur == 0 {
                factors.push(cur);
            }
            while t % cur == 0 {
                t /= cur;
            }
            cur += 1;
        }
        if t > 1 {
            factors.push(t);
        }

        for factor in factors {
            if mod_pow(g, (m - 1) / factor, m) == 1 {
                return false;
            }
        }
        true
    }
}