ac_library/
convolution.rs

//! Functions that calculate $(+, \times)$ convolution.
//!
//! Given two non-empty sequences $a_0, a_1, \ldots, a_{N - 1}$ and $b_0, b_1, \ldots, b_{M - 1}$, they calculate the sequence $c$ of length $N + M - 1$ defined by
//!
//! \\[
//!   c_i = \sum_ {j = 0}^i a_j b_{i - j}
//! \\]
//! where any term with an out-of-range index ($a_j$ with $j \geq N$, or $b_k$ with $k \geq M$) is taken to be $0$.
//!
//! # Major changes from the original ACL
//!
//! - Separated the overloaded `convolution` into `convolution<_>` and `convolution_raw<_, _>`.
//! - Renamed `convolution_ll` to `convolution_i64`.

// Declares uninhabited marker types implementing `Modulus`, one per identifier.
//
// Each `$name` must ALSO be the name of an integer constant in scope at the call site.
// Types and values live in different namespaces, so `enum $name {}` declares the type,
// while `$name as _` reads the constant of the same name as the modulus value.
//
// Every modulus declared here reports `HINT_VALUE_IS_PRIME = true`, so only pass primes.
//
// `Modulus` is named through `$crate`, so callers no longer need the trait imported.
// A trailing comma after the last identifier is accepted.
macro_rules! modulus {
    ($($name:ident),* $(,)?) => {
        $(
            #[derive(Copy, Clone, Eq, PartialEq)]
            enum $name {}

            impl $crate::modint::Modulus for $name {
                const VALUE: u32 = $name as _;
                const HINT_VALUE_IS_PRIME: bool = true;

                fn butterfly_cache() -> &'static ::std::thread::LocalKey<::std::cell::RefCell<::std::option::Option<$crate::modint::ButterflyCache<Self>>>> {
                    // One cache slot per thread and per modulus.
                    // It starts as `None` and is filled on first use by the NTT routines.
                    thread_local! {
                        static BUTTERFLY_CACHE: ::std::cell::RefCell<::std::option::Option<$crate::modint::ButterflyCache<$name>>> = ::std::default::Default::default();
                    }
                    &BUTTERFLY_CACHE
                }
            }
        )*
    };
}

