1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3#[allow(clippy::derived_hash_with_manual_eq)]
7#[derive(Debug, Clone, Eq, Hash)]
8pub struct Modulus64 {
9 pub(crate) n: u64,
11 pub(crate) inv_n: u64,
12 pub(crate) r2_mod_n: u64,
13}
14
15impl Modulus64 {
16 #[inline]
22 pub const fn new(n: u64) -> Self {
23 assert!(n & 1 == 1, "modulus should be an odd number");
24
25 let inv_n = {
26 const TABLE: u32 = {
27 let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
30
31 let mut table = 0;
32 let mut i = 0;
33 while i < 8 {
34 table |= inv_n[i] << (i * 4);
35 i += 1;
36 }
37
38 table
39 };
40 let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
42
43 let mut d = const { u64::BITS.ilog2() - 2 };
44 while d > 0 {
45 inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
46 d -= 1;
47 }
48 debug_assert!(n.wrapping_mul(inv_n) == 1);
49
50 inv_n
51 };
52 let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
53
54 Self { n, inv_n, r2_mod_n }
55 }
56
57 #[inline(always)]
58 pub const fn residue(&self, x: u64) -> Residue64<'_> {
59 let x = self.mul(x, self.r2_mod_n);
61
62 Residue64 { x, modulus: self }
63 }
64
65 #[inline(always)]
69 pub(crate) const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
70 self.mul_add(lhs, rhs, 0)
71 }
72
73 #[inline(always)]
77 pub(crate) const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
78 let (x_hi, x_lo) = {
80 let x = lhs as u128 * rhs as u128 + add as u128;
81 ((x >> u64::BITS) as u64, x as u64)
82 };
83 let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
86 let (z, b) = x_hi.overflowing_sub(y_hi);
88
89 if b {
91 z.wrapping_add(self.n)
92 } else {
93 z
94 }
95 }
96
97 #[inline]
111 pub const fn can_divide(&self, x: u64) -> bool {
112 self.residue(x).is_zero()
113 }
114}
115
116impl PartialEq for Modulus64 {
117 fn eq(&self, other: &Self) -> bool {
118 self.n == other.n
120 }
121}
122
123#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
148pub struct Residue64<'a> {
149 pub(crate) modulus: &'a Modulus64,
150 pub(crate) x: u64,
152}
153
154impl<'a> Residue64<'a> {
155 #[inline(always)]
167 pub const fn get(&self) -> u64 {
168 self.modulus.mul(self.x, 1)
169 }
170
171 #[inline(always)]
183 pub const fn modulus(&self) -> u64 {
184 self.modulus.n
185 }
186
187 #[inline(always)]
198 pub const fn is_zero(self) -> bool {
199 self.x == 0
200 }
201
202 #[inline]
220 pub const fn pow(mut self, mut exp: u64) -> Self {
221 let mut result = self.modulus.residue(1).x;
223
224 while exp > 0 {
225 if exp & 1 == 1 {
226 result = self.modulus.mul(result, self.x)
228 }
229
230 exp >>= 1;
231 self.x = self.modulus.mul(self.x, self.x)
233 }
234 self.x = result;
235
236 self
237 }
238
239 #[inline]
267 pub const fn try_inv(self) -> Result<Self, u64> {
268 let mut a = self.get();
269 let Self { modulus, .. } = self;
270
271 let mut b = modulus.n;
275 let mut x = modulus.residue(1).x; let mut y = 0; let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
278
279 while a > 0 {
280 x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
281 a >>= a.trailing_zeros();
282
283 if a < b {
284 (a, b) = (b, a);
285 (x, y) = (y, x);
286 }
287 a -= b;
288 let (diff, b) = x.overflowing_sub(y);
289 x = if b {
290 diff.wrapping_add(modulus.n)
291 } else {
292 diff
293 };
294 }
295
296 if b == 1 {
298 Ok(Self { x: y, modulus })
299 } else {
300 Err(b)
301 }
302 }
303}
304
305impl<'a> Add for Residue64<'a> {
306 type Output = Self;
307
308 #[inline(always)]
309 fn add(mut self, rhs: Self) -> Self {
310 let (sum, b) = self.x.overflowing_add(rhs.x);
311 self.x = if b || sum >= self.modulus.n {
312 sum.wrapping_sub(self.modulus.n)
313 } else {
314 sum
315 };
316
317 self
318 }
319}
320
321impl<'a> AddAssign for Residue64<'a> {
322 #[inline(always)]
323 fn add_assign(&mut self, rhs: Self) {
324 *self = *self + rhs
325 }
326}
327
328impl<'a> Sub for Residue64<'a> {
329 type Output = Self;
330
331 #[inline(always)]
332 fn sub(mut self, rhs: Self) -> Self {
333 let (diff, b) = self.x.overflowing_sub(rhs.x);
334 self.x = if b {
335 diff.wrapping_add(self.modulus.n)
336 } else {
337 diff
338 };
339
340 self
341 }
342}
343
344impl<'a> SubAssign for Residue64<'a> {
345 #[inline(always)]
346 fn sub_assign(&mut self, rhs: Self) {
347 *self = *self - rhs
348 }
349}
350
351impl<'a> Mul for Residue64<'a> {
352 type Output = Self;
353
354 #[inline(always)]
355 fn mul(mut self, rhs: Self) -> Self {
356 self.x = self.modulus.mul(self.x, rhs.x);
358
359 self
360 }
361}
362
363impl<'a> MulAssign for Residue64<'a> {
364 #[inline(always)]
365 fn mul_assign(&mut self, rhs: Self) {
366 *self = *self * rhs
367 }
368}
369
370impl<'a> Neg for Residue64<'a> {
371 type Output = Self;
372
373 #[inline(always)]
374 fn neg(mut self) -> Self::Output {
375 self.x = if self.x == 0 {
377 self.x
378 } else {
379 self.modulus.n - self.x
380 };
381
382 self
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            // x^0 mod n is `1 % n`, which is 0 for the degenerate modulus n == 1.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference binary GCD used to validate `Err` results from `try_inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    #[test]
    #[should_panic(expected = "modulus should be an odd number")]
    fn even_modulus_panics() {
        let _ = Modulus64::new(10);
    }

    #[test]
    fn degenerate_modulus_one() {
        let modulus = Modulus64::new(1);
        let res = modulus.residue(12345);
        assert_eq!(res.get(), 0);
        assert!(res.is_zero());
        assert_eq!(res.pow(0).get(), 0);
        assert!(modulus.can_divide(7));
        assert_eq!(res.try_inv().map(|inv| inv.get()), Ok(0));
    }

    #[test]
    fn largest_modulus() {
        let n = u64::MAX;
        let modulus = Modulus64::new(n);
        let one = modulus.residue(1);
        let minus_one = modulus.residue(n - 1);

        assert_eq!(modulus.residue(u64::MAX).get(), 0);
        assert_eq!((-one).get(), n - 1);
        assert_eq!((minus_one + one).get(), 0);
        assert_eq!((minus_one + minus_one).get(), n - 2);
        assert_eq!((modulus.residue(0) - one).get(), n - 1);
        assert_eq!((minus_one * minus_one).get(), 1);
    }

    #[test]
    fn known_inverses() {
        let seven = Modulus64::new(7);
        assert_eq!(seven.residue(3).try_inv().map(|inv| inv.get()), Ok(5));

        let nine = Modulus64::new(9);
        assert_eq!(nine.residue(6).try_inv().map(|inv| inv.get()), Err(3));
        assert_eq!(nine.residue(0).try_inv().map(|inv| inv.get()), Err(9));
        assert_eq!(nine.residue(2).try_inv().map(|inv| inv.get()), Ok(5));
    }
}