/// Declares one uninhabited marker `enum` per identifier and implements
/// `crate::modint::Modulus` for it.
///
/// Each identifier must also name an integer `const` that is already in scope at the call
/// site. The empty enum only occupies the type namespace, so `$name as _` inside the impl
/// reads the same-named constant from the value namespace and becomes `Modulus::VALUE`.
/// That is an `as` cast to `u32`, so a constant above `u32::MAX` would be truncated
/// silently.
///
/// `HINT_VALUE_IS_PRIME` is set to `true` unconditionally, so only prime values should be
/// passed. Each generated type gets its own `thread_local!` slot for its `ButterflyCache`.
///
/// The trait path is fully qualified with `$crate`, so call sites do not need to import
/// `Modulus` themselves.
macro_rules! modulus {
    ($($name:ident),*) => {
        $(
            #[derive(Copy, Clone, Eq, PartialEq)]
            enum $name {}

            impl $crate::modint::Modulus for $name {
                const VALUE: u32 = $name as _;
                const HINT_VALUE_IS_PRIME: bool = true;

                fn butterfly_cache() -> &'static ::std::thread::LocalKey<::std::cell::RefCell<::std::option::Option<$crate::modint::ButterflyCache<Self>>>> {
                    thread_local! {
                        static BUTTERFLY_CACHE: ::std::cell::RefCell<::std::option::Option<$crate::modint::ButterflyCache<$name>>> = ::std::default::Default::default();
                    }
                    &BUTTERFLY_CACHE
                }
            }
        )*
    };
}
34
35use crate::{
36 internal_bit, internal_math,
37 modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
38};
39use std::{
40 cmp,
41 convert::{TryFrom, TryInto as _},
42 fmt,
43};
44
45#[allow(clippy::many_single_char_names)]
94pub fn convolution<M>(a: &[StaticModInt<M>], b: &[StaticModInt<M>]) -> Vec<StaticModInt<M>>
95where
96 M: Modulus,
97{
98 if a.is_empty() || b.is_empty() {
99 return vec![];
100 }
101 let (n, m) = (a.len(), b.len());
102
103 if cmp::min(n, m) <= 60 {
104 let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
105 let mut ans = vec![StaticModInt::new(0); n + m - 1];
106 for i in 0..n {
107 for j in 0..m {
108 ans[i + j] += a[i] * b[j];
109 }
110 }
111 return ans;
112 }
113
114 let (mut a, mut b) = (a.to_owned(), b.to_owned());
115 let z = 1 << internal_bit::ceil_pow2((n + m - 1) as _);
116 a.resize(z, StaticModInt::raw(0));
117 butterfly(&mut a);
118 b.resize(z, StaticModInt::raw(0));
119 butterfly(&mut b);
120 for (a, b) in a.iter_mut().zip(&b) {
121 *a *= b;
122 }
123 butterfly_inv(&mut a);
124 a.resize(n + m - 1, StaticModInt::raw(0));
125 let iz = StaticModInt::new(z).inv();
126 for a in &mut a {
127 *a *= iz;
128 }
129 a
130}
131
132pub fn convolution_raw<T, M>(a: &[T], b: &[T]) -> Vec<T>
188where
189 T: RemEuclidU32 + TryFrom<u32> + Clone,
190 T::Error: fmt::Debug,
191 M: Modulus,
192{
193 let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>();
194 let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>();
195 convolution::<M>(&a, &b)
196 .into_iter()
197 .map(|z| {
198 z.val()
199 .try_into()
200 .expect("the numeric type is smaller than the modulus")
201 })
202 .collect()
203}
204
/// Computes the convolution of two `i64` slices, `c[k] = Σ_{i + j = k} a[i] * b[j]`, as `i64`.
///
/// The product is computed modulo three primes (`M1`, `M2`, `M3` below) via
/// `convolution_raw`. Each coefficient is then recombined with the Chinese remainder
/// theorem, using wrapping (mod 2^64) arithmetic. The result is only meaningful when every
/// true coefficient fits in an `i64`.
/// NOTE(review): that bound is the caller's responsibility and is not checked here.
///
/// Returns an empty vector if either input is empty.
#[allow(clippy::many_single_char_names)]
pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
    // Primes of the form k * 2^e + 1, so each admits power-of-two transforms:
    // M1 - 1 = 45 * 2^24, M2 - 1 = 5 * 2^25, M3 - 1 = 7 * 2^26.
    const M1: u64 = 754_974_721;
    const M2: u64 = 167_772_161;
    const M3: u64 = 469_762_049;
    // Pairwise products fit in u64 because every prime is below 2^30.
    const M2M3: u64 = M2 * M3;
    const M1M3: u64 = M1 * M3;
    const M1M2: u64 = M1 * M2;
    // The full product M1 * M2 * M3 exceeds 2^64; only its value mod 2^64 is kept.
    const M1M2M3: u64 = M1M2.wrapping_mul(M3);

    // Marker types so each prime can be used as a `Modulus` type parameter.
    modulus!(M1, M2, M3);

    if a.is_empty() || b.is_empty() {
        return vec![];
    }

    // CRT coefficients. Presumably `i1 = (M2 * M3)^-1 mod M1` (the second component returned
    // by `inv_gcd`), and likewise for `i2` and `i3`.
    let (_, i1) = internal_math::inv_gcd(M2M3 as _, M1 as _);
    let (_, i2) = internal_math::inv_gcd(M1M3 as _, M2 as _);
    let (_, i3) = internal_math::inv_gcd(M1M2 as _, M3 as _);

    // The convolution reduced modulo each prime.
    let c1 = convolution_raw::<i64, M1>(a, b);
    let c2 = convolution_raw::<i64, M2>(a, b);
    let c3 = convolution_raw::<i64, M3>(a, b);

    c1.into_iter()
        .zip(c2)
        .zip(c3)
        .map(|((c1, c2), c3)| {
            // Multiples of M = M1 * M2 * M3 (taken mod 2^64) that may have to be subtracted
            // from `x`. The table is indexed by `diff % 5` below; index 1 is a placeholder.
            const OFFSET: &[u64] = &[0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3];

            // x = Σ ((c_k * i_k) mod M_k) * (M / M_k), evaluated mod 2^64.
            // Each term is below M, so the exact sum X lies in [0, 3M) and satisfies
            // X ≡ r (mod M) for the true coefficient r.
            let mut x = [(c1, i1, M1, M2M3), (c2, i2, M2, M1M3), (c3, i3, M3, M1M2)]
                .iter()
                .map(|&(c, i, m1, m2)| c.wrapping_mul(i).rem_euclid(m1 as _).wrapping_mul(m2 as _))
                .fold(0, i64::wrapping_add);

            // r fits in i64 and 2^63 < M, so r = X - k * M for some k in {0, 1, 2, 3};
            // modulo 2^64 that means r = x - k * M1M2M3.
            // `diff` is (r - x) mod M1, normalised into [0, M1), using c1 = r mod M1.
            // For these primes, `diff % 5` identifies k (0 -> 0, 2 -> 1, 3 -> 2, 4 -> 3),
            // which matches OFFSET.
            // NOTE(review): that mapping is a numeric property of M1, M2 and M3 and must be
            // re-derived if any modulus changes.
            let mut diff = c1 - internal_math::safe_mod(x, M1 as _);
            if diff < 0 {
                diff += M1 as i64;
            }
            x = x.wrapping_sub(OFFSET[diff.rem_euclid(5) as usize] as _);
            x
        })
        .collect()
}
304
/// Forward number-theoretic transform of `a`, in place.
///
/// NOTE(review): assumes `a.len()` is a power of two (`1 << h`). For other lengths, the
/// first stage indexes past the end of `a` and panics.
///
/// The outputs are not in natural order. The caller only multiplies them pointwise and then
/// applies `butterfly_inv`, so the order does not matter there.
#[allow(clippy::many_single_char_names)]
fn butterfly<M: Modulus>(a: &mut [StaticModInt<M>]) {
    let n = a.len();
    // Exponent of the transform length.
    let h = internal_bit::ceil_pow2(n as u32);

    // Twiddle tables live in the per-modulus cache and are built by `prepare` on first use.
    M::butterfly_cache().with(|cache| {
        let mut cache = cache.borrow_mut();
        let ButterflyCache { sum_e, .. } = cache.get_or_insert_with(prepare);
        // Stage `ph` processes `w` blocks, each made of two halves of length `p`.
        for ph in 1..=h {
            let w = 1 << (ph - 1);
            let p = 1 << (h - ph);
            // Twiddle factor for the current block; block 0 uses 1.
            let mut now = StaticModInt::<M>::new(1);
            for s in 0..w {
                let offset = s << (h - ph + 1);
                // (l, r) -> (l + now * r, l - now * r) across the two halves of the block.
                for i in 0..p {
                    let l = a[i + offset];
                    let r = a[i + offset + p] * now;
                    a[i + offset] = l + r;
                    a[i + offset + p] = l - r;
                }
                // Step to the next block's twiddle. `(!s).trailing_zeros()` is the number of
                // trailing one bits of `s`, which selects the multiplier from `sum_e`.
                now *= sum_e[(!s).trailing_zeros() as usize];
            }
        }
    });
}
330
/// Inverse of `butterfly`, in place, up to a factor of `a.len()`.
///
/// Stages run in the reverse order of `butterfly`. Each stage maps (l', r') to
/// (l' + r', (l' - r') * inow). Here `inow` is the inverse of the forward twiddle for the same
/// block, because `prepare` builds `sum_ie` from the inverses of the factors used for `sum_e`.
/// Each stage therefore restores its input doubled, and the overall result is `a.len()` times
/// the original. The caller must divide that factor back out.
///
/// NOTE(review): like `butterfly`, assumes `a.len()` is a power of two.
#[allow(clippy::many_single_char_names)]
fn butterfly_inv<M: Modulus>(a: &mut [StaticModInt<M>]) {
    let n = a.len();
    // Exponent of the transform length.
    let h = internal_bit::ceil_pow2(n as u32);

    // Shares the per-modulus cache with `butterfly`; built by `prepare` on first use.
    M::butterfly_cache().with(|cache| {
        let mut cache = cache.borrow_mut();
        let ButterflyCache { sum_ie, .. } = cache.get_or_insert_with(prepare);
        for ph in (1..=h).rev() {
            let w = 1 << (ph - 1);
            let p = 1 << (h - ph);
            // Inverse twiddle factor for the current block; block 0 uses 1.
            let mut inow = StaticModInt::<M>::new(1);
            for s in 0..w {
                let offset = s << (h - ph + 1);
                for i in 0..p {
                    let l = a[i + offset];
                    let r = a[i + offset + p];
                    a[i + offset] = l + r;
                    // (l - r) * inow, with the subtraction done on raw residues.
                    // NOTE(review): `M::VALUE + l.val()` must not overflow u32. That holds for
                    // moduli below 2^31; confirm this is the supported range.
                    a[i + offset + p] = StaticModInt::new(M::VALUE + l.val() - r.val()) * inow;
                }
                // Same block-to-block step as `butterfly`, using the inverse table.
                inow *= sum_ie[(!s).trailing_zeros() as usize];
            }
        }
    });
}
356
357fn prepare<M: Modulus>() -> ButterflyCache<M> {
358 let g = StaticModInt::<M>::raw(internal_math::primitive_root(M::VALUE as i32) as u32);
359 let mut es = [StaticModInt::<M>::raw(0); 30]; let mut ies = [StaticModInt::<M>::raw(0); 30];
361 let cnt2 = (M::VALUE - 1).trailing_zeros() as usize;
362 let mut e = g.pow(((M::VALUE - 1) >> cnt2).into());
363 let mut ie = e.inv();
364 for i in (2..=cnt2).rev() {
365 es[i - 2] = e;
366 ies[i - 2] = ie;
367 e *= e;
368 ie *= ie;
369 }
370 let sum_e = es
371 .iter()
372 .scan(StaticModInt::new(1), |acc, e| {
373 *acc *= e;
374 Some(*acc)
375 })
376 .collect();
377 let sum_ie = ies
378 .iter()
379 .scan(StaticModInt::new(1), |acc, ie| {
380 *acc *= ie;
381 Some(*acc)
382 })
383 .collect();
384 ButterflyCache { sum_e, sum_ie }
385}
386
#[cfg(test)]
mod tests {
    use crate::{
        modint::{Mod998244353, Modulus, StaticModInt},
        RemEuclidU32,
    };
    use rand::{rngs::ThreadRng, Rng as _};
    use std::{
        convert::{TryFrom, TryInto as _},
        fmt,
    };

