competitive_programming_rs/math/
mod_int.rs

pub mod mod_int {
    //! Integers modulo a modulus that is chosen at run time.
    //!
    //! The modulus is stored per thread: call [`set_mod_int`] on a thread before
    //! creating or combining [`ModInt`] values on that thread.
    //!
    //! Limits to be aware of:
    //!
    //! * Products are computed in `i64`, so `*`, `/` and `pow` are only free of
    //!   overflow while `(m - 1)^2` fits in an `i64` (roughly `m <= 3.03e9`).
    //! * Division uses Fermat's little theorem (`x^(m - 2)`), which is only an
    //!   inverse when the modulus is prime and the divisor is not a multiple
    //!   of it.

    type ModInternalNum = i64;
    thread_local!(
        static MOD: std::cell::Cell<ModInternalNum> = std::cell::Cell::new(0);
    );

    /// Sets the modulus used by every [`ModInt`] operation on the current thread.
    ///
    /// # Panics
    ///
    /// Panics if the value is not strictly positive after conversion to `i64`
    /// (this also rejects unsigned values above `i64::MAX`, which wrap negative).
    pub fn set_mod_int<T: ToInternalNum>(v: T) {
        let m = v.to_internal_num();
        assert!(m > 0, "modulus must be positive, got {}", m);
        MOD.with(|x| x.set(m));
    }

    /// Returns the modulus of the current thread (`0` if it was never set).
    fn modulo() -> ModInternalNum {
        MOD.with(|x| x.get())
    }

    /// Maps `v` to its canonical representative in `[0, m)`.
    ///
    /// The range check is the fast path; `rem_euclid` handles both values that
    /// are too large and negative values (for which `%` would stay negative).
    /// Panics with a remainder-by-zero error when the modulus was never set.
    fn normalize(v: ModInternalNum, m: ModInternalNum) -> ModInternalNum {
        if (0..m).contains(&v) {
            v
        } else {
            v.rem_euclid(m)
        }
    }

    /// A residue modulo the thread's current modulus, kept in `[0, m)`.
    #[derive(Debug, Clone, Copy)]
    pub struct ModInt(ModInternalNum);

    impl ModInt {
        /// Builds a `ModInt` from any `i64`, reducing it into `[0, m)`.
        fn internal_new(v: ModInternalNum) -> Self {
            Self(normalize(v, modulo()))
        }

        /// Raises `self` to the power `e` by binary exponentiation.
        ///
        /// A non-positive exponent yields `1` (reduced modulo `m`, so `0` when
        /// the modulus is `1`); negative exponents are *not* treated as
        /// inverse powers.
        pub fn internal_pow(&self, mut e: ModInternalNum) -> Self {
            let m = modulo();
            // `1 % m` keeps the result canonical even for modulus 1.
            let mut result = 1 % m;
            let mut cur = self.0;
            while e > 0 {
                if e & 1 == 1 {
                    result = result * cur % m;
                }
                e >>= 1;
                cur = cur * cur % m;
            }
            Self(result)
        }

        /// Raises `self` to the power `e`; see [`ModInt::internal_pow`].
        ///
        /// The exponent is used as a plain integer, it is not reduced.
        pub fn pow<T>(&self, e: T) -> Self
        where
            T: ToInternalNum,
        {
            self.internal_pow(e.to_internal_num())
        }

        /// Returns the stored residue.
        pub fn value(&self) -> ModInternalNum {
            self.0
        }
    }

    /// Conversion of an operand into the internal `i64` representation.
    ///
    /// The conversion is a plain cast without modular reduction, so `u64` and
    /// `usize` values above `i64::MAX` wrap around. Use `ModInt::from` first
    /// when such values have to be reduced exactly.
    pub trait ToInternalNum {
        fn to_internal_num(&self) -> ModInternalNum;
    }
    impl ToInternalNum for ModInt {
        fn to_internal_num(&self) -> ModInternalNum {
            self.0
        }
    }
    macro_rules! impl_primitive {
        ($primitive:ident) => {
            impl From<$primitive> for ModInt {
                fn from(v: $primitive) -> Self {
                    // Every supported primitive fits in an i128, so widening
                    // first is lossless.
                    let wide = v as i128;
                    if wide > i128::from(ModInternalNum::MAX) {
                        // Only reachable for u64/usize: reduce in i128 instead
                        // of letting the cast to i64 wrap around.
                        let m = i128::from(modulo());
                        Self(wide.rem_euclid(m) as ModInternalNum)
                    } else {
                        Self::internal_new(wide as ModInternalNum)
                    }
                }
            }
            impl ToInternalNum for $primitive {
                fn to_internal_num(&self) -> ModInternalNum {
                    *self as ModInternalNum
                }
            }
        };
    }
    impl_primitive!(u8);
    impl_primitive!(u16);
    impl_primitive!(u32);
    impl_primitive!(u64);
    impl_primitive!(usize);
    impl_primitive!(i8);
    impl_primitive!(i16);
    impl_primitive!(i32);
    impl_primitive!(i64);
    impl_primitive!(isize);

