lib_modulo/residue32.rs
1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue32`].
4///
5/// See documentation of [`Residue32`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Hash, Eq)]
8pub struct Modulus32 {
9 // n inv_n = 1 (mod 2^64)
10 n: u64,
11 inv_n: u64,
12 // 2^128 (mod n) * inv_n
13 init: u64,
14 // ceil(2^64 / n)
15 recip: u64,
16}
17
18impl Modulus32 {
19 /// Maximum available modulus.
20 pub const MAX: u32 = 2_654_435_769;
21
22 /// Creates new context for modular arithmetics.
23 ///
24 /// # Panics
25 ///
26 /// - modulus `n` should be an odd integer.
27 /// - modulus `n` should be no more than `2_654_435_769`,
28 /// which is the floor of `2^32 / GOLDEN_RATIO`.
29 ///
30 /// # Example
31 ///
32 /// ```
33 /// use lib_modulo::Modulus32;
34 ///
35 /// // odd integer less than or equal to 2_654_435_769 is allowed.
36 /// let modulus = Modulus32::new(Modulus32::MAX);
37 /// let modulus = Modulus32::new(3);
38 ///
39 /// // modulus should be an odd integer!
40 /// assert!(std::panic::catch_unwind(|| { Modulus32::new(2); }).is_err())
41 /// ```
42 #[inline]
43 pub const fn new(n: u32) -> Self {
44 assert!(
45 n & 1 == 1,
46 "invalid modulus: modulus should be an odd integer."
47 );
48 assert!(
49 n <= Self::MAX,
50 "invalid modulus: modulus should be no more than 2_654_435_769."
51 );
52
53 let n = n as u64;
54
55 let inv_n = {
56 // 1 * 1 = 3 * 3 = 1 (mod 4)
57 let mut inv_n = n & 3;
58 // n inv_n = 1 (mod 2^k) => (n inv_n - 1)^2 = 0 (mod 2^{2k})
59 // => n inv_n (2 - n inv_n) = 1 (mod 2^{2k})
60 let mut i = u64::BITS.ilog2() - 1;
61 while i > 0 {
62 i -= 1;
63 inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
64 }
65 debug_assert!(n.wrapping_mul(inv_n) == 1);
66
67 inv_n
68 };
69
70 let (div, rem) = {
71 let denom = n.wrapping_neg();
72 (denom / n, denom % n)
73 };
74 // 2^128 (mod n): magic number for converting integer to Plantard representation.
75 let init = rem * rem % n;
76 // ceil(2^64 / n): magic number for fast remainder algorithm
77 let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
78
79 Self {
80 n,
81 inv_n,
82 init: init.wrapping_mul(inv_n),
83 recip,
84 }
85 }
86
87 /// Performs Plantard multiplication, i.e. `x, y -> x y / -2^64 (mod n)`.
88 ///
89 /// If `x y < self.n`, then returned value is less than `self.n`.
90 #[inline(always)]
91 const fn mul(&self, x: u64, y: u64) -> u64 {
92 // Plantard reduction: <https://thomas-plantard.github.io/pdf/Plantard21.pdf>
93 let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
94 let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
95 debug_assert!(z < self.n, "this is a bug in lib-modulo");
96 z
97 }
98
99 /// Calculates the residue of `x` modulo `self`.
100 ///
101 /// # Example
102 ///
103 /// ```
104 /// use lib_modulo::Modulus32;
105 ///
106 /// let modulus = Modulus32::new(5);
107 /// assert_eq!(modulus.residue(8).get(), 3)
108 /// ```
109 #[inline(always)]
110 pub const fn residue(&self, x: u32) -> Residue32<'_> {
111 // fast remainder algorithm
112 // See <https://onlinelibrary.wiley.com/doi/10.1002/spe.2689> for details
113 let x = {
114 let lo = self.recip.wrapping_mul(x as u64);
115 ((lo as u128 * self.n as u128) >> 64) as u64
116 };
117
118 let x = {
119 // multiplication by a constant
120 let x = self.init.wrapping_mul(x) >> 32;
121 ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
122 };
123
124 Residue32 { x, modulus: self }
125 }
126
127 /// Checks whether `x` is divisible by `self`.
128 ///
129 /// # Example
130 ///
131 /// ```
132 /// use lib_modulo::Modulus32;
133 ///
134 /// let modulus = Modulus32::new(9);
135 /// assert!(modulus.can_divide(18));
136 /// assert!(!modulus.can_divide(19));
137 /// ```
138 #[inline(always)]
139 pub const fn can_divide(&self, x: u32) -> bool {
140 self.residue(x).is_zero()
141 }
142
143 /// Checks whether `self` is a prime number.
144 ///
145 /// # Time complexity
146 ///
147 /// *O*(log *self*)
148 ///
149 /// # Example
150 ///
151 /// ```
152 /// use lib_modulo::Modulus32;
153 ///
154 /// for p in [3, 5, 7, 11, 998_244_353, 1_000_000_007] {
155 /// assert!(Modulus32::new(p).is_prime())
156 /// }
157 /// // Mersenne numbers (prime)
158 /// for d in [5, 7, 13, 17, 19, 31] {
159 /// assert!(Modulus32::new((1 << d) - 1).is_prime())
160 /// }
161 ///
162 /// // composite numbers
163 /// for i in (3..).step_by(2).take(500) {
164 /// assert!(!Modulus32::new(i * (i + 2)).is_prime())
165 /// }
166 /// ```
167 #[inline(always)]
168 pub const fn is_prime(&self) -> bool {
169 /// (SELF >> p) & 1 == 1 iff p is prime
170 const TEST_LT_64: u64 = 2891462833508853932;
171 /// (SELF >> n % 30) & 1 == 1 iff n is coprime to 2, 3, and 5
172 const TEST_2_3_5: u32 = 545925250;
173
174 if self.n < 64 {
175 return (TEST_LT_64 >> self.n) & 1 == 1;
176 } else if (TEST_2_3_5 >> (self.n % 30)) & 1 == 0 || self.n % 7 == 0 {
177 return false;
178 }
179
180 let one = self.residue(1).x;
181 let minus_one = self.n - one;
182 debug_assert!(one != 0 && minus_one != 0, "this is a bug in lib-modulo");
183
184 let (d, s) = {
185 let n = self.n - 1;
186 ((n >> n.trailing_zeros()) as u32, n.trailing_zeros() - 1)
187 };
188 let mut i = 0;
189 'test: while i < 3 {
190 let witness = [2, 7, 61][i];
191 i += 1;
192
193 let w = self.residue(witness);
194 if w.is_zero() {
195 continue;
196 }
197
198 let mut w = w.pow(d).x;
199 if w == minus_one || w == one {
200 continue;
201 }
202
203 let mut s = s;
204 while s > 0 {
205 s -= 1;
206 w = self.mul(w, w);
207 if w == minus_one {
208 continue 'test;
209 }
210 }
211
212 return false;
213 }
214
215 true
216 }
217}
218
219impl PartialEq for Modulus32 {
220 fn eq(&self, other: &Self) -> bool {
221 // other fields depend on `n`
222 self.n == other.n
223 }
224}
225
226/// Residue with odd modulus which is no more than `2_654_435_769`.
227///
228/// # Fast modular multiplication
229///
230/// [`Residue32`] provides fast modular multiplication called [Plantard multiplication].
231/// This method saves one multiplication when either of two values of a multiplication is used multiple times.
232/// Therefore, [`Residue32::pow`] will be faster than that using [Montgomery multiplication].
233///
234/// [Plantard multiplication]: https://thomas-plantard.github.io/pdf/Plantard21.pdf
235/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
236///
237/// # Usage
238///
239/// ```
240/// use lib_modulo::Modulus32;
241///
242/// // set modulus
243/// let modulus = Modulus32::new(3);
244///
245/// // performs modular arithmetics
246/// let one = modulus.residue(1);
247/// let two = modulus.residue(2);
248/// let five = modulus.residue(5);
249/// assert_eq!(two * five, one)
250/// ```
251///
252/// Two residues with different modulus can interact, but the result will be meaningless.
253/// It is highly recommended to use a block to ensure that [`Modulus32`], therefore [`Residue32`]s, are dropped.
254#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
255pub struct Residue32<'a> {
256 // compare modulus first
257 modulus: &'a Modulus32,
258 x: u64,
259}
260
261impl<'a> Residue32<'a> {
262 /// Checks whether `self` is `0`.
263 ///
264 /// # Example
265 ///
266 /// ```
267 /// use lib_modulo::Modulus32;
268 ///
269 /// let modulus = Modulus32::new(5);
270 /// assert!(modulus.residue(10).is_zero())
271 /// ```
272 #[inline(always)]
273 pub const fn is_zero(self) -> bool {
274 self.x == 0
275 }
276
277 /// Returns the residue.
278 ///
279 /// # Example
280 ///
281 /// ```
282 /// use lib_modulo::Modulus32;
283 ///
284 /// let modulus = Modulus32::new(7);
285 /// assert_eq!(modulus.residue(10).get(), 3)
286 /// ```
287 #[inline(always)]
288 pub const fn get(self) -> u64 {
289 self.modulus.mul(self.x, 1)
290 }
291
292 /// Returns the modulus.
293 ///
294 /// # Example
295 ///
296 /// ```
297 /// use lib_modulo::Modulus32;
298 ///
299 /// let modulus = Modulus32::new(11);
300 /// assert_eq!(modulus.residue(2).modulus(), 11);
301 /// ```
302 #[inline(always)]
303 pub const fn modulus(&self) -> u64 {
304 self.modulus.n
305 }
306
307 /// Raises `self` to the power of `exp`, using exponentiation by squaring.
308 ///
309 /// # Time complexity
310 ///
311 /// *Θ*(log `exp`)
312 ///
313 /// # Example
314 ///
315 /// ```
316 /// use lib_modulo::Modulus32;
317 ///
318 /// let modulus = Modulus32::new(1001);
319 /// let residue = modulus.residue(2);
320 /// for exp in 0..64 {
321 /// assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
322 /// }
323 /// ```
324 #[inline(always)]
325 pub const fn pow(self, mut exp: u32) -> Self {
326 let Self { mut x, modulus } = self;
327 // If `n = 1`, then `init = 0`. Otherwise, `n > 1`.
328 let mut prod = modulus.residue(1).x;
329
330 while exp > 1 {
331 if exp & 1 == 1 {
332 // インライン展開されると,掛け算を1回節約できる。
333 prod = modulus.mul(prod, x)
334 }
335
336 exp >>= 1;
337 x = modulus.mul(x, x); // skip last useless one
338 }
339 if exp != 0 {
340 prod = modulus.mul(prod, x);
341 }
342
343 Self { x: prod, modulus }
344 }
345
346 /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
347 ///
348 /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
349 ///
350 /// - `Ok(x)` : `x` is the modular inverse.
351 /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
352 /// where `gcd(0, a)` is defined to be `a`.
353 ///
354 /// # Time complexity
355 ///
356 /// *O*(log `self`)
357 ///
358 /// # Example
359 ///
360 /// ```
361 /// use lib_modulo::Modulus32;
362 ///
363 /// let modulus = Modulus32::new(3 * 5);
364 ///
365 /// let residue = modulus.residue(2);
366 /// assert!(residue.try_inv().is_ok_and(|inv| (inv * residue).get() == 1));
367 ///
368 /// let residue = modulus.residue(6);
369 /// assert!(residue.try_inv().is_err_and(|gcd| gcd == 3));
370 /// ```
371 pub const fn try_inv(self) -> Result<Self, u64> {
372 // invariant: [a] x = a, [a] y = b (mod n), where [a] is initial value.
373 let mut a = self.get();
374 let mut b = self.modulus();
375 let Self { modulus, .. } = self;
376 let mut x = modulus.residue(1).x;
377 let mut y = 0;
378 let frac_1_2 = modulus.residue((modulus.n as u32).div_ceil(2));
379
380 while a > 0 {
381 x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros()).x);
382 a >>= a.trailing_zeros();
383
384 if a < b {
385 (a, b) = (b, a);
386 (x, y) = (y, x);
387 }
388 a -= b;
389 let (z, b) = x.overflowing_sub(y);
390 x = if b { z.wrapping_add(modulus.n) } else { z };
391 }
392
393 // b = gcd([a], n)
394 if b == 1 {
395 Ok(Self { x: y, modulus })
396 } else {
397 Err(b)
398 }
399 }
400}
401
402impl<'a> Add for Residue32<'a> {
403 type Output = Self;
404
405 fn add(mut self, rhs: Self) -> Self::Output {
406 let (x, b) = self.x.overflowing_add(rhs.x);
407 self.x = if b || x >= self.modulus() {
408 x.wrapping_sub(self.modulus())
409 } else {
410 x
411 };
412
413 self
414 }
415}
416
417impl<'a> AddAssign for Residue32<'a> {
418 fn add_assign(&mut self, rhs: Self) {
419 *self = *self + rhs
420 }
421}
422
423impl<'a> Sub for Residue32<'a> {
424 type Output = Self;
425
426 fn sub(mut self, rhs: Self) -> Self::Output {
427 let (x, b) = self.x.overflowing_sub(rhs.x);
428 self.x = if b { x.wrapping_add(self.modulus()) } else { x };
429
430 self
431 }
432}
433
434impl<'a> SubAssign for Residue32<'a> {
435 fn sub_assign(&mut self, rhs: Self) {
436 *self = *self - rhs
437 }
438}
439
440impl<'a> Mul for Residue32<'a> {
441 type Output = Self;
442
443 fn mul(mut self, rhs: Self) -> Self::Output {
444 self.x = self.modulus.mul(self.x, rhs.x);
445 self
446 }
447}
448
449impl<'a> MulAssign for Residue32<'a> {
450 fn mul_assign(&mut self, rhs: Self) {
451 *self = *self * rhs
452 }
453}
454
455impl<'a> Neg for Residue32<'a> {
456 type Output = Self;
457
458 fn neg(mut self) -> Self::Output {
459 self.x = if self.x == 0 {
460 0
461 } else {
462 self.modulus() - self.x
463 };
464
465 self
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // Converting to a residue and back must agree with the built-in `%`.
        #[test]
        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `pow` must agree with repeated multiplication for small exponents.
        #[test]
        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);

            let res = modulus.residue(x as u32);
            // `n = 1` is a valid modulus and every residue is `0` there, so the
            // empty product has to be reduced as well.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `can_divide` must agree with `%`, and must accept every multiple of `n`.
        #[test]
        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference GCD (binary algorithm) used to cross-check `try_inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        // Common power of two, restored at the end.
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // `try_inv` yields either a true inverse or the GCD with the modulus.
        #[test]
        fn try_inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // Modulo `1` the inverse exists but the product reads back as `0`,
                // hence `1 % n` instead of a bare `1`.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }
}