ac_library/
convolution.rs

/// Declares zero-variant marker types and implements `Modulus` for each of them.
///
/// Every identifier passed in must also name an integer `const` visible at the call site.
/// Types and values live in different namespaces, so `enum M1 {}` can coexist with
/// `const M1: u64`. That constant, cast to `u32`, becomes `Modulus::VALUE`.
///
/// Each generated type:
/// - declares its value prime via `HINT_VALUE_IS_PRIME = true` (this is not verified here);
/// - gets its own `thread_local!` butterfly cache.
///
/// All crate items are referenced through `$crate`, so call sites do not need to import
/// `Modulus` or `ButterflyCache` themselves.
macro_rules! modulus {
    ($($name:ident),*) => {
        $(
            #[derive(Copy, Clone, Eq, PartialEq)]
            enum $name {}

            impl $crate::modint::Modulus for $name {
                const VALUE: u32 = $name as _;
                const HINT_VALUE_IS_PRIME: bool = true;

                fn butterfly_cache() -> &'static ::std::thread::LocalKey<::std::cell::RefCell<::std::option::Option<$crate::modint::ButterflyCache<Self>>>> {
                    thread_local! {
                        static BUTTERFLY_CACHE: ::std::cell::RefCell<::std::option::Option<$crate::modint::ButterflyCache<$name>>> = ::std::default::Default::default();
                    }
                    &BUTTERFLY_CACHE
                }
            }
        )*
    };
}
21
22use crate::{
23    internal_bit, internal_math,
24    modint::{ButterflyCache, Modulus, RemEuclidU32, StaticModInt},
25};
26use std::{
27    cmp,
28    convert::{TryFrom, TryInto as _},
29    fmt,
30};
31
32#[allow(clippy::many_single_char_names)]
33pub fn convolution<M>(a: &[StaticModInt<M>], b: &[StaticModInt<M>]) -> Vec<StaticModInt<M>>
34where
35    M: Modulus,
36{
37    if a.is_empty() || b.is_empty() {
38        return vec![];
39    }
40    let (n, m) = (a.len(), b.len());
41
42    if cmp::min(n, m) <= 60 {
43        let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
44        let mut ans = vec![StaticModInt::new(0); n + m - 1];
45        for i in 0..n {
46            for j in 0..m {
47                ans[i + j] += a[i] * b[j];
48            }
49        }
50        return ans;
51    }
52
53    let (mut a, mut b) = (a.to_owned(), b.to_owned());
54    let z = 1 << internal_bit::ceil_pow2((n + m - 1) as _);
55    a.resize(z, StaticModInt::raw(0));
56    butterfly(&mut a);
57    b.resize(z, StaticModInt::raw(0));
58    butterfly(&mut b);
59    for (a, b) in a.iter_mut().zip(&b) {
60        *a *= b;
61    }
62    butterfly_inv(&mut a);
63    a.resize(n + m - 1, StaticModInt::raw(0));
64    let iz = StaticModInt::new(z).inv();
65    for a in &mut a {
66        *a *= iz;
67    }
68    a
69}
70
71pub fn convolution_raw<T, M>(a: &[T], b: &[T]) -> Vec<T>
72where
73    T: RemEuclidU32 + TryFrom<u32> + Clone,
74    T::Error: fmt::Debug,
75    M: Modulus,
76{
77    let a = a.iter().cloned().map(Into::into).collect::<Vec<_>>();
78    let b = b.iter().cloned().map(Into::into).collect::<Vec<_>>();
79    convolution::<M>(&a, &b)
80        .into_iter()
81        .map(|z| {
82            z.val()
83                .try_into()
84                .expect("the numeric type is smaller than the modulus")
85        })
86        .collect()
87}
88
89#[allow(clippy::many_single_char_names)]
90pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
91    const M1: u64 = 754_974_721; // 2^24
92    const M2: u64 = 167_772_161; // 2^25
93    const M3: u64 = 469_762_049; // 2^26
94    const M2M3: u64 = M2 * M3;
95    const M1M3: u64 = M1 * M3;
96    const M1M2: u64 = M1 * M2;
97    const M1M2M3: u64 = M1M2.wrapping_mul(M3);
98
99    modulus!(M1, M2, M3);
100
101    if a.is_empty() || b.is_empty() {
102        return vec![];
103    }
104
105    let (_, i1) = internal_math::inv_gcd(M2M3 as _, M1 as _);
106    let (_, i2) = internal_math::inv_gcd(M1M3 as _, M2 as _);
107    let (_, i3) = internal_math::inv_gcd(M1M2 as _, M3 as _);
108
109    let c1 = convolution_raw::<i64, M1>(a, b);
110    let c2 = convolution_raw::<i64, M2>(a, b);
111    let c3 = convolution_raw::<i64, M3>(a, b);
112
113    c1.into_iter()
114        .zip(c2)
115        .zip(c3)
116        .map(|((c1, c2), c3)| {
117            const OFFSET: &[u64] = &[0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3];
118
119            let mut x = [(c1, i1, M1, M2M3), (c2, i2, M2, M1M3), (c3, i3, M3, M1M2)]
120                .iter()
121                .map(|&(c, i, m1, m2)| c.wrapping_mul(i).rem_euclid(m1 as _).wrapping_mul(m2 as _))
122                .fold(0, i64::wrapping_add);
123
124            // B = 2^63, -B <= x, r(real value) < B
125            // (x, x - M, x - 2M, or x - 3M) = r (mod 2B)
126            // r = c1[i] (mod MOD1)
127            // focus on MOD1
128            // r = x, x - M', x - 2M', x - 3M' (M' = M % 2^64) (mod 2B)
129            // r = x,
130            //     x - M' + (0 or 2B),
131            //     x - 2M' + (0, 2B or 4B),
132            //     x - 3M' + (0, 2B, 4B or 6B) (without mod!)
133            // (r - x) = 0, (0)
134            //           - M' + (0 or 2B), (1)
135            //           -2M' + (0 or 2B or 4B), (2)
136            //           -3M' + (0 or 2B or 4B or 6B) (3) (mod MOD1)
137            // we checked that
138            //   ((1) mod MOD1) mod 5 = 2
139            //   ((2) mod MOD1) mod 5 = 3
140            //   ((3) mod MOD1) mod 5 = 4
141            let mut diff = c1 - internal_math::safe_mod(x, M1 as _);
142            if diff < 0 {
143                diff += M1 as i64;
144            }
145            x = x.wrapping_sub(OFFSET[diff.rem_euclid(5) as usize] as _);
146            x
147        })
148        .collect()
149}
150
151#[allow(clippy::many_single_char_names)]
152fn butterfly<M: Modulus>(a: &mut [StaticModInt<M>]) {
153    let n = a.len();
154    let h = internal_bit::ceil_pow2(n as u32);
155
156    M::butterfly_cache().with(|cache| {
157        let mut cache = cache.borrow_mut();
158        let ButterflyCache { sum_e, .. } = cache.get_or_insert_with(prepare);
159        for ph in 1..=h {
160            let w = 1 << (ph - 1);
161            let p = 1 << (h - ph);
162            let mut now = StaticModInt::<M>::new(1);
163            for s in 0..w {
164                let offset = s << (h - ph + 1);
165                for i in 0..p {
166                    let l = a[i + offset];
167                    let r = a[i + offset + p] * now;
168                    a[i + offset] = l + r;
169                    a[i + offset + p] = l - r;
170                }
171                now *= sum_e[(!s).trailing_zeros() as usize];
172            }
173        }
174    });
175}
176
177#[allow(clippy::many_single_char_names)]
178fn butterfly_inv<M: Modulus>(a: &mut [StaticModInt<M>]) {
179    let n = a.len();
180    let h = internal_bit::ceil_pow2(n as u32);
181
182    M::butterfly_cache().with(|cache| {
183        let mut cache = cache.borrow_mut();
184        let ButterflyCache { sum_ie, .. } = cache.get_or_insert_with(prepare);
185        for ph in (1..=h).rev() {
186            let w = 1 << (ph - 1);
187            let p = 1 << (h - ph);
188            let mut inow = StaticModInt::<M>::new(1);
189            for s in 0..w {
190                let offset = s << (h - ph + 1);
191                for i in 0..p {
192                    let l = a[i + offset];
193                    let r = a[i + offset + p];
194                    a[i + offset] = l + r;
195                    a[i + offset + p] = StaticModInt::new(M::VALUE + l.val() - r.val()) * inow;
196                }
197                inow *= sum_ie[(!s).trailing_zeros() as usize];
198            }
199        }
200    });
201}
202
203fn prepare<M: Modulus>() -> ButterflyCache<M> {
204    let g = StaticModInt::<M>::raw(internal_math::primitive_root(M::VALUE as i32) as u32);
205    let mut es = [StaticModInt::<M>::raw(0); 30]; // es[i]^(2^(2+i)) == 1
206    let mut ies = [StaticModInt::<M>::raw(0); 30];
207    let cnt2 = (M::VALUE - 1).trailing_zeros() as usize;
208    let mut e = g.pow(((M::VALUE - 1) >> cnt2).into());
209    let mut ie = e.inv();
210    for i in (2..=cnt2).rev() {
211        es[i - 2] = e;
212        ies[i - 2] = ie;
213        e *= e;
214        ie *= ie;
215    }
216    let sum_e = es
217        .iter()
218        .scan(StaticModInt::new(1), |acc, e| {
219            *acc *= e;
220            Some(*acc)
221        })
222        .collect();
223    let sum_ie = ies
224        .iter()
225        .scan(StaticModInt::new(1), |acc, ie| {
226            *acc *= ie;
227            Some(*acc)
228        })
229        .collect();
230    ButterflyCache { sum_e, sum_ie }
231}
232
#[cfg(test)]
mod tests {
    use crate::{
        modint::{Mod998244353, Modulus, StaticModInt},
        RemEuclidU32,
    };
    use rand::{rngs::ThreadRng, Rng as _};
    use std::{
        convert::{TryFrom, TryInto as _},
        fmt,
    };

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L51-L71
    #[test]
    fn empty() {
        assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[1, 2], &[]).is_empty());
        assert!(super::convolution_raw::<i32, Mod998244353>(&[1], &[]).is_empty());
        assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution_raw::<i64, Mod998244353>(&[], &[1, 2]).is_empty());
        assert!(super::convolution::<Mod998244353>(&[], &[]).is_empty());
        assert!(super::convolution::<Mod998244353>(&[], &[1.into(), 2.into()]).is_empty());
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L73-L85
    #[test]
    fn mid() {
        const N: usize = 1234;
        const M: usize = 2345;

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<Mod998244353>(&mut rng, n);
        let (a, b) = (gen_values(N), gen_values(M));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L87-L118
    #[test]
    fn simple_s_mod() {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn test<M: Modulus>(rng: &mut ThreadRng) {
            // Draw values for the modulus under test so that both M1 and M2 are exercised.
            let mut gen_values = |n| gen_values::<M>(rng, n);
            for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
                let (a, b) = (gen_values(n), gen_values(m));
                assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
            }
        }

        let mut rng = rand::thread_rng();
        test::<M1>(&mut rng);
        test::<M2>(&mut rng);
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L120-L150
    #[test]
    fn simple_int() {
        simple_raw::<i32>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L152-L182
    #[test]
    fn simple_uint() {
        simple_raw::<u32>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L184-L214
    #[test]
    fn simple_ll() {
        simple_raw::<i64>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L216-L246
    #[test]
    fn simple_ull() {
        simple_raw::<u64>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L249-L279
    #[test]
    fn simple_int128() {
        simple_raw::<i128>();
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L281-L311
    #[test]
    fn simple_uint128() {
        simple_raw::<u128>();
    }

    /// Compares `convolution_raw::<T, M>` against the naive product for small sizes, for
    /// two moduli.
    fn simple_raw<T>()
    where
        T: TryFrom<u32> + Copy + RemEuclidU32 + PartialEq + fmt::Debug,
        T::Error: fmt::Debug,
    {
        const M1: u32 = 998_244_353;
        const M2: u32 = 924_844_033;

        modulus!(M1, M2);

        fn test<T, M>(rng: &mut ThreadRng)
        where
            T: TryFrom<u32> + Copy + RemEuclidU32 + PartialEq + fmt::Debug,
            T::Error: fmt::Debug,
            M: Modulus,
        {
            // Generate values of the element type and modulus under test; otherwise every
            // `simple_*` test would only exercise `u32`.
            let mut gen_raw_values = |n| gen_raw_values::<T, M>(rng, n);
            for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
                let (a, b) = (gen_raw_values(n), gen_raw_values(m));
                assert_eq!(
                    conv_raw_naive::<_, M>(&a, &b),
                    super::convolution_raw::<_, M>(&a, &b),
                );
            }
        }

        let mut rng = rand::thread_rng();
        test::<T, M1>(&mut rng);
        test::<T, M2>(&mut rng);
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L315-L329
    #[test]
    fn conv_ll() {
        let mut rng = rand::thread_rng();
        for (n, m) in (1..20).flat_map(|i| (1..20).map(move |j| (i, j))) {
            let mut random_vec = |n: usize| -> Vec<_> {
                (0..n).map(|_| rng.gen_range(-500_000, 500_000)).collect()
            };
            let (a, b) = (random_vec(n), random_vec(m));
            assert_eq!(conv_i64_naive(&a, &b), super::convolution_i64(&a, &b));
        }
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L331-L356
    #[test]
    fn conv_ll_bound() {
        const M1: u64 = 754_974_721; // 45 * 2^24 + 1
        const M2: u64 = 167_772_161; // 5 * 2^25 + 1
        const M3: u64 = 469_762_049; // 7 * 2^26 + 1
        const M2M3: u64 = M2 * M3;
        const M1M3: u64 = M1 * M3;
        const M1M2: u64 = M1 * M2;

        for i in -1000..=1000 {
            let a = vec![0u64.wrapping_sub(M1M2 + M1M3 + M2M3) as i64 + i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }

        for i in 0..1000 {
            let a = vec![i64::min_value() + i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }

        for i in 0..1000 {
            let a = vec![i64::max_value() - i];
            let b = vec![1];
            assert_eq!(a, super::convolution_i64(&a, &b));
        }
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L358-L371
    #[test]
    fn conv_641() {
        const M: u32 = 641;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<M>(&mut rng, n);
        let (a, b) = (gen_values(64), gen_values(65));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    // https://github.com/atcoder/ac-library/blob/8250de484ae0ab597391db58040a602e0dc1a419/test/unittest/convolution_test.cpp#L373-L386
    #[test]
    fn conv_18433() {
        const M: u32 = 18433;
        modulus!(M);

        let mut rng = rand::thread_rng();
        let mut gen_values = |n| gen_values::<M>(&mut rng, n);
        let (a, b) = (gen_values(1024), gen_values(1025));
        assert_eq!(conv_naive(&a, &b), super::convolution(&a, &b));
    }

    /// Reference O(n * m) convolution over `StaticModInt<M>`.
    #[allow(clippy::many_single_char_names)]
    fn conv_naive<M: Modulus>(
        a: &[StaticModInt<M>],
        b: &[StaticModInt<M>],
    ) -> Vec<StaticModInt<M>> {
        let (n, m) = (a.len(), b.len());
        let mut c = vec![StaticModInt::raw(0); n + m - 1];
        for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
            c[i + j] += a[i] * b[j];
        }
        c
    }

    /// Reference convolution of raw integers modulo `M`, built on `conv_naive`.
    fn conv_raw_naive<T, M>(a: &[T], b: &[T]) -> Vec<T>
    where
        T: TryFrom<u32> + Copy + RemEuclidU32,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        conv_naive::<M>(
            &a.iter().copied().map(Into::into).collect::<Vec<_>>(),
            &b.iter().copied().map(Into::into).collect::<Vec<_>>(),
        )
        .into_iter()
        .map(|x| x.val().try_into().unwrap())
        .collect()
    }

    /// Reference O(n * m) convolution over plain `i64`.
    #[allow(clippy::many_single_char_names)]
    fn conv_i64_naive(a: &[i64], b: &[i64]) -> Vec<i64> {
        let (n, m) = (a.len(), b.len());
        let mut c = vec![0; n + m - 1];
        for (i, j) in (0..n).flat_map(|i| (0..m).map(move |j| (i, j))) {
            c[i + j] += a[i] * b[j];
        }
        c
    }

    /// `n` uniformly random residues in `[0, M)`.
    fn gen_values<M: Modulus>(rng: &mut ThreadRng, n: usize) -> Vec<StaticModInt<M>> {
        (0..n).map(|_| rng.gen_range(0, M::VALUE).into()).collect()
    }

    /// `n` uniformly random values in `[0, M)`, converted to `T`.
    fn gen_raw_values<T, M>(rng: &mut ThreadRng, n: usize) -> Vec<T>
    where
        T: TryFrom<u32>,
        T::Error: fmt::Debug,
        M: Modulus,
    {
        (0..n)
            .map(|_| rng.gen_range(0, M::VALUE).try_into().unwrap())
            .collect()
    }
}