Skip to main content

lib_modulo/
residue64.rs

1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue64`].
4///
5/// See documentation of [`Residue64`] for details.
6#[derive(Debug, Clone, Eq, Hash)]
7pub struct Modulus64 {
8    // n inv_n = 1 (mod r = 2^32 or 2^64)
9    pub(crate) n: u64,
10    pub(crate) inv_n: u64,
11    pub(crate) r2_mod_n: u64,
12}
13
14impl Modulus64 {
15    /// Calculates some parameters for Montgomery multiplication.
16    ///
17    /// # Panics
18    ///
19    /// - modulus `n` should be an odd number.
20    #[inline]
21    pub const fn new(n: u64) -> Self {
22        assert!(n & 1 == 1, "modulus should be an odd number");
23
24        let inv_n = {
25            const TABLE: u32 = {
26                // | n     | 1 | 3  | 5  | 7 | 9 | 11 | 13 | 15 |
27                // | inv_n | 1 | 11 | 13 | 7 | 9 | 3  | 5  | 15 | <- 4 bits * 8
28                let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
29
30                let mut table = 0;
31                let mut i = 0;
32                while i < 8 {
33                    table |= inv_n[i] << (i * 4);
34                    i += 1;
35                }
36
37                table
38            };
39            // n inv_n = 1 (mod 8)
40            let mut inv_n = ((TABLE >> (n & 0b1110) * 2) & 0b1111) as u64;
41
42            let mut d = const { u64::BITS.ilog2() - 2 };
43            while d > 0 {
44                inv_n = inv_n.wrapping_mul((2 as u64).wrapping_sub(n.wrapping_mul(inv_n)));
45                d -= 1;
46            }
47            debug_assert!(n.wrapping_mul(inv_n) == 1);
48
49            inv_n
50        };
51        let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
52
53        Self { n, inv_n, r2_mod_n }
54    }
55
56    #[inline(always)]
57    pub const fn residue(&self, x: u64) -> Residue64<'_> {
58        // `x r2 < r n`
59        let x = self.mul(x, self.r2_mod_n);
60
61        Residue64 { x, modulus: &self }
62    }
63
64    /// Performs Montgomery multiplication.
65    ///
66    /// if `lhs rhs < n r`, then `result < n`
67    #[inline(always)]
68    pub(crate) const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
69        self.mul_add(lhs, rhs, 0)
70    }
71
72    /// Performs `lhs rhs + add`.
73    ///
74    /// If `lhs rhs + add < n r`, then the result is less than `n`.
75    #[inline(always)]
76    pub(crate) const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
77        // FIXME: use `a.widening_mul(b)`
78        let (x_hi, x_lo) = {
79            let x = lhs as u128 * rhs as u128 + add as u128;
80            ((x >> u64::BITS) as u64, x as u64)
81        };
82        // FIXME: use `mul_hi()`
83        // y = x n nn = x (mod r) => y_lo = x_lo
84        let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
85        // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
86        let (z, b) = x_hi.overflowing_sub(y_hi);
87
88        // x < n r, y < n r => |z| < n
89        if b {
90            z.wrapping_add(self.n)
91        } else {
92            z
93        }
94    }
95
96    /// Checks whether `x` is multiple of `self`.
97    ///
98    /// # Example
99    ///
100    /// ```
101    /// use lib_modulo::Modulus64;
102    ///
103    /// for n in (1..1 << 10).step_by(2) {
104    ///     let modulus  = Modulus64::new(n);
105    ///
106    ///     (0..1 << 10).for_each(|k| assert!(modulus.can_divide(n * k)));
107    /// }
108    /// ```
109    #[inline]
110    pub const fn can_divide(&self, x: u64) -> bool {
111        self.residue(x).is_zero()
112    }
113}
114
115impl PartialEq for Modulus64 {
116    fn eq(&self, other: &Self) -> bool {
117        self.n == other.n
118    }
119}
120
121/// Residue with odd modulus which is less than 2^64.
122///
123/// # Fast modular multiplication
124///
125/// [`Residue64`] provides fast modular multiplication called [Montgomery multiplication].
126/// Since this method provides modular multiplication without trial division,
127/// it is approximately twice as fast.
128///
129/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
130///
131/// # Usage
132///
133/// ```
134/// use lib_modulo::Modulus64;
135///
136/// // runtime-specified *odd* modulus
137/// let modulus = 5;
138///
139/// let modulus = Modulus64::new(modulus); // slow
140/// let n = modulus.residue(2) * modulus.residue(3); // fast
141/// assert_eq!(n.get(), 1);
142/// ```
143/// Two residues with different modulus can interact, but the result will be meaningless.
144/// It is highly recommended to use a block to ensure that [`Modulus64`], therefore [`Residue64`]s, are dropped.
145#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
146pub struct Residue64<'a> {
147    pub(crate) modulus: &'a Modulus64,
148    // x r (mod n)
149    pub(crate) x: u64,
150}
151
152impl<'a> Residue64<'a> {
153    /// Returns the residue.
154    ///
155    /// # Example
156    ///
157    /// ```
158    /// use lib_modulo::Modulus64;
159    ///
160    /// let modulus  = Modulus64::new(5);
161    /// let n = modulus.residue(7);
162    /// assert_eq!(n.get(), 2);
163    /// ```
164    #[inline(always)]
165    pub const fn get(&self) -> u64 {
166        self.modulus.mul(self.x, 1)
167    }
168
169    /// Returns the modulus.
170    ///
171    /// # Example
172    ///
173    /// ```
174    /// use lib_modulo::Modulus64;
175    ///
176    /// let modulus  = Modulus64::new(5);
177    /// let n = modulus.residue(7);
178    /// assert_eq!(n.modulus(), 5);
179    /// ```
180    #[inline(always)]
181    pub const fn modulus(&self) -> u64 {
182        self.modulus.n
183    }
184
185    /// Checks whether `self` is `0`.
186    ///
187    /// # Example
188    ///
189    /// ```
190    /// use lib_modulo::Modulus64;
191    ///
192    /// let modulus  = Modulus64::new(3);
193    /// let n = modulus.residue(6);
194    /// assert_eq!(n.modulus(), 0);
195    /// ```
196    #[inline(always)]
197    pub const fn is_zero(self) -> bool {
198        self.x == 0
199    }
200
201    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
202    ///
203    /// # Time complexity
204    ///
205    /// *O*(log `exp`)
206    ///
207    /// # Example
208    ///
209    /// ```
210    /// use lib_modulo::Modulus64;
211    ///
212    /// let modulus = Modulus64::new(1001);
213    /// let residue = modulus.residue(2);
214    /// for exp in 0..64 {
215    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
216    /// }
217    /// ```
218    #[inline]
219    pub const fn pow(mut self, mut exp: u64) -> Self {
220        // r inv_r = 1 (mod n)
221        let mut result = self.modulus.residue(1).x;
222
223        while exp > 0 {
224            if exp & 1 == 1 {
225                // n < r
226                result = self.modulus.mul(result, self.x)
227            }
228
229            exp >>= 1;
230            // n < r
231            self.x = self.modulus.mul(self.x, self.x)
232        }
233        self.x = result;
234
235        self
236    }
237
238    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
239    ///
240    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
241    ///
242    /// - `Ok(x)` : `x` is the modular inverse.
243    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
244    /// where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
245    ///
246    /// # Time complexity
247    ///
248    /// *O*(log `self`)
249    ///
250    /// # Example
251    ///
252    /// ```
253    /// use lib_modulo::Modulus64;
254    ///
255    /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
256    /// let modulus = Modulus64::new(998_244_353);
257    ///
258    /// for n in 1..500_000 {
259    ///     let n = modulus.residue(n);
260    ///     assert!(n.try_inv().is_ok_and(|i| (i * n).get() == 1));
261    /// }
262    /// // 0 n = 0 != 1 for any integer n
263    /// assert!(modulus.residue(0).try_inv().is_err());
264    /// ```
265    #[inline]
266    pub const fn try_inv(self) -> Result<Self, u64> {
267        let mut a = self.get();
268        let Self { modulus, .. } = self;
269
270        // performs extended binary gcd
271        //
272        // invariants: a = [a] x,  b = [a] y (mod n) where [a] is initial value
273        let mut b = modulus.n;
274        let mut x = modulus.residue(1).x; // 1 r mod n
275        let mut y = 0; // 0 r mod n
276        let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
277
278        while a > 0 {
279            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
280            a >>= a.trailing_zeros();
281
282            if a < b {
283                (a, b) = (b, a);
284                (x, y) = (y, x);
285            }
286            a -= b;
287            let (diff, b) = x.overflowing_sub(y);
288            x = if b {
289                diff.wrapping_add(modulus.n)
290            } else {
291                diff
292            };
293        }
294
295        // b = gcd([a], [b])
296        if b == 1 {
297            Ok(Self { x: y, modulus })
298        } else {
299            Err(b)
300        }
301    }
302}
303
304impl<'a> Add for Residue64<'a> {
305    type Output = Self;
306
307    #[inline(always)]
308    fn add(mut self, rhs: Self) -> Self {
309        let (sum, b) = self.x.overflowing_add(rhs.x);
310        self.x = if b || sum >= self.modulus.n {
311            sum.wrapping_sub(self.modulus.n)
312        } else {
313            sum
314        };
315
316        self
317    }
318}
319
320impl<'a> AddAssign for Residue64<'a> {
321    #[inline(always)]
322    fn add_assign(&mut self, rhs: Self) {
323        *self = *self + rhs
324    }
325}
326
327impl<'a> Sub for Residue64<'a> {
328    type Output = Self;
329
330    #[inline(always)]
331    fn sub(mut self, rhs: Self) -> Self {
332        let (diff, b) = self.x.overflowing_sub(rhs.x);
333        self.x = if b {
334            diff.wrapping_add(self.modulus.n)
335        } else {
336            diff
337        };
338
339        self
340    }
341}
342
343impl<'a> SubAssign for Residue64<'a> {
344    #[inline(always)]
345    fn sub_assign(&mut self, rhs: Self) {
346        *self = *self - rhs
347    }
348}
349
350impl<'a> Mul for Residue64<'a> {
351    type Output = Self;
352
353    #[inline(always)]
354    fn mul(mut self, rhs: Self) -> Self {
355        // n < r
356        self.x = self.modulus.mul(self.x, rhs.x);
357
358        self
359    }
360}
361
362impl<'a> MulAssign for Residue64<'a> {
363    #[inline(always)]
364    fn mul_assign(&mut self, rhs: Self) {
365        *self = *self * rhs
366    }
367}
368
369impl<'a> Neg for Residue64<'a> {
370    type Output = Self;
371
372    #[inline(always)]
373    fn neg(mut self) -> Self::Output {
374        // (x - x) r = 0 (mod n)
375        self.x = if self.x == 0 {
376            self.x
377        } else {
378            self.modulus.n - self.x
379        };
380
381        self
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    /// Reference GCD, used to check the `Err` case of `try_inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let common_twos = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                std::mem::swap(&mut a, &mut b);
            }
            a -= b;
        }

        b << common_twos
    }

    /// Modulus 1 is a degenerate but valid odd modulus: every residue is 0,
    /// including "1", so expectations of the form `1` must be read as `1 % n`.
    #[test]
    fn modulus_one() {
        let modulus = Modulus64::new(1);
        let zero = modulus.residue(12345);

        assert_eq!(zero.get(), 0);
        assert_eq!(zero.pow(0).get(), 0);
        assert!(zero.is_zero());
        assert!(modulus.can_divide(7));
        assert!(zero.try_inv().is_ok_and(|inv| (inv * zero).get() == 0));
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]

        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.residue(x).get(), x % n)
        }

        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let base = modulus.residue(x);

            // x^0 is 1 reduced modulo n, which is 0 when n == 1.
            let mut expected = 1 % n;
            for exp in 0..100 {
                assert_eq!(base.pow(exp).get(), expected, "exp = {exp}");
                expected = (expected as u128 * x as u128 % n as u128) as u64
            }
        }

        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for multiple in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(multiple))
            }
        }

        #[test]
        fn try_inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // The product must be 1 reduced modulo n, which is 0 when n == 1.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }
}