Skip to main content

lib_modulo/
residue32.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue32`].
4///
5/// See documentation of [`Residue32`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Hash, Eq)]
8pub struct Modulus32 {
9    // n inv_n = 1 (mod 2^64)
10    n: u64,
11    inv_n: u64,
12    // 2^128 (mod n) * inv_n
13    init: u64,
14    // ceil(2^64 / n)
15    recip: u64,
16}
17
18impl Modulus32 {
19    /// Maximum available modulus.
20    pub const MAX: u32 = 2_654_435_769;
21
22    /// Creates new context for modular arithmetics.
23    ///
24    /// # Panics
25    ///
26    /// - modulus `n` should be an odd integer.
27    /// - modulus `n` should be no more than `2_654_435_769`,
28    ///   which is the floor of `2^32 / GOLDEN_RATIO`.
29    ///
30    /// # Example
31    ///
32    /// ```
33    /// use lib_modulo::Modulus32;
34    ///
35    /// // odd integer less than or equal to 2_654_435_769 is allowed.
36    /// let modulus = Modulus32::new(Modulus32::MAX);
37    /// let modulus = Modulus32::new(3);
38    ///
39    /// // modulus should be an odd integer!
40    /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
41    /// ```
42    #[inline]
43    pub const fn new(n: u32) -> Self {
44        assert!(
45            n & 1 == 1,
46            "invalid modulus: modulus should be an odd integer."
47        );
48        assert!(
49            n <= Self::MAX,
50            "invalid modulus: modulus should be no more than 2_654_435_769."
51        );
52
53        let n = n as u64;
54
55        let inv_n = {
56            // 1 * 1 = 3 * 3 = 1 (mod 4)
57            let mut inv_n = n & 3;
58            // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
59            // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
60            let mut i = u64::BITS.ilog2() - 1;
61            while i > 0 {
62                i -= 1;
63                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
64            }
65            debug_assert!(n.wrapping_mul(inv_n) == 1);
66
67            inv_n
68        };
69
70        let (div, rem) = {
71            let denom = n.wrapping_neg();
72            (denom / n, denom % n)
73        };
74        // 2^128 (mod n): magic number for converting integer to Plantard representation.
75        let init = rem * rem % n;
76        // ceil(2^64 / n): magic number for fast remainder algorithm
77        let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
78
79        Self {
80            n,
81            inv_n,
82            init: init.wrapping_mul(inv_n),
83            recip,
84        }
85    }
86
87    /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
88    ///
89    /// If `x y < self.n`, then returned value is less than `self.n`.
90    #[inline(always)]
91    const fn mul(&self, x: u64, y: u64) -> u64 {
92        // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
93        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
94        let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
95        debug_assert!(z < self.n, "this is a bug in lib-modulo");
96        z
97    }
98
99    /// Calculates the residue of `x` modulo `self`.
100    ///
101    /// # Example
102    ///
103    /// ```
104    /// use lib_modulo::Modulus32;
105    ///
106    /// let modulus = Modulus32::new(5);
107    /// assert_eq!(modulus.residue(8).get(), 3)
108    /// ```
109    #[inline(always)]
110    pub const fn residue(&self, x: u32) -> Residue32<'_> {
111        // fast remainder algorithm
112        // See <https://onlinelibrary.wiley.com/doi/10.1002/spe.2689> for details
113        let x = {
114            let lo = self.recip.wrapping_mul(x as u64);
115            ((lo as u128 * self.n as u128) >> 64) as u64
116        };
117
118        let x = {
119            // multiplication by a constant
120            let x = self.init.wrapping_mul(x) >> 32;
121            ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
122        };
123
124        Residue32 { x, modulus: self }
125    }
126
127    /// Checks whether `x` is divisible by `self`.
128    ///
129    /// # Example
130    ///
131    /// ```
132    /// use lib_modulo::Modulus32;
133    ///
134    /// let modulus = Modulus32::new(9);
135    /// assert!(modulus.can_divide(18));
136    /// assert!(!modulus.can_divide(19));
137    /// ```
138    #[inline(always)]
139    pub const fn can_divide(&self, x: u32) -> bool {
140        self.residue(x).is_zero()
141    }
142
143    /// Checks whether `self` is a prime number.
144    ///
145    /// # Time complexity
146    ///
147    /// *O*(log *self*)
148    ///
149    /// # Example
150    ///
151    /// ```
152    /// use lib_modulo::Modulus32;
153    ///
154    /// for p in [3, 5, 7, 11, 998_244_353, 1_000_000_007] {
155    ///     assert!(Modulus32::new(p).is_prime())
156    /// }
157    /// // Mersenne numbers (prime)
158    /// for d in [5, 7, 13, 17, 19, 31] {
159    ///     assert!(Modulus32::new((1 << d) - 1).is_prime())
160    /// }
161    ///
162    /// // composite numbers
163    /// for i in (3..).step_by(2).take(500) {
164    ///     assert!(!Modulus32::new(i * (i + 2)).is_prime())
165    /// }
166    /// ```
167    #[inline(always)]
168    pub const fn is_prime(&self) -> bool {
169        /// (SELF >> p) & 1 == 1 iff p is prime
170        const TEST_LT_64: u64 = 2891462833508853932;
171        /// (SELF >> n % 30) & 1 == 1 iff n is coprime to 2, 3, and 5
172        const TEST_2_3_5: u32 = 545925250;
173
174        if self.n < 64 {
175            return (TEST_LT_64 >> self.n) & 1 == 1;
176        } else if (TEST_2_3_5 >> (self.n % 30)) & 1 == 0 || self.n % 7 == 0 {
177            return false;
178        }
179
180        let one = self.residue(1).x;
181        let minus_one = self.n - one;
182        debug_assert!(one != 0 && minus_one != 0, "this is a bug in lib-modulo");
183
184        let (d, s) = {
185            let n = self.n - 1;
186            ((n >> n.trailing_zeros()) as u32, n.trailing_zeros() - 1)
187        };
188        let mut i = 0;
189        'test: while i < 3 {
190            let witness = [2, 7, 61][i];
191            i += 1;
192
193            let w = self.residue(witness);
194            if w.is_zero() {
195                continue;
196            }
197
198            let mut w = w.pow(d).x;
199            if w == minus_one || w == one {
200                continue;
201            }
202
203            let mut s = s;
204            while s > 0 {
205                s -= 1;
206                w = self.mul(w, w);
207                if w == minus_one {
208                    continue 'test;
209                }
210            }
211
212            return false;
213        }
214
215        true
216    }
217}
218
219impl PartialEq for Modulus32 {
220    fn eq(&self, other: &Self) -> bool {
221        // other fields depend on `n`
222        self.n == other.n
223    }
224}
225
226/// Residue with odd modulus which is no more than `2_654_435_769`.
227///
228/// # Fast modular multiplication
229///
230/// [`Residue32`] provides fast modular multiplication called [Plantard multiplication].
231/// This method saves one multiplication when either of two values of a multiplication is used multiple times.
232/// Therefore, [`Residue32::pow`] will be faster than that using [Montgomery multiplication].
233///
234/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
235/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
236///
237/// # Usage
238///
239/// ```
240/// use lib_modulo::Modulus32;
241///
242/// // set modulus
243/// let modulus = Modulus32::new(3);
244///
245/// // performs modular arithmetics
246/// let one = modulus.residue(1);
247/// let two = modulus.residue(2);
248/// let five = modulus.residue(5);
249/// assert_eq!(two * five, one)
250/// ```
251///
252/// Two residues with different modulus can interact, but the result will be meaningless.
253/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
254#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
255pub struct Residue32<'a> {
256    // compare modulus first
257    modulus: &'a Modulus32,
258    x: u64,
259}
260
261impl<'a> Residue32<'a> {
262    /// Checks whether `self` is `0`.
263    ///
264    /// # Example
265    ///
266    /// ```
267    /// use lib_modulo::Modulus32;
268    ///
269    /// let modulus = Modulus32::new(5);
270    /// assert!(modulus.residue(10).is_zero())
271    /// ```
272    #[inline(always)]
273    pub const fn is_zero(self) -> bool {
274        self.x == 0
275    }
276
277    /// Returns the residue.
278    ///
279    /// # Example
280    ///
281    /// ```
282    /// use lib_modulo::Modulus32;
283    ///
284    /// let modulus = Modulus32::new(7);
285    /// assert_eq!(modulus.residue(10).get(), 3)
286    /// ```
287    #[inline(always)]
288    pub const fn get(self) -> u64 {
289        self.modulus.mul(self.x, 1)
290    }
291
292    /// Returns the modulus.
293    ///
294    /// # Example
295    ///
296    /// ```
297    /// use lib_modulo::Modulus32;
298    ///
299    /// let modulus = Modulus32::new(11);
300    /// assert_eq!(modulus.residue(2).modulus(), 11);
301    /// ```
302    #[inline(always)]
303    pub const fn modulus(&self) -> u64 {
304        self.modulus.n
305    }
306
307    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
308    ///
309    /// # Time complexity
310    ///
311    /// *Θ*(log `exp`)
312    ///
313    /// # Example
314    ///
315    /// ```
316    /// use lib_modulo::Modulus32;
317    ///
318    /// let modulus = Modulus32::new(1001);
319    /// let residue = modulus.residue(2);
320    /// for exp in 0..64 {
321    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
322    /// }
323    /// ```
324    #[inline(always)]
325    pub const fn pow(self, mut exp: u32) -> Self {
326        let Self { mut x, modulus } = self;
327        // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
328        let mut prod = modulus.residue(1).x;
329
330        while exp > 1 {
331            if exp & 1 == 1 {
332                // インライン展開されると,掛け算を1回節約できる。
333                prod = modulus.mul(prod, x)
334            }
335
336            exp >>= 1;
337            x = modulus.mul(x, x); // skip last useless one
338        }
339        if exp != 0 {
340            prod = modulus.mul(prod, x);
341        }
342
343        Self { x: prod, modulus }
344    }
345
346    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
347    ///
348    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
349    ///
350    /// - `Ok(x)` : `x` is the modular inverse.
351    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
352    ///   where `gcd(0, a)` is defined to be `a`.
353    ///
354    /// # Time complexity
355    ///
356    /// *O*(log `self`)
357    ///
358    /// # Example
359    ///
360    /// ```
361    /// use lib_modulo::Modulus32;
362    ///
363    /// let modulus = Modulus32::new(3 * 5);
364    ///
365    /// let residue = modulus.residue(2);
366    /// assert!(residue.try_inv().is_ok_and(|inv| (inv * residue).get() == 1));
367    ///
368    /// let residue = modulus.residue(6);
369    /// assert!(residue.try_inv().is_err_and(|gcd| gcd == 3));
370    /// ```
371    pub const fn try_inv(self) -> Result<Self, u64> {
372        // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
373        let mut a = self.get();
374        let mut b = self.modulus();
375        let Self { modulus, .. } = self;
376        let mut x = modulus.residue(1).x;
377        let mut y = 0;
378        let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
379
380        while a > 0 {
381            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
382            a >>= a.trailing_zeros();
383
384            if a < b {
385                (a, b) = (b, a);
386                (x, y) = (y, x);
387            }
388            a -= b;
389            let (z, b) = x.overflowing_sub(y);
390            x = if b { z.wrapping_add(modulus.n) } else { z };
391        }
392
393        // b = gcd([a], n)
394        if b == 1 {
395            Ok(Self { x: y, modulus })
396        } else {
397            Err(b)
398        }
399    }
400}
401
402impl<'a> Add for Residue32<'a> {
403    type Output = Self;
404
405    fn add(mut self, rhs: Self) -> Self::Output {
406        let (x, b) = self.x.overflowing_add(rhs.x);
407        self.x = if b || x >= self.modulus() {
408            x.wrapping_sub(self.modulus())
409        } else {
410            x
411        };
412
413        self
414    }
415}
416
417impl<'a> AddAssign for Residue32<'a> {
418    fn add_assign(&mut self, rhs: Self) {
419        *self = *self + rhs
420    }
421}
422
423impl<'a> Sub for Residue32<'a> {
424    type Output = Self;
425
426    fn sub(mut self, rhs: Self) -> Self::Output {
427        let (x, b) = self.x.overflowing_sub(rhs.x);
428        self.x = if b { x.wrapping_add(self.modulus()) } else { x };
429
430        self
431    }
432}
433
434impl<'a> SubAssign for Residue32<'a> {
435    fn sub_assign(&mut self, rhs: Self) {
436        *self = *self - rhs
437    }
438}
439
440impl<'a> Mul for Residue32<'a> {
441    type Output = Self;
442
443    fn mul(mut self, rhs: Self) -> Self::Output {
444        self.x = self.modulus.mul(self.x, rhs.x);
445        self
446    }
447}
448
449impl<'a> MulAssign for Residue32<'a> {
450    fn mul_assign(&mut self, rhs: Self) {
451        *self = *self * rhs
452    }
453}
454
455impl<'a> Neg for Residue32<'a> {
456    type Output = Self;
457
458    fn neg(mut self) -> Self::Output {
459        self.x = if self.x == 0 {
460            0
461        } else {
462            self.modulus() - self.x
463        };
464
465        self
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // x^0 = 1 (mod n) is 0 when n = 1, so start from `1 % n` rather than `1`.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference GCD used to validate the `Err` branch of `try_inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // modulo 1 the product is 0, which is `1 % 1`
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    /// The property tests hit `n = 1` only with negligible probability,
    /// so its edge cases are pinned explicitly here.
    #[test]
    fn modulus_one() {
        let modulus = Modulus32::new(1);
        for x in [0, 1, 2, u32::MAX] {
            let res = modulus.residue(x);
            assert_eq!(res.get(), 0);
            assert_eq!(res.pow(0).get(), 0);
            assert!(modulus.can_divide(x));
            let inv = res.try_inv().expect("every residue is invertible modulo 1");
            assert_eq!((inv * res).get(), 0);
        }
        assert!(!modulus.is_prime());
    }

    /// Compares `is_prime` against trial division for every odd modulus up to 20_001.
    #[test]
    fn is_prime_small() {
        let naive = |n: u32| n >= 2 && (2..).take_while(|&d| d * d <= n).all(|d| n % d != 0);
        for n in (1..=20_001).step_by(2) {
            assert_eq!(Modulus32::new(n).is_prime(), naive(n), "n = {n}");
        }
    }
}