const_modulo_ring/
lib.rs

1use core::{
2    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
3    str::FromStr,
4};
5
/// An element of the integers modulo the compile-time constant `M`.
///
/// The wrapped `u32` is the stored residue. The `From<u32>` and `FromStr`
/// conversions in this file reduce their input modulo `M` before storing it.
/// Equality and hashing are derived, so they compare the stored `u32` directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConstModulo32<const M: u32>(u32);
8
9impl<const M: u32> From<u32> for ConstModulo32<M> {
10    fn from(x: u32) -> Self {
11        Self(x % M)
12    }
13}
14#[auto_impl_ops::auto_ops]
15impl<const M: u32> AddAssign<&ConstModulo32<M>> for ConstModulo32<M> {
16    fn add_assign(&mut self, other: &Self) {
17        let t = self.0 + other.0;
18        self.0 = if t >= M { t - M } else { t };
19    }
20}
21#[auto_impl_ops::auto_ops]
22impl<const M: u32> SubAssign<&ConstModulo32<M>> for ConstModulo32<M> {
23    fn sub_assign(&mut self, other: &Self) {
24        let t = self.0 + (M - other.0);
25        self.0 = if t >= M { t - M } else { t };
26    }
27}
28impl<const M: u32> Neg for ConstModulo32<M> {
29    type Output = Self;
30    fn neg(self) -> Self::Output {
31        Self(M - self.0)
32    }
33}
34impl<const M: u32> Neg for &ConstModulo32<M> {
35    type Output = ConstModulo32<M>;
36    fn neg(self) -> Self::Output {
37        ConstModulo32(M - self.0)
38    }
39}
40#[auto_impl_ops::auto_ops]
41impl<const M: u32> MulAssign<&ConstModulo32<M>> for ConstModulo32<M> {
42    fn mul_assign(&mut self, other: &Self) {
43        self.0 = ((self.0 as u64) * (other.0 as u64) % (M as u64)) as u32;
44    }
45}
46#[auto_impl_ops::auto_ops]
47impl<const M: u32> DivAssign<&ConstModulo32<M>> for ConstModulo32<M> {
48    fn div_assign(&mut self, other: &Self) {
49        let t = ring_algorithm::modulo_division(self.0 as i64, other.0 as i64, M as i64)
50            .expect("Can't divide");
51        let t = if t < 0 { t + M as i64 } else { t };
52        self.0 = t as u32;
53    }
54}
55impl<const M: u32> FromStr for ConstModulo32<M> {
56    type Err = std::num::ParseIntError;
57    fn from_str(s: &str) -> Result<Self, Self::Err> {
58        let mut v = s.parse::<i128>()?;
59        v %= M as i128;
60        if v < 0 {
61            v += M as i128;
62        }
63        Ok(Self(v as u32))
64    }
65}
66impl<const M: u32> std::fmt::Display for ConstModulo32<M> {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
68        write!(f, "{}", self.0)
69    }
70}
71
/// `num_traits` identity-element implementations for `ConstModulo32`,
/// compiled only when the `num-traits` feature is enabled.
#[cfg(feature = "num-traits")]
mod impl_num_traits {
    use crate::ConstModulo32;
    use num_traits::{ConstOne, ConstZero, One, Zero};

    impl<const M: u32> Zero for ConstModulo32<M> {
        /// Returns the additive identity.
        fn zero() -> Self {
            Self::ZERO
        }
        /// Reports whether this value equals the additive identity.
        fn is_zero(&self) -> bool {
            *self == Self::ZERO
        }
    }

    impl<const M: u32> ConstZero for ConstModulo32<M> {
        /// The additive identity: residue `0`.
        const ZERO: Self = Self(0);
    }

    impl<const M: u32> One for ConstModulo32<M> {
        /// Returns the multiplicative identity.
        fn one() -> Self {
            Self::ONE
        }
    }

    impl<const M: u32> ConstOne for ConstModulo32<M> {
        /// The multiplicative identity: residue `1`, reduced modulo `M`.
        ///
        /// For `M == 1` the only residue is `0`, so the identity must be stored
        /// as `0`. A raw `1` would be out of range and would not compare equal
        /// to `ZERO`.
        const ONE: Self = Self(if M == 1 { 0 } else { 1 });
    }
}
96
#[test]
fn add() {
    // Shorthand for a residue modulo 6 built from an already-reduced value.
    let m6 = |v: u32| ConstModulo32::<6>(v);

    // 2 + 3 = 5: no wrap-around.
    assert_eq!(m6(2) + m6(3), m6(5));
    // 3 + 4 = 7, which wraps to 1.
    assert_eq!(m6(3) + m6(4), m6(1));
    // 3 + 3 = 6, which wraps to 0.
    assert_eq!(m6(3) + m6(3), m6(0));
}
#[test]
fn sub() {
    // Shorthand for a residue modulo 6 built from an already-reduced value.
    let m6 = |v: u32| ConstModulo32::<6>(v);

    // 3 - 2 = 1: no borrow.
    assert_eq!(m6(3) - m6(2), m6(1));
    // 2 - 4 = -2, which is 4 modulo 6.
    assert_eq!(m6(2) - m6(4), m6(4));
    // A value minus itself is zero.
    assert_eq!(m6(3) - m6(3), m6(0));
}
#[test]
fn mul() {
    // (lhs, rhs, expected product) triples, all modulo 6.
    let cases: [(u32, u32, u32); 3] = [
        (2, 5, 4), // 10 mod 6
        (5, 3, 3), // 15 mod 6
        (2, 3, 0), // 6 mod 6
    ];
    for &(lhs, rhs, expected) in cases.iter() {
        assert_eq!(
            ConstModulo32::<6>(lhs) * ConstModulo32::<6>(rhs),
            ConstModulo32::<6>(expected)
        );
    }
}
#[test]
fn div() {
    // (dividend, divisor, expected quotient) triples, all modulo 5.
    let cases: [(u32, u32, u32); 3] = [
        (2, 3, 4), // 3 * 4 = 12, which is 2 modulo 5
        (3, 4, 2), // 4 * 2 = 8, which is 3 modulo 5
        (2, 4, 3), // 4 * 3 = 12, which is 2 modulo 5
    ];
    for &(dividend, divisor, expected) in cases.iter() {
        assert_eq!(
            ConstModulo32::<5>(dividend) / ConstModulo32::<5>(divisor),
            ConstModulo32::<5>(expected)
        );
    }
}