lib_modulo/residue64.rs
1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3/// Factory of [`Residue64`].
4///
5/// See documentation of [`Residue64`] for details.
6#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Eq, Hash)]
8pub struct Modulus64 {
9 // n inv_n = 1 (mod r = 2^32 or 2^64)
10 pub(crate) n: u64,
11 pub(crate) inv_n: u64,
12 pub(crate) r2_mod_n: u64,
13}
14
15impl Modulus64 {
16 /// Calculates some parameters for Montgomery multiplication.
17 ///
18 /// # Panics
19 ///
20 /// - modulus `n` should be an odd number.
21 #[inline]
22 pub const fn new(n: u64) -> Self {
23 assert!(n & 1 == 1, "modulus should be an odd number");
24
25 let inv_n = {
26 const TABLE: u32 = {
27 // | n | 1 | 3 | 5 | 7 | 9 | 11 | 13 | 15 |
28 // | inv_n | 1 | 11 | 13 | 7 | 9 | 3 | 5 | 15 | <- 4 bits * 8
29 let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
30
31 let mut table = 0;
32 let mut i = 0;
33 while i < 8 {
34 table |= inv_n[i] << (i * 4);
35 i += 1;
36 }
37
38 table
39 };
40 // n inv_n = 1 (mod 8)
41 let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
42
43 let mut d = const { u64::BITS.ilog2() - 2 };
44 while d > 0 {
45 inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
46 d -= 1;
47 }
48 debug_assert!(n.wrapping_mul(inv_n) == 1);
49
50 inv_n
51 };
52 let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
53
54 Self { n, inv_n, r2_mod_n }
55 }
56
57 /// Calculates the residue of `x` modulo `self`.
58 ///
59 /// # Example
60 ///
61 /// ```
62 /// use lib_modulo::Modulus64;
63 ///
64 /// let modulus = Modulus64::new(5);
65 /// assert_eq!(modulus.residue(8).get(), 3)
66 /// ```
67 #[inline(always)]
68 pub const fn residue(&self, x: u64) -> Residue64<'_> {
69 // `x r2 < r n`
70 let x = self.mul(x, self.r2_mod_n);
71
72 Residue64 { x, modulus: self }
73 }
74
75 /// Performs Montgomery multiplication.
76 ///
77 /// if `lhs rhs < n r`, then `result < n`
78 #[inline(always)]
79 pub(crate) const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
80 self.mul_add(lhs, rhs, 0)
81 }
82
83 /// Performs `lhs rhs + add`.
84 ///
85 /// If `lhs rhs + add < n r`, then the result is less than `n`.
86 #[inline(always)]
87 pub(crate) const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
88 // FIXME: use `a.widening_mul(b)`
89 let (x_hi, x_lo) = {
90 let x = lhs as u128 * rhs as u128 + add as u128;
91 ((x >> u64::BITS) as u64, x as u64)
92 };
93 // FIXME: use `mul_hi()`
94 // y = x n nn = x (mod r) => y_lo = x_lo
95 let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
96 // x - y = 0 (mod r), x - y = x (mod n) => z = x inv_r (mod n)
97 let (z, b) = x_hi.overflowing_sub(y_hi);
98
99 // x < n r, y < n r => |z| < n
100 if b {
101 z.wrapping_add(self.n)
102 } else {
103 z
104 }
105 }
106
107 /// Checks whether `x` is multiple of `self`.
108 ///
109 /// # Example
110 ///
111 /// ```
112 /// use lib_modulo::Modulus64;
113 ///
114 /// for n in (1..1 << 10).step_by(2) {
115 /// let modulus = Modulus64::new(n);
116 ///
117 /// (0..1 << 10).for_each(|k| assert!(modulus.can_divide(n * k)));
118 /// }
119 /// ```
120 #[inline]
121 pub const fn can_divide(&self, x: u64) -> bool {
122 self.residue(x).is_zero()
123 }
124}
125
126impl PartialEq for Modulus64 {
127 fn eq(&self, other: &Self) -> bool {
128 // other parameters depend on `n`
129 self.n == other.n
130 }
131}
132
133/// A residue with an odd modulus that fits in `2^64`.
134///
135/// # Fast modular multiplication
136///
137/// [`Residue64`] provides fast modular multiplication using [Montgomery multiplication].
138/// Since this method provides modular multiplication without division,
139/// it is approximately twice as fast.
140///
141/// [Montgomery multiplication]: https://doi.org/10.1090/s0025-5718-1985-0777282-x
142///
143/// # Usage
144///
145/// ```
146/// use lib_modulo::Modulus64;
147///
148/// // runtime-specified *odd* modulus
149/// let modulus = 5;
150///
151/// let modulus = Modulus64::new(modulus); // slow
152/// let n = modulus.residue(2) * modulus.residue(3); // fast
153/// assert_eq!(n.get(), 1);
154/// ```
155///
156/// Two residues with different modulus can interact, but the result will be meaningless.
157/// It is highly recommended to use a block to ensure that [`Modulus64`], therefore [`Residue64`]s, are dropped.
158#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
159pub struct Residue64<'a> {
160 pub(crate) modulus: &'a Modulus64,
161 // x r (mod n)
162 pub(crate) x: u64,
163}
164
165impl<'a> Residue64<'a> {
166 /// Extract the internal representation of `self`.
167 ///
168 /// ```
169 /// use lib_modulo::{Modulus64, Raw64};
170 ///
171 /// let modulus = Modulus64::new(1001);
172 /// // save memory
173 /// let residues: Vec<Raw64> = (1..=1000).map(|x| modulus.residue(x).into_raw()).collect();
174 /// ```
175 #[inline(always)]
176 pub const fn into_raw(self) -> Raw64 {
177 Raw64 { x: self.x }
178 }
179
180 /// Returns the residue.
181 ///
182 /// # Example
183 ///
184 /// ```
185 /// use lib_modulo::Modulus64;
186 ///
187 /// let modulus = Modulus64::new(5);
188 /// let n = modulus.residue(7);
189 /// assert_eq!(n.get(), 2);
190 /// ```
191 #[inline(always)]
192 pub const fn get(&self) -> u64 {
193 self.modulus.mul(self.x, 1)
194 }
195
196 /// Returns the modulus.
197 ///
198 /// # Example
199 ///
200 /// ```
201 /// use lib_modulo::Modulus64;
202 ///
203 /// let modulus = Modulus64::new(5);
204 /// let n = modulus.residue(7);
205 /// assert_eq!(n.modulus(), 5);
206 /// ```
207 #[inline(always)]
208 pub const fn modulus(&self) -> u64 {
209 self.modulus.n
210 }
211
212 /// Checks whether `self` is `0`.
213 ///
214 /// # Example
215 ///
216 /// ```
217 /// use lib_modulo::Modulus64;
218 ///
219 /// let modulus = Modulus64::new(3);
220 /// assert_eq!(modulus.residue(6).get(), 0);
221 /// ```
222 #[inline(always)]
223 pub const fn is_zero(self) -> bool {
224 self.x == 0
225 }
226
227 /// Raises `self` to the power of `exp`, using exponentiation by squaring.
228 ///
229 /// # Time complexity
230 ///
231 /// *O*(log `exp`)
232 ///
233 /// # Example
234 ///
235 /// ```
236 /// use lib_modulo::Modulus64;
237 ///
238 /// let modulus = Modulus64::new(1001);
239 /// let residue = modulus.residue(2);
240 /// for exp in 0..64 {
241 /// assert_eq!(residue.pow(exp).get(), (1 << exp) % 1001)
242 /// }
243 /// ```
244 #[inline]
245 pub const fn pow(mut self, mut exp: u64) -> Self {
246 // r inv_r = 1 (mod n)
247 let mut result = self.modulus.residue(1).x;
248
249 while exp > 0 {
250 if exp & 1 == 1 {
251 // n < r
252 result = self.modulus.mul(result, self.x)
253 }
254
255 exp >>= 1;
256 // n < r
257 self.x = self.modulus.mul(self.x, self.x)
258 }
259 self.x = result;
260
261 self
262 }
263
264 /// Calculates the modular inverse of `self`, using extended binary GCD algorithm.
265 ///
266 /// Modular inverse can be defined if and only if `self` and the modulus is coprime.
267 ///
268 /// - `Ok(x)` : `x` is the modular inverse.
269 /// - `Err(x)`: `x` is the GCD of `self` and the `modulus`,
270 /// where `gcd(0, a) = gcd(a, 0)` is defined to be `a`.
271 ///
272 /// # Time complexity
273 ///
274 /// *O*(log `self`)
275 ///
276 /// # Example
277 ///
278 /// ```
279 /// use lib_modulo::Modulus64;
280 ///
281 /// // 998_244_353 is a prime number, so modular inverse of n exists iff n != 0 (mod 998_244_353)
282 /// let modulus = Modulus64::new(998_244_353);
283 ///
284 /// for n in 1..500_000 {
285 /// let n = modulus.residue(n);
286 /// assert!(n.try_inv().is_ok_and(|i| (i * n).get() == 1));
287 /// }
288 /// // 0 n = 0 != 1 for any integer n
289 /// assert!(modulus.residue(0).try_inv().is_err());
290 /// ```
291 #[inline]
292 pub const fn try_inv(self) -> Result<Self, u64> {
293 let mut a = self.get();
294 let Self { modulus, .. } = self;
295
296 // performs extended binary gcd
297 //
298 // invariants: a = [a] x, b = [a] y (mod n) where [a] is initial value
299 let mut b = modulus.n;
300 let mut x = modulus.residue(1).x; // 1 r mod n
301 let mut y = 0; // 0 r mod n
302 let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
303
304 while a > 0 {
305 x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
306 a >>= a.trailing_zeros();
307
308 if a < b {
309 (a, b) = (b, a);
310 (x, y) = (y, x);
311 }
312 a -= b;
313 let (diff, b) = x.overflowing_sub(y);
314 x = if b {
315 diff.wrapping_add(modulus.n)
316 } else {
317 diff
318 };
319 }
320
321 // b = gcd([a], [b])
322 if b == 1 {
323 Ok(Self { x: y, modulus })
324 } else {
325 Err(b)
326 }
327 }
328}
329
330impl<'a> Add for Residue64<'a> {
331 type Output = Self;
332
333 #[inline(always)]
334 fn add(mut self, rhs: Self) -> Self {
335 let (sum, b) = self.x.overflowing_add(rhs.x);
336 self.x = if b || sum >= self.modulus.n {
337 sum.wrapping_sub(self.modulus.n)
338 } else {
339 sum
340 };
341
342 self
343 }
344}
345
346impl<'a> AddAssign for Residue64<'a> {
347 #[inline(always)]
348 fn add_assign(&mut self, rhs: Self) {
349 *self = *self + rhs
350 }
351}
352
353impl<'a> Sub for Residue64<'a> {
354 type Output = Self;
355
356 #[inline(always)]
357 fn sub(mut self, rhs: Self) -> Self {
358 let (diff, b) = self.x.overflowing_sub(rhs.x);
359 self.x = if b {
360 diff.wrapping_add(self.modulus.n)
361 } else {
362 diff
363 };
364
365 self
366 }
367}
368
369impl<'a> SubAssign for Residue64<'a> {
370 #[inline(always)]
371 fn sub_assign(&mut self, rhs: Self) {
372 *self = *self - rhs
373 }
374}
375
376impl<'a> Mul for Residue64<'a> {
377 type Output = Self;
378
379 #[inline(always)]
380 fn mul(mut self, rhs: Self) -> Self {
381 // n < r
382 self.x = self.modulus.mul(self.x, rhs.x);
383
384 self
385 }
386}
387
388impl<'a> MulAssign for Residue64<'a> {
389 #[inline(always)]
390 fn mul_assign(&mut self, rhs: Self) {
391 *self = *self * rhs
392 }
393}
394
395impl<'a> Neg for Residue64<'a> {
396 type Output = Self;
397
398 #[inline(always)]
399 fn neg(mut self) -> Self::Output {
400 // (x - x) r = 0 (mod n)
401 self.x = if self.x == 0 {
402 self.x
403 } else {
404 self.modulus.n - self.x
405 };
406
407 self
408 }
409}
410
/// The bare stored value of a [`Residue64`], detached from its [`Modulus64`].
///
/// Conceptually, [`Residue64`] = [`Raw64`] + [`Modulus64`]: a [`Raw64`] keeps only the
/// internal value and holds no reference to the modulus.
///
/// Keeping the two apart helps to
///
/// - shrink collections of residues (a single `u64` per element instead of a value plus
///   a reference), and
/// - avoid self-referential structures when one type must own both a residue and its
///   modulus.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Raw64 {
    x: u64,
}
423
424impl Raw64 {
425 /// Attaches a modulus and returns a [`Residue64`].
426 ///
427 /// # Caution
428 ///
429 /// This does not perform validation or reduction.
430 /// The caller must ensure the modulus is correct for this value.
431 #[inline(always)]
432 pub const fn into_residue<'a>(self, modulus: &'a Modulus64) -> Residue64<'a> {
433 Residue64 { modulus, x: self.x }
434 }
435}
436
437impl<'a> From<Residue64<'a>> for Raw64 {
438 #[inline(always)]
439 fn from(residue: Residue64<'a>) -> Self {
440 Self { x: residue.x }
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            // `1 % n` rather than `1`: modulo n = 1 every residue, including x^0, is 0
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // `1 % n` rather than `1`: modulo n = 1 the product is 0
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn arithmetic(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            let modulus = Modulus64::new(n);
            let (a, b) = (modulus.residue(x), modulus.residue(y));

            // reference values computed with plain u128 arithmetic
            let wide_n = n as u128;
            let (wide_x, wide_y) = ((x % n) as u128, (y % n) as u128);

            assert_eq!((a + b).get() as u128, (wide_x + wide_y) % wide_n);
            assert_eq!((a - b).get() as u128, (wide_x + wide_n - wide_y) % wide_n);
            assert_eq!((a * b).get() as u128, wide_x * wide_y % wide_n);
            assert_eq!((-a).get() as u128, (wide_n - wide_x) % wide_n);

            let mut c = a;
            c += b;
            c -= b;
            c *= modulus.residue(1);
            assert_eq!(c, a);

            assert_eq!(a.into_raw().into_residue(&modulus), a);
            assert_eq!(Raw64::from(a), a.into_raw());
        }
    }

    #[test]
    fn modulus_one() {
        let modulus = Modulus64::new(1);
        let res = modulus.residue(12345);

        assert!(res.is_zero());
        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert_eq!(res.try_inv().map(|inv| inv.get()), Ok(0));
    }

    #[test]
    #[should_panic(expected = "modulus should be an odd number")]
    fn even_modulus_panics() {
        let _ = Modulus64::new(10);
    }
}