use num_integer::Integer;
use num_traits::identities::{One, Zero};
use num_traits::{Num, Pow};
use std::cmp::Ordering;
use std::convert::TryInto;
use std::num::NonZeroU32;
use std::num::ParseIntError;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Rem, RemAssign, Sub, SubAssign};
use cargo_snippet::snippet;
#[snippet("modint")]
/// Returns `n mod m` as a value in `[0, m)`, lifting Rust's sign-following `%` result
/// into the non-negative range when `n` is negative.
fn compensated_rem(n: i64, m: usize) -> i64 {
    let modulus = m as i64;
    let remainder = n % modulus;
    if remainder < 0 {
        remainder + modulus
    } else {
        remainder
    }
}
#[snippet("modint")]
/// The modulus attached to a [`ModInt`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Modulo {
/// A fixed, non-zero modulus.
Static(NonZeroU32),
/// No modulus yet. `ModInt::zero`, `ModInt::one` and `from_str_radix` produce values with this
/// variant; `check_mod_eq` lets such a value adopt the other operand's static modulus.
Dynamic,
}
#[snippet("modint")]
impl Modulo {
    /// Returns the modulus as a plain `u32`, or `None` for [`Modulo::Dynamic`].
    pub fn get(&self) -> Option<u32> {
        if let Modulo::Static(modulus) = self {
            Some(modulus.get())
        } else {
            None
        }
    }
}
#[snippet("modint")]
/// An integer value together with the modulus it is taken in.
#[derive(Debug, Clone, Copy)]
pub struct ModInt {
/// The stored value. `ModInt::new` reduces it into `[0, m)`; values created with a
/// `Modulo::Dynamic` modulus (`zero`, `one`, `from_str_radix`) hold the integer unreduced.
num: i64,
/// The modulus, or `Modulo::Dynamic` when it is not known yet.
_modulo: Modulo,
}
#[snippet("modint")]
impl Into<usize> for ModInt {
fn into(self) -> usize {
self.get() as usize
}
}
#[snippet("modint")]
/// Conversion of a primitive integer into a [`ModInt`] with a given modulus.
pub trait IntoModInt: Copy {
/// Builds a [`ModInt`] from `self` with modulus `modulo`. The implementations generated by
/// `impl_into_mint!` delegate to `ModInt::new`.
fn to_mint<M: TryInto<u32> + Copy>(self, modulo: M) -> ModInt;
}
#[snippet("modint")]
/// Implements [`IntoModInt`] for every listed integer type by forwarding to `ModInt::new`.
macro_rules! impl_into_mint {
    ($($int:ty),*) => {
        $(
            impl IntoModInt for $int {
                fn to_mint<M: TryInto<u32> + Copy>(self, modulo: M) -> ModInt {
                    // The generic constructor performs the reduction into [0, modulo).
                    ModInt::new(self, modulo)
                }
            }
        )*
    };
}
#[snippet("modint")]
// Unsigned types first, then signed; each pointer-sized type follows its fixed-width siblings.
impl_into_mint!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize);
#[snippet("modint")]
impl PartialEq for ModInt {
fn eq(&self, other: &Self) -> bool {
if !check_mod_eq(self, other).1 {
panic!("cannot compare these values because they have different modulo number")
}
self.get() == other.num
}
}
#[snippet("modint")]
impl PartialOrd for ModInt {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
if !check_mod_eq(self, other).1 {
None
} else {
Some(self.get().cmp(&other.num))
}
}
}
#[snippet("modint")]
/// Resolves the modulus shared by the two operands of a binary operation.
///
/// Returns `(m, true)` when the operands are compatible. That is the case when both carry the
/// same static modulus `m`, or when exactly one is `Modulo::Dynamic` and adopts the other's `m`.
///
/// Returns `(1, false)` when the static moduli differ or both operands are `Dynamic`. The `1` is
/// only a placeholder and carries no meaning.
fn check_mod_eq(a: &ModInt, b: &ModInt) -> (NonZeroU32, bool) {
    // A constant non-zero value is built safely; `new_unchecked` (and `unsafe`) is unnecessary.
    let placeholder = NonZeroU32::new(1).expect("1 is non-zero");
    match (a._modulo, b._modulo) {
        (Modulo::Static(x), Modulo::Static(y)) if x == y => (x, true),
        (Modulo::Static(_), Modulo::Static(_)) => (placeholder, false),
        (Modulo::Static(m), Modulo::Dynamic) | (Modulo::Dynamic, Modulo::Static(m)) => (m, true),
        (Modulo::Dynamic, Modulo::Dynamic) => (placeholder, false),
    }
}
#[snippet("modint")]
impl ModInt {
    /// Creates `n mod m`, normalized into `[0, m)`; a negative `n` wraps around.
    ///
    /// # Panics
    /// Panics if `m` does not fit in `u32` or is zero, or if `n` does not fit in `i64`.
    pub fn new<N: TryInto<i64>, M: TryInto<u32> + Copy>(n: N, m: M) -> Self {
        // `.ok()` first: the generic conversion error types are not required to implement `Debug`.
        let m = m.try_into().ok().expect("modulo number may be wrong");
        let m = NonZeroU32::new(m).expect("modulo number must be non-zero");
        let r = n.try_into().ok().expect("value is out of i64 range");
        let num = compensated_rem(r, m.get() as usize);
        Self {
            num,
            _modulo: Modulo::Static(m),
        }
    }

