competitive_hpp/modulo/
mod_int.rs

1use num::integer::*;
2use num::traits::{NumOps, One, PrimInt, ToPrimitive, Zero};
3use std::cmp::Ordering;
4use std::fmt;
5use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign};
6
/// An integer `value` stored together with the `modulo` it is reduced by.
///
/// The constructors in [`ModIntTrait`] store `n % modulo` as the value. The
/// derived comparison traits compare `value` first and `modulo` second (field
/// declaration order); `Hash` is derived so instances can be used as keys in
/// hashed collections.
///
/// # Examples
/// ```
/// use competitive_hpp::prelude::*;
/// const MOD: usize = 7;
///
/// let mi0 = ModInt::new_with(0, MOD);
/// let mi1 = ModInt::new_with(1, MOD);
/// let mi2 = ModInt::new_with(2, MOD);
/// let mi11 = ModInt::new_with(11, MOD);
/// assert_eq!(4, ModInt::new(4));
/// assert_eq!(mi0 + mi11, ModInt::new_with(4, 7));
/// assert_eq!(mi1 + mi2, ModInt::new_with(3, 7));
/// ```
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ModInt<T> {
    /// The stored residue.
    pub value: T,
    /// The modulus the residue is taken against.
    pub modulo: T,
}
25
/// Construction, inversion and exponentiation interface for a modular integer
/// whose underlying representation is `T`.
pub trait ModIntTrait<T> {
    /// Builds a value from `n` using the implementor's default modulus.
    fn new(n: T) -> Self;

    /// Builds a value from `n` reduced by the explicit `modulo`.
    fn new_with(n: T, modulo: T) -> Self;

    /// Returns the multiplicative inverse of `self` under its own modulus.
    fn inverse(&self) -> Self;

    /// Raises `self` to the power `r` under its own modulus.
    fn pow(self, r: T) -> Self;