35use crate::{
36    internal_bit, internal_math,
37    modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
38};
39use std::{
40    cmp,
41    convert::{TryFrom, TryInto as _},
42    fmt,
43};
44
45/// Calculates the $(+, \times)$ convolution in $\mathbb{Z}/p\mathbb{Z}$.
46///
47/// See the [module-level documentation] for more details.
48///
49/// Returns a empty `Vec` if `a` or `b` is empty.
50///
51/// # Constraints
52///
53/// - $2 \leq m \leq 2 \times 10^9$
54/// - $m$ is a prime number.
55/// - $\exists c \text{ s.t. } 2^c \mid (m - 1), |a| + |b| - 1 \leq 2^c$
56///
57/// where $m$ is `M::VALUE`.
58///
59/// # Complexity
60///
61/// - $O(n \log n + \log m)$ where $n = |a| + |b|$.
62///
63/// # Example
64///
65/// ```
66/// use ac_library::ModInt1000000007 as Mint;
67/// use proconio::{input, source::once::OnceSource};
68///
69/// input! {
70///     from OnceSource::from(
71///         "3\n\
72///          1 2 3\n\
73///          3\n\
74///          -1 -2 -3\n",
75///     ),
76///     a: [Mint],
77///     b: [Mint],
78/// }
79///
80/// assert_eq!(
81///     ac_library::convolution(&a, &b),
82///     [
83///         Mint::new(-1),
84///         Mint::new(-4),
85///         Mint::new(-10),
86///         Mint::new(-12),
87///         Mint::new(-9),
88///     ],
89/// );
90/// ```
91///
92/// [module-level documentation]: ./index.html
93#[allow(clippy::many_single_char_names)]
94pub fn convolution<M>(a: &[StaticModInt<M>], b: &[StaticModInt<M>]) -> Vec<StaticModInt<M>>
95where
96    M: Modulus,
97{
98    if a.is_empty() || b.is_empty() {
99        return vec![];
100    }
101    let (n, m) = (a.len(), b.len());
102
103    if cmp::min(n, m) <= 60 {
104        let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
105        let mut ans = vec![StaticModInt::new(0); n + m - 1];
106        for i in 0..n {
107            for j in 0..m {
108                ans[i + j] += a[i] * b[j];
109            }
110        }
111        return ans;
112    }
113
114    let (mut a, mut b) = (a.to_owned(), b.to_owned());
115    let z = 1 << internal_bit::ceil_pow2((n + m - 1) as _);
116    a.resize(z, StaticModInt::raw(0));
117    butterfly(&mut a);
118    b.resize(z, StaticModInt::raw(0));
119    butterfly(&mut b);
120    for (a, b) in a.iter_mut().zip(&b) {
121        *a *= b;
122    }
123    butterfly_inv(&mut a);
124    a.resize(n + m - 1, StaticModInt::raw(0));
125    let iz = StaticModInt::new(z).inv();
126    for a in &mut a {
127        *a *= iz;
128    }
129    a
130}
131
132/// Calculates the $(+, \times)$ convolution in $\mathbb{Z}/p\mathbb{Z}$.
133///
134/// See the [module-level documentation] for more details.
135///
136/// Returns a empty `Vec` if `a` or `b` is empty.
137///
138/// # Constraints
139///
140/// - $2 \leq m \leq 2 \times 10^9$
141/// - $m$ is a prime number.
142/// - $\exists c \text{ s.t. } 2^c \mid (m - 1), |a| + |b| - 1 \leq 2^c$
143/// - $(0, m] \subseteq$ `T`
144///
145/// where $m$ is `M::VALUE`.
146///
147/// # Complexity
148///
149/// - $O(n \log n + \log m)$ where $n = |a| + |b|$.
150///
151/// # Panics
152///
153/// Panics if any element of the result ($\in [0,$ `M::VALUE`$)$) is outside of the range of `T`.
154///
155/// # Example
156///
157/// ```
158/// use ac_library::{Mod1000000007 as M, Modulus as _};
159/// use proconio::{input, source::once::OnceSource};
160///
161/// const M: i32 = M::VALUE as _;
162///
163/// input! {
164///     from OnceSource::from(
165///         "3\n\
166///          1 2 3\n\
167///          3\n\
168///          -1 -2 -3\n",
169///     ),
170///     a: [i32],
171///     b: [i32],
172/// }
173///
174/// assert_eq!(
175///     ac_library::convolution::convolution_raw::<_, M>(&a, &b),
176///     [
177///         (-1i32).rem_euclid(M),
178///         (-4i32).rem_euclid(M),
179///         (-10i32).rem_euclid(M),
180///         (-12i32).rem_euclid(M),
181///         (-9i32).rem_euclid(M),
182///     ],
183/// );
184/// ```
185///
186/// [module-level documentation]: ./index.html
187pub fn convolution_raw<T, M>(a: &[T], b: &[T]) -> Vec<T>
188where
189    T: RemEuclidU32 + TryFrom<u32> + Clone,
190    T::Error: fmt::Debug,
191    M: Modulus,
192{
193    let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>();
194    let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>();
195    convolution::<M>(&a, &b)
196        .into_iter()
197        .map(|z| {
198            z.val()
199                .try_into()
200                .expect("the numeric type is smaller than the modulus")
201        })
202        .collect()
203}
204
205/// Calculates the $(+, \times)$ convolution in `i64`.
206///
207/// See the [module-level documentation] for more details.
208///
209/// Returns a empty `Vec` if `a` or `b` is empty.
210///
211/// # Constraints
212///
213/// - $|a| + |b| - 1 \leq 2^{24}$
214/// - All elements of the result are inside of the range of `i64`
215///
216/// # Complexity
217///
218/// - $O(n \log n)$ where $n = |a| + |b|$.
219///
220/// # Example
221///
222/// ```
223/// use proconio::{input, source::once::OnceSource};
224///
225/// input! {
226///     from OnceSource::from(
227///         "3\n\
228///          1 2 3\n\
229///          3\n\
230///          -1 -2 -3\n",
231///     ),
232///     a: [i64],
233///     b: [i64],
234/// }
235///
236/// assert_eq!(
237///     ac_library::convolution_i64(&a, &b),
238///     [-1, -4, -10, -12, -9],
239/// );
240/// ```
241///
242/// [module-level documentation]: ./index.html
243#[allow(clippy::many_single_char_names)]
244pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
245    const M1: u64 = 754_974_721; // 2^24
246    const M2: u64 = 167_772_161; // 2^25
247    const M3: u64 = 469_762_049; // 2^26
248    const M2M3: u64 = M2 * M3;
249    const M1M3: u64 = M1 * M3;
250    const M1M2: u64 = M1 * M2;
251    const M1M2M3: u64 = M1M2.wrapping_mul(M3);
252
253    modulus!(M1, M2, M3);
254
255    if a.is_empty() || b.is_empty() {
256        return vec![];
257    }
258
259    let (_, i1) = internal_math::inv_gcd(M2M3 as _, M1 as _);
260    let (_, i2) = internal_math::inv_gcd(M1M3 as _, M2 as _);
261    let (_, i3) = internal_math::inv_gcd(M1M2 as _, M3 as _);
262
263    let c1 = convolution_raw::<i64, M1>(a, b);
264    let c2 = convolution_raw::<i64, M2>(a, b);
265    let c3 = convolution_raw::<i64, M3>(a, b);
266
267    c1.into_iter()
268        .zip(c2)
269        .zip(c3)
270        .map(|((c1, c2), c3)| {
271            const OFFSET: &[u64] = &[0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3];
272
273            let mut x = [(c1, i1, M1, M2M3), (c2, i2, M2, M1M3), (c3, i3, M3, M1M2)]
274                .iter()
275                .map(|&(c, i, m1, m2)| c.wrapping_mul(i).rem_euclid(m1 as _).wrapping_mul(m2 as _))
276                .fold(0, i64::wrapping_add);
277
278            // B = 2^63, -B <= x, r(real value) < B
279            // (x, x - M, x - 2M, or x - 3M) = r (mod 2B)
280            // r = c1[i] (mod MOD1)
281            // focus on MOD1
282            // r = x, x - M', x - 2M', x - 3M' (M' = M % 2^64) (mod 2B)
283            // r = x,
284            //     x - M' + (0 or 2B),
285            //     x - 2M' + (0, 2B or 4B),
286            //     x - 3M' + (0, 2B, 4B or 6B) (without mod!)
287            // (r - x) = 0, (0)
288            //           - M' + (0 or 2B), (1)
289            //           -2M' + (0 or 2B or 4B), (2)
290            //           -3M' + (0 or 2B or 4B or 6B) (3) (mod MOD1)
291            // we checked that
292            //   ((1) mod MOD1) mod 5 = 2
293            //   ((2) mod MOD1) mod 5 = 3
294            //   ((3) mod MOD1) mod 5 = 4
295            let mut diff = c1 - internal_math::safe_mod(x, M1 as _);
296            if diff < 0 {
297                diff += M1 as i64;
298            }
299            x = x.wrapping_sub(OFFSET[diff.rem_euclid(5) as usize] as _);
300            x
301        })
302        .collect()
303}
304
305#[allow(clippy::many_single_char_names)]
306fn butterfly<M: Modulus>(a: &mut [StaticModInt<M>]) {
307    let n = a.len();
308    let h = internal_bit::ceil_pow2(n as u32);
309
310    M::butterfly_cache().with(|cache| {
311        let mut cache = cache.borrow_mut();
312        let ButterflyCache { sum_e, .. } = cache.get_or_insert_with(prepare);
313        for ph in 1..=h {
314            let w = 1 << (ph - 1);
315            let p = 1 << (h - ph);
316            let mut now = StaticModInt::<M>::new(1);
317            for s in 0..w {
318                let offset = s << (h - ph + 1);
319                for i in 0..p {
320                    let l = a[i + offset];
321                    let r = a[i + offset + p] * now;
322                    a[i + offset] = l + r;
323                    a[i + offset + p] = l - r;
324                }
325                now *= sum_e[(!s).trailing_zeros() as usize];
326            }
327        }
328    });
329}
330
331#[allow(clippy::many_single_char_names)]
332fn butterfly_inv<M: Modulus>(a: &mut [StaticModInt<M>]) {
333    let n = a.len();
334    let h = internal_bit::ceil_pow2(n as u32);
335
336    M::butterfly_cache().with(|cache| {
337        let mut cache = cache.borrow_mut();
338        let ButterflyCache { sum_ie, .. } = cache.get_or_insert_with(prepare);
339        for ph in (1..=h).rev() {
340            let w = 1 << (ph - 1);
341            let p = 1 << (h - ph);
342            let mut inow = StaticModInt::<M>::new(1);
343            for s in 0..w {
344                let offset = s << (h - ph + 1);
345                for i in 0..p {
346                    let l = a[i + offset];
347                    let r = a[i + offset + p];
348                    a[i + offset] = l + r;
349                    a[i + offset + p] = StaticModInt::new(M::VALUE + l.val() - r.val()) * inow;
350                }
351                inow *= sum_ie[(!s).trailing_zeros() as usize];
352            }
353        }
354    });
355}
356
357fn prepare<M: Modulus>() -> ButterflyCache<M> {
358    let g = StaticModInt::<M>::raw(internal_math::primitive_root(M::VALUE as i32) as u32);
359    let mut es = [StaticModInt::<M>::raw(0); 30]; // es[i]^(2^(2+i)) == 1
360    let mut ies = [StaticModInt::<M>::raw(0); 30];
361    let cnt2 = (M::VALUE - 1).trailing_zeros() as usize;
362    let mut e = g.pow(((M::VALUE - 1) >> cnt2).into());
363    let mut ie = e.inv();
364    for i in (2..=cnt2).rev() {
365        es[i - 2] = e;
366        ies[i - 2] = ie;
367        e *= e;
368        ie *= ie;
369    }
370    let sum_e = es
371        .iter()
372        .scan(StaticModInt::new(1), |acc, e| {
373            *acc *= e;
374            Some(*acc)
375        })
376        .collect();
377    let sum_ie = ies
378        .iter()
379        .scan(StaticModInt::new(1), |acc, ie| {
380            *acc *= ie;
381            Some(*acc)
382        })
383        .collect();
384    ButterflyCache { sum_e, sum_ie }
385}
386
#[cfg(test)]
mod tests {
    use crate::{
        modint::{Mod998244353, Modulus, StaticModInt},
        RemEuclidU32,
    };
    use rand::{rngs::ThreadRng, Rng as _};
    use std::{
        convert::{TryFrom, TryInto as _},
        fmt,
    };