    // Every entry point returns an empty vector when either operand is empty.
    #[test]
    fn empty() {
        assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[1, 2], &[]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[1], &[]).is_empty());
        assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(super::convolution::<Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution::<Mod998244353>(&[], &[1.into(), 2.into()]).is_empty());
    }

    // Random inputs long enough (1234 x 2345) to take the transform path, checked against
    // the naive O(n * m) product.
    #[test]
    fn mid() {
        const N: usize = 1234;
        const M: usize = 2345;

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<Mod998244353>(&mut rng, n);
        let (a, b) = (gen_values(N), gen_values(M));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // Every length pair in 1..20 x 1..20 for two primes. These sizes all take the
    // schoolbook path (shorter operand <= 60).
    #[test]
    fn simple_s_mod() {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn test<M: Modulus>(rng: &mut ThreadRng) {
            let mut gen_values = |n| gen_values::<M>(rng, n);
            for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
                let (a, b) = (gen_values(n), gen_values(m));
                assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
            }
        }

        let mut rng = rand::thread_rng();
        test::<M1>(&mut rng);
        test::<M2>(&mut rng);
    }

    // `convolution_raw` over each supported primitive integer type.
    #[test]
    fn simple_int() {
        simple_raw::<i32>();
    }

    #[test]
    fn simple_uint() {
        simple_raw::<u32>();
    }

    #[test]
    fn simple_ll() {
        simple_raw::<i64>();
    }

    #[test]
    fn simple_ull() {
        simple_raw::<u64>();
    }

    #[test]
    fn simple_int128() {
        simple_raw::<i128>();
    }

    #[test]
    fn simple_uint128() {
        simple_raw::<u128>();
    }

    // Shared driver: random residues in [0, M) converted to `T`, with `convolution_raw`
    // compared against `conv_raw_naive` for every length pair in 1..20 x 1..20.
    fn simple_raw<T>()
    where
        T: TryFrom<u32> + Copy + RemEuclidU32 + Eq,
        T::Error: fmt::Debug,
    {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn test<T, M>(rng: &mut ThreadRng)
        where
            T: TryFrom<u32> + Copy + RemEuclidU32 + Eq,
            T::Error: fmt::Debug,
            M: Modulus,
        {
            let mut gen_raw_values = |n| {
                gen_raw_values::<u32, M>(rng, n)
                    .into_iter()
                    .map(|x| x.try_into().unwrap())
                    .collect::<Vec<T>>()
            };
            for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
                let (a, b) = (gen_raw_values(n), gen_raw_values(m));
                assert!(
                    conv_raw_naive::<T, M>(&a, &b) == super::convolution_raw::<T, M>(&a, &b),
                    "values don't match",
                );
            }
        }

        let mut rng = rand::thread_rng();
        test::<T, M1>(&mut rng);
        test::<T, M2>(&mut rng);
    }

    // `convolution_i64` against plain i64 arithmetic for values in [-500000, 500000).
    #[test]
    fn conv_ll() {
        let mut rng = rand::thread_rng();
        for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
            let mut gen =
                |n: usize| -> Vec<_> { (0..n).map(|_| rng.gen_range(-500_000, 500_000)).collect() };
            let (a, b) = (gen(n), gen(m));
            assert_eq!(conv_i64_naive(&a, &b), super::convolution_i64(&a, &b));
        }
    }

    // Convolving a single value with [1] must return it unchanged. The inputs sit around
    // -(M1M2 + M1M3 + M2M3) (mod 2^64) and at both ends of the i64 range, which presumably
    // targets the boundaries of the CRT offset correction in `convolution_i64`.
    #[test]
    fn conv_ll_bound() {
        const M1: u64 = 754_974_721;
        const M2: u64 = 167_772_161;
        const M3: u64 = 469_762_049;
        const M2M3: u64 = M2 * M3;
        const M1M3: u64 = M1 * M3;
        const M1M2: u64 = M1 * M2;

        for i in -1000..=1000 {
            let a = vec![0u64.wrapping_sub(M1M2 + M1M3 + M2M3) as i64 + i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }

        for i in 0..1000 {
            let a = vec![i64::MIN + i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }

        for i in 0..1000 {
            let a = vec![i64::MAX - i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }
    }

    // 641 - 1 = 5 * 2^7; lengths 64 and 65 need a transform of length 128 = 2^7, the
    // largest this modulus supports.
    #[test]
    fn conv_641() {
        const M: u32 = 641;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<M>(&mut rng, n);
        let (a, b) = (gen_values(64), gen_values(65));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // 18433 - 1 = 9 * 2^11; lengths 1024 and 1025 need a transform of length 2048 = 2^11,
    // the largest this modulus supports.
    #[test]
    fn conv_18433() {
        const M: u32 = 18433;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<M>(&mut rng, n);
        let (a, b) = (gen_values(1024), gen_values(1025));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // Reference O(n * m) convolution over `StaticModInt<M>`; expects non-empty inputs
    // (`n + m - 1` would underflow otherwise).
    #[allow(clippy::many_single_char_names)]
    fn conv_naive<M: Modulus>(
        a: &[StaticModInt<M>],
        b: &[StaticModInt<M>],
    ) -> Vec<StaticModInt<M>> {
        let (n, m) = (a.len(), b.len());
        let mut c = vec![StaticModInt::raw(0); n + m - 1];
        for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
            c[i + j] += a[i] * b[j];
        }
        c
    }

    // Reference for `convolution_raw`: lift to `StaticModInt<M>`, run `conv_naive`, and
    // convert the residues back to `T`.
    fn conv_raw_naive<T, M>(a: &[T], b: &[T]) -> Vec<T>
    where
        T: TryFrom<u32> + Copy + RemEuclidU32,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        conv_naive::<M>(
            &a.iter().copied().map(Into::into).collect::<Vec<_>>(),
            &b.iter().copied().map(Into::into).collect::<Vec<_>>(),
        )
        .into_iter()
        .map(|x| x.val().try_into().unwrap())
        .collect()
    }

    // Reference O(n * m) convolution in plain i64 arithmetic; expects non-empty inputs.
    #[allow(clippy::many_single_char_names)]
    fn conv_i64_naive(a: &[i64], b: &[i64]) -> Vec<i64> {
        let (n, m) = (a.len(), b.len());
        let mut c = vec![0; n + m - 1];
        for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
            c[i + j] += a[i] * b[j];
        }
        c
    }

    // `n` random residues in [0, M::VALUE).
    fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> {
        (0..n).map(|_| rng.gen_range(0, M::VALUE).into()).collect()
    }

    // `n` random values in [0, M::VALUE), converted to `T`.
    fn gen_raw_values<T, M>(rng: &mut ThreadRng, n: usize) -> Vec<T>
    where
        T: TryFrom<u32>,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        (0..n)
            .map(|_| rng.gen_range(0, M::VALUE).try_into().unwrap())
            .collect()
    }
}