Skip to main content

lib_modulo/
residue32any.rs

/// A modulus in `[2, 2^32)`, including even values.
///
/// # Fast modular multiplication
///
/// Provides fast modular multiplication using [Barrett multiplication].
/// This works for any modulus (including even ones) and places no restrictions on the operands,
/// but is generally slower than [`Residue32`](crate::Residue32).
///
/// Unlike Montgomery or Plantard methods, this operates directly on standard
/// integer representations (i.e., no transformation is required).
///
/// [Barrett multiplication]: https://doi.org/10.1007/3-540-47721-7_24
///
/// # Usage
///
/// ```
/// use lib_modulo::Modulus32Any;
///
/// let modulus = Modulus32Any::new(14);
/// assert_eq!(modulus.mul(3, 5), 1)
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Modulus32Any {
    // The modulus `n`, always in `[2, 2^32)`. Stored widened to `u64` so it can
    // enter the 64-/128-bit products of the reduction routines without casts.
    n: u64,
    // `ceil(2^64 / n)`: the magic number shared by Barrett multiplication and
    // Lemire's remainder (and divisibility) algorithm.
    magic: u64,
}
30
31impl Modulus32Any {
32    /// Creates new instance with the given modulus.
33    ///
34    /// # Panics
35    ///
36    /// Modulus `n` should be greater than 1.
37    ///
38    /// # Example
39    ///
40    /// ```
41    /// use lib_modulo::Modulus32Any;
42    ///
43    /// // even number is available
44    /// let modulus = Modulus32Any::new(2);
45    /// // odd number is also available
46    /// let modulus = Modulus32Any::new(3);
47    /// // division by 0 is undefined
48    /// assert!(std::panic::catch_unwind(|| Modulus32Any::new(0)).is_err());
49    /// // meaningless division by 1 is NOT available for performance
50    /// assert!(std::panic::catch_unwind(|| Modulus32Any::new(1)).is_err());
51    /// ```
52    #[inline(always)]
53    pub const fn new(n: u32) -> Self {
54        assert!(n > 1);
55
56        let n = n as u64;
57        let magic = (u64::MAX / n).wrapping_add(1);
58
59        Self { n, magic }
60    }
61
62    /// Returns the modulus.
63    #[inline(always)]
64    pub const fn modulus(&self) -> u32 {
65        self.n as u32
66    }
67
68    /// Calculates residue of `x` modulo `self`.
69    #[inline(always)]
70    pub const fn residue32(&self, x: u32) -> u32 {
71        let lo = self.magic.wrapping_mul(x as u64);
72        ((lo as u128 * self.n as u128) >> 64) as u32
73    }
74
75    /// Calculates residue of `x` modulo `self`.
76    #[inline(always)]
77    pub const fn residue64(&self, x: u64) -> u64 {
78        let quot = ((x as u128 * self.magic as u128) >> 64) as u64;
79        let (rem, b) = x.overflowing_sub(quot * self.n);
80        if b {
81            rem.wrapping_add(self.n)
82        } else {
83            rem
84        }
85    }
86
87    /// Checks if `x` is divisible by `self`.
88    #[inline(always)]
89    pub const fn can_divide(&self, x: u32) -> bool {
90        // since `self.n` is not 1, `self.magic` never overflow
91        self.magic.wrapping_mul(x as u64) < self.magic
92    }
93
94    /// Performs modular multiplication.
95    ///
96    /// # Example
97    ///
98    /// ```
99    /// use lib_modulo::Modulus32Any;
100    ///
101    /// // even number is available
102    /// let modulus = Modulus32Any::new(1 << 8);
103    /// assert_eq!(modulus.mul(u32::MAX, u32::MAX), 1);
104    /// ```
105    #[inline(always)]
106    pub const fn mul(&self, a: u32, b: u32) -> u32 {
107        self.residue64(a as u64 * b as u64) as u32
108    }
109
110    /// Raises `x` to the power of `exp`, using exponentiation by squaring.
111    ///
112    /// # Time Complexity
113    ///
114    /// *O*(log `x`)
115    ///
116    /// # Example
117    ///
118    /// ```
119    /// use lib_modulo::Modulus32Any;
120    ///
121    /// let modulus = Modulus32Any::new(123_456);
122    ///
123    /// assert_eq!(modulus.pow(123_456 * 100 + 1, 1000), 1)
124    /// ```
125    #[inline(always)]
126    pub const fn pow(&self, mut x: u32, mut exp: u32) -> u32 {
127        let mut res = 1;
128        while exp > 0 {
129            if exp & 1 == 1 {
130                res = self.mul(res, x);
131            }
132            exp >>= 1;
133            x = self.mul(x, x);
134        }
135        res
136    }
137
138    /// Calculates the modular inverse of `x`, using extended gcd algorithm.
139    ///
140    /// Modular inverse can be defined if and only if `x` and the modulus is coprime.
141    ///
142    /// - `Ok(_)` : the modular inverse.
143    /// - `Err(_)`: the GCD of `x` and the `modulus`, where `gcd(0, a)` is defined to be `a`.
144    ///
145    /// # Time complexity
146    ///
147    /// *O*(log `x`)
148    ///
149    /// # Example
150    ///
151    /// ```
152    /// use lib_modulo::Modulus32Any;
153    ///
154    /// let modulus = Modulus32Any::new(3 * 5);
155    /// for (a, inv_a) in [(1, 1), (2, 8), (4, 4), (7, 13), (11, 11), (14, 14)] {
156    ///     assert_eq!(modulus.mul(a, inv_a), 1);
157    ///     assert_eq!(modulus.inv(a), Ok(inv_a));
158    /// }
159    /// for a in [3, 6, 9, 12] {
160    ///     assert_eq!(modulus.inv(a), Err(3));
161    /// }
162    /// for a in [5, 10] {
163    ///     assert_eq!(modulus.inv(a), Err(5));
164    /// }
165    /// // gcd(0, a) is defined t be `a`
166    /// assert_eq!(modulus.inv(0), Err(15));
167    /// assert_eq!(modulus.inv(15 * 99), Err(15));
168    /// ```
169    pub fn inv(&self, x: u32) -> Result<u32, u32> {
170        // invariant: a x0 = x, b x0 = y (mod [y])
171        let mut x = self.residue32(x) as i64;
172        let mut y = self.n as i64;
173        let [mut a, mut b] = [1, 0];
174
175        while x > 0 {
176            let (div, rem) = (y / x, y % x);
177
178            (x, y) = (rem, x);
179            (a, b) = (b - div * a, a);
180        }
181
182        // y = gcd(x0, y0) > 0
183        if y != 1 {
184            return Err(y as u32);
185        }
186        if b.is_negative() {
187            b += self.n as i64;
188        }
189
190        Ok(b as u32)
191    }
192}
193
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{random_iter, rng};

    use super::Modulus32Any;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in 2..=u32::MAX, a: u32, b: u32) {
            let modulus = Modulus32Any::new(n);
            assert_eq!(
                modulus.mul(a, b),
                (a as u64 * b as u64 % n as u64) as u32,
                "{:?}", modulus
            );
        }
    }

    #[test]
    fn mul_small() {
        let mut rng = rng();
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n);
            for _ in 0..1 << 12 {
                let a = rng.random();
                let b = rng.random();
                assert_eq!(modulus.mul(a, b) as u64, a as u64 * b as u64 % n as u64);
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue32(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n);
            assert_eq!(modulus.residue32(a), a % n);
        }
    }

    #[test]
    fn residue32_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n);
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue32(a), a % n)
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue64(n in 2..=u32::MAX, a: u64) {
            let modulus = Modulus32Any::new(n);
            assert_eq!(modulus.residue64(a), a % n as u64);
        }
    }

    #[test]
    fn residue64_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n);
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue64(a), a % n as u64)
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn can_divide(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n);
            assert_eq!(modulus.can_divide(a), a % n == 0, "{:?}, a = {}", modulus, a);
        }

        #[test]
        fn can_divide_multiple(n in 2..=u32::MAX, k: u32) {
            // A random `a` is almost never a multiple of a large `n`, so build one
            // explicitly; `k % (MAX / n + 1) * n` stays within `u32`.
            let a = k % (u32::MAX / n + 1) * n;
            let modulus = Modulus32Any::new(n);
            assert!(modulus.can_divide(a), "{:?}, a = {}", modulus, a);
        }
    }

    #[test]
    fn can_divide_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n);
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.can_divide(a), a % n == 0)
            }
            // Boundary multiples: zero, `n` itself and the largest multiple in `u32`.
            for a in [0, n, u32::MAX / n * n] {
                assert!(modulus.can_divide(a), "n = {}, a = {}", n, a)
            }
        }
    }

    /// Reference modular exponentiation using native `%`, independent of the
    /// Barrett reduction under test.
    fn naive_pow(x: u32, mut exp: u32, n: u32) -> u32 {
        let n = n as u64;
        let mut base = x as u64 % n;
        let mut acc = 1 % n;
        while exp > 0 {
            if exp & 1 == 1 {
                acc = acc * base % n;
            }
            base = base * base % n;
            exp >>= 1;
        }
        acc as u32
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 12))]
        #[test]
        fn pow(n in 2..=u32::MAX, x: u32, exp: u32) {
            let modulus = Modulus32Any::new(n);
            assert_eq!(modulus.pow(x, exp), naive_pow(x, exp, n), "{:?}", modulus);
        }
    }

    /// Stein's binary GCD, used as an independent reference for `inv`'s error value.
    fn binary_gcd(mut a: u32, mut b: u32) -> u32 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n);
            match modulus.inv(a) {
                Ok(inv) => {
                    assert!(inv < n, "inverse {} is not reduced for {:?}", inv, modulus);
                    assert_eq!(modulus.mul(a, inv), 1, "wrong inverse of {} for {:?}", a, modulus);
                    assert_eq!(binary_gcd(a, n), 1, "{} is not coprime to {:?}", a, modulus);
                }
                Err(gcd) => {
                    assert_ne!(gcd, 1, "coprime {} rejected by {:?}", a, modulus);
                    assert_eq!(gcd, binary_gcd(a, n), "wrong gcd of {} for {:?}", a, modulus);
                }
            }
        }
    }
}