1use num_traits::{NumOps, One, Signed, Unsigned, Zero};
4
5use core::{
6 convert::{TryFrom, TryInto},
7 mem,
8};
9
10use crate::{arith::Arithmetic, error::ArithmeticError};
11
12pub trait DoubleWidth: Sized + Unsigned {
16 type Wide: Copy + From<Self> + TryInto<Self> + NumOps + Unsigned;
18 type SignedWide: Copy + From<Self> + TryInto<Self> + NumOps + Zero + One + Signed + PartialOrd;
20}
21
22impl DoubleWidth for u8 {
23 type Wide = u16;
24 type SignedWide = i16;
25}
26
27impl DoubleWidth for u16 {
28 type Wide = u32;
29 type SignedWide = i32;
30}
31
32impl DoubleWidth for u32 {
33 type Wide = u64;
34 type SignedWide = i64;
35}
36
37impl DoubleWidth for u64 {
38 type Wide = u128;
39 type SignedWide = i128;
40}
41
42#[derive(Debug, Clone, Copy)]
48pub struct ModularArithmetic<T> {
49 pub(super) modulus: T,
50}
51
52impl<T> ModularArithmetic<T>
53where
54 T: Clone + PartialEq + NumOps + Unsigned + Zero + One,
55{
56 pub fn new(modulus: T) -> Self {
62 assert!(!modulus.is_zero(), "Modulus cannot be 0");
63 assert!(!modulus.is_one(), "Modulus cannot be 1");
64 Self { modulus }
65 }
66
67 pub fn modulus(&self) -> &T {
69 &self.modulus
70 }
71}
72
73impl<T> ModularArithmetic<T>
74where
75 T: Copy + PartialEq + NumOps + Unsigned + Zero + One + DoubleWidth,
76{
77 #[inline]
78 fn mul_inner(self, x: T, y: T) -> T {
79 let wide = (<T::Wide>::from(x) * <T::Wide>::from(y)) % <T::Wide>::from(self.modulus);
80 wide.try_into().ok().unwrap() }
82
83 fn invert(self, value: T) -> Option<T> {
86 let value = value % self.modulus; let mut t = <T::SignedWide>::zero();
88 let mut new_t = <T::SignedWide>::one();
89
90 let modulus = <T::SignedWide>::from(self.modulus);
91 let mut r = modulus;
92 let mut new_r = <T::SignedWide>::from(value);
93
94 while !new_r.is_zero() {
95 let quotient = r / new_r;
96 t = t - quotient * new_t;
97 mem::swap(&mut new_t, &mut t);
98 r = r - quotient * new_r;
99 mem::swap(&mut new_r, &mut r);
100 }
101
102 if r > <T::SignedWide>::one() {
103 None } else {
105 if t.is_negative() {
106 t = t + modulus;
107 }
108 Some(t.try_into().ok().unwrap())
109 }
111 }
112
113 fn modular_exp(self, base: T, mut exp: usize) -> T {
114 if exp == 0 {
115 return T::one();
116 }
117
118 let wide_modulus = <T::Wide>::from(self.modulus);
119 let mut base = <T::Wide>::from(base % self.modulus);
120
121 while exp & 1 == 0 {
122 base = (base * base) % wide_modulus;
123 exp >>= 1;
124 }
125 if exp == 1 {
126 return base.try_into().ok().unwrap(); }
128
129 let mut acc = base;
130 while exp > 1 {
131 exp >>= 1;
132 base = (base * base) % wide_modulus;
133 if exp & 1 == 1 {
134 acc = (acc * base) % wide_modulus;
135 }
136 }
137 acc.try_into().ok().unwrap() }
139}
140
141impl<T> Arithmetic<T> for ModularArithmetic<T>
142where
143 T: Copy + PartialEq + NumOps + Zero + One + DoubleWidth,
144 usize: TryFrom<T>,
145{
146 #[inline]
147 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
148 let wide = (<T::Wide>::from(x) + <T::Wide>::from(y)) % <T::Wide>::from(self.modulus);
149 Ok(wide.try_into().ok().unwrap()) }
151
152 #[inline]
153 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
154 let y = y % self.modulus; self.add(x, self.modulus - y)
156 }
157
158 #[inline]
159 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
160 Ok(self.mul_inner(x, y))
161 }
162
163 #[inline]
164 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
165 if y.is_zero() {
166 Err(ArithmeticError::DivisionByZero)
167 } else {
168 let y_inv = self.invert(y).ok_or(ArithmeticError::NoInverse)?;
169 self.mul(x, y_inv)
170 }
171 }
172
173 #[inline]
174 #[allow(clippy::map_err_ignore)]
175 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
176 let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
177 Ok(self.modular_exp(x, exp))
178 }
179
180 #[inline]
181 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
182 let x = x % self.modulus; Ok(self.modulus - x)
184 }
185
186 #[inline]
187 fn eq(&self, x: &T, y: &T) -> bool {
188 *x % self.modulus == *y % self.modulus
189 }
190}
191
// Compile-time checks that `ModularArithmetic` implements `Arithmetic` for each
// unsigned integer type that has a `DoubleWidth` implementation.
#[cfg(test)]
mod arithmetic_impl_checks {
    use super::{Arithmetic, ModularArithmetic};

