1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3#[derive(Debug, Clone, Eq, Hash)]
7pub struct Modulus64 {
8 pub(crate) n: u64,
10 pub(crate) inv_n: u64,
11 pub(crate) r2_mod_n: u64,
12}
13
14impl Modulus64 {
15 #[inline]
21 pub const fn new(n: u64) -> Self {
22 assert!(n & 1 == 1, "modulus should be an odd number");
23
24 let inv_n = {
25 const TABLE: u32 = {
26 let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
29
30 let mut table = 0;
31 let mut i = 0;
32 while i < 8 {
33 table |= inv_n[i] << (i * 4);
34 i += 1;
35 }
36
37 table
38 };
39 let mut inv_n = ((TABLE >> (n & 0b1110) * 2) & 0b1111) as u64;
41
42 let mut d = const { u64::BITS.ilog2() - 2 };
43 while d > 0 {
44 inv_n = inv_n.wrapping_mul((2 as u64).wrapping_sub(n.wrapping_mul(inv_n)));
45 d -= 1;
46 }
47 debug_assert!(n.wrapping_mul(inv_n) == 1);
48
49 inv_n
50 };
51 let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
52
53 Self { n, inv_n, r2_mod_n }
54 }
55
56 #[inline(always)]
57 pub const fn residue(&self, x: u64) -> Residue64<'_> {
58 let x = self.mul(x, self.r2_mod_n);
60
61 Residue64 { x, modulus: &self }
62 }
63
64 #[inline(always)]
68 pub(crate) const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
69 self.mul_add(lhs, rhs, 0)
70 }
71
72 #[inline(always)]
76 pub(crate) const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
77 let (x_hi, x_lo) = {
79 let x = lhs as u128 * rhs as u128 + add as u128;
80 ((x >> u64::BITS) as u64, x as u64)
81 };
82 let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
85 let (z, b) = x_hi.overflowing_sub(y_hi);
87
88 if b {
90 z.wrapping_add(self.n)
91 } else {
92 z
93 }
94 }
95
96 #[inline]
110 pub const fn can_divide(&self, x: u64) -> bool {
111 self.residue(x).is_zero()
112 }
113}
114
115impl PartialEq for Modulus64 {
116 fn eq(&self, other: &Self) -> bool {
117 self.n == other.n
118 }
119}
120
121#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
146pub struct Residue64<'a> {
147 pub(crate) modulus: &'a Modulus64,
148 pub(crate) x: u64,
150}
151
152impl<'a> Residue64<'a> {
153 #[inline(always)]
165 pub const fn get(&self) -> u64 {
166 self.modulus.mul(self.x, 1)
167 }
168
169 #[inline(always)]
181 pub const fn modulus(&self) -> u64 {
182 self.modulus.n
183 }
184
185 #[inline(always)]
197 pub const fn is_zero(self) -> bool {
198 self.x == 0
199 }
200
201 #[inline]
219 pub const fn pow(mut self, mut exp: u64) -> Self {
220 let mut result = self.modulus.residue(1).x;
222
223 while exp > 0 {
224 if exp & 1 == 1 {
225 result = self.modulus.mul(result, self.x)
227 }
228
229 exp >>= 1;
230 self.x = self.modulus.mul(self.x, self.x)
232 }
233 self.x = result;
234
235 self
236 }
237
238 #[inline]
266 pub const fn try_inv(self) -> Result<Self, u64> {
267 let mut a = self.get();
268 let Self { modulus, .. } = self;
269
270 let mut b = modulus.n;
274 let mut x = modulus.residue(1).x; let mut y = 0; let frac_1_2 = modulus.residue(modulus.n.div_ceil(2));
277
278 while a > 0 {
279 x = modulus.mul(x, frac_1_2.pow(a.trailing_zeros() as u64).x);
280 a >>= a.trailing_zeros();
281
282 if a < b {
283 (a, b) = (b, a);
284 (x, y) = (y, x);
285 }
286 a -= b;
287 let (diff, b) = x.overflowing_sub(y);
288 x = if b {
289 diff.wrapping_add(modulus.n)
290 } else {
291 diff
292 };
293 }
294
295 if b == 1 {
297 Ok(Self { x: y, modulus })
298 } else {
299 Err(b)
300 }
301 }
302}
303
304impl<'a> Add for Residue64<'a> {
305 type Output = Self;
306
307 #[inline(always)]
308 fn add(mut self, rhs: Self) -> Self {
309 let (sum, b) = self.x.overflowing_add(rhs.x);
310 self.x = if b || sum >= self.modulus.n {
311 sum.wrapping_sub(self.modulus.n)
312 } else {
313 sum
314 };
315
316 self
317 }
318}
319
320impl<'a> AddAssign for Residue64<'a> {
321 #[inline(always)]
322 fn add_assign(&mut self, rhs: Self) {
323 *self = *self + rhs
324 }
325}
326
327impl<'a> Sub for Residue64<'a> {
328 type Output = Self;
329
330 #[inline(always)]
331 fn sub(mut self, rhs: Self) -> Self {
332 let (diff, b) = self.x.overflowing_sub(rhs.x);
333 self.x = if b {
334 diff.wrapping_add(self.modulus.n)
335 } else {
336 diff
337 };
338
339 self
340 }
341}
342
343impl<'a> SubAssign for Residue64<'a> {
344 #[inline(always)]
345 fn sub_assign(&mut self, rhs: Self) {
346 *self = *self - rhs
347 }
348}
349
350impl<'a> Mul for Residue64<'a> {
351 type Output = Self;
352
353 #[inline(always)]
354 fn mul(mut self, rhs: Self) -> Self {
355 self.x = self.modulus.mul(self.x, rhs.x);
357
358 self
359 }
360}
361
362impl<'a> MulAssign for Residue64<'a> {
363 #[inline(always)]
364 fn mul_assign(&mut self, rhs: Self) {
365 *self = *self * rhs
366 }
367}
368
369impl<'a> Neg for Residue64<'a> {
370 type Output = Self;
371
372 #[inline(always)]
373 fn neg(mut self) -> Self::Output {
374 self.x = if self.x == 0 {
376 self.x
377 } else {
378 self.modulus.n - self.x
379 };
380
381 self
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // Round-tripping through Montgomery form reduces `x` modulo `n`.
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            let res = modulus.residue(x);
            // `1 % n` rather than `1`, so the expectation also holds for `n == 1`.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        // The operator impls agree with plain u128 arithmetic modulo `n`.
        #[test]
        fn arith(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64, y: u64) {
            let modulus = Modulus64::new(n);
            let (a, b) = (modulus.residue(x), modulus.residue(y));
            let (x, y, n) = ((x % n) as u128, (y % n) as u128, n as u128);

            assert_eq!((a + b).get() as u128, (x + y) % n);
            assert_eq!((a - b).get() as u128, (x + n - y) % n);
            assert_eq!((a * b).get() as u128, x * y % n);
            assert_eq!((-a).get() as u128, (n - x) % n);

            let mut c = a;
            c += b;
            assert_eq!(c, a + b);
            c -= b;
            assert_eq!(c, a);
            c *= b;
            assert_eq!(c, a * b);
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);

            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference binary (Stein) GCD, used to check the `Err` case of `try_inv`.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }

        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();

        while a != 0 {
            a >>= a.trailing_zeros();

            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }

        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);

            match res.try_inv() {
                // `1 % n` rather than `1`, so the expectation also holds for `n == 1`.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    #[test]
    fn modulus_one() {
        // Every integer is congruent to 0 modulo 1.
        let modulus = Modulus64::new(1);
        let res = modulus.residue(12345);

        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert!(modulus.can_divide(7));
        assert_eq!(res.try_inv().map(|inv| inv.get()), Ok(0));
    }

    #[test]
    #[should_panic(expected = "modulus should be an odd number")]
    fn even_modulus_panics() {
        let _ = Modulus64::new(4);
    }
}