ac_library/
math.rs

1//! Number-theoretic algorithms.
2
3use crate::internal_math;
4
5use std::mem::swap;
6
7/// Returns $x^n \bmod m$.
8///
9/// # Constraints
10///
11/// - $0 \leq n$
12/// - $1 \leq m$
13///
14/// # Panics
15///
16/// Panics if the above constraints are not satisfied.
17///
18/// # Complexity
19///
20/// - $O(\log n)$
21///
22/// # Example
23///
24/// ```
25/// use ac_library::math;
26///
27/// assert_eq!(math::pow_mod(2, 10000, 7), 2);
28/// ```
29#[allow(clippy::many_single_char_names)]
30pub fn pow_mod(x: i64, mut n: i64, m: u32) -> u32 {
31    assert!(0 <= n && 1 <= m && m <= 2u32.pow(31));
32    if m == 1 {
33        return 0;
34    }
35    let bt = internal_math::Barrett::new(m);
36    let mut r = 1;
37    let mut y = internal_math::safe_mod(x, m as i64) as u32;
38    while n != 0 {
39        if n & 1 != 0 {
40            r = bt.mul(r, y);
41        }
42        y = bt.mul(y, y);
43        n >>= 1;
44    }
45    r
46}
47
48/// Returns an integer $y \in [0, m)$ such that $xy \equiv 1 \pmod m$.
49///
50/// # Constraints
51///
52/// - $\gcd(x, m) = 1$
53/// - $1 \leq m$
54///
55/// # Panics
56///
57/// Panics if the above constraints are not satisfied.
58///
59/// # Complexity
60///
61/// - $O(\log m)$
62///
63/// # Example
64///
65/// ```
66/// use ac_library::math;
67///
68/// assert_eq!(math::inv_mod(3, 7), 5);
69/// ```
70pub fn inv_mod(x: i64, m: i64) -> i64 {
71    assert!(1 <= m);
72    let z = internal_math::inv_gcd(x, m);
73    assert!(z.0 == 1);
74    z.1
75}
76
77/// Performs CRT (Chinese Remainder Theorem).
78///
79/// Given two sequences $r, m$ of length $n$, this function solves the modular equation system
80///
81/// \\[
82///   x \equiv r_i \pmod{m_i}, \forall i \in \\{0, 1, \cdots, n - 1\\}
83/// \\]
84///
85/// If there is no solution, it returns $(0, 0)$.
86///
87/// Otherwise, all of the solutions can be written as the form $x \equiv y \pmod z$, using integer $y, z\\ (0 \leq y < z = \text{lcm}(m))$.
88/// It returns this $(y, z)$.
89///
90/// If $n = 0$, it returns $(0, 1)$.
91///
92/// # Constraints
93///
94/// - $|r| = |m|$
95/// - $1 \leq m_{\forall i}$
96/// - $\text{lcm}(m)$ is in `i64`
97///
98/// # Panics
99///
100/// Panics if the above constraints are not satisfied.
101///
102/// # Complexity
103///
104/// - $O(n \log \text{lcm}(m))$
105///
106/// # Example
107///
108/// ```
109/// use ac_library::math;
110///
111/// let r = [2, 3, 2];
112/// let m = [3, 5, 7];
113/// assert_eq!(math::crt(&r, &m), (23, 105));
114/// ```
115pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
116    assert_eq!(r.len(), m.len());
117    // Contracts: 0 <= r0 < m0
118    let (mut r0, mut m0) = (0, 1);
119    for (&(mut ri), &(mut mi)) in r.iter().zip(m.iter()) {
120        assert!(1 <= mi);
121        ri = internal_math::safe_mod(ri, mi);
122        if m0 < mi {
123            swap(&mut r0, &mut ri);
124            swap(&mut m0, &mut mi);
125        }
126        if m0 % mi == 0 {
127            if r0 % mi != ri {
128                return (0, 0);
129            }
130            continue;
131        }
132        // assume: m0 > mi, lcm(m0, mi) >= 2 * max(m0, mi)
133
134        // (r0, m0), (ri, mi) -> (r2, m2 = lcm(m0, m1));
135        // r2 % m0 = r0
136        // r2 % mi = ri
137        // -> (r0 + x*m0) % mi = ri
138        // -> x*u0*g = ri-r0 (mod u1*g) (u0*g = m0, u1*g = mi)
139        // -> x = (ri - r0) / g * inv(u0) (mod u1)
140
141        // im = inv(u0) (mod u1) (0 <= im < u1)
142        let (g, im) = internal_math::inv_gcd(m0, mi);
143        let u1 = mi / g;
144        // |ri - r0| < (m0 + mi) <= lcm(m0, mi)
145        if (ri - r0) % g != 0 {
146            return (0, 0);
147        }
148        // u1 * u1 <= mi * mi / g / g <= m0 * mi / g = lcm(m0, mi)
149        let x = (ri - r0) / g % u1 * im % u1;
150
151        // |r0| + |m0 * x|
152        // < m0 + m0 * (u1 - 1)
153        // = m0 + m0 * mi / g - m0
154        // = lcm(m0, mi)
155        r0 += x * m0;
156        m0 *= u1; // -> lcm(m0, mi)
157        if r0 < 0 {
158            r0 += m0
159        };
160    }
161
162    (r0, m0)
163}
164
165/// Returns
166///
167/// $$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor.$$
168///
169/// It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed.
170///
171/// # Constraints
172///
173/// - $0 \leq n \lt 2^{32}$
174/// - $1 \leq m \lt 2^{32}$
175///
176/// # Panics
177///
178/// Panics if the above constraints are not satisfied and overflow or division by zero occurred.
179///
180/// # Complexity
181///
182/// - $O(\log{(m+a)})$
183///
184/// # Example
185///
186/// ```
187/// use ac_library::math;
188///
189/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
190/// ```
191#[allow(clippy::many_single_char_names)]
192pub fn floor_sum(n: i64, m: i64, a: i64, b: i64) -> i64 {
193    use std::num::Wrapping as W;
194    assert!((0..1i64 << 32).contains(&n));
195    assert!((1..1i64 << 32).contains(&m));
196    let mut ans = W(0_u64);
197    let (wn, wm, mut wa, mut wb) = (W(n as u64), W(m as u64), W(a as u64), W(b as u64));
198    if a < 0 {
199        let a2 = W(internal_math::safe_mod(a, m) as u64);
200        ans -= wn * (wn - W(1)) / W(2) * ((a2 - wa) / wm);
201        wa = a2;
202    }
203    if b < 0 {
204        let b2 = W(internal_math::safe_mod(b, m) as u64);
205        ans -= wn * ((b2 - wb) / wm);
206        wb = b2;
207    }
208    let ret = ans + internal_math::floor_sum_unsigned(wn, wm, wa, wb);
209    ret.0 as i64
210}
211
#[cfg(test)]
mod tests {
    #![allow(clippy::unreadable_literal)]
    #![allow(clippy::cognitive_complexity)]
    use super::*;

