1use num::integer::*;
2use num::traits::{NumOps, One, PrimInt, ToPrimitive, Zero};
3use std::cmp::Ordering;
4use std::fmt;
5use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
6
/// An integer residue paired with the modulus it is taken against.
///
/// Both fields are public, so callers can build values directly. The arithmetic
/// impls in this module keep `value` in `[0, modulo)` when given such inputs.
///
/// The derived `PartialOrd`/`Ord` compare `value` first and `modulo` second,
/// following field declaration order.
#[derive(
    Copy,
    Clone,
    Debug,
    PartialEq,
    Eq,
    PartialOrd,
    Ord
)]
pub struct ModInt<T> {
    /// The stored residue.
    pub value: T,
    /// The modulus the residue is taken against.
    pub modulo: T,
}
25
/// Construction and field operations for modular integers over an integer type `T`.
pub trait ModIntTrait<T> {
    /// Builds the residue of `n` using the implementation's default modulus.
    fn new(n: T) -> Self;

    /// Builds the residue of `n` modulo `modulo`.
    fn new_with(n: T, modulo: T) -> Self;

    /// Raises `self` to the power `r`.
    fn pow(self, r: T) -> Self;

    /// Returns the multiplicative inverse of `self` under its own modulus.
    fn inverse(&self) -> Self;

