const_modulo_ring/
lib.rs

1use core::{
2    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
3    str::FromStr,
4};
/// Generates a modular-integer type `$name<const M: $base>` whose value is kept in `[0, M)`.
///
/// Macro arguments:
/// - `$name`: name of the generated public type.
/// - `$base`: unsigned storage type; also the type of the modulus `M`.
/// - `$wide`: unsigned type wide enough to hold the sum or product of two values below `M`,
///   so intermediate results never overflow.
/// - `$signed_wide`: signed type passed to `ring_algorithm::modulo_division` for division.
/// - `$num_traits`: name of the private module holding the optional `num-traits` impls.
/// - `$test_name`: name of the generated unit-test module.
macro_rules! generate {
    ($name:ident, $base:ty, $wide:ty, $signed_wide:ty, $num_traits:ident, $test_name:ident) => {
        /// An integer modulo the compile-time constant `M`.
        ///
        /// The wrapped value is always kept in `[0, M)`, and every operation below preserves
        /// that invariant. `M` must be non-zero: reducing by a zero modulus panics.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $name<const M: $base>($base);

        impl<const M: $base> From<$base> for $name<M> {
            /// Reduces `x` modulo `M`.
            fn from(x: $base) -> Self {
                Self(x % M)
            }
        }

        #[auto_impl_ops::auto_ops]
        impl<const M: $base> AddAssign<&$name<M>> for $name<M> {
            fn add_assign(&mut self, other: &Self) {
                // Add in the wide type: when M > $base::MAX / 2, the sum of two residues
                // can exceed `$base::MAX`.
                let m = M as $wide;
                let sum = self.0 as $wide + other.0 as $wide;
                // `sum < 2 * M`, so one conditional subtraction fully reduces it, and the
                // result (< M) fits back into `$base`.
                self.0 = (if sum >= m { sum - m } else { sum }) as $base;
            }
        }

        #[auto_impl_ops::auto_ops]
        impl<const M: $base> SubAssign<&$name<M>> for $name<M> {
            fn sub_assign(&mut self, other: &Self) {
                // Branch instead of computing `self.0 + (M - other.0)`, which overflows
                // `$base` when M is close to `$base::MAX`. In the borrow branch
                // `self.0 < other.0 < M`, so the result stays below M.
                self.0 = if self.0 >= other.0 {
                    self.0 - other.0
                } else {
                    self.0 + (M - other.0)
                };
            }
        }

        impl<const M: $base> Neg for $name<M> {
            type Output = Self;
            /// Additive inverse; delegates to the by-reference impl.
            fn neg(self) -> Self::Output {
                -&self
            }
        }

        impl<const M: $base> Neg for &$name<M> {
            type Output = $name<M>;
            /// Additive inverse. `-0` is `0` rather than `M`, keeping the value in `[0, M)`.
            fn neg(self) -> Self::Output {
                if self.0 == 0 {
                    $name(0)
                } else {
                    $name(M - self.0)
                }
            }
        }

        #[auto_impl_ops::auto_ops]
        impl<const M: $base> MulAssign<&$name<M>> for $name<M> {
            fn mul_assign(&mut self, other: &Self) {
                // The product of two residues fits in `$wide`. The remainder is below M
                // and therefore fits back into `$base`.
                self.0 = ((self.0 as $wide) * (other.0 as $wide) % (M as $wide)) as $base;
            }
        }

        #[auto_impl_ops::auto_ops]
        impl<const M: $base> DivAssign<&$name<M>> for $name<M> {
            // Panics with "Can't divide" when `ring_algorithm::modulo_division` reports
            // failure (presumably when `other` has no inverse modulo M; confirm against
            // that crate).
            fn div_assign(&mut self, other: &Self) {
                let m = M as $signed_wide;
                let quotient = ring_algorithm::modulo_division(
                    self.0 as $signed_wide,
                    other.0 as $signed_wide,
                    m,
                )
                .expect("Can't divide");
                // Normalise into [0, M) regardless of the sign or range the helper returns.
                self.0 = quotient.rem_euclid(m) as $base;
            }
        }

        impl<const M: $base> FromStr for $name<M> {
            type Err = std::num::ParseIntError;
            /// Parses an optionally signed decimal integer in the `i128` range and reduces it
            /// into `[0, M)`, so negative inputs map to their positive residue.
            ///
            /// # Errors
            /// Returns the `ParseIntError` from `str::parse::<i128>` when `s` is not a
            /// valid integer or lies outside the `i128` range.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                let v = s.parse::<i128>()?;
                Ok(Self(v.rem_euclid(M as i128) as $base))
            }
        }

        impl<const M: $base> std::fmt::Display for $name<M> {
            /// Writes the reduced value in `[0, M)`.
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
                write!(f, "{}", self.0)
            }
        }

        #[cfg(feature = "num-traits")]
        mod $num_traits {
            use crate::$name;
            use num_traits::{ConstOne, ConstZero, One, Zero};
            impl<const M: $base> Zero for $name<M> {
                fn zero() -> Self {
                    Self::ZERO
                }
                fn is_zero(&self) -> bool {
                    self.0 == 0
                }
            }
            impl<const M: $base> ConstZero for $name<M> {
                const ZERO: Self = Self(0);
            }
            impl<const M: $base> One for $name<M> {
                fn one() -> Self {
                    Self::ONE
                }
            }
            impl<const M: $base> ConstOne for $name<M> {
                // `1 % M` rather than `1`: for M == 1 the only residue is 0.
                const ONE: Self = Self(1 % M);
            }
        }

        #[cfg(test)]
        mod $test_name {
            use super::*;
            #[test]
            fn add() {
                let a = $name::<6>(2);
                let b = $name::<6>(3);
                let c = $name::<6>(4);
                let d = $name::<6>(5);
                let e = $name::<6>(1);
                let f = $name::<6>(0);
                assert_eq!(a + b, d);
                assert_eq!(b + c, e);
                assert_eq!(b + b, f);
            }
            #[test]
            fn sub() {
                let a = $name::<6>(3);
                let b = $name::<6>(2);
                let c = $name::<6>(4);
                let d = $name::<6>(1);
                let e = $name::<6>(4);
                let f = $name::<6>(0);
                assert_eq!(a - b, d);
                assert_eq!(b - c, e);
                assert_eq!(a - a, f);
            }
            #[test]
            fn mul() {
                let a = $name::<6>(2);
                let b = $name::<6>(5);
                let c = $name::<6>(3);
                let d = $name::<6>(4);
                let e = $name::<6>(3);
                let f = $name::<6>(0);
                assert_eq!(a * b, d);
                assert_eq!(b * c, e);
                assert_eq!(a * c, f);
            }
            #[test]
            fn div() {
                let a = $name::<5>(2);
                let b = $name::<5>(3);
                let c = $name::<5>(4);
                let d = $name::<5>(4);
                let e = $name::<5>(2);
                let f = $name::<5>(3);
                assert_eq!(a / b, d);
                assert_eq!(b / c, e);
                assert_eq!(a / c, f);
            }
            #[test]
            fn neg() {
                // Regression: -0 used to produce the unreduced value M.
                let zero = $name::<6>(0);
                assert_eq!(-zero, zero);
                assert_eq!(-&zero, zero);
                assert_eq!(-$name::<6>(2), $name::<6>(4));
                assert_eq!(-&$name::<6>(5), $name::<6>(1));
            }
            #[test]
            fn near_max_modulus() {
                // Regression: with M == MAX, adding residues and subtracting without a
                // borrow used to overflow `$base`.
                const BIG: $base = <$base>::MAX;
                let a = $name::<{ BIG }>::from(BIG - 1);
                let b = $name::<{ BIG }>::from(BIG - 2);
                assert_eq!(a + b, $name::<{ BIG }>::from(BIG - 3));
                assert_eq!(a - b, $name::<{ BIG }>::from(1));
                assert_eq!(b - a, $name::<{ BIG }>::from(BIG - 1));
                assert_eq!(a * a, $name::<{ BIG }>::from(1));
                assert_eq!(-a, $name::<{ BIG }>::from(1));
            }
            #[test]
            fn from_str_and_display() {
                assert_eq!("13".parse::<$name<6>>(), Ok($name::<6>(1)));
                assert_eq!("-1".parse::<$name<6>>(), Ok($name::<6>(5)));
                assert!("x".parse::<$name<6>>().is_err());
                assert_eq!($name::<6>(5).to_string(), "5");
            }
            #[cfg(feature = "num-traits")]
            #[test]
            fn num_traits_constants() {
                use num_traits::{One, Zero};
                // Regression: ONE used to be the unreduced value 1 when M == 1.
                assert_eq!($name::<1>::one(), $name::<1>::zero());
                assert!($name::<6>::one().is_one());
                assert!($name::<6>::zero().is_zero());
            }
        }
    };
}
155generate!(ConstModulo8, u8, u16, i16, impl_num_traits8, test8);
156generate!(ConstModulo16, u16, u32, i32, impl_num_traits16, test16);
157generate!(ConstModulo32, u32, u64, i64, impl_num_traits32, test32);
158generate!(ConstModulo64, u64, u128, i128, impl_num_traits64, test64);