Skip to main content

lib_modulo/
residue64.rs

1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue64`].
4///
5/// See documentation of [`Residue64`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Eq, Hash)]
8pub struct Modulus64 {
9    // n inv_n = 1 (mod r = 2^32 or 2^64)
10    pub(crate) n: u64,
11    pub(crate) inv_n: u64,
12    pub(crate) r2_mod_n: u64,
13}
14
15impl Modulus64 {
16    /// Calculates some parameters for Montgomery multiplication.
17    ///
18    /// # Panics
19    ///
20    /// - modulus `n` should be an odd number.
21    #[inline]
22    pub const fn new(n: u64) -> Self {
23        assert!(n & 1 == 1, "modulus should be an odd number");
24
25        let inv_n = {
26            const TABLE: u32 = {
27                // | n     | 1 | 3  | 5  | 7 | 9 | 11 | 13 | 15 |
28                // | inv_n | 1 | 11 | 13 | 7 | 9 | 3  | 5  | 15 | <- 4 bits * 8
29                let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
30
31                let mut table = 0;
32                let mut i = 0;
33                while i < 8 {
34                    table |= inv_n[i] << (i * 4);
35                    i += 1;
36                }
37
38                table
39            };
40            // n inv_n = 1 (mod 8)
41            let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
42
43            let mut d = const { u64::BITS.ilog2() - 2 };
44            while d > 0 {
45                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
46                d -= 1;
47            }
48            debug_assert!(n.wrapping_mul(inv_n) == 1);
49
50            inv_n
51        };
52        let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
53
54        Self { n, inv_n, r2_mod_n }
55    }
56
57    #[inline(always)]
58    pub const fn residue(&self, x: u64) -> Residue64<'_> {
59        // `x r2 < r n`
60        let x = self.mul(x, self.r2_mod_n);
61
62        Residue64 { x, modulus: self }
63    }
64
65    /// Performs Montgomery multiplication.
66    ///
67    /// if `lhs rhs < n r`, then `result < n`
68    #[inline(always)]
69    pub(crate) const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
70        self.mul_add(lhs, rhs, 0)
71    }
72
73    /// Performs `lhs rhs + add`.
74    ///
75    /// If `lhs rhs + add < n r`, then the result is less than `n`.
76    #[inline(always)]
77    pub(crate) const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
78        // FIXME: use `a.widening_mul(b)`
79        let (x_hi, x_lo) = {
80            let x = lhs as u128 * rhs as u128 + add as u128;
81            ((x >> u64::BITS) as u64, x as u64)
82        };
83        // FIXME: use `mul_hi()`
84        // y = x n nn = x (mod r) => y_lo = x_lo
85        let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
86        // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
87        let (z, b) = x_hi.overflowing_sub(y_hi);
88
89        // x < n r, y < n r => |z| < n
90        if b {
91            z.wrapping_add(self.n)
92        } else {
93            z
94        }
95    }
96
97    /// Checks whether `x` is multiple of `self`.
98    ///
99    /// # Example
100    ///
101    /// ```
102    /// use lib_modulo::Modulus64;
103    ///
104    /// for n in (1..1 << 10).step_by(2) {
105    ///     let modulus  = Modulus64::new(n);
106    ///
107    ///     (0..1 << 10).for_each(|k| assert!(modulus.can_divide(n * k)));
108    /// }
109    /// ```
110    #[inline]
111    pub const fn can_divide(&self, x: u64) -> bool {
112        self.residue(x).is_zero()
113    }
114}
115
116impl PartialEq for Modulus64 {
117    fn eq(&self, other: &Self) -> bool {
118        // other parameters depend on `n`
119        self.n == other.n
120    }
121}
122
123/// Residue with odd modulus which is less than 2^64.
124///
125/// # Fast modular multiplication
126///
127/// [`Residue64`] provides fast modular multiplication called [Montgomery multiplication].
128/// Since this method provides modular multiplication without trial division,
129/// it is approximately twice as fast.
130///
131/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
132///
133/// # Usage
134///
135/// ```
136/// use lib_modulo::Modulus64;
137///
138/// // runtime-specified *odd* modulus
139/// let modulus = 5;
140///
141/// let modulus = Modulus64::new(modulus); // slow
142/// let n = modulus.residue(2) * modulus.residue(3); // fast
143/// assert_eq!(n.get(), 1);
144/// ```
145/// Two residues with different modulus can interact, but the result will be meaningless.
146/// It is highly recommended to use a block to ensure that [`Modulus64`], therefore [`Residue64`]s, are dropped.
147#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
148pub struct Residue64<'a> {
149    pub(crate) modulus: &'a Modulus64,
150    // x r (mod n)
151    pub(crate) x: u64,
152}
153
154impl<'a> Residue64<'a> {
155    /// Returns the residue.
156    ///
157    /// # Example
158    ///
159    /// ```
160    /// use lib_modulo::Modulus64;
161    ///
162    /// let modulus  = Modulus64::new(5);
163    /// let n = modulus.residue(7);
164    /// assert_eq!(n.get(), 2);
165    /// ```
166    #[inline(always)]
167    pub const fn get(&self) -> u64 {
168        self.modulus.mul(self.x, 1)
169    }
170
171    /// Returns the modulus.
172    ///
173    /// # Example
174    ///
175    /// ```
176    /// use lib_modulo::Modulus64;
177    ///
178    /// let modulus  = Modulus64::new(5);
179    /// let n = modulus.residue(7);
180    /// assert_eq!(n.modulus(), 5);
181    /// ```
182    #[inline(always)]
183    pub const fn modulus(&self) -> u64 {
184        self.modulus.n
185    }
186
187    /// Checks whether `self` is `0`.
188    ///
189    /// # Example
190    ///
191    /// ```
192    /// use lib_modulo::Modulus64;
193    ///
194    /// let modulus  = Modulus64::new(3);
195    /// assert_eq!(modulus.residue(6).get(), 0);
196    /// ```
197    #[inline(always)]
198    pub const fn is_zero(self) -> bool {
199        self.x == 0
200    }
201
202    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
203    ///
204    /// # Time complexity
205    ///
206    /// *O*(log `exp`)
207    ///
208    /// # Example
209    ///
210    /// ```
211    /// use lib_modulo::Modulus64;
212    ///
213    /// let modulus = Modulus64::new(1001);
214    /// let residue = modulus.residue(2);
215    /// for exp in 0..64 {
216    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
217    /// }
218    /// ```
219    #[inline]
220    pub const fn pow(mut self, mut exp: u64) -> Self {
221        // r inv_r = 1 (mod n)
222        let mut result = self.modulus.residue(1).x;
223
224        while exp > 0 {
225            if exp & 1 == 1 {
226                // n < r
227                result = self.modulus.mul(result, self.x)
228            }
229
230            exp >>= 1;
231            // n < r
232            self.x = self.modulus.mul(self.x, self.x)
233        }
234        self.x = result;
235
236        self
237    }
238
239    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
240    ///
241    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
242    ///
243    /// - `Ok(x)` : `x` is the modular inverse.
244    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
245    ///   where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
246    ///
247    /// # Time complexity
248    ///
249    /// *O*(log `self`)
250    ///
251    /// # Example
252    ///
253    /// ```
254    /// use lib_modulo::Modulus64;
255    ///
256    /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
257    /// let modulus = Modulus64::new(998_244_353);
258    ///
259    /// for n in 1..500_000 {
260    ///     let n = modulus.residue(n);
261    ///     assert!(n.try_inv().is_ok_and(|i| (i * n).get() == 1));
262    /// }
263    /// // 0 n = 0 != 1 for any integer n
264    /// assert!(modulus.residue(0).try_inv().is_err());
265    /// ```
266    #[inline]
267    pub const fn try_inv(self) -> Result<Self, u64> {
268        let mut a = self.get();
269        let Self { modulus, .. } = self;
270
271        // performs extended binary gcd
272        //
273        // invariants: a = [a] x,  b = [a] y (mod n) where [a] is initial value
274        let mut b = modulus.n;
275        let mut x = modulus.residue(1).x; // 1 r mod n
276        let mut y = 0; // 0 r mod n
277        let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
278
279        while a > 0 {
280            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
281            a >>= a.trailing_zeros();
282
283            if a < b {
284                (a, b) = (b, a);
285                (x, y) = (y, x);
286            }
287            a -= b;
288            let (diff, b) = x.overflowing_sub(y);
289            x = if b {
290                diff.wrapping_add(modulus.n)
291            } else {
292                diff
293            };
294        }
295
296        // b = gcd([a], [b])
297        if b == 1 {
298            Ok(Self { x: y, modulus })
299        } else {
300            Err(b)
301        }
302    }
303}
304
305impl<'a> Add for Residue64<'a> {
306    type Output = Self;
307
308    #[inline(always)]
309    fn add(mut self, rhs: Self) -> Self {
310        let (sum, b) = self.x.overflowing_add(rhs.x);
311        self.x = if b || sum >= self.modulus.n {
312            sum.wrapping_sub(self.modulus.n)
313        } else {
314            sum
315        };
316
317        self
318    }
319}
320
321impl<'a> AddAssign for Residue64<'a> {
322    #[inline(always)]
323    fn add_assign(&mut self, rhs: Self) {
324        *self = *self + rhs
325    }
326}
327
328impl<'a> Sub for Residue64<'a> {
329    type Output = Self;
330
331    #[inline(always)]
332    fn sub(mut self, rhs: Self) -> Self {
333        let (diff, b) = self.x.overflowing_sub(rhs.x);
334        self.x = if b {
335            diff.wrapping_add(self.modulus.n)
336        } else {
337            diff
338        };
339
340        self
341    }
342}
343
344impl<'a> SubAssign for Residue64<'a> {
345    #[inline(always)]
346    fn sub_assign(&mut self, rhs: Self) {
347        *self = *self - rhs
348    }
349}
350
351impl<'a> Mul for Residue64<'a> {
352    type Output = Self;
353
354    #[inline(always)]
355    fn mul(mut self, rhs: Self) -> Self {
356        // n < r
357        self.x = self.modulus.mul(self.x, rhs.x);
358
359        self
360    }
361}
362
363impl<'a> MulAssign for Residue64<'a> {
364    #[inline(always)]
365    fn mul_assign(&mut self, rhs: Self) {
366        *self = *self * rhs
367    }
368}
369
370impl<'a> Neg for Residue64<'a> {
371    type Output = Self;
372
373    #[inline(always)]
374    fn neg(mut self) -> Self::Output {
375        // (x - x) r = 0 (mod n)
376        self.x = if self.x == 0 {
377            self.x
378        } else {
379            self.modulus.n - self.x
380        };
381
382        self
383    }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    // Every property runs 2^15 random cases. Moduli are drawn from all of `u64` and forced
    // odd with `n | 1`, as `Modulus64::new` requires; this includes n == 1, where every
    // residue is 0, so expected values are always reduced mod n.

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `residue` followed by `get` round-trips to `x mod n`, and `*` agrees with `u128`
        // multiplication reduced mod n.
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n);

            let expected = (x as u128 * y as u128 % n as u128) as u64;
            assert_eq!((res * modulus.residue(y)).get(), expected)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `+`, `-` and unary `-` agree with `u128` arithmetic reduced mod n.
        #[test]
        fn add_sub_neg(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            let modulus = Modulus64::new(n);
            let (a, b) = (modulus.residue(x), modulus.residue(y));
            let (nn, xm, ym) = (n as u128, x as u128 % n as u128, y as u128 % n as u128);

            assert_eq!((a + b).get() as u128, (xm + ym) % nn);
            assert_eq!((a - b).get() as u128, (xm + nn - ym) % nn);
            assert_eq!((-a).get() as u128, (nn - xm) % nn)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `pow(i)` agrees with repeated `u128` multiplication for exponents 0..100.
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            // x^0 = 1, reduced so that n == 1 expects 0 like `pow(0).get()` returns.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `can_divide` matches `%`, including on the first 100 multiples of n that fit in u64.
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference GCD (Stein's binary algorithm) used to validate `try_inv`'s error value.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        // Common power of two, restored at the end.
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // On success the product is 1 (mod n). On failure the returned value divides both
        // operands and leaves coprime cofactors, i.e. it is the greatest common divisor.
        #[test]
        fn try_inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    // Regression test for the degenerate modulus 1, where every residue is 0; random
    // sampling above practically never draws it.
    #[test]
    fn modulus_one() {
        let modulus = Modulus64::new(1);
        let res = modulus.residue(12345);

        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert!(modulus.can_divide(7));
        assert_eq!(res.try_inv().map(|inv| (inv * res).get()), Ok(0));
    }
}