Skip to main content

lib_modulo/
residue32.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2use std::{collections::HashMap, hash::BuildHasher};
3
4/// Factory of [`Residue32`].
5///
6/// See documentation of [`Residue32`] for details.
7#[allow(clippy::derived_hash_with_manual_eq)]
8#[derive(Debug, Clone, Hash, Eq)]
9pub struct Modulus32 {
10    // n inv_n = 1 (mod 2^64)
11    n: u64,
12    inv_n: u64,
13    // 2^128 (mod n) * inv_n
14    init: u64,
15    // ceil(2^64 / n)
16    recip: u64,
17}
18
19impl Modulus32 {
20    /// Maximum available modulus.
21    pub const MAX: u32 = 2_654_435_769;
22
23    /// Creates new context for modular arithmetics.
24    ///
25    /// # Panics
26    ///
27    /// - modulus `n` should be an odd integer.
28    /// - modulus `n` should be no more than `2_654_435_769`,
29    ///   which is the floor of `2^32 / GOLDEN_RATIO`.
30    ///
31    /// # Example
32    ///
33    /// ```
34    /// use lib_modulo::Modulus32;
35    ///
36    /// // odd integer less than or equal to 2_654_435_769 is allowed.
37    /// let modulus = Modulus32::new(Modulus32::MAX);
38    /// let modulus = Modulus32::new(3);
39    ///
40    /// // modulus should be an odd integer!
41    /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
42    /// ```
43    #[inline]
44    pub const fn new(n: u32) -> Self {
45        assert!(
46            n & 1 == 1,
47            "invalid modulus: modulus should be an odd integer."
48        );
49        assert!(
50            n <= Self::MAX,
51            "invalid modulus: modulus should be no more than 2_654_435_769."
52        );
53
54        let n = n as u64;
55
56        let inv_n = {
57            // 1 * 1 = 3 * 3 = 1 (mod 4)
58            let mut inv_n = n & 3;
59            // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
60            // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
61            let mut i = u64::BITS.ilog2() - 1;
62            while i > 0 {
63                i -= 1;
64                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
65            }
66            debug_assert!(n.wrapping_mul(inv_n) == 1);
67
68            inv_n
69        };
70
71        let (div, rem) = {
72            let denom = n.wrapping_neg();
73            (denom / n, denom % n)
74        };
75        // 2^128 (mod n): magic number for converting integer to Plantard representation.
76        let init = rem * rem % n;
77        // ceil(2^64 / n): magic number for fast remainder algorithm
78        let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
79
80        Self {
81            n,
82            inv_n,
83            init: init.wrapping_mul(inv_n),
84            recip,
85        }
86    }
87
88    /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
89    ///
90    /// If `x y < self.n`, then returned value is less than `self.n`.
91    #[inline(always)]
92    const fn mul(&self, x: u64, y: u64) -> u64 {
93        // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
94        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
95        let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
96        debug_assert!(z < self.n, "this is a bug in lib-modulo");
97        z
98    }
99
100    /// Calculates the residue of `x` modulo `self`.
101    ///
102    /// # Example
103    ///
104    /// ```
105    /// use lib_modulo::Modulus32;
106    ///
107    /// let modulus = Modulus32::new(5);
108    /// assert_eq!(modulus.residue(8).get(), 3)
109    /// ```
110    #[inline(always)]
111    pub const fn residue(&self, x: u32) -> Residue32<'_> {
112        // fast remainder algorithm
113        // See <https://onlinelibrary.wiley.com/doi/10.1002/spe.2689> for details
114        let x = {
115            let lo = self.recip.wrapping_mul(x as u64);
116            ((lo as u128 * self.n as u128) >> 64) as u64
117        };
118
119        let x = {
120            // multiplication by a constant
121            let x = self.init.wrapping_mul(x) >> 32;
122            ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
123        };
124
125        Residue32 { x, modulus: self }
126    }
127
128    /// Checks whether `x` is divisible by `self`.
129    ///
130    /// # Example
131    ///
132    /// ```
133    /// use lib_modulo::Modulus32;
134    ///
135    /// let modulus = Modulus32::new(9);
136    /// assert!(modulus.can_divide(18));
137    /// assert!(!modulus.can_divide(19));
138    /// ```
139    #[inline(always)]
140    pub const fn can_divide(&self, x: u32) -> bool {
141        self.recip.wrapping_mul(x as u64) <= self.recip.wrapping_sub(1)
142    }
143}
144
145impl PartialEq for Modulus32 {
146    fn eq(&self, other: &Self) -> bool {
147        // other fields depend on `n`
148        self.n == other.n
149    }
150}
151
152/// A residue with an odd modulus not exceeding `2_654_435_769`.
153///
154/// # Fast modular multiplication
155///
156/// [`Residue32`] provides fast modular multiplication using [Plantard multiplication].
157/// This method eliminates one multiplication when one of the operands is reused multiple times.
158/// As a result, [`Residue32::pow`] and other operations are typically
159/// faster than implementations based on [Montgomery multiplication].
160///
161/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
162/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
163///
164/// # Usage
165///
166/// ```
167/// use lib_modulo::Modulus32;
168///
169/// // set modulus
170/// let modulus = Modulus32::new(3);
171///
172/// // performs modular arithmetics
173/// let one = modulus.residue(1);
174/// let two = modulus.residue(2);
175/// let five = modulus.residue(5);
176/// assert_eq!(two * five, one)
177/// ```
178///
179/// Two residues with different modulus can interact, but the result will be meaningless.
180/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
181#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
182pub struct Residue32<'a> {
183    // compare modulus first
184    modulus: &'a Modulus32,
185    x: u64,
186}
187
188impl<'a> Residue32<'a> {
189    /// Extract the internal representation of `self`.
190    ///
191    /// ```
192    /// use lib_modulo::{Modulus32, Raw32};
193    ///
194    /// let modulus = Modulus32::new(1001);
195    /// // save memory
196    /// let residues: Vec<Raw32> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
197    /// ```
198    #[inline(always)]
199    pub const fn into_raw(self) -> Raw32 {
200        Raw32 { x: self.x }
201    }
202
203    /// Checks whether `self` is `0`.
204    ///
205    /// # Example
206    ///
207    /// ```
208    /// use lib_modulo::Modulus32;
209    ///
210    /// let modulus = Modulus32::new(5);
211    /// assert!(modulus.residue(10).is_zero())
212    /// ```
213    #[inline(always)]
214    pub const fn is_zero(self) -> bool {
215        self.x == 0
216    }
217
218    /// Returns the residue.
219    ///
220    /// # Example
221    ///
222    /// ```
223    /// use lib_modulo::Modulus32;
224    ///
225    /// let modulus = Modulus32::new(7);
226    /// assert_eq!(modulus.residue(10).get(), 3)
227    /// ```
228    #[inline(always)]
229    pub const fn get(self) -> u64 {
230        self.modulus.mul(self.x, 1)
231    }
232
233    /// Returns the modulus.
234    ///
235    /// # Example
236    ///
237    /// ```
238    /// use lib_modulo::Modulus32;
239    ///
240    /// let modulus = Modulus32::new(11);
241    /// assert_eq!(modulus.residue(2).modulus(), 11);
242    /// ```
243    #[inline(always)]
244    pub const fn modulus(&self) -> u64 {
245        self.modulus.n
246    }
247
248    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
249    ///
250    /// # Time complexity
251    ///
252    /// *Θ*(log `exp`)
253    ///
254    /// # Example
255    ///
256    /// ```
257    /// use lib_modulo::Modulus32;
258    ///
259    /// let modulus = Modulus32::new(1001);
260    /// let residue = modulus.residue(2);
261    /// for exp in 0..64 {
262    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
263    /// }
264    /// ```
265    #[inline(always)]
266    pub const fn pow(self, mut exp: u32) -> Self {
267        let Self { mut x, modulus } = self;
268        // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
269        let mut prod = modulus.residue(1).x;
270
271        while exp > 1 {
272            if exp & 1 == 1 {
273                // インライン展開されると,掛け算を1回節約できる。
274                prod = modulus.mul(prod, x)
275            }
276
277            exp >>= 1;
278            x = modulus.mul(x, x); // skip last useless one
279        }
280        if exp != 0 {
281            prod = modulus.mul(prod, x);
282        }
283
284        Self { x: prod, modulus }
285    }
286
287    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
288    ///
289    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
290    ///
291    /// - `Ok(x)` : `x` is the modular inverse.
292    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
293    ///   where `gcd(0, a)` is defined to be `a`.
294    ///
295    /// # Time complexity
296    ///
297    /// *O*(log `self`)
298    ///
299    /// # Example
300    ///
301    /// ```
302    /// use lib_modulo::Modulus32;
303    ///
304    /// let modulus = Modulus32::new(3 * 5);
305    ///
306    /// let residue = modulus.residue(2);
307    /// assert!(residue.try_inv().is_ok_and(|inv| (inv * residue).get() == 1));
308    ///
309    /// let residue = modulus.residue(6);
310    /// assert!(residue.try_inv().is_err_and(|gcd| gcd == 3));
311    /// ```
312    pub const fn try_inv(self) -> Result<Self, u64> {
313        // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
314        let mut a = self.get();
315        let mut b = self.modulus();
316        let Self { modulus, .. } = self;
317        let mut x = modulus.residue(1).x;
318        let mut y = 0;
319        let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
320
321        while a > 0 {
322            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
323            a >>= a.trailing_zeros();
324
325            if a < b {
326                (a, b) = (b, a);
327                (x, y) = (y, x);
328            }
329            a -= b;
330            let (z, b) = x.overflowing_sub(y);
331            x = if b { z.wrapping_add(modulus.n) } else { z };
332        }
333
334        // b = gcd([a], n)
335        if b == 1 {
336            Ok(Self { x: y, modulus })
337        } else {
338            Err(b)
339        }
340    }
341
342    /// Solves discrete logarithm problem and returns the *smallest* solution.
343    ///
344    /// Consider using [`FxHashMap`] or other fast hash maps.
345    ///
346    /// [`FxHashMap`]: https://docs.rs/rustc-hash/latest/rustc_hash/type.FxHashMap.html
347    ///
348    /// # Time complexity
349    ///
350    /// *O*(√`modulus`)
351    ///
352    /// # Example
353    ///
354    /// ```
355    /// use lib_modulo::Modulus32;
356    /// use std::collections::HashMap;
357    ///
358    /// let modulus = Modulus32::new(2025);
359    /// let mut map = HashMap::new();
360    /// let mut offset = 0;
361    /// for d in 0..5000 {
362    ///     let pow2 = modulus.residue(2).pow(d).get() as u32;
363    ///     if pow2 == 1 {
364    ///         offset = d;
365    ///     }
366    ///     assert_eq!(modulus.residue(2).log(pow2, &mut map), Some(d - offset));
367    /// }
368    /// // Since `5 + 2025 i` is multiple of 5, it is not power of 2, 3, or 7
369    /// assert!(modulus.residue(2).log(5, &mut map).is_none());
370    /// assert!(modulus.residue(3).log(5, &mut map).is_none());
371    /// assert!(modulus.residue(7).log(5, &mut map).is_none());
372    /// ```
373    pub fn log<S>(self, rhs: u32, map: &mut HashMap<Raw32, u32, S>) -> Option<u32>
374    where
375        S: BuildHasher,
376    {
377        if rhs == 1 {
378            return Some(0);
379        } else if self.is_zero() {
380            return None;
381        }
382
383        let mut offset = 1;
384        let mut gcd = 1;
385        let mut factor = self;
386        // O(log n)
387        while let Err(g) = factor.try_inv().map_err(|g| g as u32) {
388            if g == gcd {
389                break;
390            }
391
392            offset += 1;
393            gcd = g;
394            factor *= self;
395        }
396
397        if rhs % gcd != 0 {
398            return None;
399        }
400
401        // solve `x^k = y (mod modulus)` by baby-step giant-step algorithm
402        let modulus = Modulus32::new(self.modulus() as u32 / gcd);
403        let x = modulus.residue(self.get() as u32);
404        let y = modulus.residue(rhs) * modulus.residue(factor.get() as u32).try_inv().unwrap();
405
406        let sqrt = (modulus.n as u32).isqrt() + 1;
407        map.clear();
408        map.reserve(sqrt as usize);
409
410        {
411            let mut lhs = modulus.residue(1);
412            map.insert(lhs.into(), offset);
413            for i in offset + 1..offset + sqrt {
414                lhs *= x;
415                // choose smaller
416                map.entry(lhs.into()).or_insert(i);
417            }
418        }
419        {
420            if let Some(i) = map.get(&y.into()) {
421                return Some(*i);
422            }
423
424            let mut rhs = y;
425            let inv = x.try_inv().unwrap().pow(sqrt);
426            for j in 1..sqrt {
427                rhs *= inv;
428                if let Some(i) = map.get(&rhs.into()) {
429                    return Some(j * sqrt + i);
430                }
431            }
432        }
433
434        None
435    }
436}
437
438impl<'a> Add for Residue32<'a> {
439    type Output = Self;
440
441    fn add(mut self, rhs: Self) -> Self::Output {
442        let (x, b) = self.x.overflowing_add(rhs.x);
443        self.x = if b || x >= self.modulus() {
444            x.wrapping_sub(self.modulus())
445        } else {
446            x
447        };
448
449        self
450    }
451}
452
453impl<'a> AddAssign for Residue32<'a> {
454    fn add_assign(&mut self, rhs: Self) {
455        *self = *self + rhs
456    }
457}
458
459impl<'a> Sub for Residue32<'a> {
460    type Output = Self;
461
462    fn sub(mut self, rhs: Self) -> Self::Output {
463        let (x, b) = self.x.overflowing_sub(rhs.x);
464        self.x = if b { x.wrapping_add(self.modulus()) } else { x };
465
466        self
467    }
468}
469
470impl<'a> SubAssign for Residue32<'a> {
471    fn sub_assign(&mut self, rhs: Self) {
472        *self = *self - rhs
473    }
474}
475
476impl<'a> Mul for Residue32<'a> {
477    type Output = Self;
478
479    fn mul(mut self, rhs: Self) -> Self::Output {
480        self.x = self.modulus.mul(self.x, rhs.x);
481        self
482    }
483}
484
485impl<'a> MulAssign for Residue32<'a> {
486    fn mul_assign(&mut self, rhs: Self) {
487        *self = *self * rhs
488    }
489}
490
491impl<'a> Neg for Residue32<'a> {
492    type Output = Self;
493
494    fn neg(mut self) -> Self::Output {
495        self.x = if self.x == 0 {
496            0
497        } else {
498            self.modulus() - self.x
499        };
500
501        self
502    }
503}
504
/// An internal representation of [`Residue32`] without an associated [`Modulus32`].
///
/// Conceptually, [`Residue32`] = [`Raw32`] + [`Modulus32`]:
/// a [`Raw32`] is the value part on its own and holds no reference to a modulus.
///
/// Keeping the two apart makes collections of residues smaller, and lets a type own
/// both a modulus and residues of it without becoming self-referential.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Raw32 {
    // The same value as the `x` field of the originating `Residue32`.
    x: u64,
}
517
518impl Raw32 {
519    /// Attaches a modulus and returns a [`Residue32`].
520    ///
521    /// # Caution
522    ///
523    /// This does not perform validation or reduction.
524    /// The caller must ensure the modulus is correct for this value.
525    #[inline(always)]
526    pub const fn into_residue<'a>(self, modulus: &'a Modulus32) -> Residue32<'a> {
527        Residue32 { modulus, x: self.x }
528    }
529}
530
531impl<'a> From<Residue32<'a>> for Raw32 {
532    #[inline(always)]
533    fn from(residue: Residue32<'a>) -> Self {
534        Self { x: residue.x }
535    }
536}
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // Converting into a residue and back yields `x % n`.
        #[test]
        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `pow` agrees with repeated multiplication for the first 100 exponents.
        #[test]
        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // `n = 1` can be generated, and then even the empty product is 0.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `can_divide` agrees with `%`, and accepts the first multiples of `n`.
        #[test]
        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // Everything is divisible by 1.
        #[test]
        fn divisible_by_1(x: u32) {
            assert!(Modulus32::new(1).can_divide(x))
        }
    }

    // Reference binary GCD used to cross-check `try_inv`; `gcd(a, 0)` is `a`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        // Common power of two of `a` and `b`, restored at the end.
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `try_inv` returns either a true inverse or the gcd of the residue and the modulus.
        #[test]
        fn try_inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // `n = 1` can be generated, and then the product is 0 rather than 1.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }
}
628
629mod primality_test {
630    use super::super::{COPRIME_2_3_5, PRIME_LT_64};
631    use super::Modulus32;
632
633    impl Modulus32 {
634        /// Checks whether `x` is a prime number.
635        ///
636        /// This may fail if `x` is larger than `2_654_435_769`.
637        /// Use 64-bit version.
638        ///
639        /// # Time complexity
640        ///
641        /// *O*(log `self`)
642        ///
643        /// # Example
644        ///
645        /// ```
646        /// use lib_modulo::Modulus32;
647        ///
648        /// // prime numbers
649        /// for p in [2, 3, 5, 7, 11, 998_244_353, 1_000_000_007, (1 << 31) - 1] {
650        ///     assert!(p <= Modulus32::MAX);
651        ///     assert_eq!(Modulus32::primality_test(p), Ok(true))
652        /// }
653        /// // composite numbers
654        /// for c in (2..).take(1 << 10) {
655        ///     assert!(c * (c + 1) <= Modulus32::MAX);
656        ///     assert_eq!(Modulus32::primality_test(c * (c + 1)), Ok(false));
657        /// }
658        ///
659        /// // may or may not fail for large integers
660        /// assert_eq!(Modulus32::primality_test(u32::MAX), Ok(false));
661        /// assert_eq!(Modulus32::primality_test(u32::MAX - 2), Err(()));
662        /// ```
663        #[allow(clippy::result_unit_err)]
664        pub const fn primality_test(x: u32) -> Result<bool, ()> {
665            if x < 64 {
666                return Ok((PRIME_LT_64 >> x) & 1 == 1);
667            } else if (COPRIME_2_3_5 >> (x % 30)) & 1 == 0 || x % 7 == 0 {
668                return Ok(false);
669            } else if x > Self::MAX {
670                return Err(());
671            }
672
673            let modulus = Self::new(x);
674            let one = modulus.residue(1).x;
675            let minus_one = modulus.n - one;
676            debug_assert!(one != 0 && minus_one != 0, "since x > 1");
677
678            let (d, s) = {
679                let n = modulus.n - 1;
680                ((n >> n.trailing_zeros()) as u32, n.trailing_zeros() - 1)
681            };
682            let mut i = 0;
683            'test: while i < 3 {
684                let witness = [2, 7, 61][i];
685                i += 1;
686
687                let w = modulus.residue(witness);
688                if w.is_zero() {
689                    continue;
690                }
691
692                let mut w = w.pow(d).x;
693                if w == minus_one || w == one {
694                    continue;
695                }
696
697                let mut s = s;
698                while s > 0 {
699                    s -= 1;
700                    w = modulus.mul(w, w);
701                    if w == minus_one {
702                        continue 'test;
703                    }
704                }
705
706                return Ok(false);
707            }
708
709            Ok(true)
710        }
711    }
712}