lib_modulo/residue32any.rs
/// A modulus in `[2, 2^32)`, including even values.
///
/// # Fast modular multiplication
///
/// Provides fast modular multiplication using [Barrett multiplication].
/// This works for any modulus (including even ones) and places no restrictions on the operands,
/// but is generally slower than [`Residue32`](crate::Residue32).
///
/// Unlike Montgomery or Plantard methods, this operates directly on standard
/// integer representations (i.e., no transformation is required).
///
/// [Barrett multiplication]: https://doi.org/10.1007/3-540-47721-7_24
///
/// # Usage
///
/// ```
/// use lib_modulo::Modulus32Any;
///
/// let modulus = Modulus32Any::new(14).unwrap();
/// assert_eq!(modulus.mul(3, 5), 1)
/// ```
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Modulus32Any {
    // The modulus itself. It always originates from a `u32` but is stored as `u64`
    // so the 64-bit arithmetic below needs no repeated widening.
    // (Lemire's remainder algorithm is used with N = 64, L = 32.)
    n: u64,
    // Precomputed magic number shared by Barrett multiplication and
    // Lemire's remainder algorithm: ceil(2^64 / n).
    magic: u64,
}
30
/// Reasons a value is rejected as the modulus of a [`Modulus32Any`].
#[derive(thiserror::Error, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum InvalidModulus {
    /// The modulus is `0`. Reduction modulo zero is undefined.
    #[error("modulo 0 is undefined")]
    Zero,
    /// The modulus is `1`. Every residue would be `0`, so this case is meaningless
    /// and is excluded for performance reasons.
    #[error("modulo 1 is meaningless and not available for performance reason")]
    One,
}
41
42impl Modulus32Any {
43 /// Creates new instance with the given modulus.
44 ///
45 /// # Error
46 ///
47 /// Returns error if the given modulus is `0` or `1`.
48 ///
49 /// # Example
50 ///
51 /// ```
52 /// use lib_modulo::{Modulus32Any, InvalidModulus};
53 ///
54 /// // even number is available
55 /// let modulus = Modulus32Any::new(2).unwrap();
56 /// // odd number is also available
57 /// let modulus = Modulus32Any::new(3).unwrap();
58 /// // division by 0 is undefined
59 /// assert_eq!(Modulus32Any::new(0), Err(InvalidModulus::Zero));
60 /// // division by 1 is meaningless and NOT available for performance
61 /// assert_eq!(Modulus32Any::new(1), Err(InvalidModulus::One));
62 /// ```
63 pub const fn new(n: u32) -> Result<Self, InvalidModulus> {
64 match n {
65 0 => Err(InvalidModulus::Zero),
66 1 => Err(InvalidModulus::One),
67 n => {
68 let n = n as u64;
69 let magic = (u64::MAX / n).wrapping_add(1);
70
71 Ok(Self { n, magic })
72 }
73 }
74 }
75
76 /// Returns the modulus.
77 #[must_use]
78 pub const fn modulus(&self) -> u32 {
79 self.n as u32
80 }
81
82 /// Calculates residue of `x` modulo `self`.
83 #[must_use]
84 pub const fn residue32(&self, x: u32) -> u32 {
85 let lo = self.magic.wrapping_mul(x as u64);
86 ((lo as u128 * self.n as u128) >> 64) as u32
87 }
88
89 /// Calculates residue of `x` modulo `self`.
90 #[must_use]
91 pub const fn residue64(&self, x: u64) -> u64 {
92 let quot = ((x as u128 * self.magic as u128) >> 64) as u64;
93 let (rem, b) = x.overflowing_sub(quot * self.n);
94 if b {
95 rem.wrapping_add(self.n)
96 } else {
97 rem
98 }
99 }
100
101 /// Checks if `x` is divisible by `self`.
102 #[must_use]
103 pub const fn can_divide(&self, x: u32) -> bool {
104 // since `self.n` is not 1, `self.magic` never overflow
105 self.magic.wrapping_mul(x as u64) < self.magic
106 }
107
108 /// Performs `a * b` modulo `self`.
109 ///
110 /// # Example
111 ///
112 /// ```
113 /// use lib_modulo::Modulus32Any;
114 ///
115 /// // even number is available
116 /// let modulus = Modulus32Any::new(1 << 8).unwrap();
117 /// assert_eq!(modulus.mul(u32::MAX, u32::MAX), 1);
118 /// ```
119 #[must_use]
120 pub const fn mul(&self, a: u32, b: u32) -> u32 {
121 self.residue64(a as u64 * b as u64) as u32
122 }
123
124 /// Performs `a * b + c` modulo `self`.
125 ///
126 /// # Example
127 ///
128 /// ```
129 /// use lib_modulo::Modulus32Any;
130 ///
131 /// let modulus = Modulus32Any::new(2357).unwrap();
132 /// assert_eq!(
133 /// modulus.carrying_mul(123, 456, 789),
134 /// (123 * 456 + 789) % 2357
135 /// );
136 /// ```
137 #[must_use]
138 pub const fn carrying_mul(&self, a: u32, b: u32, c: u32) -> u32 {
139 self.residue64(a as u64 * b as u64 + c as u64) as u32
140 }
141
142 /// Performs `a * b + c + d` modulo `self`.
143 ///
144 /// # Example
145 ///
146 /// ```
147 /// use lib_modulo::Modulus32Any;
148 ///
149 /// // even number is available
150 /// let modulus = Modulus32Any::new(123_456).unwrap();
151 /// assert_eq!(
152 /// modulus.carrying_mul_add(u32::MAX, u32::MAX, u32::MAX, u32::MAX),
153 /// (u64::MAX % 123_456) as u32
154 /// );
155 /// ```
156 #[must_use]
157 pub const fn carrying_mul_add(&self, a: u32, b: u32, c: u32, d: u32) -> u32 {
158 self.residue64(a as u64 * b as u64 + c as u64 + d as u64) as u32
159 }
160
161 /// Raises `x` to the power of `exp`, using exponentiation by squaring.
162 ///
163 /// # Time Complexity
164 ///
165 /// *O*(log `x`)
166 ///
167 /// # Example
168 ///
169 /// ```
170 /// use lib_modulo::Modulus32Any;
171 ///
172 /// let modulus = Modulus32Any::new(123_456).unwrap();
173 ///
174 /// assert_eq!(modulus.pow(123_456 * 100 + 1, 1000), 1)
175 /// ```
176 #[must_use]
177 pub const fn pow(&self, mut x: u32, mut exp: u32) -> u32 {
178 let mut res = 1;
179 while exp > 0 {
180 if exp & 1 == 1 {
181 res = self.mul(res, x);
182 }
183 exp >>= 1;
184 x = self.mul(x, x);
185 }
186 res
187 }
188
189 /// Calculates the modular inverse of `x`, using extended gcd algorithm.
190 ///
191 /// Modular inverse can be defined if and only if `x` and the modulus is coprime.
192 ///
193 /// - `Ok(_)` : the modular inverse.
194 /// - `Err(_)`: the GCD of `x` and the `modulus`, where `gcd(0, a)` is defined to be `a`.
195 ///
196 /// # Time complexity
197 ///
198 /// *O*(log `x`)
199 ///
200 /// # Example
201 ///
202 /// ```
203 /// use lib_modulo::Modulus32Any;
204 ///
205 /// let modulus = Modulus32Any::new(3 * 5).unwrap();
206 /// for (a, inv_a) in [(1, 1), (2, 8), (4, 4), (7, 13), (11, 11), (14, 14)] {
207 /// assert_eq!(modulus.mul(a, inv_a), 1);
208 /// assert_eq!(modulus.inv(a), Ok(inv_a));
209 /// }
210 /// for a in [3, 6, 9, 12] {
211 /// assert_eq!(modulus.inv(a), Err(3));
212 /// }
213 /// for a in [5, 10] {
214 /// assert_eq!(modulus.inv(a), Err(5));
215 /// }
216 /// // gcd(0, a) is defined t be `a`
217 /// assert_eq!(modulus.inv(0), Err(15));
218 /// assert_eq!(modulus.inv(15 * 99), Err(15));
219 /// ```
220 pub const fn inv(&self, x: u32) -> Result<u32, u32> {
221 // invariant: a x0 = x, b x0 = y (mod [y])
222 let mut x = self.residue32(x) as i64;
223 let mut y = self.n as i64;
224 let [mut a, mut b] = [1, 0];
225
226 while x > 0 {
227 let (div, rem) = (y / x, y % x);
228
229 (x, y) = (rem, x);
230 (a, b) = (b - div * a, a);
231 }
232
233 // y = gcd(x0, y0) > 0
234 if y != 1 {
235 return Err(y as u32);
236 }
237 if b.is_negative() {
238 b += self.n as i64;
239 }
240
241 Ok(b as u32)
242 }
243}
244
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{random_iter, rng};

    use super::Modulus32Any;

    // Property tests use `prop_assert_eq!` so that a failing case is reported to
    // proptest as a test-case failure rather than through a caught panic.

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `mul` agrees with the widened `%` computation for arbitrary moduli and operands.
        #[test]
        fn mul(n in 2..=u32::MAX, a: u32, b: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(
                modulus.mul(a, b),
                (a as u64 * b as u64 % n as u64) as u32,
                "{:?}", modulus
            );
        }
    }

    // Exhaustive over small moduli, where random sampling of `n` rarely lands.
    #[test]
    fn mul_small() {
        let mut rng = rng();
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for _ in 0..1 << 12 {
                let a = rng.random();
                let b = rng.random();
                assert_eq!(modulus.mul(a, b) as u64, (a as u64 * b as u64 % n as u64),)
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `residue32` agrees with `%` on 32-bit inputs.
        #[test]
        fn residue32(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(modulus.residue32(a), a % n);
        }
    }

    // Same check as `residue32`, for every modulus below 2^8.
    #[test]
    fn residue32_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue32(a), a % n)
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `residue64` agrees with `%` on 64-bit inputs.
        #[test]
        fn residue64(n in 2..=u32::MAX, a: u64) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(modulus.residue64(a), a % n as u64);
        }
    }

    // Same check as `residue64`, for every modulus below 2^8.
    #[test]
    fn residue64_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue64(a), a % n as u64)
            }
        }
    }

    // Reference GCD (binary / Stein's algorithm) used to validate the `Err` value of `inv`.
    // `gcd(a, 0) == a` and `gcd(0, b) == b`.
    fn binary_gcd(mut a: u32, mut b: u32) -> u32 {
        if b == 0 {
            return a;
        }

        // Common power of two shared by both operands; restored at the end.
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            // Keep `a >= b` so the subtraction below cannot underflow.
            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `inv` either yields a true inverse or reports the GCD of its argument and the modulus.
        #[test]
        fn inv(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            match modulus.inv(a) {
                Ok(inv) => {
                    prop_assert_eq!(modulus.mul(a, inv), 1, "!");
                }
                Err(gcd) => {
                    prop_assert_eq!(gcd, binary_gcd(a, n), "?");
                }
            }
        }
    }
}