    /// Returns the stored value.
    pub fn get(&self) -> i64 {
        self.num
    }

    /// Returns the static modulus.
    ///
    /// # Panics
    /// Panics if this value has no static modulus yet (`Modulo::Dynamic`).
    pub fn get_mod(&self) -> usize {
        self._modulo
            .get()
            .expect("ModInt with a Dynamic modulo has no modulus") as usize
    }

    /// Computes `self^exp mod m` by square-and-multiply.
    ///
    /// The arithmetic is done in `u64`. Both factors are below `m <= u32::MAX`, so every
    /// intermediate product fits in `u64` on every target, including 32-bit ones where `usize`
    /// would overflow.
    ///
    /// # Panics
    /// Panics if this value has no static modulus (`Modulo::Dynamic`).
    pub fn pow_mod(&self, mut exp: usize) -> Self {
        let m = self.get_mod() as u64;
        let mut res: u64 = 1;
        let mut base = self.get() as u64;
        while exp > 0 {
            if exp & 1 != 0 {
                res = res * base % m;
            }
            base = base * base % m;
            exp >>= 1;
        }
        Self::new(res, self.get_mod())
    }

    /// Returns the modular inverse of `self`: the `x` in `[0, m)` with `self * x ≡ 1 (mod m)`.
    ///
    /// # Panics
    /// Panics if `self` is not coprime to the modulus, since no inverse exists. This includes
    /// `self == 0` for `m > 1`. Also panics if there is no static modulus (`Modulo::Dynamic`).
    pub fn inv(&self) -> i64 {
        let m = self.get_mod() as i64;
        let egcd = self.get().extended_gcd(&m);
        assert!(
            egcd.gcd == 1,
            "{} has no multiplicative inverse modulo {}",
            self.get(),
            m
        );
        compensated_rem(egcd.x, m as usize)
    }
}
#[test]
fn mint_new() {
    // A positive value is reduced into [0, m).
    assert_eq!(ModInt::new(10, 3).get(), 1);
    // A negative value is shifted up into [0, m).
    assert_eq!(ModInt::new(-10, 3).get(), 2);
    // The `to_mint` extension method builds the same value as the constructor.
    let via_trait = 4.to_mint(10);
    let via_ctor = ModInt::new(4, 10);
    assert_eq!(via_trait, via_ctor);
}
#[test]
fn inv_test() {
    // 6 * 11 = 66 = 5 * 13 + 1, so 11 is the inverse of 6 modulo 13.
    assert_eq!(ModInt::new(6, 13).inv(), 11);
}
#[snippet("modint")]
impl Add<Self> for ModInt {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
let c = check_mod_eq(&self, &rhs);
if !c.1 {
panic!("modulo between two instance is different!")
}
let r = self.get() + rhs.num;
Self {
num: if r >= self.get_mod() as i64 {
r - c.0.get() as i64
} else {
r
},
_modulo: Modulo::Static(c.0),
}
}
}
#[snippet("modint")]
impl AddAssign<Self> for ModInt {
    /// In-place form of `+`, with the same semantics and panics as the `Add` impl.
    fn add_assign(&mut self, rhs: Self) {
        let sum = *self + rhs;
        *self = sum;
    }
}
#[test]
fn mint_add() {
    let a = ModInt::new(13, 8); // 13 ≡ 5 (mod 8)
    let b = ModInt::new(10, 8); // 10 ≡ 2 (mod 8)
    let c = ModInt::new(7, 8);
    // 5 + 2 = 7: no wrap-around.
    assert_eq!((a + b).get(), 7);
    // 5 + 7 = 12 ≡ 4 (mod 8).
    assert_eq!((a + c).get(), 4);
}
#[snippet("modint")]
impl Sub<Self> for ModInt {
    type Output = Self;
    /// Subtracts `rhs` from `self` modulo the shared modulus; the result lies in `[0, m)`.
    ///
    /// # Panics
    /// Panics if the operands carry different static moduli or are both `Dynamic`.
    fn sub(self, rhs: Self) -> Self::Output {
        let (modulus, same_modulus) = check_mod_eq(&self, &rhs);
        if !same_modulus {
            panic!("modulo between two instance is different!")
        }
        let difference = self.get() - rhs.get();
        Self {
            num: compensated_rem(difference, modulus.get() as usize),
            _modulo: Modulo::Static(modulus),
        }
    }
}
#[snippet("modint")]
impl SubAssign<Self> for ModInt {
    /// In-place form of `-`, with the same semantics and panics as the `Sub` impl.
    fn sub_assign(&mut self, rhs: Self) {
        let difference = *self - rhs;
        *self = difference;
    }
}
#[test]
fn mint_sub() {
    let (two, three) = (ModInt::new(2, 10), ModInt::new(3, 10));
    assert_eq!((three - two).get(), 1);
    // 2 - 3 = -1 wraps around to 9 (mod 10).
    assert_eq!((two - three).get(), 9);
}
#[snippet("modint")]
impl Mul<Self> for ModInt {
    type Output = Self;
    /// Multiplies two values modulo their shared modulus.
    ///
    /// The product is formed in `i128`. Reduced operands can be as large as `u32::MAX - 1`, and
    /// their product exceeds `i64::MAX` for moduli above about 3.04e9.
    ///
    /// # Panics
    /// Panics if the operands carry different static moduli or are both `Dynamic`.
    fn mul(self, rhs: Self) -> Self::Output {
        let (modulus, compatible) = check_mod_eq(&self, &rhs);
        if !compatible {
            panic!("modulo between two instance is different!")
        }
        let product = i128::from(self.num) * i128::from(rhs.num);
        // With a positive modulus, `rem_euclid` yields a value in [0, m), which fits in i64.
        let num = product.rem_euclid(i128::from(modulus.get())) as i64;
        Self {
            num,
            _modulo: Modulo::Static(modulus),
        }
    }
}
#[snippet("modint")]
impl MulAssign<Self> for ModInt {
    /// In-place form of `*`, with the same semantics and panics as the `Mul` impl.
    fn mul_assign(&mut self, rhs: Self) {
        let product = *self * rhs;
        *self = product
    }
}
#[snippet("modint")]
impl Div<Self> for ModInt {
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
let c = check_mod_eq(&self, &rhs);
if !c.1 {
panic!("modulo between two instance is different!")
}
Self {
num: self.get() * rhs.inv() % c.0.get() as i64,
_modulo: Modulo::Static(c.0),
}
}
}
#[snippet("modint")]
impl DivAssign<Self> for ModInt {
    /// In-place form of `/`, with the same semantics and panics as the `Div` impl.
    fn div_assign(&mut self, rhs: Self) {
        let quotient = *self / rhs;
        *self = quotient;
    }
}
#[test]
fn div_test() {
    let a = ModInt::new(2, 5);
    let b = ModInt::new(3, 5);
    assert_eq!(a / b, ModInt::new(4, 5));
    // Expected `x / 4 (mod 13)` for x = 1, 2, ..., 12 (the inverse of 4 modulo 13 is 10).
    let expected: [i64; 12] = [10, 7, 4, 1, 11, 8, 5, 2, 12, 9, 6, 3];
    for (x, &want) in (1i64..).zip(expected.iter()) {
        let x = ModInt::new(x, 13);
        assert_eq!((x / 4i64).get(), want);
    }
}
#[snippet("modint")]
impl Rem for ModInt {
    type Output = Self;
    /// Takes the integer remainder of the two stored values, tagged with the shared modulus.
    ///
    /// # Panics
    /// Panics if the operands carry different static moduli or are both `Dynamic`, or if the
    /// stored value of `rhs` is zero.
    fn rem(self, rhs: Self) -> Self::Output {
        let (modulus, same_modulus) = check_mod_eq(&self, &rhs);
        if !same_modulus {
            panic!("modulo between two instance is different!")
        }
        let remainder = self.num % rhs.num;
        Self {
            num: remainder,
            _modulo: Modulo::Static(modulus),
        }
    }
}
#[snippet("modint")]
impl RemAssign for ModInt {
    /// In-place form of `%`, with the same semantics and panics as the `Rem` impl.
    fn rem_assign(&mut self, rhs: Self) {
        let remainder = *self % rhs;
        *self = remainder
    }
}
#[snippet("modint")]
impl Zero for ModInt {
    /// The additive identity with no modulus yet (`Modulo::Dynamic`); it adopts the modulus of
    /// whatever it is combined with.
    fn zero() -> Self {
        Self {
            _modulo: Modulo::Dynamic,
            num: 0,
        }
    }