    /// Returns the raw multiplicative inverse of `n` under `modulo`, without
    /// wrapping it in the implementing type.
    fn static_inverse_with(n: T, modulo: T) -> T;
}
33
34impl<T> ModIntTrait<T> for ModInt<T>
35where
36    T: PrimInt,
37{
38    fn new(n: T) -> Self {
39        Self::new_with(n, T::from(1000000007).unwrap())
40    }
41
42    fn new_with(n: T, modulo: T) -> Self {
43        ModInt {
44            value: n % modulo,
45            modulo,
46        }
47    }
48
49    #[inline]
50    fn inverse(&self) -> Self {
51        let value = Self::static_inverse_with(self.value, self.modulo);
52        ModInt {
53            value,
54            modulo: self.modulo,
55        }
56    }
57
58    fn pow(self, mut r: T) -> Self {
59        let mut k = self;
60        let mut ret = ModInt::new_with(T::from(1).unwrap(), self.modulo);
61        let zero = T::from(0).unwrap();
62        let two = T::from(2).unwrap();
63        while r > zero {
64            if r % two != zero {
65                ret = ret * k;
66            }
67            r = r / two;
68            k = k * k;
69        }
70        ret
71    }
72
73    fn static_inverse_with(n: T, modulo: T) -> T {
74        let ExtendedGcd { x, .. } = n.to_i64().unwrap().extended_gcd(&modulo.to_i64().unwrap());
75
76        T::from(if x < 0 {
77            x + modulo.to_i64().unwrap()
78        } else {
79            x
80        })
81        .unwrap()
82    }
83}
84
85impl<T> Zero for ModInt<T>
86where
87    T: PrimInt,
88{
89    fn zero() -> Self {
90        ModInt {
91            value: T::from(0).unwrap(),
92            modulo: T::from(1000000007).unwrap(),
93        }
94    }
95
96    fn is_zero(&self) -> bool {
97        self.value == T::from(0).unwrap()
98    }
99}
100
101impl<T> One for ModInt<T>
102where
103    T: PrimInt,
104{
105    fn one() -> Self {
106        ModInt {
107            value: T::from(1).unwrap(),
108            modulo: T::from(1000000007).unwrap(),
109        }
110    }
111    fn is_one(&self) -> bool
112    where
113        Self: PartialEq,
114    {
115        self.value == T::from(1).unwrap()
116    }
117}
118
119impl<T> fmt::Display for ModInt<T>
120where
121    T: fmt::Display,
122{
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        write!(f, "{}", self.value)
125    }
126}
127
128impl<T> Add for ModInt<T>
129where
130    T: PrimInt,
131{
132    type Output = ModInt<T>;
133
134    #[inline]
135    fn add(self, other: ModInt<T>) -> Self {
136        ModInt {
137            value: if self.value + other.value >= self.modulo {
138                (self.value + other.value) % self.modulo
139            } else {
140                self.value + other.value
141            },
142            modulo: self.modulo,
143        }
144    }
145}
146
147impl<T> Add<T> for ModInt<T>
148where
149    T: NumOps + PartialOrd + Copy,
150{
151    type Output = ModInt<T>;
152
153    #[inline]
154    fn add(self, rhs: T) -> Self {
155        ModInt {
156            value: if self.value + rhs >= self.modulo {
157                (self.value + rhs) % self.modulo
158            } else {
159                self.value + rhs
160            },
161            modulo: self.modulo,
162        }
163    }
164}
165
166macro_rules! impl_modint_add(($($ty:ty),*) => {
167    $(
168        impl<T> Add<ModInt<T>> for $ty
169        where
170            T: PrimInt,
171        {
172            type Output = ModInt<T>;
173
174            #[inline]
175            fn add(self, rhs: ModInt<T>) -> ModInt<T> {
176                ModInt {
177                    value: if T::from(self).unwrap() + rhs.value >= rhs.modulo {
178                        (T::from(self).unwrap() + rhs.value) % rhs.modulo
179                    } else {
180                        T::from(self).unwrap() + rhs.value
181                    },
182                    modulo: rhs.modulo,
183                }
184            }
185        }
186    )*
187});
188
189impl_modint_add!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
190
191impl<T> Sub for ModInt<T>
192where
193    T: PrimInt,
194{
195    type Output = ModInt<T>;
196
197    #[inline]
198    fn sub(self, other: ModInt<T>) -> Self {
199        ModInt {
200            value: if self.value < other.value {
201                self.value + self.modulo - other.value
202            } else {
203                self.value - other.value
204            },
205            modulo: self.modulo,
206        }
207    }
208}
209
210impl<T> Sub<T> for ModInt<T>
211where
212    T: PrimInt,
213{
214    type Output = ModInt<T>;
215
216    #[inline]
217    fn sub(self, rhs: T) -> Self {
218        ModInt {
219            value: if self.value < rhs {
220                self.value + self.modulo - rhs
221            } else {
222                self.value - rhs
223            },
224            modulo: self.modulo,
225        }
226    }
227}
228
229macro_rules! impl_modint_sub(($($ty:ty),*) => {
230    $(
231        impl<T> Sub<ModInt<T>> for $ty
232        where
233            T: PrimInt,
234        {
235            type Output = ModInt<T>;
236
237            #[inline]
238            fn sub(self, rhs: ModInt<T>) -> ModInt<T> {
239                ModInt {
240                    value: if T::from(self).unwrap() < rhs.value {
241                        T::from(self).unwrap() + rhs.modulo - rhs.value
242                    } else {
243                        (T::from(self).unwrap() - rhs.value) % rhs.modulo
244                    },
245                    modulo: rhs.modulo,
246                }
247            }
248        }
249    )*
250});
251
252impl_modint_sub!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
253
254impl<T> Mul for ModInt<T>
255where
256    T: PrimInt,
257{
258    type Output = ModInt<T>;
259
260    #[inline]
261    fn mul(self, other: ModInt<T>) -> Self {
262        ModInt {
263            value: (self.value * other.value) % self.modulo,
264            modulo: self.modulo,
265        }
266    }
267}
268
269impl<T> Mul<T> for ModInt<T>
270where
271    T: PrimInt,
272{
273    type Output = ModInt<T>;
274
275    #[inline]
276    fn mul(self, rhs: T) -> Self {
277        ModInt {
278            value: (self.value * rhs) % self.modulo,
279            modulo: self.modulo,
280        }
281    }
282}
283
284macro_rules! impl_modint_mul(($($ty:ty),*) => {
285    $(
286        impl<T> Mul<ModInt<T>> for $ty
287        where
288            T: PrimInt,
289        {
290            type Output = ModInt<T>;
291
292            #[inline]
293            fn mul(self, rhs: ModInt<T>) -> ModInt<T> {
294                ModInt {
295                    value: (T::from(self).unwrap() * rhs.value) % rhs.modulo,
296                    modulo: rhs.modulo,
297                }
298            }
299        }
300    )*
301});
302
303impl_modint_mul!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
304
305impl<T> Div for ModInt<T>
306where
307    T: PrimInt,
308{
309    type Output = ModInt<T>;
310
311    #[inline]
312    fn div(self, other: ModInt<T>) -> Self {
313        ModInt {
314            value: (self.value * other.inverse().value) % self.modulo,
315            modulo: self.modulo,
316        }
317    }
318}
319
320impl<T> Div<T> for ModInt<T>
321where
322    T: PrimInt,
323{
324    type Output = ModInt<T>;
325
326    #[inline]
327    fn div(self, rhs: T) -> Self {
328        let inv = Self::static_inverse_with(rhs, self.modulo);
329        ModInt {
330            value: (self.value * inv) % self.modulo,
331            modulo: self.modulo,
332        }
333    }
334}
335
336macro_rules! impl_modint_div(($($ty:ty),*) => {
337    $(
338        impl<T> Div<ModInt<T>> for $ty
339        where
340            T: PrimInt,
341        {
342            type Output = ModInt<T>;
343
344            #[inline]
345            fn div(self, rhs: ModInt<T>) -> ModInt<T> {
346                let inv = ModInt::static_inverse_with(rhs.value, rhs.modulo);
347                ModInt {
348                    value: (T::from(self).unwrap() * inv) % rhs.modulo,
349                    modulo: rhs.modulo,
350                }
351            }
352        }
353    )*
354});
355
356impl_modint_div!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
357
358impl<T> AddAssign<T> for ModInt<T>
359where
360    T: PrimInt,
361{
362    fn add_assign(&mut self, rhs: T) {
363        (*self).value = if self.value + rhs >= self.modulo {
364            (self.value + rhs) % self.modulo
365        } else {
366            self.value + rhs
367        }
368    }
369}
370
371impl<T> AddAssign<ModInt<T>> for ModInt<T>
372where
373    T: PrimInt,
374{
375    fn add_assign(&mut self, other: ModInt<T>) {
376        (*self).value = if self.value + other.value >= self.modulo {
377            (self.value + other.value) % self.modulo
378        } else {
379            self.value + other.value
380        }
381    }
382}
383
384impl<T> SubAssign<T> for ModInt<T>
385where
386    T: PrimInt,
387{
388    fn sub_assign(&mut self, rhs: T) {
389        (*self).value = if self.value < rhs {
390            self.value + self.modulo - rhs
391        } else {
392            self.value - rhs
393        }
394    }
395}
396
397impl<T> SubAssign<ModInt<T>> for ModInt<T>
398where
399    T: PrimInt,
400{
401    fn sub_assign(&mut self, other: ModInt<T>) {
402        (*self).value = if self.value < other.value {
403            self.value + self.modulo - other.value
404        } else {
405            self.value - other.value
406        }
407    }
408}
409
410impl<T> PartialEq<T> for ModInt<T>
411where
412    T: PrimInt,
413{
414    fn eq(&self, other: &T) -> bool {
415        self.value == *other
416    }
417}
418
419macro_rules! impl_modint_partial_eq(($($ty:ty),*) => {
420    $(
421        impl<T> PartialEq<ModInt<T>> for $ty
422        where
423            T: PrimInt,
424        {
425            #[inline]
426            fn eq(&self, other: &ModInt<T>) -> bool {
427                T::from(self.clone()).unwrap() == other.value.clone()
428            }
429        }
430    )*
431});
432
433impl_modint_partial_eq!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
434
435impl<T> PartialOrd<T> for ModInt<T>
436where
437    T: PrimInt,
438{
439    fn partial_cmp(&self, other: &T) -> Option<Ordering> {
440        Some(self.value.cmp(other))
441    }
442}
443
444macro_rules! impl_modint_partial_ord(($($ty:ty),*) => {
445    $(
446        impl<T> PartialOrd<ModInt<T>> for $ty
447        where
448            T: PrimInt,
449        {
450            #[inline]
451            fn partial_cmp(&self, other: &ModInt<T>) -> Option<Ordering> {
452                Some(T::from(self.clone()).unwrap().cmp(&other.value))
453            }
454        }
455    )*
456});
457
458impl_modint_partial_ord!(i8, i16, i32, i64, u8, u16, u32, u64, isize, usize);
459
460macro_rules! impl_modint_to_primitive(($(($ty:ty, $method:ident)),*) => {
461    $(
462        #[inline]
463        fn $method(&self) -> Option<$ty> {
464            self.value.$method()
465        }
466    )*
467});
468
469impl<T> ToPrimitive for ModInt<T>
470where
471    T: PrimInt,
472{
473    impl_modint_to_primitive!(
474        (i8, to_i8),
475        (i16, to_i16),
476        (i32, to_i32),
477        (i64, to_i64),
478        (u8, to_u8),
479        (u16, to_u16),
480        (u32, to_u32),
481        (u64, to_u64),
482        (isize, to_isize),
483        (usize, to_usize)
484    );
485}
486
#[cfg(test)]
mod test {
    use super::*;

