Skip to main content

lib_modulo/
residue64.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue64`].
4///
5/// See documentation of [`Residue64`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Eq, Hash)]
8pub struct Modulus64 {
9    // n inv_n = 1 (mod r = 2^32 or 2^64)
10    n: u64,
11    inv_n: u64,
12    r2_mod_n: u64,
13}
14
15impl Modulus64 {
16    /// Creates new instance with the given modulus.
17    ///
18    /// # Panics
19    ///
20    /// - modulus `n` should be an odd number.
21    #[inline]
22    #[must_use]
23    pub const fn new(n: u64) -> Self {
24        assert!(n & 1 == 1, "modulus should be an odd number");
25
26        let inv_n = {
27            const TABLE: u32 = {
28                // | n     | 1 | 3  | 5  | 7 | 9 | 11 | 13 | 15 |
29                // | inv_n | 1 | 11 | 13 | 7 | 9 | 3  | 5  | 15 | <- 4 bits * 8
30                let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
31
32                let mut table = 0;
33                let mut i = 0;
34                while i < 8 {
35                    table |= inv_n[i] << (i * 4);
36                    i += 1;
37                }
38
39                table
40            };
41            // n inv_n = 1 (mod 8)
42            let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
43
44            let mut d = const { u64::BITS.ilog2() - 2 };
45            while d > 0 {
46                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
47                d -= 1;
48            }
49            debug_assert!(n.wrapping_mul(inv_n) == 1);
50
51            inv_n
52        };
53        let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
54
55        Self { n, inv_n, r2_mod_n }
56    }
57
58    /// Calculates the residue of `x` modulo `self`.
59    ///
60    /// # Example
61    ///
62    /// ```
63    /// use lib_modulo::Modulus64;
64    ///
65    /// let modulus = Modulus64::new(5);
66    /// assert_eq!(modulus.residue(8).get(), 3)
67    /// ```
68    #[must_use]
69    pub const fn residue(&self, x: u64) -> Residue64<'_> {
70        // `x r2 < r n`
71        let x = self.mul(x, self.r2_mod_n);
72
73        Residue64 { x, modulus: self }
74    }
75
76    /// Performs Montgomery multiplication.
77    ///
78    /// if `lhs rhs < n r`, then `result < n`
79    const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
80        self.mul_add(lhs, rhs, 0)
81    }
82
83    /// Performs `lhs rhs + add`.
84    ///
85    /// If `lhs rhs + add < n r`, then the result is less than `n`.
86    const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
87        // FIXME: use `a.widening_mul(b)`
88        let (x_hi, x_lo) = {
89            let x = lhs as u128 * rhs as u128 + add as u128;
90            ((x >> u64::BITS) as u64, x as u64)
91        };
92        // FIXME: use `mul_hi()`
93        // y = x n nn = x (mod r) => y_lo = x_lo
94        let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
95        // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
96        let (z, b) = x_hi.overflowing_sub(y_hi);
97
98        // x < n r, y < n r => |z| < n
99        if b {
100            z.wrapping_add(self.n)
101        } else {
102            z
103        }
104    }
105
106    /// Checks whether `x` is multiple of `self`.
107    ///
108    /// # Example
109    ///
110    /// ```
111    /// use lib_modulo::Modulus64;
112    ///
113    /// for n in (1..1 << 10).step_by(2) {
114    ///     let modulus  = Modulus64::new(n);
115    ///
116    ///     (0..1 << 10).for_each(|k| assert!(modulus.can_divide(n * k)));
117    /// }
118    /// ```
119    #[inline]
120    #[must_use]
121    pub const fn can_divide(&self, x: u64) -> bool {
122        self.residue(x).is_zero()
123    }
124}
125
126impl PartialEq for Modulus64 {
127    fn eq(&self, other: &Self) -> bool {
128        // other parameters depend on `n`
129        self.n == other.n
130    }
131}
132
133/// A residue with an odd modulus that fits in `2^64`.
134///
135/// # Fast modular multiplication
136///
137/// [`Residue64`] provides fast modular multiplication using [Montgomery multiplication].
138/// Since this method provides modular multiplication without division,
139/// it is approximately twice as fast.
140///
141/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
142///
143/// # Usage
144///
145/// ```
146/// use lib_modulo::Modulus64;
147///
148/// // runtime-specified *odd* modulus
149/// let modulus = 5;
150///
151/// let modulus = Modulus64::new(modulus); // slow
152/// let n = modulus.residue(2) * modulus.residue(3); // fast
153/// assert_eq!(n.get(), 1);
154/// ```
155///
156/// Two residues with different modulus can interact, but the result will be meaningless.
157/// It is highly recommended to use a block to ensure that [`Modulus64`], therefore [`Residue64`]s, are dropped.
158#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
159pub struct Residue64<'a> {
160    modulus: &'a Modulus64,
161    // x r (mod n)
162    x: u64,
163}
164
165impl Residue64<'_> {
166    /// Extract the internal representation of `self`.
167    ///
168    /// ```
169    /// use lib_modulo::{Modulus64, Raw64};
170    ///
171    /// let modulus = Modulus64::new(1001);
172    /// // save memory
173    /// let residues: Vec<Raw64> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
174    ///
175    /// // `Residue64` and `raw64` can interact.
176    /// // The caller must ensure that both operands shares the same modulus.
177    /// let double_sum = residues.into_iter().fold(modulus.residue(0), |sum, r| r + sum + r);
178    /// assert_eq!(double_sum, modulus.residue((1 + 1000) * 1000));
179    /// ```
180    #[must_use]
181    pub const fn into_raw(self) -> Raw64 {
182        Raw64 { x: self.x }
183    }
184
185    /// Returns the residue.
186    ///
187    /// # Example
188    ///
189    /// ```
190    /// use lib_modulo::Modulus64;
191    ///
192    /// let modulus  = Modulus64::new(5);
193    /// let n = modulus.residue(7);
194    /// assert_eq!(n.get(), 2);
195    /// ```
196    #[must_use]
197    pub const fn get(&self) -> u64 {
198        self.modulus.mul(self.x, 1)
199    }
200
201    /// Returns the modulus.
202    ///
203    /// # Example
204    ///
205    /// ```
206    /// use lib_modulo::Modulus64;
207    ///
208    /// let modulus  = Modulus64::new(5);
209    /// let n = modulus.residue(7);
210    /// assert_eq!(n.modulus(), 5);
211    /// ```
212    #[must_use]
213    pub const fn modulus(&self) -> u64 {
214        self.modulus.n
215    }
216
217    /// Checks whether `self` is `0`.
218    ///
219    /// # Example
220    ///
221    /// ```
222    /// use lib_modulo::Modulus64;
223    ///
224    /// let modulus  = Modulus64::new(3);
225    /// assert_eq!(modulus.residue(6).get(), 0);
226    /// ```
227    #[must_use]
228    pub const fn is_zero(self) -> bool {
229        self.x == 0
230    }
231
232    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
233    ///
234    /// # Time complexity
235    ///
236    /// *O*(log `exp`)
237    ///
238    /// # Example
239    ///
240    /// ```
241    /// use lib_modulo::Modulus64;
242    ///
243    /// let modulus = Modulus64::new(1001);
244    /// let residue = modulus.residue(2);
245    /// for exp in 0..64 {
246    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
247    /// }
248    /// ```
249    #[inline]
250    #[must_use]
251    pub const fn pow(mut self, mut exp: u64) -> Self {
252        // r inv_r = 1 (mod n)
253        let mut result = self.modulus.residue(1).x;
254
255        while exp > 0 {
256            if exp & 1 == 1 {
257                // n < r
258                result = self.modulus.mul(result, self.x);
259            }
260
261            exp >>= 1;
262            // n < r
263            self.x = self.modulus.mul(self.x, self.x);
264        }
265        self.x = result;
266
267        self
268    }
269
270    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
271    ///
272    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
273    ///
274    /// - `Ok(x)` : `x` is the modular inverse.
275    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
276    ///   where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
277    ///
278    /// # Time complexity
279    ///
280    /// *O*(log `self`)
281    ///
282    /// # Example
283    ///
284    /// ```
285    /// use lib_modulo::Modulus64;
286    ///
287    /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
288    /// let modulus = Modulus64::new(998_244_353);
289    ///
290    /// for n in 1..500_000 {
291    ///     let n = modulus.residue(n);
292    ///     assert!(n.inv().is_ok_and(|i| (i * n).get() == 1));
293    /// }
294    /// // 0 n = 0 != 1 for any integer n
295    /// assert!(modulus.residue(0).inv().is_err());
296    /// ```
297    #[inline]
298    pub const fn inv(self) -> Result<Self, u64> {
299        let mut a = self.get();
300        let Self { modulus, .. } = self;
301
302        // performs extended binary gcd
303        //
304        // invariants: a = [a] x,  b = [a] y (mod n) where [a] is initial value
305        let mut b = modulus.n;
306        let mut x = modulus.residue(1).x; // 1 r mod n
307        let mut y = 0; // 0 r mod n
308        let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
309
310        while a > 0 {
311            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
312            a >>= a.trailing_zeros();
313
314            if a < b {
315                (a, b) = (b, a);
316                (x, y) = (y, x);
317            }
318            a -= b;
319            let (diff, b) = x.overflowing_sub(y);
320            x = if b {
321                diff.wrapping_add(modulus.n)
322            } else {
323                diff
324            };
325        }
326
327        // b = gcd([a], [b])
328        if b == 1 {
329            Ok(Self { x: y, modulus })
330        } else {
331            Err(b)
332        }
333    }
334}
335
336impl Add for Residue64<'_> {
337    type Output = Self;
338
339    fn add(self, rhs: Self) -> Self {
340        self + rhs.into_raw()
341    }
342}
343
344impl AddAssign for Residue64<'_> {
345    fn add_assign(&mut self, rhs: Self) {
346        *self = *self + rhs;
347    }
348}
349
350impl Sub for Residue64<'_> {
351    type Output = Self;
352
353    fn sub(self, rhs: Self) -> Self {
354        self - rhs.into_raw()
355    }
356}
357
358impl SubAssign for Residue64<'_> {
359    fn sub_assign(&mut self, rhs: Self) {
360        *self = *self - rhs;
361    }
362}
363
364impl Mul for Residue64<'_> {
365    type Output = Self;
366
367    fn mul(self, rhs: Self) -> Self {
368        self * rhs.into_raw()
369    }
370}
371
372impl MulAssign for Residue64<'_> {
373    fn mul_assign(&mut self, rhs: Self) {
374        *self = *self * rhs;
375    }
376}
377
378impl Neg for Residue64<'_> {
379    type Output = Self;
380
381    fn neg(mut self) -> Self::Output {
382        // (x - x) r = 0 (mod n)
383        self.x = if self.x == 0 {
384            self.x
385        } else {
386            self.modulus.n - self.x
387        };
388
389        self
390    }
391}
392
/// The value part of a [`Residue64`], stored without its [`Modulus64`].
///
/// Conceptually, [`Residue64`] = [`Raw64`] + [`Modulus64`]: a [`Raw64`] keeps only the
/// internal (Montgomery-form) value and drops the reference to the modulus.
///
/// Splitting the two helps shrink collections of residues, and avoids
/// self-referential structures when a type needs to hold both a residue and its modulus.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Raw64 {
    // Internal value; only meaningful together with the modulus it was produced under.
    x: u64,
}
405
406impl Raw64 {
407    /// Attaches a modulus and returns a [`Residue64`].
408    ///
409    /// Typically, this only needs to be called once per computation
410    /// because `Raw64` and `Residue64` can interact.
411    ///
412    /// # Caution
413    ///
414    /// This does not perform validation or reduction.
415    /// The caller must ensure the modulus is correct for this value.
416    #[must_use]
417    pub const fn into_residue(self, modulus: &Modulus64) -> Residue64<'_> {
418        Residue64 { modulus, x: self.x }
419    }
420}
421
422impl<'a> From<Residue64<'a>> for Raw64 {
423    fn from(residue: Residue64<'a>) -> Self {
424        Self { x: residue.x }
425    }
426}
427
428impl<'a> Add<Raw64> for Residue64<'a> {
429    type Output = Residue64<'a>;
430
431    /// Performs the `+` operation.
432    ///
433    /// # Caution
434    ///
435    /// The caller must ensure that both operands shares the same modulus.
436    fn add(mut self, rhs: Raw64) -> Self::Output {
437        let (sum, b) = self.x.overflowing_add(rhs.x);
438        self.x = if b || sum >= self.modulus.n {
439            sum.wrapping_sub(self.modulus.n)
440        } else {
441            sum
442        };
443
444        self
445    }
446}
447
448impl<'a> Add<Residue64<'a>> for Raw64 {
449    type Output = Residue64<'a>;
450
451    /// Performs the `+` operation.
452    ///
453    /// # Caution
454    ///
455    /// The caller must ensure that both operands shares the same modulus.
456    fn add(self, rhs: Residue64<'a>) -> Self::Output {
457        rhs + self
458    }
459}
460
461impl AddAssign<Raw64> for Residue64<'_> {
462    /// Performs the `+=` operation.
463    ///
464    /// # Caution
465    ///
466    /// The caller must ensure that both operands shares the same modulus.
467    fn add_assign(&mut self, rhs: Raw64) {
468        *self = *self + rhs;
469    }
470}
471
472impl<'a> Sub<Raw64> for Residue64<'a> {
473    type Output = Residue64<'a>;
474
475    /// Performs the `-` operation.
476    ///
477    /// # Caution
478    ///
479    /// The caller must ensure that both operands shares the same modulus.
480    fn sub(mut self, rhs: Raw64) -> Self::Output {
481        let (diff, b) = self.x.overflowing_sub(rhs.x);
482        self.x = if b {
483            diff.wrapping_add(self.modulus.n)
484        } else {
485            diff
486        };
487
488        self
489    }
490}
491
492impl<'a> Sub<Residue64<'a>> for Raw64 {
493    type Output = Residue64<'a>;
494
495    /// Performs the `-` operation.
496    ///
497    /// # Caution
498    ///
499    /// The caller must ensure that both operands shares the same modulus.
500    fn sub(self, mut rhs: Residue64<'a>) -> Self::Output {
501        let (diff, b) = self.x.overflowing_sub(rhs.x);
502        rhs.x = if b {
503            diff.wrapping_add(rhs.modulus.n)
504        } else {
505            diff
506        };
507
508        rhs
509    }
510}
511
512impl SubAssign<Raw64> for Residue64<'_> {
513    /// Performs the `-=` operation.
514    ///
515    /// # Caution
516    ///
517    /// The caller must ensure that both operands shares the same modulus.
518    fn sub_assign(&mut self, rhs: Raw64) {
519        *self = *self - rhs;
520    }
521}
522
523impl<'a> Mul<Raw64> for Residue64<'a> {
524    type Output = Residue64<'a>;
525
526    /// Performs the `*` operation.
527    ///
528    /// # Caution
529    ///
530    /// The caller must ensure that both operands shares the same modulus.
531    fn mul(mut self, rhs: Raw64) -> Self::Output {
532        // n < r
533        self.x = self.modulus.mul(self.x, rhs.x);
534
535        self
536    }
537}
538
539impl<'a> Mul<Residue64<'a>> for Raw64 {
540    type Output = Residue64<'a>;
541
542    /// Performs the `*` operation.
543    ///
544    /// # Caution
545    ///
546    /// The caller must ensure that both operands shares the same modulus.
547    fn mul(self, rhs: Residue64<'a>) -> Self::Output {
548        rhs * self
549    }
550}
551
552impl MulAssign<Raw64> for Residue64<'_> {
553    /// Performs the `*=` operation.
554    ///
555    /// # Caution
556    ///
557    /// The caller must ensure that both operands shares the same modulus.
558    fn mul_assign(&mut self, rhs: Raw64) {
559        *self = *self * rhs;
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            let modulus = Modulus64::new(n);

            let prod = modulus.residue(x) * modulus.residue(y);
            assert_eq!(prod.get(), (x as u128 * y as u128 % n as u128) as u64)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn add_sub_neg(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            let modulus = Modulus64::new(n);
            let (a, b) = (modulus.residue(x), modulus.residue(y));
            // widen to u128 so the reference computations cannot overflow
            let (x, y, n) = ((x % n) as u128, (y % n) as u128, n as u128);

            assert_eq!((a + b).get() as u128, (x + y) % n);
            assert_eq!((a - b).get() as u128, (x + n - y) % n);
            assert_eq!((-a).get() as u128, (n - x) % n);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            // `1 % n` rather than `1`, so the trivial modulus `n = 1` (where 1 = 0) is handled
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
        }
    }

    /// Reference GCD used to validate the `Err` branch of `Residue64::inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.inv() {
                // `1 % n` so that the trivial modulus `n = 1` is handled
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    // a GCD of 1 must have produced `Ok`
                    assert_ne!(gcd, 1);
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    #[test]
    fn trivial_modulus() {
        // Modulo 1 every integer is 0, and 0 is its own inverse.
        let modulus = Modulus64::new(1);
        let res = modulus.residue(12345);

        assert!(res.is_zero());
        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert!(res.inv().is_ok_and(|inv| inv.get() == 0));
    }

    #[test]
    #[should_panic(expected = "modulus should be an odd number")]
    fn even_modulus_panics() {
        let _ = Modulus64::new(4);
    }
}