// dsalgo/default_static_modular_arithmetic.rs

1use crate::{
2    modular_inverse_euclidean_u64::modinv,
3    static_modular_arithmetic_trait::ModularArithmetic,
4    static_modulus_trait::Get,
5};
6
7/// why `default`?
8/// because there exists other modular arithmetic implementations.
9/// e.g. Montgomery Multiplication, or Burrett Reduction.
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
11
12pub struct DefaultStatic<T, M: Get<T = T>>(std::marker::PhantomData<(T, M)>);
13
/// Implements `ModularArithmetic` for `DefaultStatic<$uint, M>`.
///
/// * `$uint` - the unsigned integer type of the residues.
/// * `$mul_cast_uint` - a wider unsigned type able to hold the full product
///   of two `$uint` values; `mul` computes in this type before reducing.
macro_rules! impl_default_static {
    ($uint:ty, $mul_cast_uint:ty) => {
        impl<M: Get<T = $uint>> ModularArithmetic for DefaultStatic<$uint, M> {
            type T = $uint;

            /// Returns the modulus provided by `M`.
            fn modulus() -> Self::T {
                M::get()
            }

            /// Returns `(lhs + rhs) mod m`.
            ///
            /// # Panics
            /// Panics if either operand is not already reduced (`>= m`).
            fn add(
                lhs: Self::T,
                rhs: Self::T,
            ) -> Self::T {
                let m = M::get();

                assert!(lhs < m && rhs < m);

                // `lhs + rhs` may not fit in `Self::T` when `m` is larger
                // than half the type's range, so never form the raw sum
                // unless it is known to stay below `m`.
                // `rhs < m` guarantees `gap >= 1`.
                let gap = m - rhs;

                if lhs >= gap {
                    // lhs + rhs >= m: the reduced result is lhs + rhs - m.
                    lhs - gap
                } else {
                    // lhs + rhs < m: no overflow and no reduction needed.
                    lhs + rhs
                }
            }

            /// Returns the additive inverse of `x` modulo `m`.
            ///
            /// # Panics
            /// Panics if `x` is not already reduced (`>= m`).
            fn neg(x: Self::T) -> Self::T {
                let m = M::get();

                assert!(x < m);

                // Zero is its own negation; `m - 0` would not be reduced.
                if x == 0 {
                    0
                } else {
                    m - x
                }
            }

            /// Returns `(lhs * rhs) mod m`, computing the product in the
            /// wider `$mul_cast_uint` type.
            fn mul(
                lhs: Self::T,
                rhs: Self::T,
            ) -> Self::T {
                let product =
                    (lhs as $mul_cast_uint) * (rhs as $mul_cast_uint);

                // The remainder is below `m`, so narrowing back is lossless.
                (product % (M::get() as $mul_cast_uint)) as Self::T
            }

            /// Returns the multiplicative inverse of `x` as computed by
            /// `modinv`.
            ///
            /// # Panics
            /// Panics if `x == 0`, or if `modinv` does not yield a value
            /// (presumably when `x` is not invertible modulo `m` - confirm
            /// against `modinv`).
            fn inv(x: $uint) -> Self::T {
                assert!(x > 0);

                modinv(M::get() as u64, x as u64).unwrap() as Self::T
            }
        }
    };
}
71
72impl_default_static!(u32, u64);
73
74impl_default_static!(u64, u128);
75
// TODO: switch to the `ConstMod32`-based aliases below later; they still do
// not compile on AtCoder.
// use crate::modular::modulus::ConstMod32;
// #[allow(dead_code)]
// pub type Modular1_000_000_007 =
//     DefaultStatic<u32, ConstMod32<1_000_000_007>>;
// #[allow(dead_code)]
// pub type Modular998_244_353 =
//     DefaultStatic<u32, ConstMod32<998_244_353>>;
84use crate::define_const_modulus_macro::{
85    Mod1_000_000_007,
86    Mod998_244_353,
87};
88
89#[allow(dead_code)]
90
91pub type Modular1_000_000_007 = DefaultStatic<u32, Mod1_000_000_007>;
92
93#[allow(dead_code)]
94
95pub type Modular998_244_353 = DefaultStatic<u32, Mod998_244_353>;
96
#[cfg(test)]
mod tests {
    use super::*;

    // A value one above the modulus of `Modular1_000_000_007` is expected to
    // be stored as 1 once wrapped in `Modint`.
    #[test]
    fn test() {
        use crate::modular_int_with_arithmetic::Modint;

        type Mint = Modint<u32, Modular1_000_000_007>;

        let a = Mint::from(1_000_000_008);

        assert_eq!(a.value(), 1);
    }
}