// lib_modulo/residue32any.rs

/// A modulus in `[2, 2^32)`, including even values.
///
/// # Fast modular multiplication
///
/// Provides fast modular multiplication using [Barrett multiplication].
/// This works for any modulus (including even ones) and places no restrictions on the operands,
/// but is generally slower than [`Residue32`](crate::Residue32).
///
/// Unlike Montgomery or Plantard methods, this operates directly on standard
/// integer representations (i.e., no transformation is required).
///
/// [Barrett multiplication]: https://doi.org/10.1007/3-540-47721-7_24
///
/// # Usage
///
/// ```
/// use lib_modulo::Modulus32Any;
///
/// let modulus = Modulus32Any::new(14).unwrap();
/// assert_eq!(modulus.mul(3, 5), 1)
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Modulus32Any {
    // The modulus itself, stored widened to 64 bits so it can be used directly
    // in the 64-/128-bit arithmetic of the reduction routines.
    n: u64,
    // Magic number shared by Barrett multiplication and Lemire's remainder
    // algorithm (N = 64, L = 32): ceil(2^64 / n).
    magic: u64,
}
30
31/// Invalid moduli of [`Modulus32Any`].
32#[derive(thiserror::Error, Debug, PartialEq, Eq, Hash, Clone, Copy)]
33pub enum InvalidModulus {
34    /// Modulus is 0. This is undefined.
35    #[error("modulo 0 is undefined")]
36    Zero,
37    /// Modulus is 1. This is meaningless.
38    #[error("modulo 1 is meaningless and not available for performance reason")]
39    One,
40}
41
42impl Modulus32Any {
43    /// Creates new instance with the given modulus.
44    ///
45    /// # Error
46    ///
47    /// Returns error if the given modulus is `0` or `1`.
48    ///
49    /// # Example
50    ///
51    /// ```
52    /// use lib_modulo::{Modulus32Any, InvalidModulus};
53    ///
54    /// // even number is available
55    /// let modulus = Modulus32Any::new(2).unwrap();
56    /// // odd number is also available
57    /// let modulus = Modulus32Any::new(3).unwrap();
58    /// // division by 0 is undefined
59    /// assert_eq!(Modulus32Any::new(0), Err(InvalidModulus::Zero));
60    /// // division by 1 is meaningless and NOT available for performance
61    /// assert_eq!(Modulus32Any::new(1), Err(InvalidModulus::One));
62    /// ```
63    pub const fn new(n: u32) -> Result<Self, InvalidModulus> {
64        match n {
65            0 => Err(InvalidModulus::Zero),
66            1 => Err(InvalidModulus::One),
67            n => {
68                let n = n as u64;
69                let magic = (u64::MAX / n).wrapping_add(1);
70
71                Ok(Self { n, magic })
72            }
73        }
74    }
75
76    /// Returns the modulus.
77    #[must_use]
78    pub const fn modulus(&self) -> u32 {
79        self.n as u32
80    }
81
82    /// Calculates residue of `x` modulo `self`.
83    #[must_use]
84    pub const fn residue32(&self, x: u32) -> u32 {
85        let lo = self.magic.wrapping_mul(x as u64);
86        ((lo as u128 * self.n as u128) >> 64) as u32
87    }
88
89    /// Calculates residue of `x` modulo `self`.
90    #[must_use]
91    pub const fn residue64(&self, x: u64) -> u64 {
92        let quot = ((x as u128 * self.magic as u128) >> 64) as u64;
93        let (rem, b) = x.overflowing_sub(quot * self.n);
94        if b {
95            rem.wrapping_add(self.n)
96        } else {
97            rem
98        }
99    }
100
101    /// Checks if `x` is divisible by `self`.
102    #[must_use]
103    pub const fn can_divide(&self, x: u32) -> bool {
104        // since `self.n` is not 1, `self.magic` never overflow
105        self.magic.wrapping_mul(x as u64) < self.magic
106    }
107
108    /// Performs `a * b` modulo `self`.
109    ///
110    /// # Example
111    ///
112    /// ```
113    /// use lib_modulo::Modulus32Any;
114    ///
115    /// // even number is available
116    /// let modulus = Modulus32Any::new(1 << 8).unwrap();
117    /// assert_eq!(modulus.mul(u32::MAX, u32::MAX), 1);
118    /// ```
119    #[must_use]
120    pub const fn mul(&self, a: u32, b: u32) -> u32 {
121        self.residue64(a as u64 * b as u64) as u32
122    }
123
124    /// Performs `a * b + c` modulo `self`.
125    ///
126    /// # Example
127    ///
128    /// ```
129    /// use lib_modulo::Modulus32Any;
130    ///
131    /// let modulus = Modulus32Any::new(2357).unwrap();
132    /// assert_eq!(
133    ///     modulus.carrying_mul(123, 456, 789),
134    ///     (123 * 456 + 789) % 2357
135    /// );
136    /// ```
137    #[must_use]
138    pub const fn carrying_mul(&self, a: u32, b: u32, c: u32) -> u32 {
139        self.residue64(a as u64 * b as u64 + c as u64) as u32
140    }
141
142    /// Performs `a * b + c + d` modulo `self`.
143    ///
144    /// # Example
145    ///
146    /// ```
147    /// use lib_modulo::Modulus32Any;
148    ///
149    /// // even number is available
150    /// let modulus = Modulus32Any::new(123_456).unwrap();
151    /// assert_eq!(
152    ///     modulus.carrying_mul_add(u32::MAX, u32::MAX, u32::MAX, u32::MAX),
153    ///     (u64::MAX % 123_456) as u32
154    /// );
155    /// ```
156    #[must_use]
157    pub const fn carrying_mul_add(&self, a: u32, b: u32, c: u32, d: u32) -> u32 {
158        self.residue64(a as u64 * b as u64 + c as u64 + d as u64) as u32
159    }
160
161    /// Raises `x` to the power of `exp`, using exponentiation by squaring.
162    ///
163    /// # Time Complexity
164    ///
165    /// *O*(log `x`)
166    ///
167    /// # Example
168    ///
169    /// ```
170    /// use lib_modulo::Modulus32Any;
171    ///
172    /// let modulus = Modulus32Any::new(123_456).unwrap();
173    ///
174    /// assert_eq!(modulus.pow(123_456 * 100 + 1, 1000), 1)
175    /// ```
176    #[must_use]
177    pub const fn pow(&self, mut x: u32, mut exp: u32) -> u32 {
178        let mut res = 1;
179        while exp > 0 {
180            if exp & 1 == 1 {
181                res = self.mul(res, x);
182            }
183            exp >>= 1;
184            x = self.mul(x, x);
185        }
186        res
187    }
188
189    /// Calculates the modular inverse of `x`, using extended gcd algorithm.
190    ///
191    /// Modular inverse can be defined if and only if `x` and the modulus is coprime.
192    ///
193    /// - `Ok(_)` : the modular inverse.
194    /// - `Err(_)`: the GCD of `x` and the `modulus`, where `gcd(0, a)` is defined to be `a`.
195    ///
196    /// # Time complexity
197    ///
198    /// *O*(log `x`)
199    ///
200    /// # Example
201    ///
202    /// ```
203    /// use lib_modulo::Modulus32Any;
204    ///
205    /// let modulus = Modulus32Any::new(3 * 5).unwrap();
206    /// for (a, inv_a) in [(1, 1), (2, 8), (4, 4), (7, 13), (11, 11), (14, 14)] {
207    ///     assert_eq!(modulus.mul(a, inv_a), 1);
208    ///     assert_eq!(modulus.inv(a), Ok(inv_a));
209    /// }
210    /// for a in [3, 6, 9, 12] {
211    ///     assert_eq!(modulus.inv(a), Err(3));
212    /// }
213    /// for a in [5, 10] {
214    ///     assert_eq!(modulus.inv(a), Err(5));
215    /// }
216    /// // gcd(0, a) is defined t be `a`
217    /// assert_eq!(modulus.inv(0), Err(15));
218    /// assert_eq!(modulus.inv(15 * 99), Err(15));
219    /// ```
220    pub const fn inv(&self, x: u32) -> Result<u32, u32> {
221        // invariant: a x0 = x, b x0 = y (mod [y])
222        let mut x = self.residue32(x) as i64;
223        let mut y = self.n as i64;
224        let [mut a, mut b] = [1, 0];
225
226        while x > 0 {
227            let (div, rem) = (y / x, y % x);
228
229            (x, y) = (rem, x);
230            (a, b) = (b - div * a, a);
231        }
232
233        // y = gcd(x0, y0) > 0
234        if y != 1 {
235            return Err(y as u32);
236        }
237        if b.is_negative() {
238            b += self.n as i64;
239        }
240
241        Ok(b as u32)
242    }
243}
244
#[cfg(test)]
mod tests {
    // NOTE(review): `rng.random()` below needs the `Rng` trait in scope; it is
    // presumably re-exported by the proptest prelude — confirm the versions match.
    use proptest::prelude::*;
    use rand::{random_iter, rng};