    impl<T: ToInternalNum> std::ops::AddAssign<T> for ModInt {
        /// Adds `rhs` (any integer, negative values included) modulo `m`.
        fn add_assign(&mut self, rhs: T) {
            let m = modulo();
            // Both terms are in [0, m), so one conditional subtraction suffices.
            self.0 += normalize(rhs.to_internal_num(), m);
            if self.0 >= m {
                self.0 -= m;
            }
        }
    }

    impl<T: ToInternalNum> std::ops::Add<T> for ModInt {
        type Output = ModInt;
        fn add(self, rhs: T) -> Self::Output {
            let mut res = self;
            res += rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::SubAssign<T> for ModInt {
        /// Subtracts `rhs` (any integer, negative values included) modulo `m`.
        fn sub_assign(&mut self, rhs: T) {
            let m = modulo();
            // Both terms are in [0, m), so the difference is in (-m, m).
            self.0 -= normalize(rhs.to_internal_num(), m);
            if self.0 < 0 {
                self.0 += m;
            }
        }
    }

    impl<T: ToInternalNum> std::ops::Sub<T> for ModInt {
        type Output = Self;
        fn sub(self, rhs: T) -> Self::Output {
            let mut res = self;
            res -= rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::MulAssign<T> for ModInt {
        /// Multiplies by `rhs` (any integer, negative values included) modulo `m`.
        fn mul_assign(&mut self, rhs: T) {
            let m = modulo();
            self.0 = self.0 * normalize(rhs.to_internal_num(), m) % m;
        }
    }

    impl<T: ToInternalNum> std::ops::Mul<T> for ModInt {
        type Output = Self;
        fn mul(self, rhs: T) -> Self::Output {
            let mut res = self;
            res *= rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::DivAssign<T> for ModInt {
        /// Multiplies by the modular inverse of `rhs`, computed as `rhs^(m - 2)`.
        ///
        /// This is only a true division when the modulus is prime and `rhs` is
        /// not a multiple of it; no check is performed, and a divisor that is
        /// a multiple of the modulus does not panic.
        fn div_assign(&mut self, rhs: T) {
            let m = modulo();
            let rhs = normalize(rhs.to_internal_num(), m);
            let inv = Self(rhs).internal_pow(m - 2);
            self.0 = self.0 * inv.0 % m;
        }
    }

    impl<T: ToInternalNum> std::ops::Div<T> for ModInt {
        type Output = Self;
        fn div(self, rhs: T) -> Self::Output {
            let mut res = self;
            res /= rhs;
            res
        }
    }
}
181
#[cfg(test)]
mod test {
    use super::mod_int::*;
    use rand::distributions::Uniform;
    use rand::Rng;

    /// Prime moduli every randomized test is run against.
    const PRIME_MOD: [i64; 3] = [1_000_000_007, 1_000_000_009, 998244353];
    /// An operand far larger than any tested modulus.
    const INF: i64 = 1 << 60;

    /// Checks `+`, `-`, `+=` and `-=` against plain `i64` arithmetic for
    /// 10000 random pairs of residues below `prime_mod`, using both `ModInt`
    /// and raw integer right-hand sides.
    fn random_add_sub(prime_mod: i64) {
        set_mod_int(prime_mod);
        let mut rng = rand::thread_rng();
        let residues = Uniform::from(0..prime_mod);
        for _ in 0..10000 {
            let x: i64 = rng.sample(&residues);
            let y: i64 = rng.sample(&residues);

            let mx = ModInt::from(x);
            let my = ModInt::from(y);

            // Binary operators.
            let expected_sum = (x + y) % prime_mod;
            let expected_diff = (x + prime_mod - y) % prime_mod;
            assert_eq!((mx + my).value(), expected_sum);
            assert_eq!((mx + y).value(), expected_sum);
            assert_eq!((mx - my).value(), expected_diff);
            assert_eq!((mx - y).value(), expected_diff);

            // Compound assignment, tracked alongside a plain integer.
            let mut plain = x;
            let mut modular = mx;

            plain += y;
            modular += my;
            assert_eq!(modular.value(), plain % prime_mod);

            modular += y;
            plain += y;
            assert_eq!(modular.value(), plain % prime_mod);

            modular -= my;
            plain = (plain + prime_mod - y % prime_mod) % prime_mod;
            assert_eq!(modular.value(), plain);

            modular -= y;
            plain = (plain + prime_mod - y % prime_mod) % prime_mod;
            assert_eq!(modular.value(), plain);
        }
    }

    #[test]
    fn test_random_add_sub1() {
        random_add_sub(PRIME_MOD[0]);
    }

    #[test]
    fn test_random_add_sub2() {
        random_add_sub(PRIME_MOD[1]);
    }

    #[test]
    fn test_random_add_sub3() {
        random_add_sub(PRIME_MOD[2]);
    }

    /// Checks `*` against plain `i64` arithmetic for 10000 random pairs of
    /// residues below `prime_mod`.
    fn random_mul(prime_mod: i64) {
        set_mod_int(prime_mod);
        let mut rng = rand::thread_rng();
        let residues = Uniform::from(0..prime_mod);
        for _ in 0..10000 {
            let x: i64 = rng.sample(&residues);
            let y: i64 = rng.sample(&residues);

            let mx = ModInt::from(x);
            let my = ModInt::from(y);

            let expected = (x * y) % prime_mod;
            assert_eq!((mx * my).value(), expected);
            assert_eq!((mx * y).value(), expected);
        }
    }

    #[test]
    fn test_random_mul1() {
        random_mul(PRIME_MOD[0]);
    }

    #[test]
    fn test_random_mul2() {
        random_mul(PRIME_MOD[1]);
    }

    #[test]
    fn test_random_mul3() {
        random_mul(PRIME_MOD[2]);
    }

    /// A sum that lands exactly on the modulus must wrap to zero.
    #[test]
    fn zero_test() {
        set_mod_int(1_000_000_007i64);
        let sum = ModInt::from(1_000_000_000i64) + ModInt::from(7i64);
        assert_eq!(sum.value(), 0);
    }

    #[test]
    fn pow_test() {
        set_mod_int(1_000_000_007i64);
        let cubed_plus_one = ModInt::from(3i64).pow(4i64);
        assert_eq!(cubed_plus_one.value(), 81);
    }

    /// Dividing by `i` and multiplying back must give 1 for every tested `i`.
    #[test]
    fn div_test() {
        set_mod_int(1_000_000_007i64);
        for i in 1..100000i64 {
            let mut a = ModInt::from(1i64);
            a /= i;
            a *= i;
            assert_eq!(a.value(), 1);
        }
    }

    /// Every operator, in both its binary and compound-assignment form, with
    /// a right-hand side (`INF`) that is much larger than the modulus.
    /// Expected values are computed in `i128` so they cannot overflow.
    #[test]
    fn edge_cases() {
        const MOD: i128 = 1_000_000_007;
        const BASE: i64 = 1_000_000_000;
        set_mod_int(1_000_000_007i64);

        let base = i128::from(BASE);
        let inf = i128::from(INF);
        let expected_product = ((base * inf) % MOD) as i64;
        let expected_sum = ((base + inf) % MOD) as i64;
        let expected_difference = ((base + MOD - inf % MOD) % MOD) as i64;
        let expected_quotient = 961239577;

        let a = ModInt::from(BASE) * INF;
        assert_eq!(a.value(), expected_product);

        let mut a = ModInt::from(BASE);
        a *= INF;
        assert_eq!(a.value(), expected_product);

        let a = ModInt::from(BASE) + INF;
        assert_eq!(a.value(), expected_sum);

        let mut a = ModInt::from(BASE);
        a += INF;
        assert_eq!(a.value(), expected_sum);

        let a = ModInt::from(BASE) - INF;
        assert_eq!(a.value(), expected_difference);

        let mut a = ModInt::from(BASE);
        a -= INF;
        assert_eq!(a.value(), expected_difference);

        let a = ModInt::from(BASE) / INF;
        assert_eq!(a.value(), expected_quotient);

        let mut a = ModInt::from(BASE);
        a /= INF;
        assert_eq!(a.value(), expected_quotient);
    }

    /// A multiple of the modulus must be reduced to zero on construction.
    #[test]
    fn overflow_guard() {
        set_mod_int(1_000_000_007i64);
        let multiple = ModInt::from(1_000_000_007i64 * 10);
        assert_eq!(multiple.value(), 0);
    }

    /// The same number must give the same residue whichever primitive type
    /// it is converted from.
    #[test]
    fn initialize_from_various_primitives() {
        set_mod_int(1_000_000_007);
        let from_usize = ModInt::from(100usize);
        let from_i64 = ModInt::from(100i64);
        assert_eq!(from_usize.value(), from_i64.value());
    }
}
360        assert_eq!(a.value(), b.value());
361    }
362}