// competitive_programming_rs/math/fast_fourier_transform.rs

/// Number-theoretic transform (an FFT over the integers modulo a prime) used to
/// multiply polynomials / convolve sequences exactly.
///
/// The two tables are precomputed once by the constructor and indexed inside
/// the butterfly passes by the number of trailing one-bits of the block index.
#[derive(Debug, Clone)]
pub struct FastFourierTransform {
    /// Modulus every value is reduced by.
    modulo: i64,
    /// Twiddle-factor step table for the forward transform.
    sum_e: [i64; 30],
    /// Twiddle-factor step table for the inverse transform.
    sum_ie: [i64; 30],
}
6
7impl FastFourierTransform {
8 pub fn new(modulo: i64) -> Self {
9 let primitive_root = primitive_root(modulo);
10
11 let mut es = [0; 30];
12 let mut ies = [0; 30];
13 let count2 = (modulo - 1).trailing_zeros();
14 let mut e = mod_pow(primitive_root, (modulo - 1) >> count2, modulo);
15 let mut ie = mod_inv(e, modulo);
16 let count2 = count2 as usize;
17 for i in (2..=count2).rev() {
18 es[i - 2] = e;
19 ies[i - 2] = ie;
20 e = (e * e) % modulo;
21 ie = (ie * ie) % modulo;
22 }
23
24 let mut sum_e = [0; 30];
25 let mut now = 1;
26 for i in 0..=(count2 - 2) {
27 sum_e[i] = (es[i] * now) % modulo;
28 now = (now * ies[i]) % modulo;
29 }
30
31 let mut es = [0; 30];
32 let mut ies = [0; 30];
33 let count2 = (modulo - 1).trailing_zeros();
34 let mut e = mod_pow(primitive_root, (modulo - 1) >> count2, modulo);
35 let mut ie = mod_inv(e, modulo);
36 let count2 = count2 as usize;
37 for i in (2..=count2).rev() {
38 es[i - 2] = e;
39 ies[i - 2] = ie;
40 e = (e * e) % modulo;
41 ie = (ie * ie) % modulo;
42 }
43
44 let mut sum_ie = [0; 30];
45 let mut now = 1;
46 for i in 0..=(count2 - 2) {
47 sum_ie[i] = (ies[i] * now) % modulo;
48 now = (now * es[i]) % modulo;
49 }
50
51 Self {
52 sum_e,
53 modulo,
54 sum_ie,
55 }
56 }
57 fn butterfly(&self, a: &mut [i64]) {
58 let h = a.len().next_power_of_two().trailing_zeros();
59 for ph in 1..=h {
60 let w = 1 << (ph - 1);
61 let p = 1 << (h - ph);
62 let mut now = 1;
63 for s in 0..w {
64 let offset = s << (h - ph + 1);
65 for i in 0..p {
66 let l = a[i + offset];
67 let r = (a[i + offset + p] * now) % self.modulo;
68
69 a[i + offset] = l + r;
70 if a[i + offset] >= self.modulo {
71 a[i + offset] -= self.modulo;
72 }
73
74 a[i + offset + p] = l + self.modulo - r;
75 if a[i + offset + p] >= self.modulo {
76 a[i + offset + p] -= self.modulo;
77 }
78 }
79
80 now = (self.sum_e[(!s).trailing_zeros() as usize] * now) % self.modulo;
81 }
82 }
83 }
84
85 fn butterfly_inv(&self, a: &mut [i64]) {
86 let h = a.len().next_power_of_two().trailing_zeros();
87 for ph in (1..=h).rev() {
88 let w = 1 << (ph - 1);
89 let p = 1 << (h - ph);
90 let mut inv_now = 1;
91 for s in 0..w {
92 let offset = s << (h - ph + 1);
93 for i in 0..p {
94 let l = a[i + offset];
95 let r = a[i + offset + p];
96
97 a[i + offset] = l + r;
98 if a[i + offset] >= self.modulo {
99 a[i + offset] -= self.modulo;
100 }
101
102 a[i + offset + p] = ((l + self.modulo - r) * inv_now) % self.modulo;
103 }
104
105 inv_now = (self.sum_ie[(!s).trailing_zeros() as usize] * inv_now) % self.modulo;
106 }
107 }
108 }
109
110 pub fn convolution(&self, a: &[i64], b: &[i64]) -> Vec<i64> {
111 if a.is_empty() || b.is_empty() {
112 return Vec::new();
113 }
114
115 let n = a.len();
116 let m = b.len();
117
118 let z = (n + m - 1).next_power_of_two();
119 let mut a = a.iter().map(|&v| v % self.modulo).collect::<Vec<_>>();
120 a.resize(z, 0);
121 self.butterfly(&mut a);
122
123 let mut b = b.iter().map(|&v| v % self.modulo).collect::<Vec<_>>();
124 b.resize(z, 0);
125 self.butterfly(&mut b);
126
127 for i in 0..z {
128 a[i] = (a[i] * b[i]) % self.modulo;
129 }
130
131 self.butterfly_inv(&mut a);
132 a.resize(n + m - 1, 0);
133 let iz = mod_inv(z as i64, self.modulo);
134 for i in 0..a.len() {
135 a[i] = (a[i] * iz) % self.modulo;
136 }
137 a
138 }
139}
140
141fn mod_inv(x: i64, m: i64) -> i64 {
142 mod_pow(x, m - 2, m)
143}
144
/// Computes `x^e mod m` by binary exponentiation.
///
/// The result lies in `0..m` for every `m >= 1`, including `m == 1` (always 0)
/// and negative `x`. Products are formed in `i128`, so no modulus that fits in
/// an `i64` can overflow. A non-positive exponent yields `1 % m`.
///
/// # Panics
///
/// Panics if `m == 0`.
fn mod_pow(x: i64, mut e: i64, m: i64) -> i64 {
    let wide_m = i128::from(m);
    // The remainder is smaller than |m|, so narrowing back to i64 is lossless.
    let mul = |a: i64, b: i64| (i128::from(a) * i128::from(b) % wide_m) as i64;

    let mut cur = x.rem_euclid(m);
    let mut result = 1 % m;
    while e > 0 {
        if e & 1 == 1 {
            result = mul(result, cur);
        }
        e >>= 1;
        if e > 0 {
            cur = mul(cur, cur);
        }
    }
    result
}
157
158fn primitive_root(m: i64) -> i64 {
159 if m == 2 {
160 return 1;
161 };
162 if m == 167772161 {
163 return 3;
164 };
165 if m == 469762049 {
166 return 3;
167 };
168 if m == 754974721 {
169 return 11;
170 };
171 if m == 998244353 {
172 return 3;
173 };
174 let mut divs = [0; 20];
175 divs[0] = 2;
176 let mut cnt = 1;
177 let mut x = (m - 1) / 2;
178 while x % 2 == 0 {
179 x /= 2
180 }
181
182 let mut i = 3;
183 while i * i <= x {
184 if x % i == 0 {
185 divs[cnt] = i;
186 cnt += 1;
187 while x % i == 0 {
188 x /= i;
189 }
190 }
191 i += 2;
192 }
193 if x > 1 {
194 divs[cnt] = x;
195 cnt += 1;
196 }
197
198 for g in 2.. {
199 let mut ok = true;
200 for i in 0..cnt {
201 if mod_pow(g, (m - 1) / divs[i], m) == 1 {
202 ok = false;
203 break;
204 }
205 }
206 if ok {
207 return g;
208 }
209 }
210 unreachable!()
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{thread_rng, Rng};

    /// Fixed small example with a hand-computed answer.
    #[test]
    fn test_fft() {
        let a = vec![1, 2, 3, 4];
        let b = vec![5, 6, 7, 8, 9];
        let m = 998244353;
        let fft = FastFourierTransform::new(m);
        let c = fft.convolution(&a, &b);
        assert_eq!(vec![5, 16, 34, 60, 70, 70, 59, 36], c);
    }

    /// Compares the transform against the quadratic schoolbook product on
    /// random inputs, including values at or above the modulus.
    #[test]
    fn test_fft_rand() {
        let mut rng = thread_rng();
        let modulo = 998244353;
        let fft = FastFourierTransform::new(modulo);

        for _ in 0..10 {
            let n: usize = 5000 + rng.gen_range(0, 5000);
            let m: usize = 5000 + rng.gen_range(0, 5000);
            let a = (0..n)
                .map(|_| rng.gen_range(0, modulo * 2))
                .collect::<Vec<_>>();
            let b = (0..m)
                .map(|_| rng.gen_range(0, modulo * 2))
                .collect::<Vec<_>>();
            let c = fft.convolution(&a, &b);

            // Each product is below (2 * modulo)^2, which still fits in an i64.
            let mut check = vec![0; n + m - 1];
            for (i, &x) in a.iter().enumerate() {
                for (j, &y) in b.iter().enumerate() {
                    check[i + j] += x * y;
                    check[i + j] %= modulo;
                }
            }

            assert_eq!(check, c);
        }
    }

    /// Every value returned by `primitive_root` must generate the full
    /// multiplicative group, for table entries and computed ones alike.
    #[test]
    fn test_primitive_root() {
        let moduli = [
            2,
            3,
            5,
            7,
            11,
            998244353,
            1000000007,
            469762049,
            167772161,
            754974721,
            324013369,
            831143041,
            1685283601,
        ];
        for &m in moduli.iter() {
            assert!(is_primitive_root(m, primitive_root(m)), "modulus {}", m);
        }
    }

    /// Returns whether `g` is a primitive root modulo the prime `m`.
    ///
    /// `g` must be a unit (`g^(m - 1) == 1`) and `g^((m - 1) / q)` must differ
    /// from 1 for every prime factor `q` of `m - 1`. The unit check rejects
    /// `g == 0`, which the factor test alone would accept.
    fn is_primitive_root(m: i64, g: i64) -> bool {
        if mod_pow(g, m - 1, m) != 1 {
            return false;
        }

        // Distinct prime factors of m - 1 by trial division.
        let mut factors = vec![];
        let mut t = m - 1;
        let mut cur = 2;
        while cur * cur <= t {
            if t % cur == 0 {
                factors.push(cur);
                while t % cur == 0 {
                    t /= cur;
                }
            }
            cur += 1;
        }
        if t > 1 {
            factors.push(t);
        }

        factors
            .into_iter()
            .all(|factor| mod_pow(g, (m - 1) / factor, m) != 1)
    }
}