arithmetic_eval/arith/modular.rs

1//! Modular arithmetic.
2
3use num_traits::{NumOps, One, Signed, Unsigned, Zero};
4
5use core::{
6    convert::{TryFrom, TryInto},
7    mem,
8};
9
10use crate::{arith::Arithmetic, error::ArithmeticError};
11
12/// Encapsulates extension of an unsigned integer type into signed and unsigned double-width types.
13/// This allows performing certain operations (e.g., multiplication) without a possibility of
14/// integer overflow.
15pub trait DoubleWidth: Sized + Unsigned {
16    /// Unsigned double-width extension type.
17    type Wide: Copy + From<Self> + TryInto<Self> + NumOps + Unsigned;
18    /// Signed double-width extension type.
19    type SignedWide: Copy + From<Self> + TryInto<Self> + NumOps + Zero + One + Signed + PartialOrd;
20}
21
22impl DoubleWidth for u8 {
23    type Wide = u16;
24    type SignedWide = i16;
25}
26
27impl DoubleWidth for u16 {
28    type Wide = u32;
29    type SignedWide = i32;
30}
31
32impl DoubleWidth for u32 {
33    type Wide = u64;
34    type SignedWide = i64;
35}
36
37impl DoubleWidth for u64 {
38    type Wide = u128;
39    type SignedWide = i128;
40}
41
42/// Modular arithmetic on integers.
43///
44/// As an example, `ModularArithmetic<T>` implements `Arithmetic<T>` if `T` is one of unsigned
45/// built-in integer types (`u8`, `u16`, `u32`, `u64`; `u128` **is excluded** because it cannot be
46/// extended to double width).
47#[derive(Debug, Clone, Copy)]
48pub struct ModularArithmetic<T> {
49    pub(super) modulus: T,
50}
51
52impl<T> ModularArithmetic<T>
53where
54    T: Clone + PartialEq + NumOps + Unsigned + Zero + One,
55{
56    /// Creates a new arithmetic with the specified `modulus`.
57    ///
58    /// # Panics
59    ///
60    /// - Panics if modulus is 0 or 1.
61    pub fn new(modulus: T) -> Self {
62        assert!(!modulus.is_zero(), "Modulus cannot be 0");
63        assert!(!modulus.is_one(), "Modulus cannot be 1");
64        Self { modulus }
65    }
66
67    /// Returns the modulus for this arithmetic.
68    pub fn modulus(&self) -> &T {
69        &self.modulus
70    }
71}
72
73impl<T> ModularArithmetic<T>
74where
75    T: Copy + PartialEq + NumOps + Unsigned + Zero + One + DoubleWidth,
76{
77    #[inline]
78    fn mul_inner(self, x: T, y: T) -> T {
79        let wide = (<T::Wide>::from(x) * <T::Wide>::from(y)) % <T::Wide>::from(self.modulus);
80        wide.try_into().ok().unwrap() // `unwrap` is safe by construction
81    }
82
83    /// Computes the multiplicative inverse of `value` using the extended Euclid algorithm.
84    /// Care is taken to not overflow anywhere.
85    fn invert(self, value: T) -> Option<T> {
86        let value = value % self.modulus; // Reduce value since this influences speed.
87        let mut t = <T::SignedWide>::zero();
88        let mut new_t = <T::SignedWide>::one();
89
90        let modulus = <T::SignedWide>::from(self.modulus);
91        let mut r = modulus;
92        let mut new_r = <T::SignedWide>::from(value);
93
94        while !new_r.is_zero() {
95            let quotient = r / new_r;
96            t = t - quotient * new_t;
97            mem::swap(&mut new_t, &mut t);
98            r = r - quotient * new_r;
99            mem::swap(&mut new_r, &mut r);
100        }
101
102        if r > <T::SignedWide>::one() {
103            None // r = gcd(self.modulus, value) > 1
104        } else {
105            if t.is_negative() {
106                t = t + modulus;
107            }
108            Some(t.try_into().ok().unwrap())
109            // ^-- `unwrap` is safe by construction
110        }
111    }
112
113    fn modular_exp(self, base: T, mut exp: usize) -> T {
114        if exp == 0 {
115            return T::one();
116        }
117
118        let wide_modulus = <T::Wide>::from(self.modulus);
119        let mut base = <T::Wide>::from(base % self.modulus);
120
121        while exp & 1 == 0 {
122            base = (base * base) % wide_modulus;
123            exp >>= 1;
124        }
125        if exp == 1 {
126            return base.try_into().ok().unwrap(); // `unwrap` is safe by construction
127        }
128
129        let mut acc = base;
130        while exp > 1 {
131            exp >>= 1;
132            base = (base * base) % wide_modulus;
133            if exp & 1 == 1 {
134                acc = (acc * base) % wide_modulus;
135            }
136        }
137        acc.try_into().ok().unwrap() // `unwrap` is safe by construction
138    }
139}
140
141impl<T> Arithmetic<T> for ModularArithmetic<T>
142where
143    T: Copy + PartialEq + NumOps + Zero + One + DoubleWidth,
144    usize: TryFrom<T>,
145{
146    #[inline]
147    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
148        let wide = (<T::Wide>::from(x) + <T::Wide>::from(y)) % <T::Wide>::from(self.modulus);
149        Ok(wide.try_into().ok().unwrap()) // `unwrap` is safe by construction
150    }
151
152    #[inline]
153    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
154        let y = y % self.modulus; // Prevent possible overflow in the following subtraction
155        self.add(x, self.modulus - y)
156    }
157
158    #[inline]
159    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
160        Ok(self.mul_inner(x, y))
161    }
162
163    #[inline]
164    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
165        if y.is_zero() {
166            Err(ArithmeticError::DivisionByZero)
167        } else {
168            let y_inv = self.invert(y).ok_or(ArithmeticError::NoInverse)?;
169            self.mul(x, y_inv)
170        }
171    }
172
173    #[inline]
174    #[allow(clippy::map_err_ignore)]
175    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
176        let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
177        Ok(self.modular_exp(x, exp))
178    }
179
180    #[inline]
181    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
182        let x = x % self.modulus; // Prevent possible overflow in the following subtraction
183        Ok(self.modulus - x)
184    }
185
186    #[inline]
187    fn eq(&self, x: &T, y: &T) -> bool {
188        *x % self.modulus == *y % self.modulus
189    }
190}
191
// Compile-time checks that `ModularArithmetic` implements `Arithmetic` for every unsigned
// integer type that has a `DoubleWidth` extension.
#[cfg(test)]
mod arithmetic_impl_assertions {
    use super::{Arithmetic, ModularArithmetic};