    /// Returns the multiplicative inverse of the raw value `n` modulo `modulo`.
    fn static_inverse_with(n: T, modulo: T) -> T;
}
33
34impl<T> ModIntTrait<T> for ModInt<T>
35where
36 T: PrimInt,
37{
38 fn new(n: T) -> Self {
39 Self::new_with(n, T::from(1000000007).unwrap())
40 }
41
42 fn new_with(n: T, modulo: T) -> Self {
43 ModInt {
44 value: n % modulo,
45 modulo,
46 }
47 }
48
49 #[inline]
50 fn inverse(&self) -> Self {
51 let value = Self::static_inverse_with(self.value, self.modulo);
52 ModInt {
53 value,
54 modulo: self.modulo,
55 }
56 }
57
58 fn pow(self, mut r: T) -> Self {
59 let mut k = self;
60 let mut ret = ModInt::new_with(T::from(1).unwrap(), self.modulo);
61 let zero = T::from(0).unwrap();
62 let two = T::from(2).unwrap();
63 while r > zero {
64 if r % two != zero {
65 ret = ret * k;
66 }
67 r = r / two;
68 k = k * k;
69 }
70 ret
71 }
72
73 fn static_inverse_with(n: T, modulo: T) -> T {
74 let ExtendedGcd { x, .. } = n.to_i64().unwrap().extended_gcd(&modulo.to_i64().unwrap());
75
76 T::from(if x < 0 {
77 x + modulo.to_i64().unwrap()
78 } else {
79 x
80 })
81 .unwrap()
82 }
83}
84
85impl<T> Zero for ModInt<T>
86where
87 T: PrimInt,
88{
89 fn zero() -> Self {
90 ModInt {
91 value: T::from(0).unwrap(),
92 modulo: T::from(1000000007).unwrap(),
93 }
94 }
95
96 fn is_zero(&self) -> bool {
97 self.value == T::from(0).unwrap()
98 }
99}
100
101impl<T> One for ModInt<T>
102where
103 T: PrimInt,
104{
105 fn one() -> Self {
106 ModInt {
107 value: T::from(1).unwrap(),
108 modulo: T::from(1000000007).unwrap(),
109 }
110 }
111 fn is_one(&self) -> bool
112 where
113 Self: PartialEq,
114 {
115 self.value == T::from(1).unwrap()
116 }
117}
118
119impl<T> fmt::Display for ModInt<T>
120where
121 T: fmt::Display,
122{
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 write!(f, "{}", self.value)
125 }
126}
127
128impl<T> Add for ModInt<T>
129where
130 T: PrimInt,
131{
132 type Output = ModInt<T>;
133
134 #[inline]
135 fn add(self, other: ModInt<T>) -> Self {
136 ModInt {
137 value: if self.value + other.value >= self.modulo {
138 (self.value + other.value) % self.modulo
139 } else {
140 self.value + other.value
141 },
142 modulo: self.modulo,
143 }
144 }
145}
146
147impl<T> Add<T> for ModInt<T>
148where
149 T: NumOps + PartialOrd + Copy,
150{
151 type Output = ModInt<T>;
152
153 #[inline]
154 fn add(self, rhs: T) -> Self {
155 ModInt {
156 value: if self.value + rhs >= self.modulo {
157 (self.value + rhs) % self.modulo
158 } else {
159 self.value + rhs
160 },
161 modulo: self.modulo,
162 }
163 }
164}
165
166macro_rules! impl_modint_add(($($ty:ty),*) => {
167 $(
168 impl<T> Add<ModInt<T>> for $ty
169 where
170 T: PrimInt,
171 {
172 type Output = ModInt<T>;
173
174 #[inline]
175 fn add(self, rhs: ModInt<T>) -> ModInt<T> {
176 ModInt {
177 value: if T::from(self).unwrap() + rhs.value >= rhs.modulo {
178 (T::from(self).unwrap() + rhs.value) % rhs.modulo
179 } else {
180 T::from(self).unwrap() + rhs.value
181 },
182 modulo: rhs.modulo,
183 }
184 }
185 }
186 )*
187});
188
189impl_modint_add!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
190
191impl<T> Sub for ModInt<T>
192where
193 T: PrimInt,
194{
195 type Output = ModInt<T>;
196
197 #[inline]
198 fn sub(self, other: ModInt<T>) -> Self {
199 ModInt {
200 value: if self.value < other.value {
201 self.value + self.modulo - other.value
202 } else {
203 self.value - other.value
204 },
205 modulo: self.modulo,
206 }
207 }
208}
209
210impl<T> Sub<T> for ModInt<T>
211where
212 T: PrimInt,
213{
214 type Output = ModInt<T>;
215
216 #[inline]
217 fn sub(self, rhs: T) -> Self {
218 ModInt {
219 value: if self.value < rhs {
220 self.value + self.modulo - rhs
221 } else {
222 self.value - rhs
223 },
224 modulo: self.modulo,
225 }
226 }
227}
228
229macro_rules! impl_modint_sub(($($ty:ty),*) => {
230 $(
231 impl<T> Sub<ModInt<T>> for $ty
232 where
233 T: PrimInt,
234 {
235 type Output = ModInt<T>;
236
237 #[inline]
238 fn sub(self, rhs: ModInt<T>) -> ModInt<T> {
239 ModInt {
240 value: if T::from(self).unwrap() < rhs.value {
241 T::from(self).unwrap() + rhs.modulo - rhs.value
242 } else {
243 (T::from(self).unwrap() - rhs.value) % rhs.modulo
244 },
245 modulo: rhs.modulo,
246 }
247 }
248 }
249 )*
250});
251
252impl_modint_sub!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
253
254impl<T> Mul for ModInt<T>
255where
256 T: PrimInt,
257{
258 type Output = ModInt<T>;
259
260 #[inline]
261 fn mul(self, other: ModInt<T>) -> Self {
262 ModInt {
263 value: (self.value * other.value) % self.modulo,
264 modulo: self.modulo,
265 }
266 }
267}
268
269impl<T> Mul<T> for ModInt<T>
270where
271 T: PrimInt,
272{
273 type Output = ModInt<T>;
274
275 #[inline]
276 fn mul(self, rhs: T) -> Self {
277 ModInt {
278 value: (self.value * rhs) % self.modulo,
279 modulo: self.modulo,
280 }
281 }
282}
283
284macro_rules! impl_modint_mul(($($ty:ty),*) => {
285 $(
286 impl<T> Mul<ModInt<T>> for $ty
287 where
288 T: PrimInt,
289 {
290 type Output = ModInt<T>;
291
292 #[inline]
293 fn mul(self, rhs: ModInt<T>) -> ModInt<T> {
294 ModInt {
295 value: (T::from(self).unwrap() * rhs.value) % rhs.modulo,
296 modulo: rhs.modulo,
297 }
298 }
299 }
300 )*
301});
302
303impl_modint_mul!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
304
305impl<T> Div for ModInt<T>
306where
307 T: PrimInt,
308{
309 type Output = ModInt<T>;
310
311 #[inline]
312 fn div(self, other: ModInt<T>) -> Self {
313 ModInt {
314 value: (self.value * other.inverse().value) % self.modulo,
315 modulo: self.modulo,
316 }
317 }
318}
319
320impl<T> Div<T> for ModInt<T>
321where
322 T: PrimInt,
323{
324 type Output = ModInt<T>;
325
326 #[inline]
327 fn div(self, rhs: T) -> Self {
328 let inv = Self::static_inverse_with(rhs, self.modulo);
329 ModInt {
330 value: (self.value * inv) % self.modulo,
331 modulo: self.modulo,
332 }
333 }
334}
335
336macro_rules! impl_modint_div(($($ty:ty),*) => {
337 $(
338 impl<T> Div<ModInt<T>> for $ty
339 where
340 T: PrimInt,
341 {
342 type Output = ModInt<T>;
343
344 #[inline]
345 fn div(self, rhs: ModInt<T>) -> ModInt<T> {
346 let inv = ModInt::static_inverse_with(rhs.value, rhs.modulo);
347 ModInt {
348 value: (T::from(self).unwrap() * inv) % rhs.modulo,
349 modulo: rhs.modulo,
350 }
351 }
352 }
353 )*
354});
355
356impl_modint_div!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
357
358impl<T> AddAssign<T> for ModInt<T>
359where
360 T: PrimInt,
361{
362 fn add_assign(&mut self, rhs: T) {
363 (*self).value = if self.value + rhs >= self.modulo {
364 (self.value + rhs) % self.modulo
365 } else {
366 self.value + rhs
367 }
368 }
369}
370
371impl<T> AddAssign<ModInt<T>> for ModInt<T>
372where
373 T: PrimInt,
374{
375 fn add_assign(&mut self, other: ModInt<T>) {
376 (*self).value = if self.value + other.value >= self.modulo {
377 (self.value + other.value) % self.modulo
378 } else {
379 self.value + other.value
380 }
381 }
382}
383
384impl<T> SubAssign<T> for ModInt<T>
385where
386 T: PrimInt,
387{
388 fn sub_assign(&mut self, rhs: T) {
389 (*self).value = if self.value < rhs {
390 self.value + self.modulo - rhs
391 } else {
392 self.value - rhs
393 }
394 }
395}
396
397impl<T> SubAssign<ModInt<T>> for ModInt<T>
398where
399 T: PrimInt,
400{
401 fn sub_assign(&mut self, other: ModInt<T>) {
402 (*self).value = if self.value < other.value {
403 self.value + self.modulo - other.value
404 } else {
405 self.value - other.value
406 }
407 }
408}
409
410impl<T> PartialEq<T> for ModInt<T>
411where
412 T: PrimInt,
413{
414 fn eq(&self, other: &T) -> bool {
415 self.value == *other
416 }
417}
418
419macro_rules! impl_modint_partial_eq(($($ty:ty),*) => {
420 $(
421 impl<T> PartialEq<ModInt<T>> for $ty
422 where
423 T: PrimInt,
424 {
425 #[inline]
426 fn eq(&self, other: &ModInt<T>) -> bool {
427 T::from(self.clone()).unwrap() == other.value.clone()
428 }
429 }
430 )*
431});
432
433impl_modint_partial_eq!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
434
435impl<T> PartialOrd<T> for ModInt<T>
436where
437 T: PrimInt,
438{
439 fn partial_cmp(&self, other: &T) -> Option<Ordering> {
440 Some(self.value.cmp(other))
441 }
442}
443
444macro_rules! impl_modint_partial_ord(($($ty:ty),*) => {
445 $(
446 impl<T> PartialOrd<ModInt<T>> for $ty
447 where
448 T: PrimInt,
449 {
450 #[inline]
451 fn partial_cmp(&self, other: &ModInt<T>) -> Option<Ordering> {
452 Some(T::from(self.clone()).unwrap().cmp(&other.value))
453 }
454 }
455 )*
456});
457
458impl_modint_partial_ord!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
459
460macro_rules! impl_modint_to_primitive(($(($ty:ty, $method:ident)),*) => {
461 $(
462 #[inline]
463 fn $method(&self) -> Option<$ty> {
464 self.value.$method()
465 }
466 )*
467});
468
469impl<T> ToPrimitive for ModInt<T>
470where
471 T: PrimInt,
472{
473 impl_modint_to_primitive!(
474 (i8, to_i8),
475 (i16, to_i16),
476 (i32, to_i32),
477 (i64, to_i64),
478 (u8, to_u8),
479 (u16, to_u16),
480 (u32, to_u32),
481 (u64, to_u64),
482 (isize, to_isize),
483 (usize, to_usize)
484 );
485}
486
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_modint_modint() {
        const MOD: usize = 7;
        let mi0 = ModInt::new_with(0, MOD);
        let mi1 = ModInt::new_with(1, MOD);
        let mi2 = ModInt::new_with(2, MOD);
        let mi4 = ModInt::new_with(4, MOD);
        let mi7 = ModInt::new_with(7, MOD);
        let mi11 = ModInt::new_with(11, MOD);

        assert_eq!(mi0 + mi7, ModInt::new_with(0, 7));
        assert_eq!(mi1 + mi2, ModInt::new_with(3, 7));
        assert_eq!(mi1 + mi11, ModInt::new_with(5, 7));
        assert_eq!(mi1 - mi4, ModInt::new_with(4, 7));
    }

    #[test]
    fn test_modint_other_type() {
        const MOD: usize = 7;
        let mi0 = ModInt::new_with(0, MOD);

        assert_eq!(mi0 + 6, ModInt::new_with(6, MOD));
        assert_eq!(mi0 + 7, ModInt::new_with(0, MOD));
        assert_eq!(7usize + mi0, ModInt::new_with(0, MOD));
        assert_eq!(15usize + mi0, ModInt::new_with(1, MOD));
        assert_eq!(mi0 - 4, ModInt::new_with(3, MOD));
        assert_eq!(mi0 - ModInt::new_with(0, MOD), ModInt::new_with(0, MOD));
        assert_eq!(7usize - mi0, ModInt::new_with(0, MOD));
    }

    #[test]
    fn test_new() {
        let mi0 = ModInt::new(0u64);
        let mi1 = ModInt::new(7u64);
        let mi2 = ModInt::new(1000000007u64);

        assert!(mi0 == 0);
        assert_eq!(mi0, ModInt::new(0));
        assert_eq!(mi1 + mi2, ModInt::new(7));
        assert_eq!(mi0 - mi1, ModInt::new(1000000007 - 7));
        assert_eq!(100 * mi1, ModInt::new(700u64));
        assert_eq!(100u64 * mi1 * 2 / 10 / ModInt::new(5), ModInt::new(28));
    }

    #[test]
    fn test_inverse() {
        const MOD: u64 = 13;

        assert_eq!(1, ModInt::new_with(1, MOD).inverse());
        assert_eq!(7, ModInt::new_with(2, MOD).inverse());
        assert_eq!(9, ModInt::new_with(3, MOD).inverse());
        assert_eq!(10, ModInt::new_with(4, MOD).inverse());
        assert_eq!(8, ModInt::new_with(5, MOD).inverse());
        assert_eq!(11, ModInt::new_with(6, MOD).inverse());
        assert_eq!(2, ModInt::new_with(7, MOD).inverse());
        assert_eq!(5, ModInt::new_with(8, MOD).inverse());
        assert_eq!(3, ModInt::new_with(9, MOD).inverse());
        assert_eq!(4, ModInt::new_with(10, MOD).inverse());
        assert_eq!(6, ModInt::new_with(11, MOD).inverse());
        assert_eq!(12, ModInt::new_with(12, MOD).inverse());
    }

    #[test]
    fn test_div() {
        const MOD: u64 = 13;

        assert_eq!(4, (ModInt::new_with(2, MOD) / ModInt::new_with(7, MOD)));
        assert_eq!(4, (2u64 / ModInt::new_with(7, MOD)));
        assert_eq!(4, (ModInt::new_with(2, MOD) / 7));
    }

    #[test]
    fn test_mul() {
        const MOD: u64 = 13;

        assert_eq!(2, ModInt::new_with(3, MOD) * ModInt::new_with(5, MOD));
        assert_eq!(2, ModInt::new_with(3, MOD) * 5);
        assert_eq!(2, 3 * ModInt::new_with(5, MOD));
    }

    #[test]
    fn test_assign() {
        const MOD: u64 = 13;

        let mut t = ModInt::new_with(3, MOD) + ModInt::new_with(5, MOD);
        t += 7;
        assert_eq!(2, t);
        t -= 4;
        assert_eq!(11, t);
        t += ModInt::new_with(5, MOD);
        assert_eq!(3, t);
        t -= ModInt::new_with(20, MOD);
        assert_eq!(9, t);
    }

    #[test]
    fn test_partialord() {
        const MOD: u64 = 13;

        assert!(ModInt::new_with(3, MOD) < ModInt::new_with(5, MOD));
        assert!(3 < ModInt::new_with(5, MOD));
        assert!(ModInt::new_with(3, MOD) < 5);
        assert!(!(ModInt::new(10) < 7));
    }

    #[test]
    fn test_to_primitive() {
        // The conversion must expose the stored residue itself, not just "something other than 2".
        assert_eq!(Some(13), ModInt::new(13u64).to_u64());
        assert_eq!(Some(6), ModInt::new_with(13u64, 7).to_u8());
        assert_eq!(None, ModInt::new_with(300u64, 1000).to_u8());
    }

    #[test]
    fn test_pow() {
        const MOD: u64 = 13;

        assert_eq!(1, ModInt::new_with(3, MOD).pow(0));
        assert_eq!(1, ModInt::new_with(3, MOD).pow(3));
        assert_eq!(8, ModInt::new_with(2, MOD).pow(3));
        // Fermat's little theorem: a^(p-1) ≡ 1 (mod p).
        assert_eq!(1, ModInt::new_with(2, MOD).pow(12));
    }

    #[test]
    fn test_identities_and_display() {
        assert!(ModInt::<u64>::zero().is_zero());
        assert!(ModInt::<u64>::one().is_one());
        assert_eq!("9", format!("{}", ModInt::new_with(9u64, 13)));
    }
}