competitive-programming-rs 41.0.0

Competitive Programming Library in Rust
Documentation
pub mod mod_int {
    type ModInternalNum = i64;
    thread_local!(
        static MOD: std::cell::RefCell<ModInternalNum> = std::cell::RefCell::new(0);
    );

    pub fn set_mod_int<T: ToInternalNum>(v: T) {
        MOD.with(|x| x.replace(v.to_internal_num()));
    }
    fn modulo() -> ModInternalNum {
        MOD.with(|x| *x.borrow())
    }

    #[derive(Debug)]
    pub struct ModInt(ModInternalNum);
    impl Clone for ModInt {
        fn clone(&self) -> Self {
            Self(self.0)
        }
    }
    impl Copy for ModInt {}

    impl ModInt {
        fn internal_new(mut v: ModInternalNum) -> Self {
            let m = modulo();
            if v >= m {
                v %= m;
            }
            Self(v)
        }

        pub fn internal_pow(&self, mut e: ModInternalNum) -> Self {
            let mut result = 1;
            let mut cur = self.0;
            let modulo = modulo();
            while e > 0 {
                if e & 1 == 1 {
                    result *= cur;
                    result %= modulo;
                }
                e >>= 1;
                cur = (cur * cur) % modulo;
            }
            Self(result)
        }

        pub fn pow<T>(&self, e: T) -> Self
        where
            T: ToInternalNum,
        {
            self.internal_pow(e.to_internal_num())
        }

        pub fn value(&self) -> ModInternalNum {
            self.0
        }
    }

    pub trait ToInternalNum {
        fn to_internal_num(&self) -> ModInternalNum;
    }
    impl ToInternalNum for ModInt {
        fn to_internal_num(&self) -> ModInternalNum {
            self.0
        }
    }
    macro_rules! impl_primitive {
        ($primitive:ident) => {
            impl From<$primitive> for ModInt {
                fn from(v: $primitive) -> Self {
                    let v = v as ModInternalNum;
                    Self::internal_new(v)
                }
            }
            impl ToInternalNum for $primitive {
                fn to_internal_num(&self) -> ModInternalNum {
                    *self as ModInternalNum
                }
            }
        };
    }
    impl_primitive!(u8);
    impl_primitive!(u16);
    impl_primitive!(u32);
    impl_primitive!(u64);
    impl_primitive!(usize);
    impl_primitive!(i8);
    impl_primitive!(i16);
    impl_primitive!(i32);
    impl_primitive!(i64);
    impl_primitive!(isize);

    impl<T: ToInternalNum> std::ops::AddAssign<T> for ModInt {
        fn add_assign(&mut self, rhs: T) {
            let mut rhs = rhs.to_internal_num();
            let m = modulo();
            if rhs >= m {
                rhs %= m;
            }

            self.0 += rhs;
            if self.0 >= m {
                self.0 -= m;
            }
        }
    }

    impl<T: ToInternalNum> std::ops::Add<T> for ModInt {
        type Output = ModInt;
        fn add(self, rhs: T) -> Self::Output {
            let mut res = self;
            res += rhs;
            res
        }
    }
    impl<T: ToInternalNum> std::ops::SubAssign<T> for ModInt {
        fn sub_assign(&mut self, rhs: T) {
            let mut rhs = rhs.to_internal_num();
            let m = modulo();
            if rhs >= m {
                rhs %= m;
            }
            if rhs > 0 {
                self.0 += m - rhs;
            }
            if self.0 >= m {
                self.0 -= m;
            }
        }
    }
    impl<T: ToInternalNum> std::ops::Sub<T> for ModInt {
        type Output = Self;
        fn sub(self, rhs: T) -> Self::Output {
            let mut res = self;
            res -= rhs;
            res
        }
    }
    impl<T: ToInternalNum> std::ops::MulAssign<T> for ModInt {
        fn mul_assign(&mut self, rhs: T) {
            let mut rhs = rhs.to_internal_num();
            let m = modulo();
            if rhs >= m {
                rhs %= m;
            }
            self.0 *= rhs;
            self.0 %= m;
        }
    }
    impl<T: ToInternalNum> std::ops::Mul<T> for ModInt {
        type Output = Self;
        fn mul(self, rhs: T) -> Self::Output {
            let mut res = self;
            res *= rhs;
            res
        }
    }

    impl<T: ToInternalNum> std::ops::DivAssign<T> for ModInt {
        fn div_assign(&mut self, rhs: T) {
            let mut rhs = rhs.to_internal_num();
            let m = modulo();
            if rhs >= m {
                rhs %= m;
            }
            let inv = Self(rhs).internal_pow(m - 2);
            self.0 *= inv.value();
            self.0 %= m;
        }
    }

    impl<T: ToInternalNum> std::ops::Div<T> for ModInt {
        type Output = Self;
        fn div(self, rhs: T) -> Self::Output {
            let mut res = self;
            res /= rhs;
            res
        }
    }
}

#[cfg(test)]
mod test {
    use super::mod_int::*;
    use rand::distributions::Uniform;
    use rand::Rng;

    const PRIME_MOD: [i64; 3] = [1_000_000_007, 1_000_000_009, 998244353];
    const INF: i64 = 1 << 60;

