// competitive_programming_rs/math/mod_int.rs
pub mod mod_int {
    //! Modular integer arithmetic with a modulus chosen at runtime.
    //!
    //! The modulus lives in thread-local storage, so [`set_mod_int`] must be
    //! called on each thread before that thread creates or combines
    //! [`ModInt`] values.
    //!
    //! Limits:
    //! * Products are computed in `i64`, so `*`, `/` and `pow` are overflow-free
    //!   only while `(m - 1) * (m - 1) <= i64::MAX`, i.e. `m <= 3_037_000_499`.
    //! * Division uses Fermat's little theorem (`x^(m-2)`), which yields a true
    //!   inverse only when `m` is prime.

    type ModInternalNum = i64;

    thread_local!(
        // 0 means "not configured yet"; `modulo` refuses to use it.
        static MOD: std::cell::Cell<ModInternalNum> = std::cell::Cell::new(0);
    );

    /// Sets the modulus used by [`ModInt`] operations on the current thread.
    ///
    /// # Panics
    ///
    /// Panics if the modulus is not positive. This includes unsigned values of
    /// `2^63` or more, which wrap to negative numbers in the `i64` representation.
    pub fn set_mod_int<T: ToInternalNum>(v: T) {
        let m = v.to_internal_num();
        assert!(m > 0, "ModInt modulus must be positive, got {}", m);
        MOD.with(|cell| cell.set(m));
    }

    /// Returns the current thread's modulus.
    ///
    /// # Panics
    ///
    /// Panics if [`set_mod_int`] has not been called on this thread.
    fn modulo() -> ModInternalNum {
        let m = MOD.with(|cell| cell.get());
        assert!(m > 0, "ModInt modulus is not set; call set_mod_int first");
        m
    }

    /// A residue modulo the current thread's modulus.
    ///
    /// The wrapped value is kept in `0..m` for the modulus `m` that was active
    /// when the value was produced.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct ModInt(ModInternalNum);

    impl ModInt {
        /// Raises `self` to the power `e` by binary exponentiation.
        ///
        /// `e` is expected to be non-negative. A negative exponent performs no
        /// multiplications and yields `1 % m`.
        pub fn internal_pow(&self, mut e: ModInternalNum) -> Self {
            let modulo = modulo();
            // `1 % modulo` instead of `1`, so the result is 0 when the modulus is 1.
            let mut result = 1 % modulo;
            let mut cur = self.0;
            while e > 0 {
                if e & 1 == 1 {
                    result = result * cur % modulo;
                }
                e >>= 1;
                cur = cur * cur % modulo;
            }
            Self(result)
        }

        /// Raises `self` to the power `e`.
        ///
        /// See [`ModInt::internal_pow`].
        pub fn pow<T>(&self, e: T) -> Self
        where
            T: ToInternalNum,
        {
            self.internal_pow(e.to_internal_num())
        }

        /// Returns the stored residue.
        pub fn value(&self) -> ModInternalNum {
            self.0
        }
    }

    /// Values that can take part in [`ModInt`] arithmetic.
    pub trait ToInternalNum {
        /// Returns the value as a raw `i64`, without any reduction.
        ///
        /// This raw form is used for exponents and for the modulus itself.
        fn to_internal_num(&self) -> ModInternalNum;

        /// Returns the value reduced into `0..m`. Negative values wrap around,
        /// so `-1` becomes `m - 1`.
        ///
        /// The default implementation reduces [`ToInternalNum::to_internal_num`].
        fn to_residue(&self, m: ModInternalNum) -> ModInternalNum {
            let v = self.to_internal_num();
            if (0..m).contains(&v) {
                v
            } else {
                v.rem_euclid(m)
            }
        }
    }

    impl ToInternalNum for ModInt {
        fn to_internal_num(&self) -> ModInternalNum {
            self.0
        }
    }

    macro_rules! impl_primitive {
        ($primitive:ident) => {
            impl From<$primitive> for ModInt {
                fn from(v: $primitive) -> Self {
                    Self(v.to_residue(modulo()))
                }
            }
            impl ToInternalNum for $primitive {
                fn to_internal_num(&self) -> ModInternalNum {
                    *self as ModInternalNum
                }
                fn to_residue(&self, m: ModInternalNum) -> ModInternalNum {
                    // i128 represents every listed primitive exactly, including
                    // u64::MAX, so there is no wrap-around before the reduction.
                    let v = *self as i128;
                    let m = i128::from(m);
                    if (0..m).contains(&v) {
                        v as ModInternalNum
                    } else {
                        v.rem_euclid(m) as ModInternalNum
                    }
                }
            }
        };
    }
    impl_primitive!(u8);
    impl_primitive!(u16);
    impl_primitive!(u32);
    impl_primitive!(u64);
    impl_primitive!(usize);
    impl_primitive!(i8);
    impl_primitive!(i16);
    impl_primitive!(i32);
    impl_primitive!(i64);
    impl_primitive!(isize);

