// Declares, for every identifier given, a zero-variant marker type that
// implements `Modulus`.
//
// Each identifier must also name a constant that is in scope at the call site
// (e.g. `const M1: u64 = 754_974_721;`). In value position `$m` resolves to
// that constant, and it becomes `Modulus::VALUE`. Every generated modulus
// reports `HINT_VALUE_IS_PRIME = true` and owns its own thread-local
// butterfly cache, which starts out empty (`None`).
macro_rules! modulus {
    ($($m:ident),*) => {
        $(
            #[derive(Copy, Clone, Eq, PartialEq)]
            enum $m {}

            impl Modulus for $m {
                const VALUE: u32 = $m as _;
                const HINT_VALUE_IS_PRIME: bool = true;

                fn butterfly_cache() -> &'static ::std::thread::LocalKey<
                    ::std::cell::RefCell<
                        ::std::option::Option<$crate::modint::ButterflyCache<Self>>,
                    >,
                > {
                    // A nested `static` cannot refer to the outer `Self`,
                    // hence the explicit `$m` in the cache's type.
                    thread_local! {
                        static CACHE: ::std::cell::RefCell<
                            ::std::option::Option<$crate::modint::ButterflyCache<$m>>,
                        > = ::std::cell::RefCell::new(::std::option::Option::None);
                    }
                    &CACHE
                }
            }
        )*
    };
}
21
22use crate::{
23 internal_bit, internal_math,
24 modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
25};
26use std::{
27 cmp,
28 convert::{TryFrom, TryInto as _},
29 fmt,
30};
31
32#[allow(clippy::many_single_char_names)]
33pub fn convolution<M>(a: &[StaticModInt<M>], b: &[StaticModInt<M>]) -> Vec<StaticModInt<M>>
34where
35 M: Modulus,
36{
37 if a.is_empty() || b.is_empty() {
38 return vec![];
39 }
40 let (n, m) = (a.len(), b.len());
41
42 if cmp::min(n, m) <= 60 {
43 let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
44 let mut ans = vec![StaticModInt::new(0); n + m - 1];
45 for i in 0..n {
46 for j in 0..m {
47 ans[i + j] += a[i] * b[j];
48 }
49 }
50 return ans;
51 }
52
53 let (mut a, mut b) = (a.to_owned(), b.to_owned());
54 let z = 1 << internal_bit::ceil_pow2((n + m - 1) as _);
55 a.resize(z, StaticModInt::raw(0));
56 butterfly(&mut a);
57 b.resize(z, StaticModInt::raw(0));
58 butterfly(&mut b);
59 for (a, b) in a.iter_mut().zip(&b) {
60 *a *= b;
61 }
62 butterfly_inv(&mut a);
63 a.resize(n + m - 1, StaticModInt::raw(0));
64 let iz = StaticModInt::new(z).inv();
65 for a in &mut a {
66 *a *= iz;
67 }
68 a
69}
70
71pub fn convolution_raw<T, M>(a: &[T], b: &[T]) -> Vec<T>
72where
73 T: RemEuclidU32 + TryFrom<u32> + Clone,
74 T::Error: fmt::Debug,
75 M: Modulus,
76{
77 let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>();
78 let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>();
79 convolution::<M>(&a, &b)
80 .into_iter()
81 .map(|z| {
82 z.val()
83 .try_into()
84 .expect("the numeric type is smaller than the modulus")
85 })
86 .collect()
87}
88
/// Computes the convolution of two `i64` sequences exactly.
///
/// The product is computed modulo three NTT-friendly primes `M1`, `M2`, `M3`.
/// Each coefficient is then rebuilt with the Chinese remainder theorem, using
/// wrapping 64-bit arithmetic, followed by a correction step for the multiple
/// of `M1 * M2 * M3` the wrapping reconstruction can be off by. The
/// derivation below assumes every true coefficient `r` satisfies
/// `-2^63 <= r < 2^63`.
///
/// Returns an empty vector if either input is empty.
#[allow(clippy::many_single_char_names)]
pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
    const M1: u64 = 754_974_721; // 45 * 2^24 + 1
    const M2: u64 = 167_772_161; // 5 * 2^25 + 1
    const M3: u64 = 469_762_049; // 7 * 2^26 + 1
    const M2M3: u64 = M2 * M3;
    const M1M3: u64 = M1 * M3;
    const M1M2: u64 = M1 * M2;
    // M1 * M2 * M3 does not fit in u64; this is the product reduced mod 2^64
    // (called M' below).
    const M1M2M3: u64 = M1M2.wrapping_mul(M3);

    // Local marker types `M1`, `M2`, `M3` whose `Modulus::VALUE` is the
    // same-named constant above.
    modulus!(M1, M2, M3);

    if a.is_empty() || b.is_empty() {
        return vec![];
    }

    // CRT coefficients. NOTE(review): assumes `inv_gcd(x, m)` returns
    // `(gcd, x^-1 mod m)`, so that i1 = (M2*M3)^-1 mod M1, etc. Confirm
    // against `internal_math`.
    let (_, i1) = internal_math::inv_gcd(M2M3 as _, M1 as _);
    let (_, i2) = internal_math::inv_gcd(M1M3 as _, M2 as _);
    let (_, i3) = internal_math::inv_gcd(M1M2 as _, M3 as _);

    // The same convolution reduced modulo each prime.
    let c1 = convolution_raw::<i64, M1>(a, b);
    let c2 = convolution_raw::<i64, M2>(a, b);
    let c3 = convolution_raw::<i64, M3>(a, b);

    c1.into_iter()
        .zip(c2)
        .zip(c3)
        .map(|((c1, c2), c3)| {
            // Index = ((r - x) mod M1) mod 5; see the derivation below.
            // Index 1 is not expected to occur and is padding.
            const OFFSET: &[u64] = &[0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3];

            // x = Σ ((c_k * i_k) mod m_k) * (M / m_k), with M = M1*M2*M3.
            // Each term lies in [0, M). The sum is accumulated with wrapping
            // i64 arithmetic, i.e. modulo 2^64.
            let mut x = [(c1, i1, M1, M2M3), (c2, i2, M2, M1M3), (c3, i3, M3, M1M2)]
                .iter()
                .map(|&(c, i, m1, m2)| c.wrapping_mul(i).rem_euclid(m1 as _).wrapping_mul(m2 as _))
                .fold(0, i64::wrapping_add);

            // Let B = 2^63 with -B <= r < B (r = the true coefficient).
            // Since the unwrapped sum lies in [0, 3M) and is congruent to r
            // mod M, r is one of x, x - M, x - 2M, x - 3M (mod 2B).
            // With M' = M mod 2^64, and "without mod":
            //   r = x,
            //       x -  M' + (0 or 2B),
            //       x - 2M' + (0, 2B or 4B),
            //       x - 3M' + (0, 2B, 4B or 6B).
            // Taking (r - x) modulo M1, where r ≡ c1 (mod M1), the OFFSET
            // table encodes the property (specific to these three primes)
            // that
            //   case 0 gives 0 mod 5,
            //   case 1 gives 2 mod 5,
            //   case 2 gives 3 mod 5,
            //   case 3 gives 4 mod 5.
            // diff = (c1 - x) mod M1, normalized into [0, M1).
            let mut diff = c1 - internal_math::safe_mod(x, M1 as _);
            if diff < 0 {
                diff += M1 as i64;
            }
            x = x.wrapping_sub(OFFSET[diff.rem_euclid(5) as usize] as _);
            x
        })
        .collect()
}
150
151#[allow(clippy::many_single_char_names)]
152fn butterfly<M: Modulus>(a: &mut [StaticModInt<M>]) {
153 let n = a.len();
154 let h = internal_bit::ceil_pow2(n as u32);
155
156 M::butterfly_cache().with(|cache| {
157 let mut cache = cache.borrow_mut();
158 let ButterflyCache { sum_e, .. } = cache.get_or_insert_with(prepare);
159 for ph in 1..=h {
160 let w = 1 << (ph - 1);
161 let p = 1 << (h - ph);
162 let mut now = StaticModInt::<M>::new(1);
163 for s in 0..w {
164 let offset = s << (h - ph + 1);
165 for i in 0..p {
166 let l = a[i + offset];
167 let r = a[i + offset + p] * now;
168 a[i + offset] = l + r;
169 a[i + offset + p] = l - r;
170 }
171 now *= sum_e[(!s).trailing_zeros() as usize];
172 }
173 }
174 });
175}
176
177#[allow(clippy::many_single_char_names)]
178fn butterfly_inv<M: Modulus>(a: &mut [StaticModInt<M>]) {
179 let n = a.len();
180 let h = internal_bit::ceil_pow2(n as u32);
181
182 M::butterfly_cache().with(|cache| {
183 let mut cache = cache.borrow_mut();
184 let ButterflyCache { sum_ie, .. } = cache.get_or_insert_with(prepare);
185 for ph in (1..=h).rev() {
186 let w = 1 << (ph - 1);
187 let p = 1 << (h - ph);
188 let mut inow = StaticModInt::<M>::new(1);
189 for s in 0..w {
190 let offset = s << (h - ph + 1);
191 for i in 0..p {
192 let l = a[i + offset];
193 let r = a[i + offset + p];
194 a[i + offset] = l + r;
195 a[i + offset + p] = StaticModInt::new(M::VALUE + l.val() - r.val()) * inow;
196 }
197 inow *= sum_ie[(!s).trailing_zeros() as usize];
198 }
199 }
200 });
201}
202
203fn prepare<M: Modulus>() -> ButterflyCache<M> {
204 let g = StaticModInt::<M>::raw(internal_math::primitive_root(M::VALUE as i32) as u32);
205 let mut es = [StaticModInt::<M>::raw(0); 30]; let mut ies = [StaticModInt::<M>::raw(0); 30];
207 let cnt2 = (M::VALUE - 1).trailing_zeros() as usize;
208 let mut e = g.pow(((M::VALUE - 1) >> cnt2).into());
209 let mut ie = e.inv();
210 for i in (2..=cnt2).rev() {
211 es[i - 2] = e;
212 ies[i - 2] = ie;
213 e *= e;
214 ie *= ie;
215 }
216 let sum_e = es
217 .iter()
218 .scan(StaticModInt::new(1), |acc, e| {
219 *acc *= e;
220 Some(*acc)
221 })
222 .collect();
223 let sum_ie = ies
224 .iter()
225 .scan(StaticModInt::new(1), |acc, ie| {
226 *acc *= ie;
227 Some(*acc)
228 })
229 .collect();
230 ButterflyCache { sum_e, sum_ie }
231}
232
#[cfg(test)]
mod tests {
    use crate::{
        modint::{Mod998244353, Modulus, StaticModInt},
        RemEuclidU32,
    };
    use rand::{rngs::ThreadRng, Rng as _};
    use std::{
        convert::{TryFrom, TryInto as _},
        fmt,
    };