    // `ModInt` combined with `ModInt` under a small explicit modulus.
    #[test]
    fn test_modint_modint() {
        const MOD: usize = 7;
        let mi0 = ModInt::new_with(0, MOD);
        let mi1 = ModInt::new_with(1, MOD);
        let mi2 = ModInt::new_with(2, MOD);
        let mi4 = ModInt::new_with(4, MOD);
        let mi7 = ModInt::new_with(7, MOD);
        let mi11 = ModInt::new_with(11, MOD);

        assert_eq!(mi0 + mi7, ModInt::new_with(0, 7));
        assert_eq!(mi1 + mi2, ModInt::new_with(3, 7));
        assert_eq!(mi1 + mi11, ModInt::new_with(5, 7));
        assert_eq!(mi1 - mi4, ModInt::new_with(4, 7));
    }

    // `ModInt` combined with raw primitives on either side of the operator.
    #[test]
    fn test_modint_other_type() {
        const MOD: usize = 7;
        let mi0 = ModInt::new_with(0, MOD);

        assert_eq!(mi0 + 6, ModInt::new_with(6, MOD));
        assert_eq!(mi0 + 7, ModInt::new_with(0, MOD));
        assert_eq!(7usize + mi0, ModInt::new_with(0, MOD));
        assert_eq!(15usize + mi0, ModInt::new_with(1, MOD));
        assert_eq!(mi0 - 4, ModInt::new_with(3, MOD));
        assert_eq!(mi0 - ModInt::new_with(0, MOD), ModInt::new_with(0, MOD));
        assert_eq!(7usize - mi0, ModInt::new_with(0, MOD));
    }

