// lib_modulo/residue64.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue64`].
4///
5/// See documentation of [`Residue64`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Eq, Hash)]
8pub struct Modulus64 {
9    // n inv_n = 1 (mod r = 2^32 or 2^64)
10    n: u64,
11    inv_n: u64,
12    r2_mod_n: u64,
13}
14
15impl Modulus64 {
16    /// Creates new instance with the given modulus.
17    ///
18    /// # Panics
19    ///
20    /// - modulus `n` should be an odd number.
21    #[inline]
22    pub const fn new(n: u64) -> Self {
23        assert!(n & 1 == 1, "modulus should be an odd number");
24
25        let inv_n = {
26            const TABLE: u32 = {
27                // | n     | 1 | 3  | 5  | 7 | 9 | 11 | 13 | 15 |
28                // | inv_n | 1 | 11 | 13 | 7 | 9 | 3  | 5  | 15 | <- 4 bits * 8
29                let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
30
31                let mut table = 0;
32                let mut i = 0;
33                while i < 8 {
34                    table |= inv_n[i] << (i * 4);
35                    i += 1;
36                }
37
38                table
39            };
40            // n inv_n = 1 (mod 8)
41            let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
42
43            let mut d = const { u64::BITS.ilog2() - 2 };
44            while d > 0 {
45                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
46                d -= 1;
47            }
48            debug_assert!(n.wrapping_mul(inv_n) == 1);
49
50            inv_n
51        };
52        let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
53
54        Self { n, inv_n, r2_mod_n }
55    }
56
57    /// Calculates the residue of `x` modulo `self`.
58    ///
59    /// # Example
60    ///
61    /// ```
62    /// use lib_modulo::Modulus64;
63    ///
64    /// let modulus = Modulus64::new(5);
65    /// assert_eq!(modulus.residue(8).get(), 3)
66    /// ```
67    #[inline(always)]
68    pub const fn residue(&self, x: u64) -> Residue64<'_> {
69        // `x r2 < r n`
70        let x = self.mul(x, self.r2_mod_n);
71
72        Residue64 { x, modulus: self }
73    }
74
75    /// Performs Montgomery multiplication.
76    ///
77    /// if `lhs rhs < n r`, then `result < n`
78    #[inline(always)]
79    const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
80        self.mul_add(lhs, rhs, 0)
81    }
82
83    /// Performs `lhs rhs + add`.
84    ///
85    /// If `lhs rhs + add < n r`, then the result is less than `n`.
86    #[inline(always)]
87    const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
88        // FIXME: use `a.widening_mul(b)`
89        let (x_hi, x_lo) = {
90            let x = lhs as u128 * rhs as u128 + add as u128;
91            ((x >> u64::BITS) as u64, x as u64)
92        };
93        // FIXME: use `mul_hi()`
94        // y = x n nn = x (mod r) => y_lo = x_lo
95        let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
96        // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
97        let (z, b) = x_hi.overflowing_sub(y_hi);
98
99        // x < n r, y < n r => |z| < n
100        if b {
101            z.wrapping_add(self.n)
102        } else {
103            z
104        }
105    }
106
107    /// Checks whether `x` is multiple of `self`.
108    ///
109    /// # Example
110    ///
111    /// ```
112    /// use lib_modulo::Modulus64;
113    ///
114    /// for n in (1..1 << 10).step_by(2) {
115    ///     let modulus  = Modulus64::new(n);
116    ///
117    ///     (0..1 << 10).for_each(|k| assert!(modulus.can_divide(n * k)));
118    /// }
119    /// ```
120    #[inline]
121    pub const fn can_divide(&self, x: u64) -> bool {
122        self.residue(x).is_zero()
123    }
124}
125
126impl PartialEq for Modulus64 {
127    fn eq(&self, other: &Self) -> bool {
128        // other parameters depend on `n`
129        self.n == other.n
130    }
131}
132
133/// A residue with an odd modulus that fits in `2^64`.
134///
135/// # Fast modular multiplication
136///
137/// [`Residue64`] provides fast modular multiplication using [Montgomery multiplication].
138/// Since this method provides modular multiplication without division,
139/// it is approximately twice as fast.
140///
141/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
142///
143/// # Usage
144///
145/// ```
146/// use lib_modulo::Modulus64;
147///
148/// // runtime-specified *odd* modulus
149/// let modulus = 5;
150///
151/// let modulus = Modulus64::new(modulus); // slow
152/// let n = modulus.residue(2) * modulus.residue(3); // fast
153/// assert_eq!(n.get(), 1);
154/// ```
155///
156/// Two residues with different modulus can interact, but the result will be meaningless.
157/// It is highly recommended to use a block to ensure that [`Modulus64`], therefore [`Residue64`]s, are dropped.
158#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
159pub struct Residue64<'a> {
160    modulus: &'a Modulus64,
161    // x r (mod n)
162    x: u64,
163}
164
165impl<'a> Residue64<'a> {
166    /// Extract the internal representation of `self`.
167    ///
168    /// ```
169    /// use lib_modulo::{Modulus64, Raw64};
170    ///
171    /// let modulus = Modulus64::new(1001);
172    /// // save memory
173    /// let residues: Vec<Raw64> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
174    /// ```
175    #[inline(always)]
176    pub const fn into_raw(self) -> Raw64 {
177        Raw64 { x: self.x }
178    }
179
180    /// Returns the residue.
181    ///
182    /// # Example
183    ///
184    /// ```
185    /// use lib_modulo::Modulus64;
186    ///
187    /// let modulus  = Modulus64::new(5);
188    /// let n = modulus.residue(7);
189    /// assert_eq!(n.get(), 2);
190    /// ```
191    #[inline(always)]
192    pub const fn get(&self) -> u64 {
193        self.modulus.mul(self.x, 1)
194    }
195
196    /// Returns the modulus.
197    ///
198    /// # Example
199    ///
200    /// ```
201    /// use lib_modulo::Modulus64;
202    ///
203    /// let modulus  = Modulus64::new(5);
204    /// let n = modulus.residue(7);
205    /// assert_eq!(n.modulus(), 5);
206    /// ```
207    #[inline(always)]
208    pub const fn modulus(&self) -> u64 {
209        self.modulus.n
210    }
211
212    /// Checks whether `self` is `0`.
213    ///
214    /// # Example
215    ///
216    /// ```
217    /// use lib_modulo::Modulus64;
218    ///
219    /// let modulus  = Modulus64::new(3);
220    /// assert_eq!(modulus.residue(6).get(), 0);
221    /// ```
222    #[inline(always)]
223    pub const fn is_zero(self) -> bool {
224        self.x == 0
225    }
226
227    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
228    ///
229    /// # Time complexity
230    ///
231    /// *O*(log `exp`)
232    ///
233    /// # Example
234    ///
235    /// ```
236    /// use lib_modulo::Modulus64;
237    ///
238    /// let modulus = Modulus64::new(1001);
239    /// let residue = modulus.residue(2);
240    /// for exp in 0..64 {
241    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
242    /// }
243    /// ```
244    #[inline]
245    pub const fn pow(mut self, mut exp: u64) -> Self {
246        // r inv_r = 1 (mod n)
247        let mut result = self.modulus.residue(1).x;
248
249        while exp > 0 {
250            if exp & 1 == 1 {
251                // n < r
252                result = self.modulus.mul(result, self.x)
253            }
254
255            exp >>= 1;
256            // n < r
257            self.x = self.modulus.mul(self.x, self.x)
258        }
259        self.x = result;
260
261        self
262    }
263
264    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
265    ///
266    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
267    ///
268    /// - `Ok(x)` : `x` is the modular inverse.
269    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
270    ///   where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
271    ///
272    /// # Time complexity
273    ///
274    /// *O*(log `self`)
275    ///
276    /// # Example
277    ///
278    /// ```
279    /// use lib_modulo::Modulus64;
280    ///
281    /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
282    /// let modulus = Modulus64::new(998_244_353);
283    ///
284    /// for n in 1..500_000 {
285    ///     let n = modulus.residue(n);
286    ///     assert!(n.inv().is_ok_and(|i| (i * n).get() == 1));
287    /// }
288    /// // 0 n = 0 != 1 for any integer n
289    /// assert!(modulus.residue(0).inv().is_err());
290    /// ```
291    #[inline]
292    pub const fn inv(self) -> Result<Self, u64> {
293        let mut a = self.get();
294        let Self { modulus, .. } = self;
295
296        // performs extended binary gcd
297        //
298        // invariants: a = [a] x,  b = [a] y (mod n) where [a] is initial value
299        let mut b = modulus.n;
300        let mut x = modulus.residue(1).x; // 1 r mod n
301        let mut y = 0; // 0 r mod n
302        let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
303
304        while a > 0 {
305            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
306            a >>= a.trailing_zeros();
307
308            if a < b {
309                (a, b) = (b, a);
310                (x, y) = (y, x);
311            }
312            a -= b;
313            let (diff, b) = x.overflowing_sub(y);
314            x = if b {
315                diff.wrapping_add(modulus.n)
316            } else {
317                diff
318            };
319        }
320
321        // b = gcd([a], [b])
322        if b == 1 {
323            Ok(Self { x: y, modulus })
324        } else {
325            Err(b)
326        }
327    }
328}
329
330impl<'a> Add for Residue64<'a> {
331    type Output = Self;
332
333    #[inline(always)]
334    fn add(mut self, rhs: Self) -> Self {
335        let (sum, b) = self.x.overflowing_add(rhs.x);
336        self.x = if b || sum >= self.modulus.n {
337            sum.wrapping_sub(self.modulus.n)
338        } else {
339            sum
340        };
341
342        self
343    }
344}
345
346impl<'a> AddAssign for Residue64<'a> {
347    #[inline(always)]
348    fn add_assign(&mut self, rhs: Self) {
349        *self = *self + rhs
350    }
351}
352
353impl<'a> Sub for Residue64<'a> {
354    type Output = Self;
355
356    #[inline(always)]
357    fn sub(mut self, rhs: Self) -> Self {
358        let (diff, b) = self.x.overflowing_sub(rhs.x);
359        self.x = if b {
360            diff.wrapping_add(self.modulus.n)
361        } else {
362            diff
363        };
364
365        self
366    }
367}
368
369impl<'a> SubAssign for Residue64<'a> {
370    #[inline(always)]
371    fn sub_assign(&mut self, rhs: Self) {
372        *self = *self - rhs
373    }
374}
375
376impl<'a> Mul for Residue64<'a> {
377    type Output = Self;
378
379    #[inline(always)]
380    fn mul(mut self, rhs: Self) -> Self {
381        // n < r
382        self.x = self.modulus.mul(self.x, rhs.x);
383
384        self
385    }
386}
387
388impl<'a> MulAssign for Residue64<'a> {
389    #[inline(always)]
390    fn mul_assign(&mut self, rhs: Self) {
391        *self = *self * rhs
392    }
393}
394
395impl<'a> Neg for Residue64<'a> {
396    type Output = Self;
397
398    #[inline(always)]
399    fn neg(mut self) -> Self::Output {
400        // (x - x) r = 0 (mod n)
401        self.x = if self.x == 0 {
402            self.x
403        } else {
404            self.modulus.n - self.x
405        };
406
407        self
408    }
409}
410
/// The bare value of a [`Residue64`], detached from its [`Modulus64`].
///
/// Conceptually, [`Residue64`] = [`Raw64`] + [`Modulus64`]:
/// a [`Raw64`] keeps the value part only and holds no reference to a modulus.
///
/// Splitting the two is useful to shrink collections of [`Residue64`],
/// and to avoid self-referential structures when one type has to own both
/// a residue and its modulus.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Raw64 {
    /// Internal representation copied verbatim from the originating [`Residue64`].
    x: u64,
}
423
424impl Raw64 {
425    /// Attaches a modulus and returns a [`Residue64`].
426    ///
427    /// # Caution
428    ///
429    /// This does not perform validation or reduction.
430    /// The caller must ensure the modulus is correct for this value.
431    #[inline(always)]
432    pub const fn into_residue<'a>(self, modulus: &'a Modulus64) -> Residue64<'a> {
433        Residue64 { modulus, x: self.x }
434    }
435}
436
437impl<'a> From<Residue64<'a>> for Raw64 {
438    #[inline(always)]
439    fn from(residue: Residue64<'a>) -> Self {
440        Self { x: residue.x }
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    // Every property draws an arbitrary odd modulus with `n | 1`. That includes
    // `n = 1`, where every residue is 0, so expected values are written as
    // `1 % n` instead of a literal `1`.

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// Converting into and out of a residue agrees with `%`.
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// `pow` agrees with repeated naive multiplication in `u128`.
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            // x^0 = 1, which is 0 when the modulus is 1
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// `can_divide` agrees with a zero remainder.
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
        }
    }

    /// Reference binary GCD, with `gcd(a, 0) = a`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        // common power of two, restored at the end
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// `inv` either yields a true inverse or reports the GCD with the modulus.
        #[test]
        fn inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.inv() {
                // modulo 1 the product is 0, which is congruent to 1 there
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    // `gcd` divides both numbers and nothing larger does
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    /// Deterministic coverage of the degenerate modulus 1, which the random
    /// properties above practically never draw.
    #[test]
    fn modulus_one() {
        let modulus = Modulus64::new(1);
        let zero = modulus.residue(u64::MAX);

        assert_eq!(zero.get(), 0);
        assert!(zero.is_zero());
        assert!(modulus.can_divide(12345));
        assert_eq!(zero.pow(0).get(), 0);
        assert!(zero.inv().is_ok_and(|inv| (inv * zero).get() == 0));
    }
}
523}