    /// Every `(n, m)` with `1 <= n, m < 20`, `n`-major.
    fn small_shapes() -> impl Iterator<Item = (usize, usize)> {
        (1..20).flat_map(|n| (1..20).map(move |m| (n, m)))
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L51-L71
    #[test]
    fn empty() {
        use super::{convolution, convolution_raw};

        assert!(convolution_raw::<i32, Mod998244353>(&[], &[]).is_empty());
        assert!(convolution_raw::<i32, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(convolution_raw::<i32, Mod998244353>(&[1, 2], &[]).is_empty());
        assert!(convolution_raw::<i32, Mod998244353>(&[1], &[]).is_empty());
        assert!(convolution_raw::<i64, Mod998244353>(&[], &[]).is_empty());
        assert!(convolution_raw::<i64, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(convolution::<Mod998244353>(&[], &[]).is_empty());
        assert!(convolution::<Mod998244353>(&[], &[1.into(), 2.into()]).is_empty());
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85
    #[test]
    fn mid() {
        let mut rng = rand::thread_rng();
        let a = gen_values::<Mod998244353>(&mut rng, 1234);
        let b = gen_values::<Mod998244353>(&mut rng, 2345);
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L87-L118
    #[test]
    fn simple_s_mod() {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn check<M: Modulus>(rng: &mut ThreadRng) {
            for (n, m) in small_shapes() {
                let a = gen_values::<M>(rng, n);
                let b = gen_values::<M>(rng, m);
                assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
            }
        }

        let mut rng = rand::thread_rng();
        check::<M1>(&mut rng);
        check::<M2>(&mut rng);
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L120-L150
    #[test]
    fn simple_int() {
        simple_raw::<i32>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L152-L182
    #[test]
    fn simple_uint() {
        simple_raw::<u32>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L184-L214
    #[test]
    fn simple_ll() {
        simple_raw::<i64>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L216-L246
    #[test]
    fn simple_ull() {
        simple_raw::<u64>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L249-L279
    #[test]
    fn simple_int128() {
        simple_raw::<i128>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L281-L311
    #[test]
    fn simple_uint128() {
        simple_raw::<u128>();
    }

    /// Compares `convolution_raw` against the naive product for element type `T`.
    /// Two moduli are checked over all small shapes.
    fn simple_raw<T>()
    where
        T: TryFrom<u32> + Copy + RemEuclidU32 + Eq,
        T::Error: fmt::Debug,
    {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn check<T, M>(rng: &mut ThreadRng)
        where
            T: TryFrom<u32> + Copy + RemEuclidU32 + Eq,
            T::Error: fmt::Debug,
            M: Modulus,
        {
            let mut random_raw = |len: usize| -> Vec<T> {
                gen_raw_values::<u32, M>(rng, len)
                    .into_iter()
                    .map(|x| x.try_into().unwrap())
                    .collect()
            };
            for (n, m) in small_shapes() {
                let a = random_raw(n);
                let b = random_raw(m);
                assert!(
                    conv_raw_naive::<T, M>(&a, &b) == super::convolution_raw::<T, M>(&a, &b),
                    "values don't match",
                );
            }
        }

        let mut rng = rand::thread_rng();
        check::<T, M1>(&mut rng);
        check::<T, M2>(&mut rng);
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L315-L329
    #[test]
    fn conv_ll() {
        let mut rng = rand::thread_rng();
        for (n, m) in small_shapes() {
            let mut random_vec = |len: usize| -> Vec<i64> {
                (0..len).map(|_| rng.gen_range(-500_000, 500_000)).collect()
            };
            let a = random_vec(n);
            let b = random_vec(m);
            assert_eq!(conv_i64_naive(&a, &b), super::convolution_i64(&a, &b));
        }
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L331-L356
    #[test]
    fn conv_ll_bound() {
        const M1: u64 = 754_974_721; // 2^24
        const M2: u64 = 167_772_161; // 2^25
        const M3: u64 = 469_762_049; // 2^26
        const M2M3: u64 = M2 * M3;
        const M1M3: u64 = M1 * M3;
        const M1M2: u64 = M1 * M2;

        // Values around -(M1*M2 + M1*M3 + M2*M3) and at both extremes of the i64 range must
        // come back unchanged from a convolution with [1].
        let pivot = 0u64.wrapping_sub(M1M2 + M1M3 + M2M3) as i64;
        let samples = (-1000..=1000)
            .map(|i| pivot + i)
            .chain((0..1000).map(|i| i64::MIN + i))
            .chain((0..1000).map(|i| i64::MAX - i));
        for value in samples {
            let a = vec![value];
            assert_eq!(a, super::convolution_i64(&a, &[1]));
        }
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371
    #[test]
    fn conv_641() {
        const M: u32 = 641;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let a = gen_values::<M>(&mut rng, 64);
        let b = gen_values::<M>(&mut rng, 65);
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386
    #[test]
    fn conv_18433() {
        const M: u32 = 18433;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let a = gen_values::<M>(&mut rng, 1024);
        let b = gen_values::<M>(&mut rng, 1025);
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    /// Schoolbook O(n * m) convolution in Z/MZ, used as the reference result.
    #[allow(clippy::many_single_char_names)]
    fn conv_naive<M: Modulus>(
        a: &[StaticModInt<M>],
        b: &[StaticModInt<M>],
    ) -> Vec<StaticModInt<M>> {
        let mut c = vec![StaticModInt::raw(0); a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                c[i + j] += x * y;
            }
        }
        c
    }

    /// Reference for `convolution_raw`.
    /// Lifts to residues, convolves naively, and converts back.
    fn conv_raw_naive<T, M>(a: &[T], b: &[T]) -> Vec<T>
    where
        T: TryFrom<u32> + Copy + RemEuclidU32,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        let lift =
            |xs: &[T]| -> Vec<StaticModInt<M>> { xs.iter().copied().map(Into::into).collect() };
        conv_naive::<M>(&lift(a), &lift(b))
            .into_iter()
            .map(|x| x.val().try_into().unwrap())
            .collect()
    }

    /// Schoolbook O(n * m) convolution over `i64`, used as the reference result.
    #[allow(clippy::many_single_char_names)]
    fn conv_i64_naive(a: &[i64], b: &[i64]) -> Vec<i64> {
        let mut c = vec![0; a.len() + b.len() - 1];
        for (i, &x) in a.iter().enumerate() {
            for (j, &y) in b.iter().enumerate() {
                c[i + j] += x * y;
            }
        }
        c
    }

    /// `len` uniformly random residues modulo `M::VALUE`.
    fn gen_values<M: Modulus>(rng: &mut ThreadRng, len: usize) -> Vec<StaticModInt<M>> {
        (0..len).map(|_| rng.gen_range(0, M::VALUE).into()).collect()
    }

    /// `len` uniformly random values in `[0, M::VALUE)`, converted into `T`.
    fn gen_raw_values<T, M>(rng: &mut ThreadRng, len: usize) -> Vec<T>
    where
        T: TryFrom<u32>,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        (0..len)
            .map(|_| rng.gen_range(0, M::VALUE).try_into().unwrap())
            .collect()
    }
}