lib_modulo/residue64.rs
1use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue64`].
4///
5/// See documentation of [`Residue64`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Eq, Hash)]
8pub struct Modulus64 {
9 // n inv_n = 1 (mod r = 2^32 or 2^64)
10 n: u64,
11 inv_n: u64,
12 r2_mod_n: u64,
13}
14
15impl Modulus64 {
16 /// Creates new instance with the given modulus.
17 ///
18 /// # Panics
19 ///
20 /// - modulus `n` should be an odd number.
21 #[inline]
22 #[must_use]
23 pub const fn new(n: u64) -> Self {
24 assert!(n & 1 == 1, "modulus should be an odd number");
25
26 let inv_n = {
27 const TABLE: u32 = {
28 // | n | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 |
29 // | inv_n | 1 | 11 | 13 | 7 | 9 | 3 | 5 | 15 | <- 4 bits * 8
30 let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
31
32 let mut table = 0;
33 let mut i = 0;
34 while i < 8 {
35 table |= inv_n[i] << (i * 4);
36 i += 1;
37 }
38
39 table
40 };
41 // n inv_n = 1 (mod 8)
42 let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
43
44 let mut d = const { u64::BITS.ilog2() - 2 };
45 while d > 0 {
46 inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
47 d -= 1;
48 }
49 debug_assert!(n.wrapping_mul(inv_n) == 1);
50
51 inv_n
52 };
53 let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
54
55 Self { n, inv_n, r2_mod_n }
56 }
57
58 /// Calculates the residue of `x` modulo `self`.
59 ///
60 /// # Example
61 ///
62 /// ```
63 /// use lib_modulo::Modulus64;
64 ///
65 /// let modulus = Modulus64::new(5);
66 /// assert_eq!(modulus.residue(8).get(), 3)
67 /// ```
68 #[must_use]
69 pub const fn residue(&self, x: u64) -> Residue64<'_> {
70 // `x r2 < r n`
71 let x = self.mul(x, self.r2_mod_n);
72
73 Residue64 { x, modulus: self }
74 }
75
76 /// Performs Montgomery multiplication.
77 ///
78 /// if `lhs rhs < n r`, then `result < n`
79 const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
80 self.mul_add(lhs, rhs, 0)
81 }
82
83 /// Performs `lhs rhs + add`.
84 ///
85 /// If `lhs rhs + add < n r`, then the result is less than `n`.
86 const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
87 // FIXME: use `a.widening_mul(b)`
88 let (x_hi, x_lo) = {
89 let x = lhs as u128 * rhs as u128 + add as u128;
90 ((x >> u64::BITS) as u64, x as u64)
91 };
92 // FIXME: use `mul_hi()`
93 // y = x n nn = x (mod r) => y_lo = x_lo
94 let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
95 // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
96 let (z, b) = x_hi.overflowing_sub(y_hi);
97
98 // x < n r, y < n r => |z| < n
99 if b {
100 z.wrapping_add(self.n)
101 } else {
102 z
103 }
104 }
105
106 /// Checks whether `x` is multiple of `self`.
107 ///
108 /// # Example
109 ///
110 /// ```
111 /// use lib_modulo::Modulus64;
112 ///
113 /// for n in (1..1 << 10).step_by(2) {
114 /// let modulus = Modulus64::new(n);
115 ///
116 /// (0..1 << 10).for_each(|k| assert!(modulus.can_divide(n * k)));
117 /// }
118 /// ```
119 #[inline]
120 #[must_use]
121 pub const fn can_divide(&self, x: u64) -> bool {
122 self.residue(x).is_zero()
123 }
124}
125
126impl PartialEq for Modulus64 {
127 fn eq(&self, other: &Self) -> bool {
128 // other parameters depend on `n`
129 self.n == other.n
130 }
131}
132
133/// A residue with an odd modulus that fits in `2^64`.
134///
135/// # Fast modular multiplication
136///
137/// [`Residue64`] provides fast modular multiplication using [Montgomery multiplication].
138/// Since this method provides modular multiplication without division,
139/// it is approximately twice as fast.
140///
141/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
142///
143/// # Usage
144///
145/// ```
146/// use lib_modulo::Modulus64;
147///
148/// // runtime-specified *odd* modulus
149/// let modulus = 5;
150///
151/// let modulus = Modulus64::new(modulus); // slow
152/// let n = modulus.residue(2) * modulus.residue(3); // fast
153/// assert_eq!(n.get(), 1);
154/// ```
155///
156/// Two residues with different modulus can interact, but the result will be meaningless.
157/// It is highly recommended to use a block to ensure that [`Modulus64`], therefore [`Residue64`]s, are dropped.
158#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
159pub struct Residue64<'a> {
160 modulus: &'a Modulus64,
161 // x r (mod n)
162 x: u64,
163}
164
165impl Residue64<'_> {
166 /// Extract the internal representation of `self`.
167 ///
168 /// ```
169 /// use lib_modulo::{Modulus64, Raw64};
170 ///
171 /// let modulus = Modulus64::new(1001);
172 /// // save memory
173 /// let residues: Vec<Raw64> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
174 ///
175 /// // `Residue64` and `raw64` can interact.
176 /// // The caller must ensure that both operands shares the same modulus.
177 /// let double_sum = residues.into_iter().fold(modulus.residue(0), |sum, r| r + sum + r);
178 /// assert_eq!(double_sum, modulus.residue((1 + 1000) * 1000));
179 /// ```
180 #[must_use]
181 pub const fn into_raw(self) -> Raw64 {
182 Raw64 { x: self.x }
183 }
184
185 /// Returns the residue.
186 ///
187 /// # Example
188 ///
189 /// ```
190 /// use lib_modulo::Modulus64;
191 ///
192 /// let modulus = Modulus64::new(5);
193 /// let n = modulus.residue(7);
194 /// assert_eq!(n.get(), 2);
195 /// ```
196 #[must_use]
197 pub const fn get(&self) -> u64 {
198 self.modulus.mul(self.x, 1)
199 }
200
201 /// Returns the modulus.
202 ///
203 /// # Example
204 ///
205 /// ```
206 /// use lib_modulo::Modulus64;
207 ///
208 /// let modulus = Modulus64::new(5);
209 /// let n = modulus.residue(7);
210 /// assert_eq!(n.modulus(), 5);
211 /// ```
212 #[must_use]
213 pub const fn modulus(&self) -> u64 {
214 self.modulus.n
215 }
216
217 /// Checks whether `self` is `0`.
218 ///
219 /// # Example
220 ///
221 /// ```
222 /// use lib_modulo::Modulus64;
223 ///
224 /// let modulus = Modulus64::new(3);
225 /// assert_eq!(modulus.residue(6).get(), 0);
226 /// ```
227 #[must_use]
228 pub const fn is_zero(self) -> bool {
229 self.x == 0
230 }
231
232 /// Raises `self` to the power of `exp`, using exponentiation by squaring.
233 ///
234 /// # Time complexity
235 ///
236 /// *O*(log `exp`)
237 ///
238 /// # Example
239 ///
240 /// ```
241 /// use lib_modulo::Modulus64;
242 ///
243 /// let modulus = Modulus64::new(1001);
244 /// let residue = modulus.residue(2);
245 /// for exp in 0..64 {
246 /// assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
247 /// }
248 /// ```
249 #[inline]
250 #[must_use]
251 pub const fn pow(mut self, mut exp: u64) -> Self {
252 // r inv_r = 1 (mod n)
253 let mut result = self.modulus.residue(1).x;
254
255 while exp > 0 {
256 if exp & 1 == 1 {
257 // n < r
258 result = self.modulus.mul(result, self.x);
259 }
260
261 exp >>= 1;
262 // n < r
263 self.x = self.modulus.mul(self.x, self.x);
264 }
265 self.x = result;
266
267 self
268 }
269
270 /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
271 ///
272 /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
273 ///
274 /// - `Ok(x)` : `x` is the modular inverse.
275 /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
276 /// where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
277 ///
278 /// # Time complexity
279 ///
280 /// *O*(log `self`)
281 ///
282 /// # Example
283 ///
284 /// ```
285 /// use lib_modulo::Modulus64;
286 ///
287 /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
288 /// let modulus = Modulus64::new(998_244_353);
289 ///
290 /// for n in 1..500_000 {
291 /// let n = modulus.residue(n);
292 /// assert!(n.inv().is_ok_and(|i| (i * n).get() == 1));
293 /// }
294 /// // 0 n = 0 != 1 for any integer n
295 /// assert!(modulus.residue(0).inv().is_err());
296 /// ```
297 #[inline]
298 pub const fn inv(self) -> Result<Self, u64> {
299 let mut a = self.get();
300 let Self { modulus, .. } = self;
301
302 // performs extended binary gcd
303 //
304 // invariants: a = [a] x, b = [a] y (mod n) where [a] is initial value
305 let mut b = modulus.n;
306 let mut x = modulus.residue(1).x; // 1 r mod n
307 let mut y = 0; // 0 r mod n
308 let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
309
310 while a > 0 {
311 x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
312 a >>= a.trailing_zeros();
313
314 if a < b {
315 (a, b) = (b, a);
316 (x, y) = (y, x);
317 }
318 a -= b;
319 let (diff, b) = x.overflowing_sub(y);
320 x = if b {
321 diff.wrapping_add(modulus.n)
322 } else {
323 diff
324 };
325 }
326
327 // b = gcd([a], [b])
328 if b == 1 {
329 Ok(Self { x: y, modulus })
330 } else {
331 Err(b)
332 }
333 }
334}
335
336impl Add for Residue64<'_> {
337 type Output = Self;
338
339 fn add(self, rhs: Self) -> Self {
340 self + rhs.into_raw()
341 }
342}
343
344impl AddAssign for Residue64<'_> {
345 fn add_assign(&mut self, rhs: Self) {
346 *self = *self + rhs;
347 }
348}
349
350impl Sub for Residue64<'_> {
351 type Output = Self;
352
353 fn sub(self, rhs: Self) -> Self {
354 self - rhs.into_raw()
355 }
356}
357
358impl SubAssign for Residue64<'_> {
359 fn sub_assign(&mut self, rhs: Self) {
360 *self = *self - rhs;
361 }
362}
363
364impl Mul for Residue64<'_> {
365 type Output = Self;
366
367 fn mul(self, rhs: Self) -> Self {
368 self * rhs.into_raw()
369 }
370}
371
372impl MulAssign for Residue64<'_> {
373 fn mul_assign(&mut self, rhs: Self) {
374 *self = *self * rhs;
375 }
376}
377
378impl Neg for Residue64<'_> {
379 type Output = Self;
380
381 fn neg(mut self) -> Self::Output {
382 // (x - x) r = 0 (mod n)
383 self.x = if self.x == 0 {
384 self.x
385 } else {
386 self.modulus.n - self.x
387 };
388
389 self
390 }
391}
392
/// The bare internal value of a [`Residue64`], detached from its [`Modulus64`].
///
/// Conceptually, [`Residue64`] = [`Raw64`] + [`Modulus64`]: a [`Raw64`] keeps only
/// the value part and holds no reference to the modulus.
///
/// Dropping the reference makes collections of residues smaller. It also lets a type
/// store a residue alongside its modulus without becoming self-referential.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Raw64 {
    x: u64,
}
405
406impl Raw64 {
407 /// Attaches a modulus and returns a [`Residue64`].
408 ///
409 /// Typically, this only needs to be called once per computation
410 /// because `Raw64` and `Residue64` can interact.
411 ///
412 /// # Caution
413 ///
414 /// This does not perform validation or reduction.
415 /// The caller must ensure the modulus is correct for this value.
416 #[must_use]
417 pub const fn into_residue(self, modulus: &Modulus64) -> Residue64<'_> {
418 Residue64 { modulus, x: self.x }
419 }
420}
421
422impl<'a> From<Residue64<'a>> for Raw64 {
423 fn from(residue: Residue64<'a>) -> Self {
424 Self { x: residue.x }
425 }
426}
427
428impl<'a> Add<Raw64> for Residue64<'a> {
429 type Output = Residue64<'a>;
430
431 /// Performs the `+` operation.
432 ///
433 /// # Caution
434 ///
435 /// The caller must ensure that both operands shares the same modulus.
436 fn add(mut self, rhs: Raw64) -> Self::Output {
437 let (sum, b) = self.x.overflowing_add(rhs.x);
438 self.x = if b || sum >= self.modulus.n {
439 sum.wrapping_sub(self.modulus.n)
440 } else {
441 sum
442 };
443
444 self
445 }
446}
447
448impl<'a> Add<Residue64<'a>> for Raw64 {
449 type Output = Residue64<'a>;
450
451 /// Performs the `+` operation.
452 ///
453 /// # Caution
454 ///
455 /// The caller must ensure that both operands shares the same modulus.
456 fn add(self, rhs: Residue64<'a>) -> Self::Output {
457 rhs + self
458 }
459}
460
461impl AddAssign<Raw64> for Residue64<'_> {
462 /// Performs the `+=` operation.
463 ///
464 /// # Caution
465 ///
466 /// The caller must ensure that both operands shares the same modulus.
467 fn add_assign(&mut self, rhs: Raw64) {
468 *self = *self + rhs;
469 }
470}
471
472impl<'a> Sub<Raw64> for Residue64<'a> {
473 type Output = Residue64<'a>;
474
475 /// Performs the `-` operation.
476 ///
477 /// # Caution
478 ///
479 /// The caller must ensure that both operands shares the same modulus.
480 fn sub(mut self, rhs: Raw64) -> Self::Output {
481 let (diff, b) = self.x.overflowing_sub(rhs.x);
482 self.x = if b {
483 diff.wrapping_add(self.modulus.n)
484 } else {
485 diff
486 };
487
488 self
489 }
490}
491
492impl<'a> Sub<Residue64<'a>> for Raw64 {
493 type Output = Residue64<'a>;
494
495 /// Performs the `-` operation.
496 ///
497 /// # Caution
498 ///
499 /// The caller must ensure that both operands shares the same modulus.
500 fn sub(self, mut rhs: Residue64<'a>) -> Self::Output {
501 let (diff, b) = self.x.overflowing_sub(rhs.x);
502 rhs.x = if b {
503 diff.wrapping_add(rhs.modulus.n)
504 } else {
505 diff
506 };
507
508 rhs
509 }
510}
511
512impl SubAssign<Raw64> for Residue64<'_> {
513 /// Performs the `-=` operation.
514 ///
515 /// # Caution
516 ///
517 /// The caller must ensure that both operands shares the same modulus.
518 fn sub_assign(&mut self, rhs: Raw64) {
519 *self = *self - rhs;
520 }
521}
522
523impl<'a> Mul<Raw64> for Residue64<'a> {
524 type Output = Residue64<'a>;
525
526 /// Performs the `*` operation.
527 ///
528 /// # Caution
529 ///
530 /// The caller must ensure that both operands shares the same modulus.
531 fn mul(mut self, rhs: Raw64) -> Self::Output {
532 // n < r
533 self.x = self.modulus.mul(self.x, rhs.x);
534
535 self
536 }
537}
538
539impl<'a> Mul<Residue64<'a>> for Raw64 {
540 type Output = Residue64<'a>;
541
542 /// Performs the `*` operation.
543 ///
544 /// # Caution
545 ///
546 /// The caller must ensure that both operands shares the same modulus.
547 fn mul(self, rhs: Residue64<'a>) -> Self::Output {
548 rhs * self
549 }
550}
551
552impl MulAssign<Raw64> for Residue64<'_> {
553 /// Performs the `*=` operation.
554 ///
555 /// # Caution
556 ///
557 /// The caller must ensure that both operands shares the same modulus.
558 fn mul_assign(&mut self, rhs: Raw64) {
559 *self = *self * rhs;
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            // Converting into Montgomery form and back must give `x mod n`.
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            let mut naive = 1;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn arithmetic(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            // Every operator, including the mixed `Raw64` forms and the compound
            // assignments, must agree with naive `u128` arithmetic modulo `n`.
            let modulus = Modulus64::new(n);
            let (a, b) = (modulus.residue(x), modulus.residue(y));

            let (xn, yn, nn) = ((x % n) as u128, (y % n) as u128, n as u128);
            let sum = ((xn + yn) % nn) as u64;
            let diff = ((xn + nn - yn) % nn) as u64;
            let prod = (xn * yn % nn) as u64;

            assert_eq!((a + b).get(), sum);
            assert_eq!((a.into_raw() + b).get(), sum);
            assert_eq!((a - b).get(), diff);
            assert_eq!((a.into_raw() - b).get(), diff);
            assert_eq!((a * b).get(), prod);
            assert_eq!((a.into_raw() * b).get(), prod);
            assert_eq!((-a).get(), ((nn - xn) % nn) as u64);

            let mut c = a;
            c += b;
            c -= b;
            c *= b;
            assert_eq!(c, a * b);
        }
    }

    /// Reference GCD used to validate the `Err` branch of `Residue64::inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.inv() {
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    // `gcd` must divide both operands, and dividing it out must leave them coprime.
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }
}