    /// Whether the stored value is `0`.
    fn is_zero(&self) -> bool {
        0 == self.num
    }
}
#[snippet("modint")]
impl One for ModInt {
    /// The multiplicative identity with no modulus yet (`Modulo::Dynamic`); it adopts the
    /// modulus of whatever it is combined with.
    fn one() -> Self {
        Self {
            _modulo: Modulo::Dynamic,
            num: 1,
        }
    }

    /// Whether the stored value is `1`.
    fn is_one(&self) -> bool {
        1 == self.num
    }
}
#[snippet("modint")]
impl Num for ModInt {
    type FromStrRadixErr = ParseIntError;
    /// Parses `str` in base `radix` into a `ModInt` with no modulus yet (`Modulo::Dynamic`).
    ///
    /// The parsed value is stored unreduced; it is reduced once it is combined with a value that
    /// carries a static modulus.
    ///
    /// # Errors
    /// Returns the standard `ParseIntError` for empty input or invalid digits (including a
    /// leading `-`). It is also returned for values above `i64::MAX`.
    ///
    /// # Panics
    /// Panics if `radix` is not in `2..=36`, like the standard integer parsers.
    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
        // Parsing as `u64` first rejects a leading '-', so the stored value is never negative.
        u64::from_str_radix(str, radix)?;
        // Parsing as `i64` gives the value and reports anything above `i64::MAX` as an overflow.
        let num = i64::from_str_radix(str, radix)?;
        Ok(ModInt {
            num,
            _modulo: Modulo::Dynamic,
        })
    }
}
#[snippet("modint")]
impl Pow<usize> for ModInt {
type Output = Self;
fn pow(mut self, mut exp: usize) -> Self::Output {
if exp == 0 {
return Self::one();
}
while exp & 1 == 0 {
self = self * self;
exp >>= 1;
}
if exp == 1 {
return self;
}
let mut acc = self;
while exp > 1 {
exp >>= 1;
self = self * self;
if exp & 1 == 1 {
acc = acc * self;
}
}
acc
}
}
#[test]
fn pow_test() {
    // 3^3 = 27 ≡ 7 (mod 10)
    assert_eq!(ModInt::new(3, 10).pow(3).get(), 7);
    // 100^2 = 10000 = 9999 + 1 ≡ 1 (mod 9999)
    assert_eq!(ModInt::new(100, 9999).pow(2).get(), 1);
}
#[snippet("modint")]
/// Implements `+`, `-`, `*`, `/` and their `*Assign` forms between [`ModInt`] and each listed
/// primitive integer type. The primitive operand is taken modulo the `ModInt` operand's modulus.
///
/// The primitive is widened to `i128`, which is lossless for every integer type up to 64 bits,
/// and is reduced there before the right-hand `ModInt` is built. A plain `as i64` cast would wrap
/// `u64`/`usize` values above `i64::MAX` to negative numbers and yield the wrong residue.
///
/// # Panics
/// Every operator panics if the `ModInt` operand has no static modulus (`Modulo::Dynamic`),
/// because `get_mod()` panics in that case.
macro_rules! impl_ops_between_mint_and_primitive {
    ($($t:ty),*) => {
        $(
            impl Add<$t> for ModInt {
                type Output = Self;
                fn add(self, rhs: $t) -> Self::Output {
                    let m = self.get_mod();
                    self + Self::new((rhs as i128).rem_euclid(m as i128) as i64, m)
                }
            }
            impl AddAssign<$t> for ModInt {
                fn add_assign(&mut self, rhs: $t) {
                    *self = *self + rhs;
                }
            }
            impl Sub<$t> for ModInt {
                type Output = Self;
                fn sub(self, rhs: $t) -> Self::Output {
                    let m = self.get_mod();
                    self - Self::new((rhs as i128).rem_euclid(m as i128) as i64, m)
                }
            }
            impl SubAssign<$t> for ModInt {
                fn sub_assign(&mut self, rhs: $t) {
                    *self = *self - rhs;
                }
            }
            impl Mul<$t> for ModInt {
                type Output = Self;
                fn mul(self, rhs: $t) -> Self::Output {
                    let m = self.get_mod();
                    self * Self::new((rhs as i128).rem_euclid(m as i128) as i64, m)
                }
            }
            impl MulAssign<$t> for ModInt {
                fn mul_assign(&mut self, rhs: $t) {
                    *self = *self * rhs;
                }
            }
            impl Div<$t> for ModInt {
                type Output = Self;
                fn div(self, rhs: $t) -> Self::Output {
                    let m = self.get_mod();
                    self / Self::new((rhs as i128).rem_euclid(m as i128) as i64, m)
                }
            }
            impl DivAssign<$t> for ModInt {
                fn div_assign(&mut self, rhs: $t) {
                    *self = *self / rhs;
                }
            }
        )*
    };
}
#[snippet("modint")]
// Unsigned types first, then signed; each pointer-sized type follows its fixed-width siblings.
impl_ops_between_mint_and_primitive!(u8, u16, u32, u64, usize, i8, i16, i32, i64, isize);
#[test]
fn op_between_different_type() {
    let mut value = ModInt::new(1, 10);
    value += 1; // 1 + 1 = 2
    assert_eq!(value.get(), 2);
    value *= 2; // 2 * 2 = 4
    assert_eq!(value.get(), 4);
    value += 10001; // 10001 ≡ 1 (mod 10), so 4 + 1 = 5
    assert_eq!(value.get(), 5);
}