    impl<T: ToInternalNum> std::ops::AddAssign<T> for ModInt {
        fn add_assign(&mut self, rhs: T) {
            let m = modulo();
            let rhs = rhs.to_residue(m);
            // Both operands are in 0..m, so a single subtraction renormalizes.
            self.0 += rhs;
            if self.0 >= m {
                self.0 -= m;
            }
        }
    }

    impl<T: ToInternalNum> std::ops::Add<T> for ModInt {
        type Output = ModInt;
        fn add(self, rhs: T) -> Self::Output {
            let mut res = self;
            res += rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::SubAssign<T> for ModInt {
        fn sub_assign(&mut self, rhs: T) {
            let m = modulo();
            let rhs = rhs.to_residue(m);
            // Both operands are in 0..m, so the difference is in (-m, m).
            self.0 -= rhs;
            if self.0 < 0 {
                self.0 += m;
            }
        }
    }

    impl<T: ToInternalNum> std::ops::Sub<T> for ModInt {
        type Output = Self;
        fn sub(self, rhs: T) -> Self::Output {
            let mut res = self;
            res -= rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::MulAssign<T> for ModInt {
        fn mul_assign(&mut self, rhs: T) {
            let m = modulo();
            let rhs = rhs.to_residue(m);
            self.0 = self.0 * rhs % m;
        }
    }

    impl<T: ToInternalNum> std::ops::Mul<T> for ModInt {
        type Output = Self;
        fn mul(self, rhs: T) -> Self::Output {
            let mut res = self;
            res *= rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::DivAssign<T> for ModInt {
        /// Multiplies by the Fermat inverse `rhs^(m-2)`. This is only correct
        /// for a prime modulus.
        ///
        /// # Panics
        ///
        /// Panics if `rhs` is congruent to 0, unless the modulus is 1 (every
        /// value is 0 in that case).
        fn div_assign(&mut self, rhs: T) {
            let m = modulo();
            let rhs = rhs.to_residue(m);
            assert!(rhs != 0 || m == 1, "attempt to divide a ModInt by zero");
            let inv = ModInt(rhs).internal_pow(m - 2);
            self.0 = self.0 * inv.0 % m;
        }
    }

    impl<T: ToInternalNum> std::ops::Div<T> for ModInt {
        type Output = Self;
        fn div(self, rhs: T) -> Self::Output {
            let mut res = self;
            res /= rhs;
            res
        }
    }
}
181
182#[cfg(test)]
183mod test {
184 use super::mod_int::*;
185 use rand::distributions::Uniform;
186 use rand::Rng;
187
188 const PRIME_MOD: [i64; 3] = [1_000_000_007, 1_000_000_009, 998244353];
189 const INF: i64 = 1 << 60;
190
191 fn random_add_sub(prime_mod: i64) {
192 let mut rng = rand::thread_rng();
193 set_mod_int(prime_mod);
194 for _ in 0..10000 {
195 let x: i64 = rng.sample(Uniform::from(0..prime_mod));
196 let y: i64 = rng.sample(Uniform::from(0..prime_mod));
197
198 let mx = ModInt::from(x);
199 let my = ModInt::from(y);
200
201 assert_eq!((mx + my).value(), (x + y) % prime_mod);
202 assert_eq!((mx + y).value(), (x + y) % prime_mod);
203 assert_eq!((mx - my).value(), (x + prime_mod - y) % prime_mod);
204 assert_eq!((mx - y).value(), (x + prime_mod - y) % prime_mod);
205
206 let mut x = x;
207 let mut mx = mx;
208 x += y;
209 mx += my;
210 assert_eq!(mx.value(), x % prime_mod);
211
212 mx += y;
213 x += y;
214 assert_eq!(mx.value(), x % prime_mod);
215
216 mx -= my;
217 x = (x + prime_mod - y % prime_mod) % prime_mod;
218 assert_eq!(mx.value(), x);
219
220 mx -= y;
221 x = (x + prime_mod - y % prime_mod) % prime_mod;
222 assert_eq!(mx.value(), x);
223 }
224 }
225
226 #[test]
227 fn test_random_add_sub1() {
228 random_add_sub(PRIME_MOD[0]);
229 }
230
231 #[test]
232 fn test_random_add_sub2() {
233 random_add_sub(PRIME_MOD[1]);
234 }
235
236 #[test]
237 fn test_random_add_sub3() {
238 random_add_sub(PRIME_MOD[2]);
239 }
240
241 fn random_mul(prime_mod: i64) {
242 let mut rng = rand::thread_rng();
243 set_mod_int(prime_mod);
244 for _ in 0..10000 {
245 let x: i64 = rng.sample(Uniform::from(0..prime_mod));
246 let y: i64 = rng.sample(Uniform::from(0..prime_mod));
247
248 let mx = ModInt::from(x);
249 let my = ModInt::from(y);
250
251 assert_eq!((mx * my).value(), (x * y) % prime_mod);
252 assert_eq!((mx * y).value(), (x * y) % prime_mod);
253 }
254 }
255 #[test]
256 fn test_random_mul1() {
257 random_mul(PRIME_MOD[0]);
258 }
259 #[test]
260 fn test_random_mul2() {
261 random_mul(PRIME_MOD[1]);
262 }
263 #[test]
264 fn test_random_mul3() {
265 random_mul(PRIME_MOD[2]);
266 }
267
268 #[test]
269 fn zero_test() {
270 set_mod_int(1_000_000_007i64);
271 let a = ModInt::from(1_000_000_000i64);
272 let b = ModInt::from(7i64);
273 let c = a + b;
274 assert_eq!(c.value(), 0);
275 }
276
277 #[test]
278 fn pow_test() {
279 set_mod_int(1_000_000_007i64);
280 let a = ModInt::from(3i64);
281 let a = a.pow(4i64);
282 assert_eq!(a.value(), 81);
283 }
284
285 #[test]
286 fn div_test() {
287 set_mod_int(1_000_000_007i64);
288 for i in 1..100000i64 {
289 let mut a = ModInt::from(1i64);
290 a /= i;
291 a *= i;
292 assert_eq!(a.value(), 1);
293 }
294 }
295
296 #[test]
297 fn edge_cases() {
298 const MOD: i128 = 1_000_000_007;
299 set_mod_int(1_000_000_007i64);
300
301 let a = ModInt::from(1_000_000_000i64) * INF;
302 assert_eq!(
303 a.value(),
304 ((1_000_000_000i128 * i128::from(INF)) % MOD) as i64
305 );
306
307 let mut a = ModInt::from(1_000_000_000i64);
308 a *= INF;
309 assert_eq!(
310 a.value(),
311 ((1_000_000_000i128 * i128::from(INF)) % MOD) as i64
312 );
313
314 let a = ModInt::from(1_000_000_000i64) + INF;
315 assert_eq!(
316 a.value(),
317 ((1_000_000_000i128 + i128::from(INF)) % MOD) as i64
318 );
319
320 let mut a = ModInt::from(1_000_000_000i64);
321 a += INF;
322 assert_eq!(
323 a.value(),
324 ((1_000_000_000i128 + i128::from(INF)) % MOD) as i64
325 );
326
327 let a = ModInt::from(1_000_000_000i64) - INF;
328 assert_eq!(
329 a.value(),
330 ((1_000_000_000i128 + MOD - (INF as i128) % MOD) % MOD) as i64
331 );
332
333 let mut a = ModInt::from(1_000_000_000i64);
334 a -= INF;
335 assert_eq!(
336 a.value(),
337 ((1_000_000_000i128 + MOD - (INF as i128) % MOD) % MOD) as i64
338 );
339
340 let a = ModInt::from(1_000_000_000i64) / INF;
341 assert_eq!(a.value(), 961239577);
342
343 let mut a = ModInt::from(1_000_000_000i64);
344 a /= INF;
345 assert_eq!(a.value(), 961239577);
346 }
347
348 #[test]
349 fn overflow_guard() {
350 set_mod_int(1_000_000_007i64);
351 let a = ModInt::from(1_000_000_007i64 * 10);
352 assert_eq!(a.value(), 0);
353 }
354
355 #[test]
356 fn initialize_from_various_primitives() {
357 set_mod_int(1_000_000_007);
358 let a = ModInt::from(100usize);
359 let b = ModInt::from(100i64);
360 assert_eq!(a.value(), b.value());
361 }
362}