ac_library/
math.rs

1//! Number-theoretic algorithms.
2
3use crate::internal_math;
4
5use std::mem::swap;
6
7/// Returns $x^n \bmod m$.
8///
9/// # Constraints
10///
11/// - $0 \leq n$
12/// - $1 \leq m$
13///
14/// # Panics
15///
16/// Panics if the above constraints are not satisfied.
17///
18/// # Complexity
19///
20/// - $O(\log n)$
21///
22/// # Example
23///
24/// ```
25/// use ac_library::math;
26///
27/// assert_eq!(math::pow_mod(2, 10000, 7), 2);
28/// ```
29#[allow(clippy::many_single_char_names)]
30pub fn pow_mod(x: i64, mut n: i64, m: u32) -> u32 {
31    assert!(0 <= n && 1 <= m && m <= 2u32.pow(31));
32    if m == 1 {
33        return 0;
34    }
35    let bt = internal_math::Barrett::new(m);
36    let mut r = 1;
37    let mut y = internal_math::safe_mod(x, m as i64) as u32;
38    while n != 0 {
39        if n & 1 != 0 {
40            r = bt.mul(r, y);
41        }
42        y = bt.mul(y, y);
43        n >>= 1;
44    }
45    r
46}
47
48/// Returns an integer $y \in [0, m)$ such that $xy \equiv 1 \pmod m$.
49///
50/// # Constraints
51///
52/// - $\gcd(x, m) = 1$
53/// - $1 \leq m$
54///
55/// # Panics
56///
57/// Panics if the above constraints are not satisfied.
58///
59/// # Complexity
60///
61/// - $O(\log m)$
62///
63/// # Example
64///
65/// ```
66/// use ac_library::math;
67///
68/// assert_eq!(math::inv_mod(3, 7), 5);
69/// ```
70pub fn inv_mod(x: i64, m: i64) -> i64 {
71    assert!(1 <= m);
72    let z = internal_math::inv_gcd(x, m);
73    assert!(z.0 == 1);
74    z.1
75}
76
77/// Performs CRT (Chinese Remainder Theorem).
78///
79/// Given two sequences $r, m$ of length $n$, this function solves the modular equation system
80///
81/// \\[
82///   x \equiv r_i \pmod{m_i}, \forall i \in \\{0, 1, \cdots, n - 1\\}
83/// \\]
84///
85/// If there is no solution, it returns $(0, 0)$.
86///
87/// Otherwise, all of the solutions can be written as the form $x \equiv y \pmod z$, using integer $y, z\\ (0 \leq y < z = \text{lcm}(m))$.
88/// It returns this $(y, z)$.
89///
90/// If $n = 0$, it returns $(0, 1)$.
91///
92/// # Constraints
93///
94/// - $|r| = |m|$
95/// - $1 \leq m_{\forall i}$
96/// - $\text{lcm}(m)$ is in `i64`
97///
98/// # Panics
99///
100/// Panics if the above constraints are not satisfied.
101///
102/// # Complexity
103///
104/// - $O(n \log \text{lcm}(m))$
105///
106/// # Example
107///
108/// ```
109/// use ac_library::math;
110///
111/// let r = [2, 3, 2];
112/// let m = [3, 5, 7];
113/// assert_eq!(math::crt(&r, &m), (23, 105));
114/// ```
115pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) {
116    assert_eq!(r.len(), m.len());
117    // Contracts: 0 <= r0 < m0
118    let (mut r0, mut m0) = (0, 1);
119    for (&(mut ri), &(mut mi)) in r.iter().zip(m.iter()) {
120        assert!(1 <= mi);
121        ri = internal_math::safe_mod(ri, mi);
122        if m0 < mi {
123            swap(&mut r0, &mut ri);
124            swap(&mut m0, &mut mi);
125        }
126        if m0 % mi == 0 {
127            if r0 % mi != ri {
128                return (0, 0);
129            }
130            continue;
131        }
132        // assume: m0 > mi, lcm(m0, mi) >= 2 * max(m0, mi)
133
134        // (r0, m0), (ri, mi) -> (r2, m2 = lcm(m0, m1));
135        // r2 % m0 = r0
136        // r2 % mi = ri
137        // -> (r0 + x*m0) % mi = ri
138        // -> x*u0*g = ri-r0 (mod u1*g) (u0*g = m0, u1*g = mi)
139        // -> x = (ri - r0) / g * inv(u0) (mod u1)
140
141        // im = inv(u0) (mod u1) (0 <= im < u1)
142        let (g, im) = internal_math::inv_gcd(m0, mi);
143        let u1 = mi / g;
144        // |ri - r0| < (m0 + mi) <= lcm(m0, mi)
145        if (ri - r0) % g != 0 {
146            return (0, 0);
147        }
148        // u1 * u1 <= mi * mi / g / g <= m0 * mi / g = lcm(m0, mi)
149        let x = (ri - r0) / g % u1 * im % u1;
150
151        // |r0| + |m0 * x|
152        // < m0 + m0 * (u1 - 1)
153        // = m0 + m0 * mi / g - m0
154        // = lcm(m0, mi)
155        r0 += x * m0;
156        m0 *= u1; // -> lcm(m0, mi)
157        if r0 < 0 {
158            r0 += m0
159        };
160    }
161
162    (r0, m0)
163}
164
/// Returns $\sum_{i = 0}^{n - 1} \lfloor \frac{a \times i + b}{m} \rfloor$.
///
/// `a` and `b` may be negative: each term is a true floor (rounded towards
/// negative infinity), not a truncated quotient.
///
/// # Constraints
///
/// - $0 \leq n \leq 10^9$
/// - $1 \leq m \leq 10^9$
/// - every intermediate value fits in `i64`
///
/// # Panics
///
/// Panics if $m = 0$ (division by zero). Arithmetic overflow panics when
/// overflow checks are enabled.
///
/// # Complexity
///
/// - $O(\log(n + m + a + b))$
///
/// # Example
///
/// ```
/// use ac_library::math;
///
/// assert_eq!(math::floor_sum(6, 5, 4, 3), 13);
/// ```
pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 {
    let mut ans = 0;

    // Split a = qa * m + a' and b = qb * m + b' with 0 <= a', b' < m using
    // Euclidean division, so negative inputs are reduced correctly as well.
    // The qa part contributes qa * (0 + 1 + ... + (n - 1)); n * (n - 1) is
    // always even, so halving before multiplying is exact and keeps the
    // intermediate product small.
    ans += (n - 1) * n / 2 * a.div_euclid(m);
    a = a.rem_euclid(m);
    ans += n * b.div_euclid(m);
    b = b.rem_euclid(m);

    // Largest value reached by floor((a * i + b) / m) over the range, taken at i = n.
    let y_max = (a * n + b) / m;
    if y_max == 0 {
        // Every remaining term is zero. This also covers a == 0, so the
        // divisions by `a` below are never reached with a zero divisor.
        return ans;
    }
    let x_max = y_max * m - b;
    // Indices at or beyond ceil(x_max / a) each contribute the full y_max.
    ans += (n - (x_max + a - 1) / a) * y_max;
    // Count the remaining lattice points with the roles of `a` and `m` swapped.
    ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
    ans
}
208
#[cfg(test)]
mod tests {
    #![allow(clippy::unreadable_literal)]
    #![allow(clippy::cognitive_complexity)]
    use super::*;