    #[test]
    fn empty() {
        assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[1, 2], &[]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[1], &[]).is_empty());
        assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(super::convolution::<Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution::<Mod998244353>(&[], &[1.into(), 2.into()]).is_empty());
    }

    #[test]
    fn mid() {
        const N: usize = 1234;
        const M: usize = 2345;

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<Mod998244353>(&mut rng, n);
        let (a, b) = (gen_values(N), gen_values(M));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    #[test]
    fn simple_s_mod() {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn test<M: Modulus>(rng: &mut ThreadRng) {
            // Generate values for the modulus under test. A fixed modulus
            // here would leave `M2` untested.
            let mut gen_values = |n| gen_values::<M>(rng, n);
            for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
                let (a, b) = (gen_values(n), gen_values(m));
                assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
            }
        }

        let mut rng = rand::thread_rng();
        test::<M1>(&mut rng);
        test::<M2>(&mut rng);
    }

    #[test]
    fn simple_int() {
        simple_raw::<i32>();
    }

    #[test]
    fn simple_uint() {
        simple_raw::<u32>();
    }

    #[test]
    fn simple_ll() {
        simple_raw::<i64>();
    }

    #[test]
    fn simple_ull() {
        simple_raw::<u64>();
    }

    #[test]
    fn simple_int128() {
        simple_raw::<i128>();
    }

    #[test]
    fn simple_uint128() {
        simple_raw::<u128>();
    }

    /// Checks `convolution_raw` against the naive product for element type
    /// `T` under two different moduli.
    fn simple_raw<T>()
    where
        T: TryFrom<u32> + Copy + RemEuclidU32 + PartialEq + fmt::Debug,
        T::Error: fmt::Debug,
    {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn test<T, M>(rng: &mut ThreadRng)
        where
            T: TryFrom<u32> + Copy + RemEuclidU32 + PartialEq + fmt::Debug,
            T::Error: fmt::Debug,
            M: Modulus,
        {
            // Generate `T` values below `M`. A hard-coded `u32` /
            // `Mod998244353` here made every `simple_*` test exercise only
            // `u32`.
            let mut gen_raw_values = |n| gen_raw_values::<T, M>(rng, n);
            for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
                let (a, b) = (gen_raw_values(n), gen_raw_values(m));
                assert_eq!(
                    conv_raw_naive::<T, M>(&a, &b),
                    super::convolution_raw::<T, M>(&a, &b),
                );
            }
        }

        let mut rng = rand::thread_rng();
        test::<T, M1>(&mut rng);
        test::<T, M2>(&mut rng);
    }

    #[test]
    fn conv_ll() {
        let mut rng = rand::thread_rng();
        for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
            let mut gen =
                |n: usize| -> Vec<_> { (0..n).map(|_| rng.gen_range(-500_000, 500_000)).collect() };
            let (a, b) = (gen(n), gen(m));
            assert_eq!(conv_i64_naive(&a, &b), super::convolution_i64(&a, &b));
        }
    }

    #[test]
    fn conv_ll_bound() {
        const M1: u64 = 754_974_721; // 45 * 2^24 + 1
        const M2: u64 = 167_772_161; // 5 * 2^25 + 1
        const M3: u64 = 469_762_049; // 7 * 2^26 + 1
        const M2M3: u64 = M2 * M3;
        const M1M3: u64 = M1 * M3;
        const M1M2: u64 = M1 * M2;

        for i in -1000..=1000 {
            let a = vec![0u64.wrapping_sub(M1M2 + M1M3 + M2M3) as i64 + i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }

        for i in 0..1000 {
            let a = vec![i64::min_value() + i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }

        for i in 0..1000 {
            let a = vec![i64::max_value() - i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }
    }

    #[test]
    fn conv_641() {
        const M: u32 = 641;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<M>(&mut rng, n);
        let (a, b) = (gen_values(64), gen_values(65));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    #[test]
    fn conv_18433() {
        const M: u32 = 18433;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<M>(&mut rng, n);
        let (a, b) = (gen_values(1024), gen_values(1025));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    #[allow(clippy::many_single_char_names)]
    fn conv_naive<M: Modulus>(
        a: &[StaticModInt<M>],
        b: &[StaticModInt<M>],
    ) -> Vec<StaticModInt<M>> {
        let (n, m) = (a.len(), b.len());
        let mut c = vec![StaticModInt::raw(0); n + m - 1];
        for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
            c[i + j] += a[i] * b[j];
        }
        c
    }

    fn conv_raw_naive<T, M>(a: &[T], b: &[T]) -> Vec<T>
    where
        T: TryFrom<u32> + Copy + RemEuclidU32,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        conv_naive::<M>(
            &a.iter().copied().map(Into::into).collect::<Vec<_>>(),
            &b.iter().copied().map(Into::into).collect::<Vec<_>>(),
        )
        .into_iter()
        .map(|x| x.val().try_into().unwrap())
        .collect()
    }

    #[allow(clippy::many_single_char_names)]
    fn conv_i64_naive(a: &[i64], b: &[i64]) -> Vec<i64> {
        let (n, m) = (a.len(), b.len());
        let mut c = vec![0; n + m - 1];
        for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
            c[i + j] += a[i] * b[j];
        }
        c
    }

    fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> {
        (0..n).map(|_| rng.gen_range(0, M::VALUE).into()).collect()
    }

    fn gen_raw_values<T, M>(rng: &mut ThreadRng, n: usize) -> Vec<T>
    where
        T: TryFrom<u32>,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        (0..n)
            .map(|_| rng.gen_range(0, M::VALUE).try_into().unwrap())
            .collect()
    }
}