Skip to main content

lib_modulo/
residue32.rs

1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
/// Factory of [`Residue32`].
///
/// Stores the modulus together with the constants that [`Modulus32::new`] derives from it,
/// so that residues can be created and multiplied without recomputing them.
///
/// See documentation of [`Residue32`] for details.
// `Hash` is derived over every field while `PartialEq` is written by hand and compares
// only `n`; the two stay consistent because all other fields are computed from `n`.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct Modulus32 {
    // the modulus itself, widened to `u64`
    n: u64,
    // inverse of `n` modulo 2^64: n inv_n = 1 (mod 2^64)
    inv_n: u64,
    // (2^128 (mod n)) * inv_n, reduced modulo 2^64;
    // multiplying by it converts an integer into the internal representation
    init: u64,
    // ceil(2^64 / n), reduced modulo 2^64 (so it is 0 when n = 1);
    // used by the fast remainder and divisibility checks
    recip: u64,
}
17
18impl Modulus32 {
    /// Maximum available modulus.
    ///
    /// [`Modulus32::new`] panics when given a larger value.
    /// This is the floor of `2^32 / GOLDEN_RATIO`.
    pub const MAX: u32 = 2_654_435_769;
21
22    /// Creates new instance with the given modulus.
23    ///
24    /// # Panics
25    ///
26    /// - modulus `n` should be an odd integer.
27    /// - modulus `n` should be no more than `2_654_435_769`,
28    ///   which is the floor of `2^32 / GOLDEN_RATIO`.
29    ///
30    /// # Example
31    ///
32    /// ```
33    /// use lib_modulo::Modulus32;
34    ///
35    /// // odd integer less than or equal to 2_654_435_769 is allowed.
36    /// let modulus = Modulus32::new(Modulus32::MAX);
37    /// let modulus = Modulus32::new(3);
38    ///
39    /// // modulus should be an odd integer!
40    /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
41    /// ```
42    #[inline]
43    pub const fn new(n: u32) -> Self {
44        assert!(
45            n & 1 == 1,
46            "invalid modulus: modulus should be an odd integer."
47        );
48        assert!(
49            n <= Self::MAX,
50            "invalid modulus: modulus should be no more than 2_654_435_769."
51        );
52
53        let n = n as u64;
54
55        let inv_n = {
56            // 1 * 1 = 3 * 3 = 1 (mod 4)
57            let mut inv_n = n & 3;
58            // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
59            // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
60            let mut i = u64::BITS.ilog2() - 1;
61            while i > 0 {
62                i -= 1;
63                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
64            }
65            debug_assert!(n.wrapping_mul(inv_n) == 1);
66
67            inv_n
68        };
69
70        let (div, rem) = {
71            let denom = n.wrapping_neg();
72            (denom / n, denom % n)
73        };
74        // 2^128 (mod n): magic number for converting integer to Plantard representation.
75        let init = rem * rem % n;
76        // ceil(2^64 / n): magic number for fast remainder algorithm
77        let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
78
79        Self {
80            n,
81            inv_n,
82            init: init.wrapping_mul(inv_n),
83            recip,
84        }
85    }
86
87    /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
88    ///
89    /// If `x y < self.n`, then returned value is less than `self.n`.
90    #[inline(always)]
91    const fn mul(&self, x: u64, y: u64) -> u64 {
92        // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
93        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
94        let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
95        debug_assert!(z < self.n, "this is a bug in lib-modulo");
96        z
97    }
98
99    /// Calculates the residue of `x` modulo `self`.
100    ///
101    /// # Example
102    ///
103    /// ```
104    /// use lib_modulo::Modulus32;
105    ///
106    /// let modulus = Modulus32::new(5);
107    /// assert_eq!(modulus.residue(8).get(), 3)
108    /// ```
109    #[inline(always)]
110    pub const fn residue(&self, x: u32) -> Residue32<'_> {
111        // fast remainder algorithm
112        // See <https://onlinelibrary.wiley.com/doi/10.1002/spe.2689> for details
113        let x = {
114            let lo = self.recip.wrapping_mul(x as u64);
115            ((lo as u128 * self.n as u128) >> 64) as u64
116        };
117
118        let x = {
119            // multiplication by a constant
120            let x = self.init.wrapping_mul(x) >> 32;
121            ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
122        };
123
124        Residue32 { x, modulus: self }
125    }
126
127    /// Checks whether `x` is divisible by `self`.
128    ///
129    /// # Example
130    ///
131    /// ```
132    /// use lib_modulo::Modulus32;
133    ///
134    /// let modulus = Modulus32::new(9);
135    /// assert!(modulus.can_divide(18));
136    /// assert!(!modulus.can_divide(19));
137    /// ```
138    #[inline(always)]
139    pub const fn can_divide(&self, x: u32) -> bool {
140        self.recip.wrapping_mul(x as u64) <= self.recip.wrapping_sub(1)
141    }
142}
143
144impl PartialEq for Modulus32 {
145    fn eq(&self, other: &Self) -> bool {
146        // other fields depend on `n`
147        self.n == other.n
148    }
149}
150
/// A residue with an odd modulus not exceeding `2_654_435_769`.
///
/// A [`Residue32`] borrows the [`Modulus32`] it was created from and stores its value
/// in an internal representation; use [`Residue32::get`] to obtain the plain residue.
///
/// # Fast modular multiplication
///
/// [`Residue32`] provides fast modular multiplication using [Plantard multiplication].
/// This method eliminates one multiplication when one of the operands is reused multiple times.
/// As a result, [`Residue32::pow`] and other operations are typically
/// faster than implementations based on [Montgomery multiplication].
///
/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
///
/// # Usage
///
/// ```
/// use lib_modulo::Modulus32;
///
/// // set modulus
/// let modulus = Modulus32::new(3);
///
/// // performs modular arithmetic
/// let one = modulus.residue(1);
/// let two = modulus.residue(2);
/// let five = modulus.residue(5);
/// assert_eq!(two * five, one)
/// ```
///
/// Two residues with different moduli can be combined, but the result will be meaningless:
/// the arithmetic operators use the modulus of the left-hand operand and do not check
/// that the right-hand operand shares it.
/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Residue32<'a> {
    // compare modulus first: this field is declared before `x`,
    // so the derived comparison looks at it first
    modulus: &'a Modulus32,
    // value in the internal representation, not the plain residue
    x: u64,
}
186
187impl<'a> Residue32<'a> {
188    /// Extract the internal representation of `self`.
189    ///
190    /// ```
191    /// use lib_modulo::{Modulus32, Raw32};
192    ///
193    /// let modulus = Modulus32::new(1001);
194    /// // save memory
195    /// let residues: Vec<Raw32> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
196    /// ```
197    #[inline(always)]
198    pub const fn into_raw(self) -> Raw32 {
199        Raw32 { x: self.x }
200    }
201
202    /// Checks whether `self` is `0`.
203    ///
204    /// # Example
205    ///
206    /// ```
207    /// use lib_modulo::Modulus32;
208    ///
209    /// let modulus = Modulus32::new(5);
210    /// assert!(modulus.residue(10).is_zero())
211    /// ```
212    #[inline(always)]
213    pub const fn is_zero(self) -> bool {
214        self.x == 0
215    }
216
217    /// Returns the residue.
218    ///
219    /// # Example
220    ///
221    /// ```
222    /// use lib_modulo::Modulus32;
223    ///
224    /// let modulus = Modulus32::new(7);
225    /// assert_eq!(modulus.residue(10).get(), 3)
226    /// ```
227    #[inline(always)]
228    pub const fn get(self) -> u64 {
229        self.modulus.mul(self.x, 1)
230    }
231
232    /// Returns the modulus.
233    ///
234    /// # Example
235    ///
236    /// ```
237    /// use lib_modulo::Modulus32;
238    ///
239    /// let modulus = Modulus32::new(11);
240    /// assert_eq!(modulus.residue(2).modulus(), 11);
241    /// ```
242    #[inline(always)]
243    pub const fn modulus(&self) -> u64 {
244        self.modulus.n
245    }
246
247    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
248    ///
249    /// # Time complexity
250    ///
251    /// *Θ*(log `exp`)
252    ///
253    /// # Example
254    ///
255    /// ```
256    /// use lib_modulo::Modulus32;
257    ///
258    /// let modulus = Modulus32::new(1001);
259    /// let residue = modulus.residue(2);
260    /// for exp in 0..64 {
261    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
262    /// }
263    /// ```
264    #[inline(always)]
265    pub const fn pow(self, mut exp: u32) -> Self {
266        let Self { mut x, modulus } = self;
267        // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
268        let mut prod = modulus.residue(1).x;
269
270        while exp > 1 {
271            if exp & 1 == 1 {
272                // インライン展開されると,掛け算を1回節約できる。
273                prod = modulus.mul(prod, x)
274            }
275
276            exp >>= 1;
277            x = modulus.mul(x, x); // skip last useless one
278        }
279        if exp != 0 {
280            prod = modulus.mul(prod, x);
281        }
282
283        Self { x: prod, modulus }
284    }
285
286    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
287    ///
288    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
289    ///
290    /// - `Ok(x)` : `x` is the modular inverse.
291    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
292    ///   where `gcd(0, a)` is defined to be `a`.
293    ///
294    /// # Time complexity
295    ///
296    /// *O*(log `self`)
297    ///
298    /// # Example
299    ///
300    /// ```
301    /// use lib_modulo::Modulus32;
302    ///
303    /// let modulus = Modulus32::new(3 * 5);
304    ///
305    /// let residue = modulus.residue(2);
306    /// assert!(residue.inv().is_ok_and(|inv| (inv * residue).get() == 1));
307    ///
308    /// let residue = modulus.residue(6);
309    /// assert!(residue.inv().is_err_and(|gcd| gcd == 3));
310    /// ```
311    pub const fn inv(self) -> Result<Self, u64> {
312        // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
313        let mut a = self.get();
314        let mut b = self.modulus();
315        let Self { modulus, .. } = self;
316        let mut x = modulus.residue(1).x;
317        let mut y = 0;
318        let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
319
320        while a > 0 {
321            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
322            a >>= a.trailing_zeros();
323
324            if a < b {
325                (a, b) = (b, a);
326                (x, y) = (y, x);
327            }
328            a -= b;
329            let (z, b) = x.overflowing_sub(y);
330            x = if b { z.wrapping_add(modulus.n) } else { z };
331        }
332
333        // b = gcd([a], n)
334        if b == 1 {
335            Ok(Self { x: y, modulus })
336        } else {
337            Err(b)
338        }
339    }
340}
341
342impl<'a> Add for Residue32<'a> {
343    type Output = Self;
344
345    fn add(mut self, rhs: Self) -> Self::Output {
346        let (x, b) = self.x.overflowing_add(rhs.x);
347        self.x = if b || x >= self.modulus() {
348            x.wrapping_sub(self.modulus())
349        } else {
350            x
351        };
352
353        self
354    }
355}
356
357impl<'a> AddAssign for Residue32<'a> {
358    fn add_assign(&mut self, rhs: Self) {
359        *self = *self + rhs
360    }
361}
362
363impl<'a> Sub for Residue32<'a> {
364    type Output = Self;
365
366    fn sub(mut self, rhs: Self) -> Self::Output {
367        let (x, b) = self.x.overflowing_sub(rhs.x);
368        self.x = if b { x.wrapping_add(self.modulus()) } else { x };
369
370        self
371    }
372}
373
374impl<'a> SubAssign for Residue32<'a> {
375    fn sub_assign(&mut self, rhs: Self) {
376        *self = *self - rhs
377    }
378}
379
380impl<'a> Mul for Residue32<'a> {
381    type Output = Self;
382
383    fn mul(mut self, rhs: Self) -> Self::Output {
384        self.x = self.modulus.mul(self.x, rhs.x);
385        self
386    }
387}
388
389impl<'a> MulAssign for Residue32<'a> {
390    fn mul_assign(&mut self, rhs: Self) {
391        *self = *self * rhs
392    }
393}
394
395impl<'a> Neg for Residue32<'a> {
396    type Output = Self;
397
398    fn neg(mut self) -> Self::Output {
399        self.x = if self.x == 0 {
400            0
401        } else {
402            self.modulus() - self.x
403        };
404
405        self
406    }
407}
408
/// An internal representation of [`Residue32`] without an associated [`Modulus32`].
///
/// Conceptually, [`Residue32`] = [`Raw32`] + [`Modulus32`].
/// [`Raw32`] stores the value part alone, without holding a reference to its modulus.
///
/// This separation is useful for reducing the size of collections of [`Residue32`]
/// and for avoiding self-referential structures when a type needs to contain both
/// a residue and its modulus.
///
/// Use [`Raw32::into_residue`] to attach a modulus again.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Raw32 {
    // the `x` field of the `Residue32` this value was taken from, copied verbatim
    x: u64,
}
421
422impl Raw32 {
423    /// Attaches a modulus and returns a [`Residue32`].
424    ///
425    /// # Caution
426    ///
427    /// This does not perform validation or reduction.
428    /// The caller must ensure the modulus is correct for this value.
429    #[inline(always)]
430    pub const fn into_residue<'a>(self, modulus: &'a Modulus32) -> Residue32<'a> {
431        Residue32 { modulus, x: self.x }
432    }
433}
434
435impl<'a> From<Residue32<'a>> for Raw32 {
436    #[inline(always)]
437    fn from(residue: Residue32<'a>) -> Self {
438        Self { x: residue.x }
439    }
440}
441#[cfg(test)]
442mod tests {
443    use super::*;
444
445    use proptest::prelude::*;
446
447    proptest! {
448        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
449        #[test]
450        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
451            let modulus = Modulus32::new(n);
452
453            let res = modulus.residue(x);
454            assert_eq!(res.get() as u32, x % n)
455        }
456    }
457
458    proptest! {
459        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
460        #[test]
461        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
462            let modulus = Modulus32::new(n as u32);
463
464            let res = modulus.residue(x as u32);
465            let mut naive = 1;
466            for i in 0..100 {
467                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
468                naive = naive * x % n
469            }
470        }
471    }
472
473    proptest! {
474        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
475        #[test]
476        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
477            let modulus = Modulus32::new(n);
478
479            assert_eq!(modulus.can_divide(x), x % n == 0);
480        }
481    }
482
483    proptest! {
484        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
485        #[test]
486        fn divisible_by_1(x: u32) {
487            assert!(Modulus32::new(1).can_divide(x))
488        }
489    }
490
491    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
492        if b == 0 {
493            return a;
494        }
495
496        let shift = (a | b).trailing_zeros();
497        b >>= b.trailing_zeros();
498
499        while a != 0 {
500            a >>= a.trailing_zeros();
501
502            if a < b {
503                (a, b) = (b, a)
504            }
505            a -= b
506        }
507
508        b << shift
509    }
510
511    proptest! {
512        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
513        #[test]
514        fn inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
515            let modulus = Modulus32::new(n);
516            let res = modulus.residue(x);
517
518            match res.inv() {
519                Ok(inv) => assert_eq!((inv * res).get(), 1),
520                Err(gcd) => {
521                    assert!(res.get() % gcd == 0);
522                    assert!(res.modulus() % gcd == 0);
523                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
524                }
525            }
526        }
527    }
528}