    static_assertions::assert_impl_all!(ModularArithmetic<u8>: Arithmetic<u8>);
    static_assertions::assert_impl_all!(ModularArithmetic<u16>: Arithmetic<u16>);
    static_assertions::assert_impl_all!(ModularArithmetic<u32>: Arithmetic<u32>);
    static_assertions::assert_impl_all!(ModularArithmetic<u64>: Arithmetic<u64>);
}
200
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn modular_arithmetic_basics() {
        let arith = ModularArithmetic::new(11_u32);

        // Addition wraps around the modulus, including for unreduced operands.
        for &(x, y, sum) in &[(1, 5, 6), (2, 9, 0), (5, 9, 3), (5, 20, 3)] {
            assert_eq!(arith.add(x, y).unwrap(), sum);
        }

        // Negative differences are brought back into `[0, 11)`.
        for &(x, y, diff) in &[(5, 9, 7), (5, 20, 7)] {
            assert_eq!(arith.sub(x, y).unwrap(), diff);
        }

        for &(x, y, product) in &[(5, 4, 9), (11, 4, 0)] {
            assert_eq!(arith.mul(x, y).unwrap(), product);
        }
        // `u32::MAX ≡ 3 (mod 11)`, so its square is 9; the product must not overflow.
        assert_eq!(u32::MAX % 11, 3);
        assert_eq!(arith.mul(u32::MAX, u32::MAX).unwrap(), 9);

        // Modulo 11: 4⁻¹ = 3 and 9⁻¹ = 5.
        for &(x, y, quotient) in &[(1, 4, 3), (2, 4, 6), (1, 9, 5)] {
            assert_eq!(arith.div(x, y).unwrap(), quotient);
        }

        for &(base, exp, power) in &[(2, 5, 10), (3, 10, 1), (3, 4, 4), (7, 3, 2)] {
            assert_eq!(arith.pow(base, exp).unwrap(), power);
        }
    }

    #[test]
    fn modular_arithmetic_never_overflows() {
        const MODULUS: u8 = 241;

        let arith = ModularArithmetic::new(MODULUS);
        let wide_mod = u16::from(MODULUS);
        let signed_mod = i16::from(MODULUS);

        // Exhaustively compare against computations in the 16-bit types.
        for x in 0..=u8::MAX {
            for y in 0..=u8::MAX {
                let expected_sum = (u16::from(x) + u16::from(y)) % wide_mod;
                assert_eq!(u16::from(arith.add(x, y).unwrap()), expected_sum);

                let expected_diff = (i16::from(x) - i16::from(y)).rem_euclid(signed_mod);
                assert_eq!(i16::from(arith.sub(x, y).unwrap()), expected_diff);

                let expected_product = (u16::from(x) * u16::from(y)) % wide_mod;
                assert_eq!(u16::from(arith.mul(x, y).unwrap()), expected_product);
            }
        }

        // The modulus is prime, so exactly the non-multiples of it are invertible.
        for x in 0..=u8::MAX {
            match arith.invert(x) {
                None => assert_eq!(x % MODULUS, 0),
                Some(inv) => {
                    assert_ne!(x % MODULUS, 0);
                    assert_eq!((u16::from(inv) * u16::from(x)) % wide_mod, 1);
                }
            }
        }
    }

    // Number of random samples drawn per check in `mini_fuzz_for_prime_modulus`.
    const SAMPLE_COUNT: usize = 25_000;

    /// Cross-checks `ModularArithmetic<u64>` against plain `u128` / `i128` arithmetic
    /// on random inputs. `modulus` is expected to be prime: every non-multiple must
    /// be invertible and satisfy Fermat's little theorem.
    fn mini_fuzz_for_prime_modulus(modulus: u64) {
        let arith = ModularArithmetic::new(modulus);
        let wide_mod = u128::from(modulus);
        let signed_mod = i128::from(modulus);
        let mut rng = StdRng::seed_from_u64(modulus);

        for _ in 0..SAMPLE_COUNT {
            let (x, y) = rng.gen::<(u64, u64)>();

            let expected_sum = (u128::from(x) + u128::from(y)) % wide_mod;
            assert_eq!(u128::from(arith.add(x, y).unwrap()), expected_sum);

            let expected_diff = (i128::from(x) - i128::from(y)).rem_euclid(signed_mod);
            assert_eq!(i128::from(arith.sub(x, y).unwrap()), expected_diff);

            let expected_product = (u128::from(x) * u128::from(y)) % wide_mod;
            assert_eq!(u128::from(arith.mul(x, y).unwrap()), expected_product);
        }

        for _ in 0..SAMPLE_COUNT {
            let x = rng.gen::<u64>();
            match arith.invert(x) {
                None => assert_eq!(x % modulus, 0),
                Some(inv) => {
                    assert_ne!(x % modulus, 0);
                    assert_eq!((u128::from(inv) * u128::from(x)) % wide_mod, 1);
                }
            }
        }

        for _ in 0..(SAMPLE_COUNT / 10) {
            let x = rng.gen::<u64>();
            let wide_x = u128::from(x);
            let exp = rng.gen_range(1_u64..1_000);

            // Reference result: naive repeated multiplication.
            let expected_pow = (0..exp).fold(1_u128, |acc, _| (acc * wide_x) % wide_mod);
            assert_eq!(u128::from(arith.pow(x, exp).unwrap()), expected_pow);

            // Fermat's little theorem: x^(p - 1) ≡ 1 (mod p) when p does not divide x.
            if x % modulus != 0 {
                assert_eq!(arith.pow(x, modulus - 1).unwrap(), 1);
            }
        }
    }

    #[test]
    fn mini_fuzz_for_small_modulus() {
        for &modulus in &[3_u64, 7, 23, 61] {
            mini_fuzz_for_prime_modulus(modulus);
        }
    }

    #[test]
    fn mini_fuzz_for_u32_modulus() {
        for &modulus in &[3_000_000_019_u64, 3_500_000_011, 4_000_000_007] {
            mini_fuzz_for_prime_modulus(modulus);
        }
    }

    #[test]
    fn mini_fuzz_for_large_u64_modulus() {
        let moduli = [
            2_594_642_710_891_962_701_u64,
            5_647_618_287_156_850_721,
            9_223_372_036_854_775_837,
            10_902_486_311_044_492_273,
        ];
        for &modulus in &moduli {
            mini_fuzz_for_prime_modulus(modulus);
        }
    }
}