    /// Exercises `pow_mod` on boundary bases and exponents for several moduli,
    /// followed by a few ordinary inputs.
    #[test]
    fn test_pow_mod() {
        let moduli = [1, 3, 723, 998244353, 2u32.pow(31)];
        // Each row is (x, n, expected results listed in the order of `moduli`).
        let table: [(i64, i64, [u32; 5]); 8] = [
            (0, 0, [0, 1, 1, 1, 1]),
            (0, 1, [0, 0, 0, 0, 0]),
            (0, i64::MAX, [0, 0, 0, 0, 0]),
            (1, 0, [0, 1, 1, 1, 1]),
            (1, 1, [0, 1, 1, 1, 1]),
            (1, i64::MAX, [0, 1, 1, 1, 1]),
            (i64::MAX, 0, [0, 1, 1, 1, 1]),
            (i64::MAX, i64::MAX, [0, 1, 640, 683296792, 2147483647]),
        ];
        for &(x, n, expected) in table.iter() {
            for (&m, &want) in moduli.iter().zip(expected.iter()) {
                assert_eq!(pow_mod(x, n, m), want, "pow_mod({}, {}, {})", x, n, m);
            }
        }

        let prime = 1_000_000_007;
        assert_eq!(pow_mod(2, 3, prime), 8);
        assert_eq!(pow_mod(5, 7, prime), 78125);
        assert_eq!(pow_mod(123, 456, prime), 565291922);
    }

    /// A zero modulus violates `1 <= m` and must panic.
    #[test]
    #[should_panic]
    fn test_inv_mod_1() {
        let _ = inv_mod(271828, 0);
    }

    /// Both arguments are even, so no inverse exists and the call must panic.
    #[test]
    #[should_panic]
    fn test_inv_mod_2() {
        let _ = inv_mod(3141592, 1000000008);
    }

    /// Checks `crt` against precomputed (remainder, modulus) answers.
    #[test]
    fn test_crt() {
        // Each row is (remainders, moduli, expected result).
        let table: [([i64; 3], [i64; 3], (i64, i64)); 3] = [
            ([44, 23, 13], [13, 50, 22], (1773, 7150)),
            (
                [12345, 67890, 99999],
                [13, 444321, 95318],
                (103333581255, 550573258014),
            ),
            ([0, 3, 4], [1, 9, 5], (39, 45)),
        ];
        for (remainders, moduli, expected) in table.iter() {
            assert_eq!(crt(remainders, moduli), *expected);
        }
    }

    /// Checks `floor_sum` on fixed large inputs and against the naive sum on a small grid
    /// that includes negative `a` and `b`.
    #[test]
    fn test_floor_sum() {
        assert_eq!(floor_sum(0, 1, 0, 0), 0);
        assert_eq!(floor_sum(1_000_000_000, 1, 1, 1), 500_000_000_500_000_000);
        assert_eq!(
            floor_sum(1_000_000_000, 1_000_000_000, 999_999_999, 999_999_999),
            499_999_999_500_000_000
        );
        assert_eq!(floor_sum(332955, 5590132, 2231, 999423), 22014575);

        for n in 0..20 {
            for m in 1..20 {
                for a in -20..20 {
                    for b in -20..20 {
                        let expected = floor_sum_naive(n, m, a, b);
                        assert_eq!(floor_sum(n, m, a, b), expected);
                    }
                }
            }
        }
    }

    /// Reference implementation: adds up the floored quotients term by term.
    #[allow(clippy::many_single_char_names)]
    fn floor_sum_naive(n: i64, m: i64, a: i64, b: i64) -> i64 {
        (0..n)
            .map(|i| {
                let z = a * i + b;
                // Subtracting the non-negative remainder first makes the division exact,
                // which yields the floor even when `z` is negative.
                (z - internal_math::safe_mod(z, m)) / m
            })
            .sum()
    }
}