    static_assertions::assert_impl_all!(ModularArithmetic<u8>: Arithmetic<u8>);
    static_assertions::assert_impl_all!(ModularArithmetic<u16>: Arithmetic<u16>);
    static_assertions::assert_impl_all!(ModularArithmetic<u32>: Arithmetic<u32>);
    static_assertions::assert_impl_all!(ModularArithmetic<u64>: Arithmetic<u64>);
}
200
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn modular_arithmetic_basics() {
        let arithmetic = ModularArithmetic::new(11_u32);
        assert_eq!(arithmetic.add(1, 5).unwrap(), 6);
        assert_eq!(arithmetic.add(2, 9).unwrap(), 0);
        assert_eq!(arithmetic.add(5, 9).unwrap(), 3);
        assert_eq!(arithmetic.add(5, 20).unwrap(), 3);

        assert_eq!(arithmetic.sub(5, 9).unwrap(), 7);
        assert_eq!(arithmetic.sub(5, 20).unwrap(), 7);

        assert_eq!(arithmetic.mul(5, 4).unwrap(), 9);
        assert_eq!(arithmetic.mul(11, 4).unwrap(), 0);

        // Check overflows.
        assert_eq!(u32::MAX % 11, 3);
        assert_eq!(arithmetic.mul(u32::MAX, u32::MAX).unwrap(), 9);

        assert_eq!(arithmetic.div(1, 4).unwrap(), 3); // 4 * 3 = 12 = 1 (mod 11)
        assert_eq!(arithmetic.div(2, 4).unwrap(), 6);
        assert_eq!(arithmetic.div(1, 9).unwrap(), 5); // 9 * 5 = 45 = 1 (mod 11)

        assert_eq!(arithmetic.pow(2, 5).unwrap(), 10);
        assert_eq!(arithmetic.pow(3, 10).unwrap(), 1); // by Fermat theorem
        assert_eq!(arithmetic.pow(3, 4).unwrap(), 4);
        assert_eq!(arithmetic.pow(7, 3).unwrap(), 2);
    }

    /// Covers operations not exercised elsewhere: `modulus`, `neg`, `eq`, division by a literal
    /// zero, and exponentiation edge cases.
    #[test]
    fn modular_arithmetic_neg_eq_and_edge_cases() {
        let arithmetic = ModularArithmetic::new(11_u32);
        assert_eq!(*arithmetic.modulus(), 11);

        assert_eq!(arithmetic.neg(3).unwrap(), 8);
        assert_eq!(arithmetic.neg(14).unwrap(), 8); // 14 = 3 (mod 11)
        assert_eq!(arithmetic.add(arithmetic.neg(5).unwrap(), 5).unwrap(), 0);

        assert!(arithmetic.eq(&3, &14));
        assert!(arithmetic.eq(&0, &11));
        assert!(!arithmetic.eq(&3, &4));

        assert!(matches!(
            arithmetic.div(1, 0),
            Err(ArithmeticError::DivisionByZero)
        ));

        assert_eq!(arithmetic.pow(5, 0).unwrap(), 1);
        assert_eq!(arithmetic.pow(0, 3).unwrap(), 0);
    }

    #[test]
    #[should_panic(expected = "Modulus cannot be 0")]
    fn zero_modulus_is_rejected() {
        let _ = ModularArithmetic::new(0_u32);
    }

    #[test]
    #[should_panic(expected = "Modulus cannot be 1")]
    fn unit_modulus_is_rejected() {
        let _ = ModularArithmetic::new(1_u8);
    }

    #[test]
    fn modular_arithmetic_never_overflows() {
        const MODULUS: u8 = 241;

        let arithmetic = ModularArithmetic::new(MODULUS);
        for x in 0..=u8::MAX {
            for y in 0..=u8::MAX {
                let expected = (u16::from(x) + u16::from(y)) % u16::from(MODULUS);
                assert_eq!(u16::from(arithmetic.add(x, y).unwrap()), expected);

                let mut expected = (i16::from(x) - i16::from(y)) % i16::from(MODULUS);
                if expected < 0 {
                    expected += i16::from(MODULUS);
                }
                assert_eq!(i16::from(arithmetic.sub(x, y).unwrap()), expected);

                let expected = (u16::from(x) * u16::from(y)) % u16::from(MODULUS);
                assert_eq!(u16::from(arithmetic.mul(x, y).unwrap()), expected);
            }
        }

        for x in 0..=u8::MAX {
            let inv = arithmetic.invert(x);
            if x % MODULUS == 0 {
                assert!(inv.is_none());
            } else {
                let inv = u16::from(inv.unwrap());
                assert_eq!((inv * u16::from(x)) % u16::from(MODULUS), 1);
            }
        }
    }

    // Takes ~1s in the debug mode.
    const SAMPLE_COUNT: usize = 25_000;

    fn mini_fuzz_for_prime_modulus(modulus: u64) {
        let arithmetic = ModularArithmetic::new(modulus);
        let unsigned_wide_mod = u128::from(modulus);
        let signed_wide_mod = i128::from(modulus);
        let mut rng = StdRng::seed_from_u64(modulus);

        for (x, y) in (0..SAMPLE_COUNT).map(|_| rng.gen::<(u64, u64)>()) {
            let expected = (u128::from(x) + u128::from(y)) % unsigned_wide_mod;
            assert_eq!(u128::from(arithmetic.add(x, y).unwrap()), expected);

            let mut expected = (i128::from(x) - i128::from(y)) % signed_wide_mod;
            if expected < 0 {
                expected += signed_wide_mod;
            }
            assert_eq!(i128::from(arithmetic.sub(x, y).unwrap()), expected);

            let expected = (u128::from(x) * u128::from(y)) % unsigned_wide_mod;
            assert_eq!(u128::from(arithmetic.mul(x, y).unwrap()), expected);
        }

        for x in (0..SAMPLE_COUNT).map(|_| rng.gen::<u64>()) {
            let inv = arithmetic.invert(x);
            if x % modulus == 0 {
                // Quite unlikely, but better be safe than sorry.
                assert!(inv.is_none());
            } else {
                let inv = u128::from(inv.unwrap());
                assert_eq!((inv * u128::from(x)) % unsigned_wide_mod, 1);
            }
        }

        for _ in 0..(SAMPLE_COUNT / 10) {
            let x = rng.gen::<u64>();
            let wide_x = u128::from(x);

            // Check a random small exponent.
            let exp = rng.gen_range(1_u64..1_000);
            let expected_pow = (0..exp).fold(1_u128, |acc, _| (acc * wide_x) % unsigned_wide_mod);
            assert_eq!(u128::from(arithmetic.pow(x, exp).unwrap()), expected_pow);

            if x % modulus != 0 {
                // Check Fermat's little theorem.
                let pow = arithmetic.pow(x, modulus - 1).unwrap();
                assert_eq!(pow, 1);
            }
        }
    }

    #[test]
    fn mini_fuzz_for_small_modulus() {
        mini_fuzz_for_prime_modulus(3);
        mini_fuzz_for_prime_modulus(7);
        mini_fuzz_for_prime_modulus(23);
        mini_fuzz_for_prime_modulus(61);
    }

    #[test]
    fn mini_fuzz_for_u32_modulus() {
        // Primes taken from https://www.numberempire.com/primenumbers.php
        mini_fuzz_for_prime_modulus(3_000_000_019);
        mini_fuzz_for_prime_modulus(3_500_000_011);
        mini_fuzz_for_prime_modulus(4_000_000_007);
    }

    #[test]
    fn mini_fuzz_for_large_u64_modulus() {
        // Primes taken from https://bigprimes.org/
        mini_fuzz_for_prime_modulus(2_594_642_710_891_962_701);
        mini_fuzz_for_prime_modulus(5_647_618_287_156_850_721);
        mini_fuzz_for_prime_modulus(9_223_372_036_854_775_837);
        mini_fuzz_for_prime_modulus(10_902_486_311_044_492_273);
    }
}
342}