    // `ModInt::new` uses the default modulus 1000000007.
    #[test]
    fn test_new() {
        let mi0 = ModInt::new(0u64);
        let mi1 = ModInt::new(7u64);
        let mi2 = ModInt::new(1000000007u64);

        assert!(mi0 == 0);
        assert_eq!(mi0, ModInt::new(0));
        assert_eq!(mi1 + mi2, ModInt::new(7));
        assert_eq!(mi0 - mi1, ModInt::new(1000000007 - 7));
        assert_eq!(100 * mi1, ModInt::new(700u64));
        assert_eq!(100u64 * mi1 * 2 / 10 / ModInt::new(5), ModInt::new(28));
    }

    // Every non-zero residue modulo 13 against its known inverse.
    #[test]
    fn test_inverse() {
        const MOD: u64 = 13;

        assert_eq!(1, ModInt::new_with(1, MOD).inverse());
        assert_eq!(7, ModInt::new_with(2, MOD).inverse());
        assert_eq!(9, ModInt::new_with(3, MOD).inverse());
        assert_eq!(10, ModInt::new_with(4, MOD).inverse());
        assert_eq!(8, ModInt::new_with(5, MOD).inverse());
        assert_eq!(11, ModInt::new_with(6, MOD).inverse());
        assert_eq!(2, ModInt::new_with(7, MOD).inverse());
        assert_eq!(5, ModInt::new_with(8, MOD).inverse());
        assert_eq!(3, ModInt::new_with(9, MOD).inverse());
        assert_eq!(4, ModInt::new_with(10, MOD).inverse());
        assert_eq!(6, ModInt::new_with(11, MOD).inverse());
        assert_eq!(12, ModInt::new_with(12, MOD).inverse());
    }

    // Division in all three operand combinations: 2 / 7 == 4 (mod 13).
    #[test]
    fn test_div() {
        const MOD: u64 = 13;

        assert_eq!(4, (ModInt::new_with(2, MOD) / ModInt::new_with(7, MOD)));
        assert_eq!(4, (2u64 / ModInt::new_with(7, MOD)));
        assert_eq!(4, (ModInt::new_with(2, MOD) / 7));
    }

    // Multiplication in all three operand combinations: 3 * 5 == 2 (mod 13).
    #[test]
    fn test_mul() {
        const MOD: u64 = 13;

        assert_eq!(2, ModInt::new_with(3, MOD) * ModInt::new_with(5, MOD));
        assert_eq!(2, ModInt::new_with(3, MOD) * 5);
        assert_eq!(2, 3 * ModInt::new_with(5, MOD));
    }

    // Compound assignment with raw integers and with `ModInt` operands.
    #[test]
    fn test_assign() {
        const MOD: u64 = 13;

        let mut t = ModInt::new_with(3, MOD) + ModInt::new_with(5, MOD);
        t += 7;
        assert_eq!(2, t);
        t -= 4;
        assert_eq!(11, t);
        t += ModInt::new_with(5, MOD);
        assert_eq!(3, t);
        t -= ModInt::new_with(20, MOD);
        assert_eq!(9, t);
    }

    // Ordering against other `ModInt`s and against raw integers on both sides.
    #[test]
    fn test_partialord() {
        const MOD: u64 = 13;

        assert!(ModInt::new_with(3, MOD) < ModInt::new_with(5, MOD));
        assert!(3 < ModInt::new_with(5, MOD));
        assert!(ModInt::new_with(3, MOD) < 5);
        assert!(!(ModInt::new(10) < 7));
    }

    #[test]
    fn test_to_primitive() {
        // Pin the exact converted value: the earlier `assert_ne!(2, ..)` would
        // have passed for any result other than 2.
        assert_eq!(13, ModInt::new(13).to_u64().unwrap());
    }
}