    /// Checks `pow_mod` on the corner values 0, 1 and `i64::max_value()` for
    /// base and exponent against several moduli, plus a few ordinary cases.
    #[test]
    fn test_pow_mod() {
        let big = i64::max_value();
        let moduli: [u32; 5] = [1, 3, 723, 998244353, 2u32.pow(31)];

        // Each row: (x, n, expected result for every entry of `moduli`, in order).
        let corner_cases: [(i64, i64, [u32; 5]); 8] = [
            (0, 0, [0, 1, 1, 1, 1]),
            (0, 1, [0, 0, 0, 0, 0]),
            (0, big, [0, 0, 0, 0, 0]),
            (1, 0, [0, 1, 1, 1, 1]),
            (1, 1, [0, 1, 1, 1, 1]),
            (1, big, [0, 1, 1, 1, 1]),
            (big, 0, [0, 1, 1, 1, 1]),
            (big, big, [0, 1, 640, 683296792, 2147483647]),
        ];
        for &(base, exponent, expected) in corner_cases.iter() {
            for (&modulus, &want) in moduli.iter().zip(expected.iter()) {
                assert_eq!(pow_mod(base, exponent, modulus), want);
            }
        }

        // Each row: (x, n, m, expected).
        let ordinary_cases: [(i64, i64, u32, u32); 3] = [
            (2, 3, 1_000_000_007, 8),
            (5, 7, 1_000_000_007, 78125),
            (123, 456, 1_000_000_007, 565291922),
        ];
        for &(base, exponent, modulus, want) in ordinary_cases.iter() {
            assert_eq!(pow_mod(base, exponent, modulus), want);
        }
    }

    /// A modulus of zero violates `1 <= m`.
    #[test]
    #[should_panic]
    fn test_inv_mod_1() {
        inv_mod(271828, 0);
    }

    /// Both arguments are even, so they are not coprime and no inverse exists.
    #[test]
    #[should_panic]
    fn test_inv_mod_2() {
        inv_mod(3141592, 1000000008);
    }

    /// Checks `crt` on three systems of three congruences each.
    #[test]
    fn test_crt() {
        // Each row: (remainders, moduli, expected (y, z)).
        let systems: [([i64; 3], [i64; 3], (i64, i64)); 3] = [
            ([44, 23, 13], [13, 50, 22], (1773, 7150)),
            (
                [12345, 67890, 99999],
                [13, 444321, 95318],
                (103333581255, 550573258014),
            ),
            ([0, 3, 4], [1, 9, 5], (39, 45)),
        ];
        for (remainders, moduli, expected) in systems.iter() {
            assert_eq!(crt(remainders, moduli), *expected);
        }
    }

    /// Checks `floor_sum` on the empty range, the largest documented sizes,
    /// and one mid-sized input.
    #[test]
    fn test_floor_sum() {
        // Each row: ((n, m, a, b), expected).
        let cases: [((i64, i64, i64, i64), i64); 4] = [
            ((0, 1, 0, 0), 0),
            ((1_000_000_000, 1, 1, 1), 500_000_000_500_000_000),
            (
                (1_000_000_000, 1_000_000_000, 999_999_999, 999_999_999),
                499_999_999_500_000_000,
            ),
            ((332955, 5590132, 2231, 999423), 22014575),
        ];
        for &((count, modulus, slope, offset), expected) in cases.iter() {
            assert_eq!(floor_sum(count, modulus, slope, offset), expected);
        }
    }
}