    use super::Modulus32Any;

    // `mul` agrees with native 64-bit multiplication followed by `%`,
    // for arbitrary moduli and operands.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in 2..=u32::MAX, a: u32, b: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            assert_eq!(
                modulus.mul(a, b),
                (a as u64 * b as u64 % n as u64) as u32,
                "{:?}", modulus
            );
        }
    }

    /// Same check as `mul`, but sweeping every modulus in `2..256` with random operands.
    #[test]
    fn mul_small() {
        let mut rng = rng();
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for _ in 0..1 << 12 {
                let a = rng.random();
                let b = rng.random();
                assert_eq!(modulus.mul(a, b) as u64, (a as u64 * b as u64 % n as u64))
            }
        }
    }

    // `residue32` agrees with the native `%` operator on `u32`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue32(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            assert_eq!(modulus.residue32(a), a % n);
        }
    }

    /// Same check as `residue32`, but sweeping every modulus in `2..256`.
    #[test]
    fn residue32_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue32(a), a % n)
            }
        }
    }

    // `residue64` agrees with the native `%` operator on `u64`.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn residue64(n in 2..=u32::MAX, a: u64) {
            let modulus = Modulus32Any::new(n).unwrap();
            assert_eq!(modulus.residue64(a), a % n as u64);
        }
    }

    /// Same check as `residue64`, but sweeping every modulus in `2..256`.
    #[test]
    fn residue64_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue64(a), a % n as u64)
            }
        }
    }

    /// Reference GCD (binary / Stein's algorithm) used to cross-check `inv`.
    ///
    /// `binary_gcd(a, 0)` returns `a`, and `binary_gcd(0, b)` returns `b`.
    fn binary_gcd(mut a: u32, mut b: u32) -> u32 {
        if b == 0 {
            return a;
        }

        // Common power of two shared by both operands; restored at the end.
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            // Strip factors of two from `a`; `b` is kept odd throughout.
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    // `inv` either returns a value that multiplies with `a` to 1, or the
    // gcd of `a` and the modulus.
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            match modulus.inv(a) {
                Ok(inv) => assert_eq!(modulus.mul(a, inv), 1, "!"),
                Err(gcd) => assert_eq!(gcd, binary_gcd(a, n), "?")
            }
        }
    }
}