    fn random_add_sub(prime_mod: i64) {
        let mut rng = rand::thread_rng();
        set_mod_int(prime_mod);
        for _ in 0..10000 {
            let x: i64 = rng.sample(Uniform::from(0..prime_mod));
            let y: i64 = rng.sample(Uniform::from(0..prime_mod));

            let mx = ModInt::from(x);
            let my = ModInt::from(y);

            assert_eq!((mx + my).value(), (x + y) % prime_mod);
            assert_eq!((mx + y).value(), (x + y) % prime_mod);
            assert_eq!((mx - my).value(), (x + prime_mod - y) % prime_mod);
            assert_eq!((mx - y).value(), (x + prime_mod - y) % prime_mod);

            let mut x = x;
            let mut mx = mx;
            x += y;
            mx += my;
            assert_eq!(mx.value(), x % prime_mod);

            mx += y;
            x += y;
            assert_eq!(mx.value(), x % prime_mod);

            mx -= my;
            x = (x + prime_mod - y % prime_mod) % prime_mod;
            assert_eq!(mx.value(), x);

            mx -= y;
            x = (x + prime_mod - y % prime_mod) % prime_mod;
            assert_eq!(mx.value(), x);
        }
    }

    #[test]
    fn test_random_add_sub1() {
        random_add_sub(PRIME_MOD[0]);
    }

    #[test]
    fn test_random_add_sub2() {
        random_add_sub(PRIME_MOD[1]);
    }

    #[test]
    fn test_random_add_sub3() {
        random_add_sub(PRIME_MOD[2]);
    }

    fn random_mul(prime_mod: i64) {
        let mut rng = rand::thread_rng();
        set_mod_int(prime_mod);
        for _ in 0..10000 {
            let x: i64 = rng.sample(Uniform::from(0..prime_mod));
            let y: i64 = rng.sample(Uniform::from(0..prime_mod));

            let mx = ModInt::from(x);
            let my = ModInt::from(y);

            assert_eq!((mx * my).value(), (x * y) % prime_mod);
            assert_eq!((mx * y).value(), (x * y) % prime_mod);
        }
    }
    #[test]
    fn test_random_mul1() {
        random_mul(PRIME_MOD[0]);
    }
    #[test]
    fn test_random_mul2() {
        random_mul(PRIME_MOD[1]);
    }
    #[test]
    fn test_random_mul3() {
        random_mul(PRIME_MOD[2]);
    }

    #[test]
    fn zero_test() {
        set_mod_int(1_000_000_007i64);
        let a = ModInt::from(1_000_000_000i64);
        let b = ModInt::from(7i64);
        let c = a + b;
        assert_eq!(c.value(), 0);
    }

    #[test]
    fn pow_test() {
        set_mod_int(1_000_000_007i64);
        let a = ModInt::from(3i64);
        let a = a.pow(4i64);
        assert_eq!(a.value(), 81);
    }

    #[test]
    fn div_test() {
        set_mod_int(1_000_000_007i64);
        for i in 1..100000i64 {
            let mut a = ModInt::from(1i64);
            a /= i;
            a *= i;
            assert_eq!(a.value(), 1);
        }
    }

    #[test]
    fn edge_cases() {
        const MOD: i128 = 1_000_000_007;
        set_mod_int(1_000_000_007i64);

        let a = ModInt::from(1_000_000_000i64) * INF;
        assert_eq!(
            a.value(),
            ((1_000_000_000i128 * i128::from(INF)) % MOD) as i64
        );

        let mut a = ModInt::from(1_000_000_000i64);
        a *= INF;
        assert_eq!(
            a.value(),
            ((1_000_000_000i128 * i128::from(INF)) % MOD) as i64
        );

        let a = ModInt::from(1_000_000_000i64) + INF;
        assert_eq!(
            a.value(),
            ((1_000_000_000i128 + i128::from(INF)) % MOD) as i64
        );

        let mut a = ModInt::from(1_000_000_000i64);
        a += INF;
        assert_eq!(
            a.value(),
            ((1_000_000_000i128 + i128::from(INF)) % MOD) as i64
        );

        let a = ModInt::from(1_000_000_000i64) - INF;
        assert_eq!(
            a.value(),
            ((1_000_000_000i128 + MOD - (INF as i128) % MOD) % MOD) as i64
        );

        let mut a = ModInt::from(1_000_000_000i64);
        a -= INF;
        assert_eq!(
            a.value(),
            ((1_000_000_000i128 + MOD - (INF as i128) % MOD) % MOD) as i64
        );

        let a = ModInt::from(1_000_000_000i64) / INF;
        assert_eq!(a.value(), 961239577);

        let mut a = ModInt::from(1_000_000_000i64);
        a /= INF;
        assert_eq!(a.value(), 961239577);
    }

    #[test]
    fn overflow_guard() {
        set_mod_int(1_000_000_007i64);
        let a = ModInt::from(1_000_000_007i64 * 10);
        assert_eq!(a.value(), 0);
    }

    #[test]
    fn initialize_from_various_primitives() {
        set_mod_int(1_000_000_007);
        let a = ModInt::from(100usize);
        let b = ModInt::from(100i64);
        assert_eq!(a.value(), b.value());
    }
}