Skip to main content

lib_modulo/
residue32.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use std::{collections::HashMap, hash::BuildHasher};
3
/// Factory of [`Residue32`].
///
/// Stores the modulus together with the constants that [`Modulus32::new`] precomputes from it.
///
/// See documentation of [`Residue32`] for details.
// `Hash` is derived while `PartialEq` is written by hand and compares `n` only.
// NOTE(review): this relies on every other field being derived from `n` (as `new` does),
// so equal values hash equally — confirm no other constructor breaks that.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct Modulus32 {
    // The modulus, widened to 64 bits.
    n: u64,
    // Inverse of `n` modulo 2^64: n inv_n = 1 (mod 2^64)
    inv_n: u64,
    // 2^128 (mod n) * inv_n (wrapping); used to convert an integer to Plantard representation
    init: u64,
    // ceil(2^64 / n), computed with wrapping arithmetic (so it is `0` for `n = 1`);
    // magic number of the fast remainder / divisibility checks
    recip: u64,
}
18
19impl Modulus32 {
    /// Maximum available modulus.
    ///
    /// This is the floor of `2^32 / GOLDEN_RATIO`; [`Modulus32::new`] panics for larger values.
    pub const MAX: u32 = 2_654_435_769;
22
23    /// Creates new context for modular arithmetics.
24    ///
25    /// # Panics
26    ///
27    /// - modulus `n` should be an odd integer.
28    /// - modulus `n` should be no more than `2_654_435_769`,
29    ///   which is the floor of `2^32 / GOLDEN_RATIO`.
30    ///
31    /// # Example
32    ///
33    /// ```
34    /// use lib_modulo::Modulus32;
35    ///
36    /// // odd integer less than or equal to 2_654_435_769 is allowed.
37    /// let modulus = Modulus32::new(Modulus32::MAX);
38    /// let modulus = Modulus32::new(3);
39    ///
40    /// // modulus should be an odd integer!
41    /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
42    /// ```
43    #[inline]
44    pub const fn new(n: u32) -> Self {
45        assert!(
46            n & 1 == 1,
47            "invalid modulus: modulus should be an odd integer."
48        );
49        assert!(
50            n <= Self::MAX,
51            "invalid modulus: modulus should be no more than 2_654_435_769."
52        );
53
54        let n = n as u64;
55
56        let inv_n = {
57            // 1 * 1 = 3 * 3 = 1 (mod 4)
58            let mut inv_n = n & 3;
59            // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
60            // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
61            let mut i = u64::BITS.ilog2() - 1;
62            while i > 0 {
63                i -= 1;
64                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
65            }
66            debug_assert!(n.wrapping_mul(inv_n) == 1);
67
68            inv_n
69        };
70
71        let (div, rem) = {
72            let denom = n.wrapping_neg();
73            (denom / n, denom % n)
74        };
75        // 2^128 (mod n): magic number for converting integer to Plantard representation.
76        let init = rem * rem % n;
77        // ceil(2^64 / n): magic number for fast remainder algorithm
78        let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
79
80        Self {
81            n,
82            inv_n,
83            init: init.wrapping_mul(inv_n),
84            recip,
85        }
86    }
87
88    /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
89    ///
90    /// If `x y < self.n`, then returned value is less than `self.n`.
91    #[inline(always)]
92    const fn mul(&self, x: u64, y: u64) -> u64 {
93        // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
94        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
95        let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
96        debug_assert!(z < self.n, "this is a bug in lib-modulo");
97        z
98    }
99
100    /// Calculates the residue of `x` modulo `self`.
101    ///
102    /// # Example
103    ///
104    /// ```
105    /// use lib_modulo::Modulus32;
106    ///
107    /// let modulus = Modulus32::new(5);
108    /// assert_eq!(modulus.residue(8).get(), 3)
109    /// ```
110    #[inline(always)]
111    pub const fn residue(&self, x: u32) -> Residue32<'_> {
112        // fast remainder algorithm
113        // See <https://onlinelibrary.wiley.com/doi/10.1002/spe.2689> for details
114        let x = {
115            let lo = self.recip.wrapping_mul(x as u64);
116            ((lo as u128 * self.n as u128) >> 64) as u64
117        };
118
119        let x = {
120            // multiplication by a constant
121            let x = self.init.wrapping_mul(x) >> 32;
122            ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
123        };
124
125        Residue32 { x, modulus: self }
126    }
127
128    /// Checks whether `x` is divisible by `self`.
129    ///
130    /// # Example
131    ///
132    /// ```
133    /// use lib_modulo::Modulus32;
134    ///
135    /// let modulus = Modulus32::new(9);
136    /// assert!(modulus.can_divide(18));
137    /// assert!(!modulus.can_divide(19));
138    /// ```
139    #[inline(always)]
140    pub const fn can_divide(&self, x: u32) -> bool {
141        self.recip.wrapping_mul(x as u64) <= self.recip.wrapping_sub(1)
142    }
143
144    /// Checks whether `x` is a prime number.
145    ///
146    /// This may fail if `x` is larger than `2_654_435_769`.
147    /// Use 64-bit version.
148    ///
149    /// # Time complexity
150    ///
151    /// *O*(log `self`)
152    ///
153    /// # Example
154    ///
155    /// ```
156    /// use lib_modulo::Modulus32;
157    ///
158    /// for p in [2, 3, 5, 7, 11, 998_244_353, 1_000_000_007] {
159    ///     assert!(Modulus32::primality_test(p).unwrap())
160    /// }
161    /// // Mersenne numbers (prime)
162    /// for d in [5, 7, 13, 17, 19, 31] {
163    ///     assert!(Modulus32::primality_test((1 << d) - 1).unwrap())
164    /// }
165    ///
166    /// // composite numbers
167    /// for i in (2..).take(1 << 10) {
168    ///     assert!(!Modulus32::primality_test(i * (i + 1)).unwrap())
169    /// }
170    /// ```
171    #[allow(clippy::result_unit_err)]
172    #[inline(always)]
173    pub const fn primality_test(x: u32) -> Result<bool, ()> {
174        /// (SELF >> p) & 1 == 1 iff p is prime
175        const TEST_LT_64: u64 = 2891462833508853932;
176        /// (SELF >> n % 30) & 1 == 1 iff n is coprime to 2, 3, and 5
177        const TEST_2_3_5: u32 = 545925250;
178
179        if x < 64 {
180            return Ok((TEST_LT_64 >> x) & 1 == 1);
181        } else if (TEST_2_3_5 >> (x % 30)) & 1 == 0 || x % 7 == 0 {
182            return Ok(false);
183        } else if x > Self::MAX {
184            return Err(());
185        }
186
187        let modulus = Self::new(x);
188        let one = modulus.residue(1).x;
189        let minus_one = modulus.n - one;
190        debug_assert!(one != 0 && minus_one != 0, "since x > 1");
191
192        let (d, s) = {
193            let n = modulus.n - 1;
194            ((n >> n.trailing_zeros()) as u32, n.trailing_zeros() - 1)
195        };
196        let mut i = 0;
197        'test: while i < 3 {
198            let witness = [2, 7, 61][i];
199            i += 1;
200
201            let w = modulus.residue(witness);
202            if w.is_zero() {
203                continue;
204            }
205
206            let mut w = w.pow(d).x;
207            if w == minus_one || w == one {
208                continue;
209            }
210
211            let mut s = s;
212            while s > 0 {
213                s -= 1;
214                w = modulus.mul(w, w);
215                if w == minus_one {
216                    continue 'test;
217                }
218            }
219
220            return Ok(false);
221        }
222
223        Ok(true)
224    }
225}
226
227impl PartialEq for Modulus32 {
228    fn eq(&self, other: &Self) -> bool {
229        // other fields depend on `n`
230        self.n == other.n
231    }
232}
233
/// A residue with an odd modulus not exceeding `2_654_435_769`.
///
/// Values are created by [`Modulus32::residue`] and borrow the [`Modulus32`] they belong to.
///
/// # Fast modular multiplication
///
/// [`Residue32`] provides fast modular multiplication using [Plantard multiplication].
/// This method eliminates one multiplication when one of the operands is reused multiple times.
/// As a result, [`Residue32::pow`] and other operations are typically
/// faster than implementations based on [Montgomery multiplication].
///
/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
///
/// # Usage
///
/// ```
/// use lib_modulo::Modulus32;
///
/// // set modulus
/// let modulus = Modulus32::new(3);
///
/// // performs modular arithmetics
/// let one = modulus.residue(1);
/// let two = modulus.residue(2);
/// let five = modulus.residue(5);
/// assert_eq!(two * five, one)
/// ```
///
/// Two residues with different moduli can interact, but the result will be meaningless.
/// It is highly recommended to use a block to ensure that the [`Modulus32`], and therefore its [`Residue32`]s, are dropped.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Residue32<'a> {
    // compare modulus first: the derived `PartialEq` compares fields in declaration order
    modulus: &'a Modulus32,
    // the value in internal (Plantard) representation, as produced by `Modulus32::residue`;
    // `get` converts it back to an ordinary integer
    x: u64,
}
269
270impl<'a> Residue32<'a> {
271    /// Checks whether `self` is `0`.
272    ///
273    /// # Example
274    ///
275    /// ```
276    /// use lib_modulo::Modulus32;
277    ///
278    /// let modulus = Modulus32::new(5);
279    /// assert!(modulus.residue(10).is_zero())
280    /// ```
281    #[inline(always)]
282    pub const fn is_zero(self) -> bool {
283        self.x == 0
284    }
285
286    /// Returns the residue.
287    ///
288    /// # Example
289    ///
290    /// ```
291    /// use lib_modulo::Modulus32;
292    ///
293    /// let modulus = Modulus32::new(7);
294    /// assert_eq!(modulus.residue(10).get(), 3)
295    /// ```
296    #[inline(always)]
297    pub const fn get(self) -> u64 {
298        self.modulus.mul(self.x, 1)
299    }
300
301    /// Returns the modulus.
302    ///
303    /// # Example
304    ///
305    /// ```
306    /// use lib_modulo::Modulus32;
307    ///
308    /// let modulus = Modulus32::new(11);
309    /// assert_eq!(modulus.residue(2).modulus(), 11);
310    /// ```
311    #[inline(always)]
312    pub const fn modulus(&self) -> u64 {
313        self.modulus.n
314    }
315
316    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
317    ///
318    /// # Time complexity
319    ///
320    /// *Θ*(log `exp`)
321    ///
322    /// # Example
323    ///
324    /// ```
325    /// use lib_modulo::Modulus32;
326    ///
327    /// let modulus = Modulus32::new(1001);
328    /// let residue = modulus.residue(2);
329    /// for exp in 0..64 {
330    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
331    /// }
332    /// ```
333    #[inline(always)]
334    pub const fn pow(self, mut exp: u32) -> Self {
335        let Self { mut x, modulus } = self;
336        // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
337        let mut prod = modulus.residue(1).x;
338
339        while exp > 1 {
340            if exp & 1 == 1 {
341                // インライン展開されると,掛け算を1回節約できる。
342                prod = modulus.mul(prod, x)
343            }
344
345            exp >>= 1;
346            x = modulus.mul(x, x); // skip last useless one
347        }
348        if exp != 0 {
349            prod = modulus.mul(prod, x);
350        }
351
352        Self { x: prod, modulus }
353    }
354
355    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
356    ///
357    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
358    ///
359    /// - `Ok(x)` : `x` is the modular inverse.
360    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
361    ///   where `gcd(0, a)` is defined to be `a`.
362    ///
363    /// # Time complexity
364    ///
365    /// *O*(log `self`)
366    ///
367    /// # Example
368    ///
369    /// ```
370    /// use lib_modulo::Modulus32;
371    ///
372    /// let modulus = Modulus32::new(3 * 5);
373    ///
374    /// let residue = modulus.residue(2);
375    /// assert!(residue.try_inv().is_ok_and(|inv| (inv * residue).get() == 1));
376    ///
377    /// let residue = modulus.residue(6);
378    /// assert!(residue.try_inv().is_err_and(|gcd| gcd == 3));
379    /// ```
380    pub const fn try_inv(self) -> Result<Self, u64> {
381        // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
382        let mut a = self.get();
383        let mut b = self.modulus();
384        let Self { modulus, .. } = self;
385        let mut x = modulus.residue(1).x;
386        let mut y = 0;
387        let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
388
389        while a > 0 {
390            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
391            a >>= a.trailing_zeros();
392
393            if a < b {
394                (a, b) = (b, a);
395                (x, y) = (y, x);
396            }
397            a -= b;
398            let (z, b) = x.overflowing_sub(y);
399            x = if b { z.wrapping_add(modulus.n) } else { z };
400        }
401
402        // b = gcd([a], n)
403        if b == 1 {
404            Ok(Self { x: y, modulus })
405        } else {
406            Err(b)
407        }
408    }
409
410    /// Solves discrete logarithm problem and returns the *smallest* solution.
411    ///
412    /// Consider using [`FxHashMap`] or other fast hash maps.
413    ///
414    /// [`FxHashMap`]: https://docs.rs/rustc-hash/latest/rustc_hash/type.FxHashMap.html
415    ///
416    /// # Time complexity
417    ///
418    /// *O*(√`modulus`)
419    ///
420    /// # Example
421    ///
422    /// ```
423    /// use lib_modulo::Modulus32;
424    /// use std::collections::HashMap;
425    ///
426    /// let modulus = Modulus32::new(2025);
427    /// let mut map = HashMap::new();
428    /// let mut offset = 0;
429    /// for d in 0..5000 {
430    ///     let pow2 = modulus.residue(2).pow(d).get() as u32;
431    ///     if pow2 == 1 {
432    ///         offset = d;
433    ///     }
434    ///     assert_eq!(modulus.residue(2).log(pow2, &mut map), Some(d - offset));
435    /// }
436    /// // Since `5 + 2025 i` is multiple of 5, it is not power of 2, 3, or 7
437    /// assert!(modulus.residue(2).log(5, &mut map).is_none());
438    /// assert!(modulus.residue(3).log(5, &mut map).is_none());
439    /// assert!(modulus.residue(7).log(5, &mut map).is_none());
440    /// ```
441    pub fn log<S>(self, rhs: u32, map: &mut HashMap<u64, u32, S>) -> Option<u32>
442    where
443        S: BuildHasher,
444    {
445        if rhs == 1 {
446            return Some(0);
447        } else if self.is_zero() {
448            return None;
449        }
450
451        let mut offset = 1;
452        let mut gcd = 1;
453        let mut factor = self;
454        // O(log n)
455        while let Err(g) = factor.try_inv().map_err(|g| g as u32) {
456            if g == gcd {
457                break;
458            }
459
460            offset += 1;
461            gcd = g;
462            factor *= self;
463        }
464
465        if rhs % gcd != 0 {
466            return None;
467        }
468
469        // solve `x^k = y (mod modulus)` by baby-step giant-step algorithm
470        let modulus = Modulus32::new(self.modulus() as u32 / gcd);
471        let x = modulus.residue(self.get() as u32);
472        let y = modulus.residue(rhs) * modulus.residue(factor.get() as u32).try_inv().unwrap();
473
474        let sqrt = (modulus.n as u32).isqrt() + 1;
475        map.clear();
476        map.reserve(sqrt as usize);
477
478        {
479            let mut lhs = modulus.residue(1);
480            map.insert(lhs.x, offset);
481            for i in offset + 1..offset + sqrt {
482                lhs *= x;
483                // choose smaller
484                map.entry(lhs.x).or_insert(i);
485            }
486        }
487        {
488            if let Some(i) = map.get(&y.x) {
489                return Some(*i);
490            }
491
492            let mut rhs = y;
493            let inv = x.try_inv().unwrap().pow(sqrt);
494            for j in 1..sqrt {
495                rhs *= inv;
496                if let Some(i) = map.get(&rhs.x) {
497                    return Some(j * sqrt + i);
498                }
499            }
500        }
501
502        None
503    }
504}
505
506impl<'a> Add for Residue32<'a> {
507    type Output = Self;
508
509    fn add(mut self, rhs: Self) -> Self::Output {
510        let (x, b) = self.x.overflowing_add(rhs.x);
511        self.x = if b || x >= self.modulus() {
512            x.wrapping_sub(self.modulus())
513        } else {
514            x
515        };
516
517        self
518    }
519}
520
521impl<'a> AddAssign for Residue32<'a> {
522    fn add_assign(&mut self, rhs: Self) {
523        *self = *self + rhs
524    }
525}
526
527impl<'a> Sub for Residue32<'a> {
528    type Output = Self;
529
530    fn sub(mut self, rhs: Self) -> Self::Output {
531        let (x, b) = self.x.overflowing_sub(rhs.x);
532        self.x = if b { x.wrapping_add(self.modulus()) } else { x };
533
534        self
535    }
536}
537
538impl<'a> SubAssign for Residue32<'a> {
539    fn sub_assign(&mut self, rhs: Self) {
540        *self = *self - rhs
541    }
542}
543
544impl<'a> Mul for Residue32<'a> {
545    type Output = Self;
546
547    fn mul(mut self, rhs: Self) -> Self::Output {
548        self.x = self.modulus.mul(self.x, rhs.x);
549        self
550    }
551}
552
553impl<'a> MulAssign for Residue32<'a> {
554    fn mul_assign(&mut self, rhs: Self) {
555        *self = *self * rhs
556    }
557}
558
559impl<'a> Neg for Residue32<'a> {
560    type Output = Self;
561
562    fn neg(mut self) -> Self::Output {
563        self.x = if self.x == 0 {
564            0
565        } else {
566            self.modulus() - self.x
567        };
568
569        self
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// Converting into a residue and back agrees with the `%` operator.
        #[test]
        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// `pow` agrees with repeated naive multiplication.
        #[test]
        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // `1 % n` instead of `1`: for `n = 1` the only residue is `0`.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// `can_divide` agrees with the `%` operator and accepts multiples of `n`.
        #[test]
        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// Everything is divisible by 1 (the reciprocal wraps around to 0 in this case).
        #[test]
        fn divisible_by_1(x: u32) {
            assert!(Modulus32::new(1).can_divide(x))
        }
    }

    /// Reference binary GCD, where `gcd(a, 0) = a` and `gcd(0, b) = b`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        // common power of two
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        /// `try_inv` returns either a true inverse or the GCD with the modulus.
        #[test]
        fn try_inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // `1 % n` instead of `1`: for `n = 1` the product is `0`.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }
}