Skip to main content

lib_modulo/
residue32.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue32`].
4///
5/// See documentation of [`Residue32`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Hash, Eq)]
8pub struct Modulus32 {
9    // n inv_n = 1 (mod 2^64)
10    n: u64,
11    inv_n: u64,
12    // 2^128 (mod n) * inv_n
13    init: u64,
14    // ceil(2^64 / n)
15    recip: u64,
16}
17
18impl Modulus32 {
19    /// Maximum available modulus.
20    pub const MAX: u32 = 2_654_435_769;
21
22    /// Creates new instance with the given modulus.
23    ///
24    /// # Panics
25    ///
26    /// - modulus `n` should be an odd integer.
27    /// - modulus `n` should be no more than `2_654_435_769`,
28    ///   which is the floor of `2^32 / GOLDEN_RATIO`.
29    ///
30    /// # Example
31    ///
32    /// ```
33    /// use lib_modulo::Modulus32;
34    ///
35    /// // odd integer less than or equal to 2_654_435_769 is allowed.
36    /// let modulus = Modulus32::new(Modulus32::MAX);
37    /// let modulus = Modulus32::new(3);
38    ///
39    /// // modulus should be an odd integer!
40    /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
41    /// ```
42    #[inline]
43    #[must_use]
44    pub const fn new(n: u32) -> Self {
45        assert!(
46            n & 1 == 1,
47            "invalid modulus: modulus should be an odd integer."
48        );
49        assert!(
50            n <= Self::MAX,
51            "invalid modulus: modulus should be no more than 2_654_435_769."
52        );
53
54        let n = n as u64;
55
56        let inv_n = {
57            // 1 * 1 = 3 * 3 = 1 (mod 4)
58            let mut inv_n = n & 3;
59            // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
60            // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
61            let mut i = u64::BITS.ilog2() - 1;
62            while i > 0 {
63                i -= 1;
64                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
65            }
66            debug_assert!(n.wrapping_mul(inv_n) == 1);
67
68            inv_n
69        };
70
71        let (div, rem) = {
72            let denom = n.wrapping_neg();
73            (denom / n, denom % n)
74        };
75        // 2^128 (mod n): magic number for converting integer to Plantard representation.
76        let init = rem * rem % n;
77        // ceil(2^64 / n): magic number for fast remainder algorithm
78        let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
79
80        Self {
81            n,
82            inv_n,
83            init: init.wrapping_mul(inv_n),
84            recip,
85        }
86    }
87
88    /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
89    ///
90    /// If `x y < self.n`, then returned value is less than `self.n`.
91    #[inline(always)]
92    const fn mul(&self, x: u64, y: u64) -> u64 {
93        // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
94        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
95        let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
96        debug_assert!(z < self.n, "this is a bug in lib-modulo");
97        z
98    }
99
100    /// Calculates the residue of `x` modulo `self`.
101    ///
102    /// # Example
103    ///
104    /// ```
105    /// use lib_modulo::Modulus32;
106    ///
107    /// let modulus = Modulus32::new(5);
108    /// assert_eq!(modulus.residue(8).get(), 3)
109    /// ```
110    #[must_use]
111    pub const fn residue(&self, x: u32) -> Residue32<'_> {
112        // fast remainder algorithm
113        // See <https://onlinelibrary.wiley.com/doi/10.1002/spe.2689> for details
114        let x = {
115            let lo = self.recip.wrapping_mul(x as u64);
116            ((lo as u128 * self.n as u128) >> 64) as u64
117        };
118
119        let x = {
120            // multiplication by a constant
121            let x = self.init.wrapping_mul(x) >> 32;
122            ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
123        };
124
125        Residue32 { x, modulus: self }
126    }
127
128    /// Checks whether `x` is divisible by `self`.
129    ///
130    /// # Example
131    ///
132    /// ```
133    /// use lib_modulo::Modulus32;
134    ///
135    /// let modulus = Modulus32::new(9);
136    /// assert!(modulus.can_divide(18));
137    /// assert!(!modulus.can_divide(19));
138    /// ```
139    #[must_use]
140    pub const fn can_divide(&self, x: u32) -> bool {
141        self.recip.wrapping_mul(x as u64) <= self.recip.wrapping_sub(1)
142    }
143}
144
145impl PartialEq for Modulus32 {
146    fn eq(&self, other: &Self) -> bool {
147        // other fields depend on `n`
148        self.n == other.n
149    }
150}
151
152/// A residue with an odd modulus not exceeding `2_654_435_769`.
153///
154/// # Fast modular multiplication
155///
156/// [`Residue32`] provides fast modular multiplication using [Plantard multiplication].
157/// This method eliminates one multiplication when one of the operands is reused multiple times.
158/// As a result, [`Residue32::pow`] and other operations are typically
159/// faster than implementations based on [Montgomery multiplication].
160///
161/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
162/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
163///
164/// # Usage
165///
166/// ```
167/// use lib_modulo::Modulus32;
168///
169/// // set modulus
170/// let modulus = Modulus32::new(3);
171///
172/// // performs modular arithmetic
173/// let one = modulus.residue(1);
174/// let two = modulus.residue(2);
175/// let five = modulus.residue(5);
176/// assert_eq!(two * five, one)
177/// ```
178///
179/// Two residues with different modulus can interact, but the result will be meaningless.
180/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
181#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
182pub struct Residue32<'a> {
183    // compare modulus first
184    modulus: &'a Modulus32,
185    x: u64,
186}
187
188impl Residue32<'_> {
189    /// Extract the internal representation of `self`.
190    ///
191    /// ```
192    /// use lib_modulo::{Modulus32, Raw32};
193    ///
194    /// let modulus = Modulus32::new(1001);
195    /// // save memory
196    /// let residues: Vec<Raw32> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
197    ///
198    /// // `Residue32` and `raw32` can interact.
199    /// // The caller must ensure that both operands shares the same modulus.
200    /// let double_sum = residues.into_iter().fold(modulus.residue(0), |sum, r| r + sum + r);
201    /// assert_eq!(double_sum, modulus.residue((1 + 1000) * 1000));
202    /// ```
203    #[must_use]
204    pub const fn into_raw(self) -> Raw32 {
205        Raw32 { x: self.x }
206    }
207
208    /// Checks whether `self` is `0`.
209    ///
210    /// # Example
211    ///
212    /// ```
213    /// use lib_modulo::Modulus32;
214    ///
215    /// let modulus = Modulus32::new(5);
216    /// assert!(modulus.residue(10).is_zero())
217    /// ```
218    #[must_use]
219    pub const fn is_zero(self) -> bool {
220        self.x == 0
221    }
222
223    /// Returns the residue.
224    ///
225    /// # Example
226    ///
227    /// ```
228    /// use lib_modulo::Modulus32;
229    ///
230    /// let modulus = Modulus32::new(7);
231    /// assert_eq!(modulus.residue(10).get(), 3)
232    /// ```
233    #[must_use]
234    pub const fn get(self) -> u64 {
235        self.modulus.mul(self.x, 1)
236    }
237
238    /// Returns the modulus.
239    ///
240    /// # Example
241    ///
242    /// ```
243    /// use lib_modulo::Modulus32;
244    ///
245    /// let modulus = Modulus32::new(11);
246    /// assert_eq!(modulus.residue(2).modulus(), 11);
247    /// ```
248    #[must_use]
249    pub const fn modulus(&self) -> u64 {
250        self.modulus.n
251    }
252
253    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
254    ///
255    /// # Time complexity
256    ///
257    /// *Θ*(log `exp`)
258    ///
259    /// # Example
260    ///
261    /// ```
262    /// use lib_modulo::Modulus32;
263    ///
264    /// let modulus = Modulus32::new(1001);
265    /// let residue = modulus.residue(2);
266    /// for exp in 0..64 {
267    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
268    /// }
269    /// ```
270    #[must_use]
271    pub const fn pow(self, mut exp: u32) -> Self {
272        let Self { mut x, modulus } = self;
273        // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
274        let mut prod = modulus.residue(1).x;
275
276        while exp > 1 {
277            if exp & 1 == 1 {
278                // インライン展開されると,掛け算を1回節約できる。
279                prod = modulus.mul(prod, x);
280            }
281
282            exp >>= 1;
283            x = modulus.mul(x, x); // skip last useless one
284        }
285        if exp != 0 {
286            prod = modulus.mul(prod, x);
287        }
288
289        Self { x: prod, modulus }
290    }
291
292    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
293    ///
294    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
295    ///
296    /// - `Ok(x)` : `x` is the modular inverse.
297    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
298    ///   where `gcd(0, a)` is defined to be `a`.
299    ///
300    /// # Time complexity
301    ///
302    /// *O*(log `self`)
303    ///
304    /// # Example
305    ///
306    /// ```
307    /// use lib_modulo::Modulus32;
308    ///
309    /// let modulus = Modulus32::new(3 * 5);
310    ///
311    /// let residue = modulus.residue(2);
312    /// assert!(residue.inv().is_ok_and(|inv| (inv * residue).get() == 1));
313    ///
314    /// let residue = modulus.residue(6);
315    /// assert!(residue.inv().is_err_and(|gcd| gcd == 3));
316    /// ```
317    pub const fn inv(self) -> Result<Self, u64> {
318        // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
319        let mut a = self.get();
320        let mut b = self.modulus();
321        let Self { modulus, .. } = self;
322        let mut x = modulus.residue(1).x;
323        let mut y = 0;
324        let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
325
326        while a > 0 {
327            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
328            a >>= a.trailing_zeros();
329
330            if a < b {
331                (a, b) = (b, a);
332                (x, y) = (y, x);
333            }
334            a -= b;
335            let (z, b) = x.overflowing_sub(y);
336            x = if b { z.wrapping_add(modulus.n) } else { z };
337        }
338
339        // b = gcd([a], n)
340        if b == 1 {
341            Ok(Self { x: y, modulus })
342        } else {
343            Err(b)
344        }
345    }
346}
347
348impl Add for Residue32<'_> {
349    type Output = Self;
350
351    fn add(self, rhs: Self) -> Self::Output {
352        self + rhs.into_raw()
353    }
354}
355
356impl AddAssign for Residue32<'_> {
357    fn add_assign(&mut self, rhs: Self) {
358        *self = *self + rhs;
359    }
360}
361
362impl Sub for Residue32<'_> {
363    type Output = Self;
364
365    fn sub(self, rhs: Self) -> Self::Output {
366        self - rhs.into_raw()
367    }
368}
369
370impl SubAssign for Residue32<'_> {
371    fn sub_assign(&mut self, rhs: Self) {
372        *self = *self - rhs;
373    }
374}
375
376impl Mul for Residue32<'_> {
377    type Output = Self;
378
379    fn mul(self, rhs: Self) -> Self::Output {
380        self * rhs.into_raw()
381    }
382}
383
384impl MulAssign for Residue32<'_> {
385    fn mul_assign(&mut self, rhs: Self) {
386        *self = *self * rhs;
387    }
388}
389
390impl Neg for Residue32<'_> {
391    type Output = Self;
392
393    fn neg(mut self) -> Self::Output {
394        self.x = if self.x == 0 {
395            0
396        } else {
397            self.modulus() - self.x
398        };
399
400        self
401    }
402}
403
/// The value part of a [`Residue32`], stored without its [`Modulus32`].
///
/// Conceptually, [`Residue32`] = [`Raw32`] + [`Modulus32`]: a [`Raw32`] keeps only the
/// internal representation and leaves out the reference to the modulus.
///
/// Detaching the value this way shrinks collections of residues, and lets a type hold
/// both a residue and its modulus without becoming self-referential.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Raw32 {
    // Same internal representation as the `x` field of `Residue32`.
    x: u64,
}
416
417impl Raw32 {
418    /// Attaches a modulus and returns a [`Residue32`].
419    ///
420    /// Typically, this only needs to be called once per computation
421    /// because `Raw32` and `Residue32` can interact.
422    ///
423    /// # Caution
424    ///
425    /// This does not perform validation or reduction.
426    /// The caller must ensure the modulus is correct for this value.
427    #[must_use]
428    pub const fn into_residue(self, modulus: &Modulus32) -> Residue32<'_> {
429        Residue32 { modulus, x: self.x }
430    }
431}
432
433impl<'a> From<Residue32<'a>> for Raw32 {
434    fn from(residue: Residue32<'a>) -> Self {
435        Self { x: residue.x }
436    }
437}
438
439impl<'a> Add<Raw32> for Residue32<'a> {
440    type Output = Residue32<'a>;
441
442    /// Performs the `+` operation.
443    ///
444    /// # Caution
445    ///
446    /// The caller must ensure that both operands shares the same modulus.
447    fn add(mut self, rhs: Raw32) -> Self::Output {
448        let (sum, b) = self.x.overflowing_add(rhs.x);
449        self.x = if b || sum >= self.modulus.n {
450            sum.wrapping_sub(self.modulus.n)
451        } else {
452            sum
453        };
454
455        self
456    }
457}
458
459impl<'a> Add<Residue32<'a>> for Raw32 {
460    type Output = Residue32<'a>;
461
462    /// Performs the `+` operation.
463    ///
464    /// # Caution
465    ///
466    /// The caller must ensure that both operands shares the same modulus.
467    fn add(self, rhs: Residue32<'a>) -> Self::Output {
468        rhs + self
469    }
470}
471
472impl AddAssign<Raw32> for Residue32<'_> {
473    /// Performs the `+=` operation.
474    ///
475    /// # Caution
476    ///
477    /// The caller must ensure that both operands shares the same modulus.
478    fn add_assign(&mut self, rhs: Raw32) {
479        *self = *self + rhs;
480    }
481}
482
483impl<'a> Sub<Raw32> for Residue32<'a> {
484    type Output = Residue32<'a>;
485
486    /// Performs the `-` operation.
487    ///
488    /// # Caution
489    ///
490    /// The caller must ensure that both operands shares the same modulus.
491    fn sub(mut self, rhs: Raw32) -> Self::Output {
492        let (diff, b) = self.x.overflowing_sub(rhs.x);
493        self.x = if b {
494            diff.wrapping_add(self.modulus.n)
495        } else {
496            diff
497        };
498
499        self
500    }
501}
502
503impl<'a> Sub<Residue32<'a>> for Raw32 {
504    type Output = Residue32<'a>;
505
506    /// Performs the `-` operation.
507    ///
508    /// # Caution
509    ///
510    /// The caller must ensure that both operands shares the same modulus.
511    fn sub(self, mut rhs: Residue32<'a>) -> Self::Output {
512        let (diff, b) = self.x.overflowing_sub(rhs.x);
513        rhs.x = if b {
514            diff.wrapping_add(rhs.modulus.n)
515        } else {
516            diff
517        };
518
519        rhs
520    }
521}
522
523impl SubAssign<Raw32> for Residue32<'_> {
524    /// Performs the `-=` operation.
525    ///
526    /// # Caution
527    ///
528    /// The caller must ensure that both operands shares the same modulus.
529    fn sub_assign(&mut self, rhs: Raw32) {
530        *self = *self - rhs;
531    }
532}
533
534impl<'a> Mul<Raw32> for Residue32<'a> {
535    type Output = Residue32<'a>;
536
537    /// Performs the `*` operation.
538    ///
539    /// # Caution
540    ///
541    /// The caller must ensure that both operands shares the same modulus.
542    fn mul(mut self, rhs: Raw32) -> Self::Output {
543        // n < r
544        self.x = self.modulus.mul(self.x, rhs.x);
545
546        self
547    }
548}
549
550impl<'a> Mul<Residue32<'a>> for Raw32 {
551    type Output = Residue32<'a>;
552
553    /// Performs the `*` operation.
554    ///
555    /// # Caution
556    ///
557    /// The caller must ensure that both operands shares the same modulus.
558    fn mul(self, rhs: Residue32<'a>) -> Self::Output {
559        rhs * self
560    }
561}
562
563impl MulAssign<Raw32> for Residue32<'_> {
564    /// Performs the `*=` operation.
565    ///
566    /// # Caution
567    ///
568    /// The caller must ensure that both operands shares the same modulus.
569    fn mul_assign(&mut self, rhs: Raw32) {
570        *self = *self * rhs;
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // `x^0 = 1`, whose residue is 0 when `n == 1` (reachable: 0 | 1 == 1).
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible_by_1(x: u32) {
            assert!(Modulus32::new(1).can_divide(x))
        }
    }

    /// Reference binary GCD used to validate `Residue32::inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.inv() {
                Ok(inv) => assert_eq!((inv * res).get(), 1),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    /// Deterministic coverage of the degenerate modulus `n == 1`,
    /// which the property tests above hit only rarely.
    #[test]
    fn modulus_one() {
        let modulus = Modulus32::new(1);
        for x in [0, 1, 2, u32::MAX] {
            let res = modulus.residue(x);
            assert_eq!(res.get(), 0);
            assert!(res.is_zero());
            assert_eq!(res.pow(0).get(), 0);
            assert_eq!(res.pow(5).get(), 0);
            assert!(res.inv().is_ok_and(|inv| inv.is_zero()));
        }
    }
}