Skip to main content

lib_modulo/
residue32.rs

1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue32`].
4///
5/// See documentation of [`Residue32`] for details.
6#[derive(Debug, Clone, Hash, Eq)]
7pub struct Modulus32 {
8    // n inv_n = 1 (mod 2^64)
9    n: u64,
10    inv_n: u64,
11    // 2^128 (mod n)
12    init: u64,
13}
14
15impl Modulus32 {
16    /// Maximum available modulus.
17    pub const MAX_MODULUS: u32 = 2_654_435_769;
18
19    /// Creates new context for modular arithmetics.
20    ///
21    /// # Panics
22    ///
23    /// - modulus `n` should be an odd integer.
24    /// - modulus `n` should be no more than `2_654_435_769`,
25    /// which is the floor of `2^32 / GOLDEN_RATIO`.
26    ///
27    /// # Example
28    ///
29    /// ```
30    /// use lib_modulo::Modulus32;
31    ///
32    /// // odd integer less than or equal to 2_654_435_769 is allowed.
33    /// let modulus = Modulus32::new(Modulus32::MAX_MODULUS);
34    /// let modulus = Modulus32::new(3);
35    ///
36    /// // modulus should be an odd integer!
37    /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
38    /// ```
39    #[inline]
40    pub const fn new(n: u32) -> Self {
41        assert!(
42            n & 1 == 1,
43            "invalid modulus: modulus should be an odd integer."
44        );
45        assert!(
46            n <= Self::MAX_MODULUS,
47            "invalid modulus: modulus should be no more than 2_654_435_769."
48        );
49
50        let n = n as u64;
51
52        let inv_n = {
53            // 1 * 1 = 3 * 3 = 1 (mod 4)
54            let mut inv_n = n & 3;
55            // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
56            // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
57            let mut i = u64::BITS.ilog2() - 1;
58            while i > 0 {
59                i -= 1;
60                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
61            }
62            debug_assert!(n.wrapping_mul(inv_n) == 1);
63
64            inv_n
65        };
66
67        let init = if std::mem::size_of::<usize>() >= std::mem::size_of::<u64>() {
68            ((n as u128).wrapping_neg() % n as u128) as u64
69        } else {
70            // 128 bit division on 32-bit machine will be slow (no experiment)
71            let pow_2_64 = n.wrapping_neg() % n;
72            let pow_2_128 = pow_2_64 * pow_2_64 % n;
73            pow_2_128
74        };
75
76        Self { n, inv_n, init }
77    }
78
79    /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
80    ///
81    /// If `x y < self.n`, then returned value is less than `self.n`.
82    #[inline(always)]
83    const fn mul(&self, x: u64, y: u64) -> u64 {
84        // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
85        // let z = x.wrapping_mul(y).wrapping_mul(self.inv_n) >> 32;
86        // let z = (z + 1).wrapping_mul(self.n) >> 32;
87
88        // if z == self.n {
89        //     0
90        // } else {
91        //     z
92        // }
93
94        // modified Plantard multiplication
95        // z = (Q1 2^32 + (2^32 - 1)) P >> 64 (notation is the same as the paper)
96        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) | u32::MAX as u64;
97        let z = ((z as u128 * self.n as u128) >> 64) as u64;
98        debug_assert!(z < self.n, "this is a bug in lib-modulo");
99        z
100    }
101
102    /// Calculates the residue of `x` modulo `self`.
103    ///
104    /// # Example
105    ///
106    /// ```
107    /// use lib_modulo::Modulus32;
108    ///
109    /// let modulus = Modulus32::new(5);
110    /// assert_eq!(modulus.residue(8).get(), 3)
111    /// ```
112    #[inline(always)]
113    pub const fn residue(&self, x: u32) -> Residue32<'_> {
114        let mut x = x as u64;
115
116        if x * self.init > !(self.n << 32) {
117            x %= self.n
118        }
119
120        Residue32 {
121            x: self.mul(x, self.init),
122            modulus: &self,
123        }
124    }
125
126    /// Checks whether `x` is divisible by `self`.
127    ///
128    /// # Example
129    ///
130    /// ```
131    /// use lib_modulo::Modulus32;
132    ///
133    /// let modulus = Modulus32::new(9);
134    /// assert!(modulus.can_divide(18));
135    /// assert!(!modulus.can_divide(19));
136    /// ```
137    #[inline(always)]
138    pub const fn can_divide(&self, x: u32) -> bool {
139        self.residue(x).is_zero()
140    }
141}
142
143impl PartialEq for Modulus32 {
144    fn eq(&self, other: &Self) -> bool {
145        self.n == other.n
146    }
147}
148
149/// Residue with odd modulus which is no more than `2_654_435_769`.
150///
151/// # Fast modular multiplication
152///
153/// [`Residue32`] provides fast modular multiplication called [Plantard multiplication].
154/// This method saves one multiplication when either of two values of a multiplication is used multiple times.
155/// Therefore, [`Residue32::pow`] will be faster than that using [Montgomery multiplication].
156///
157/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
158/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
159///
160/// # Usage
161///
162/// ```
163/// use lib_modulo::Modulus32;
164///
165/// // set modulus
166/// let modulus = Modulus32::new(3);
167///
168/// // performs modular arithmetics
169/// let one = modulus.residue(1);
170/// let two = modulus.residue(2);
171/// let five = modulus.residue(5);
172/// assert_eq!(two * five, one)
173/// ```
174///
175/// Two residues with different modulus can interact, but the result will be meaningless.
176/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
177#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
178pub struct Residue32<'a> {
179    // compare modulus first
180    modulus: &'a Modulus32,
181    x: u64,
182}
183
184impl<'a> Residue32<'a> {
185    /// Checks whether `self` is `0`.
186    ///
187    /// # Example
188    ///
189    /// ```
190    /// use lib_modulo::Modulus32;
191    ///
192    /// let modulus = Modulus32::new(5);
193    /// assert!(modulus.residue(10).is_zero())
194    /// ```
195    #[inline(always)]
196    pub const fn is_zero(self) -> bool {
197        self.x == 0
198    }
199
200    /// Returns the residue.
201    ///
202    /// # Example
203    ///
204    /// ```
205    /// use lib_modulo::Modulus32;
206    ///
207    /// let modulus = Modulus32::new(7);
208    /// assert_eq!(modulus.residue(10).get(), 3)
209    /// ```
210    #[inline(always)]
211    pub const fn get(self) -> u64 {
212        self.modulus.mul(self.x, 1)
213    }
214
215    /// Returns the modulus.
216    ///
217    /// # Example
218    ///
219    /// ```
220    /// use lib_modulo::Modulus32;
221    ///
222    /// let modulus = Modulus32::new(11);
223    /// assert_eq!(modulus.residue(2).modulus(), 11);
224    /// ```
225    #[inline(always)]
226    pub const fn modulus(&self) -> u64 {
227        self.modulus.n
228    }
229
230    /// Raises `self` to the power of `exp`, using exponentiation by squaring.
231    ///
232    /// # Time complexity
233    ///
234    /// *Θ*(log `exp`)
235    ///
236    /// # Example
237    ///
238    /// ```
239    /// use lib_modulo::Modulus32;
240    ///
241    /// let modulus = Modulus32::new(1001);
242    /// let residue = modulus.residue(2);
243    /// for exp in 0..64 {
244    ///     assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
245    /// }
246    /// ```
247    #[inline(always)]
248    pub const fn pow(self, mut exp: u32) -> Self {
249        let Self { mut x, modulus } = self;
250        // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
251        let mut prod = modulus.mul(1, modulus.init);
252
253        while exp > 1 {
254            if exp & 1 == 1 {
255                // インライン展開されると,掛け算を1回節約できる。
256                prod = modulus.mul(prod, x)
257            }
258
259            exp >>= 1;
260            x = modulus.mul(x, x); // skip last useless one
261        }
262        if exp != 0 {
263            prod = modulus.mul(prod, x);
264        }
265
266        Self { x: prod, modulus }
267    }
268
269    /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
270    ///
271    /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
272    ///
273    /// - `Ok(x)` : `x` is the modular inverse.
274    /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
275    /// where `gcd(0, a)` is defined to be `a`.
276    ///
277    /// # Time complexity
278    ///
279    /// *O*(log `self`)
280    ///
281    /// # Example
282    ///
283    /// ```
284    /// use lib_modulo::Modulus32;
285    ///
286    /// let modulus = Modulus32::new(3 * 5);
287    ///
288    /// let residue = modulus.residue(2);
289    /// assert!(residue.try_inv().is_ok_and(|inv| (inv * residue).get() == 1));
290    ///
291    /// let residue = modulus.residue(6);
292    /// assert!(residue.try_inv().is_err_and(|gcd| gcd == 3));
293    /// ```
294    pub const fn try_inv(self) -> Result<Self, u64> {
295        // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
296        let mut a = self.get();
297        let mut b = self.modulus();
298        let Self { modulus, .. } = self;
299        let mut x = modulus.residue(1).x;
300        let mut y = 0;
301        let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
302
303        while a > 0 {
304            x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
305            a >>= a.trailing_zeros();
306
307            if a < b {
308                (a, b) = (b, a);
309                (x, y) = (y, x);
310            }
311            a -= b;
312            let (z, b) = x.overflowing_sub(y);
313            x = if b { z.wrapping_add(modulus.n) } else { z };
314        }
315
316        // b = gcd([a], n)
317        if b == 1 {
318            Ok(Self { x: y, modulus })
319        } else {
320            Err(b)
321        }
322    }
323}
324
325impl<'a> Add for Residue32<'a> {
326    type Output = Self;
327
328    fn add(mut self, rhs: Self) -> Self::Output {
329        let (x, b) = self.x.overflowing_add(rhs.x);
330        self.x = if b || x >= self.modulus() {
331            x.wrapping_sub(self.modulus())
332        } else {
333            x
334        };
335
336        self
337    }
338}
339
340impl<'a> AddAssign for Residue32<'a> {
341    fn add_assign(&mut self, rhs: Self) {
342        *self = *self + rhs
343    }
344}
345
346impl<'a> Sub for Residue32<'a> {
347    type Output = Self;
348
349    fn sub(mut self, rhs: Self) -> Self::Output {
350        let (x, b) = self.x.overflowing_sub(rhs.x);
351        self.x = if b { x.wrapping_add(self.modulus()) } else { x };
352
353        self
354    }
355}
356
357impl<'a> SubAssign for Residue32<'a> {
358    fn sub_assign(&mut self, rhs: Self) {
359        *self = *self - rhs
360    }
361}
362
363impl<'a> Mul for Residue32<'a> {
364    type Output = Self;
365
366    fn mul(mut self, rhs: Self) -> Self::Output {
367        self.x = self.modulus.mul(self.x, rhs.x);
368        self
369    }
370}
371
372impl<'a> MulAssign for Residue32<'a> {
373    fn mul_assign(&mut self, rhs: Self) {
374        *self = *self * rhs
375    }
376}
377
378impl<'a> Neg for Residue32<'a> {
379    type Output = Self;
380
381    fn neg(mut self) -> Self::Output {
382        self.x = if self.x == 0 {
383            0
384        } else {
385            self.modulus() - self.x
386        };
387
388        self
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=2_654_435_769_u32).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=2_654_435_769_u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // `1 % n` rather than `1`: the generator can produce `n = 1`, where every
            // residue (including `x^0`) is `0`.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=2_654_435_769_u32).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference GCD used to validate `try_inv`'s error value.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=2_654_435_769_u32).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // `1 % n`: for `n = 1` the product of any two residues is `0`.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    /// Deterministic regression test for the degenerate modulus `1`, which the random
    /// generators above only hit with tiny probability.
    #[test]
    fn modulus_one() {
        let modulus = Modulus32::new(1);
        let res = modulus.residue(12_345);

        assert!(res.is_zero());
        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert!(modulus.can_divide(u32::MAX));
        assert!(res.try_inv().is_ok_and(|inv| (inv * res).get() == 0));
    }
}