use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// An odd modulus `n` together with the constants needed for Montgomery
/// arithmetic with radix `R = 2^64`.
///
/// Build values with [`Modulus64::new`]. `inv_n` and `r2_mod_n` are pure
/// functions of `n`, which is why equality and hashing look at `n` only.
#[derive(Debug, Clone, Eq)]
pub struct Modulus64 {
    /// The odd modulus.
    pub(crate) n: u64,
    /// `n^{-1} mod 2^64`.
    pub(crate) inv_n: u64,
    /// `R^2 mod n` (= `2^128 mod n`), used to convert values into Montgomery form.
    pub(crate) r2_mod_n: u64,
}

// `PartialEq` for `Modulus64` is hand-written and compares `n` alone, so `Hash` must hash
// `n` alone as well: `a == b` has to imply `hash(a) == hash(b)`. A derived `Hash` next to a
// manual `PartialEq` is what Clippy's `derived_hash_with_manual_eq` lint rejects.
impl std::hash::Hash for Modulus64 {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.n, state);
    }
}
impl Modulus64 {
    /// Creates a modulus for Montgomery arithmetic modulo `n` with radix `R = 2^64`.
    ///
    /// Precomputes `n^{-1} mod 2^64` and `R^2 mod n`. Usable in `const` contexts.
    ///
    /// # Panics
    ///
    /// Panics if `n` is even (including `n == 0`). Montgomery reduction needs `n`
    /// to be invertible modulo `2^64`.
    #[inline]
    pub const fn new(n: u64) -> Self {
        assert!(n & 1 == 1, "modulus should be an odd number");
        let inv_n = {
            // Nibble `i` holds the inverse of `2i + 1` modulo 16, so the table yields
            // `n^{-1} mod 2^4` for every odd `n`.
            const TABLE: u32 = {
                let inv_n = [1, 11, 13, 7, 9, 3, 5, 15];
                let mut table = 0;
                let mut i = 0;
                while i < 8 {
                    table |= inv_n[i] << (i * 4);
                    i += 1;
                }
                table
            };
            // For `n = 2i + 1`, `n & 0b1110 == 2i`, so the shift amount is `4i`.
            let mut inv_n = ((TABLE >> ((n & 0b1110) * 2)) & 0b1111) as u64;
            // The Newton/Hensel step `x <- x * (2 - n * x)` doubles the number of correct
            // low bits: 4 -> 8 -> 16 -> 32 -> 64, i.e. `log2(64) - 2 = 4` steps.
            let mut d = const { u64::BITS.ilog2() - 2 };
            while d > 0 {
                inv_n = inv_n.wrapping_mul(2u64.wrapping_sub(n.wrapping_mul(inv_n)));
                d -= 1;
            }
            debug_assert!(n.wrapping_mul(inv_n) == 1);
            inv_n
        };
        // `2^128 - n` is congruent to `2^128` modulo `n` and fits in a u128, so this is
        // `R^2 mod n` without a 129-bit intermediate.
        let r2_mod_n = ((n as u128).wrapping_neg() % (n as u128)) as u64;
        Self { n, inv_n, r2_mod_n }
    }

    /// Converts `x` into Montgomery form (`x * R mod n`).
    ///
    /// `x` does not need to be reduced modulo `n`.
    #[inline(always)]
    pub const fn residue(&self, x: u64) -> Residue64<'_> {
        // REDC(x * R^2) = x * R (mod n). Since `x < 2^64` and `r2_mod_n < n`, the product
        // stays below `n * 2^64`, as REDC requires.
        let x = self.mul(x, self.r2_mod_n);
        Residue64 { x, modulus: self }
    }

    /// Montgomery product: returns `lhs * rhs * R^{-1} mod n`, reduced into `[0, n)`.
    ///
    /// Requires `lhs * rhs < n * 2^64`.
    #[inline(always)]
    pub(crate) const fn mul(&self, lhs: u64, rhs: u64) -> u64 {
        self.mul_add(lhs, rhs, 0)
    }

    /// REDC of `lhs * rhs + add`: returns `(lhs * rhs + add) * R^{-1} mod n` in `[0, n)`.
    ///
    /// Requires `lhs * rhs + add < n * 2^64` for the result to be fully reduced.
    /// The 128-bit sum itself can never overflow.
    #[inline(always)]
    pub(crate) const fn mul_add(&self, lhs: u64, rhs: u64, add: u64) -> u64 {
        let (x_hi, x_lo) = {
            let x = lhs as u128 * rhs as u128 + add as u128;
            ((x >> u64::BITS) as u64, x as u64)
        };
        // m = x_lo * n^{-1} mod R makes m * n agree with x in the low 64 bits, so
        // (x - m * n) / R is exactly x_hi - hi(m * n), which lies in (-n, n).
        let y_hi = ((x_lo.wrapping_mul(self.inv_n) as u128 * self.n as u128) >> u64::BITS) as u64;
        let (z, borrow) = x_hi.overflowing_sub(y_hi);
        if borrow {
            z.wrapping_add(self.n)
        } else {
            z
        }
    }

    /// Returns `true` if `n` divides `x`.
    #[inline]
    pub const fn can_divide(&self, x: u64) -> bool {
        // x * R mod n is 0 exactly when x mod n is 0, because R = 2^64 is coprime to odd n.
        self.residue(x).is_zero()
    }
}
impl PartialEq for Modulus64 {
    /// Two moduli are equal when they reduce by the same `n`.
    ///
    /// `inv_n` and `r2_mod_n` are computed from `n` by [`Modulus64::new`], so they add
    /// no information and are deliberately not compared.
    fn eq(&self, other: &Self) -> bool {
        let Self { n: lhs, .. } = self;
        let Self { n: rhs, .. } = other;
        lhs == rhs
    }
}
/// An element of `Z/nZ` stored in Montgomery form, borrowing its [`Modulus64`].
///
/// Equality and hashing take both the modulus and the representation `x` into account.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Residue64<'a> {
    /// The modulus this residue belongs to.
    pub(crate) modulus: &'a Modulus64,
    /// Montgomery representation `value * 2^64 mod n`; the arithmetic in this file keeps
    /// it in `[0, n)`.
    pub(crate) x: u64,
}
impl Residue64<'_> {
    /// Returns the ordinary value of this residue, in `[0, n)`.
    #[inline(always)]
    pub const fn get(&self) -> u64 {
        // REDC(x * 1) = x * R^{-1} leaves Montgomery form.
        self.modulus.mul(self.x, 1)
    }

    /// Returns the modulus `n` this residue belongs to.
    #[inline(always)]
    pub const fn modulus(&self) -> u64 {
        self.modulus.n
    }

    /// Returns `true` if this residue is congruent to 0 modulo `n`.
    #[inline(always)]
    pub const fn is_zero(self) -> bool {
        // value * R mod n is 0 exactly when value is 0 mod n, since R is coprime to odd n.
        self.x == 0
    }

    /// Raises this residue to the power `exp` by right-to-left binary exponentiation.
    ///
    /// `pow(0)` is `1 mod n`.
    #[inline]
    pub const fn pow(mut self, mut exp: u64) -> Self {
        let mut result = self.modulus.residue(1).x;
        while exp > 0 {
            if exp & 1 == 1 {
                result = self.modulus.mul(result, self.x);
            }
            exp >>= 1;
            self.x = self.modulus.mul(self.x, self.x);
        }
        self.x = result;
        self
    }

    /// Computes the multiplicative inverse of this residue with a binary extended GCD.
    ///
    /// # Errors
    ///
    /// Returns `Err(g)` with `g = gcd(value, n) != 1` when no inverse exists. For a zero
    /// residue this is `n` itself.
    #[inline]
    pub const fn try_inv(self) -> Result<Self, u64> {
        let modulus = self.modulus;
        let n = modulus.n;
        let mut a = self.get();
        let mut b = n;
        // Montgomery-form coefficients with x * self = a and y * self = b (mod n).
        let mut x = modulus.residue(1).x;
        let mut y = 0;
        while a > 0 {
            let shift = a.trailing_zeros();
            a >>= shift;
            // Divide x by 2^shift (mod n) to keep x * self = a. Halving acts directly on the
            // Montgomery representation, because (v * R) / 2 = (v / 2) * R (mod n).
            let mut k = 0;
            while k < shift {
                x = if x & 1 == 0 {
                    x >> 1
                } else {
                    // x and n are both odd with x < n, so x + n is even and (x + n) / 2 < n.
                    // Computing it as the sum of halves plus one avoids overflowing x + n.
                    (x >> 1) + (n >> 1) + 1
                };
                k += 1;
            }
            // Keep a >= b so the subtraction below cannot underflow.
            if a < b {
                (a, b) = (b, a);
                (x, y) = (y, x);
            }
            // a and b are both odd here, so a - b is even (or zero).
            a -= b;
            let (diff, borrow) = x.overflowing_sub(y);
            x = if borrow { diff.wrapping_add(n) } else { diff };
        }
        // Now a == 0 and b == gcd(value, n); y * self = b.
        if b == 1 {
            Ok(Self { x: y, modulus })
        } else {
            Err(b)
        }
    }
}
impl Add for Residue64<'_> {
    type Output = Self;
    /// Adds two residues modulo `n`.
    ///
    /// Montgomery form is linear, so adding representations adds the values. Both
    /// operands must belong to the same modulus; this is checked in debug builds.
    #[inline(always)]
    fn add(mut self, rhs: Self) -> Self {
        debug_assert!(
            self.modulus.n == rhs.modulus.n,
            "cannot add residues of different moduli"
        );
        // Both inputs are < n, so the true sum is < 2n. One conditional subtraction reduces
        // it; the carry flag covers sums that do not fit in a u64.
        let (sum, carry) = self.x.overflowing_add(rhs.x);
        self.x = if carry || sum >= self.modulus.n {
            sum.wrapping_sub(self.modulus.n)
        } else {
            sum
        };
        self
    }
}
impl<'a> AddAssign for Residue64<'a> {
#[inline(always)]
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs
}
}
impl Sub for Residue64<'_> {
    type Output = Self;
    /// Subtracts two residues modulo `n`.
    ///
    /// Both operands must belong to the same modulus; this is checked in debug builds.
    #[inline(always)]
    fn sub(mut self, rhs: Self) -> Self {
        debug_assert!(
            self.modulus.n == rhs.modulus.n,
            "cannot subtract residues of different moduli"
        );
        // Both inputs are < n, so one conditional addition of n fixes a negative difference.
        let (diff, borrow) = self.x.overflowing_sub(rhs.x);
        self.x = if borrow {
            diff.wrapping_add(self.modulus.n)
        } else {
            diff
        };
        self
    }
}
impl SubAssign for Residue64<'_> {
    /// In-place modular subtraction; delegates to [`Sub`].
    #[inline(always)]
    fn sub_assign(&mut self, rhs: Self) {
        let difference = *self - rhs;
        *self = difference;
    }
}
impl Mul for Residue64<'_> {
    type Output = Self;
    /// Multiplies two residues modulo `n`.
    ///
    /// The Montgomery product of `a * R` and `b * R` is `a * b * R`, so the result stays
    /// in Montgomery form. Both operands must belong to the same modulus; this is checked
    /// in debug builds.
    #[inline(always)]
    fn mul(mut self, rhs: Self) -> Self {
        debug_assert!(
            self.modulus.n == rhs.modulus.n,
            "cannot multiply residues of different moduli"
        );
        self.x = self.modulus.mul(self.x, rhs.x);
        self
    }
}
impl<'a> MulAssign for Residue64<'a> {
#[inline(always)]
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs
}
}
impl Neg for Residue64<'_> {
    type Output = Self;
    /// Returns the additive inverse modulo `n`; zero maps to itself.
    #[inline(always)]
    fn neg(mut self) -> Self::Output {
        // For a nonzero x in (0, n), n - x is again in (0, n).
        if self.x != 0 {
            self.x = self.modulus.n - self.x;
        }
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // The modulus strategy `(0..=u64::MAX).prop_map(|n| n | 1)` yields `n == 1` for the
    // inputs 0 and 1. Modulo 1 every residue is 0, including "one", so the expected
    // multiplicative identity is written as `1 % n` rather than `1`.

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);
            assert_eq!(res.get(), x % n)
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = (naive as u128 * x as u128 % n as u128) as u64
            }
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            assert_eq!(modulus.can_divide(x), x % n == 0);
            for m in std::iter::successors(Some(n), |m| m.checked_add(n)).take(100) {
                assert!(modulus.can_divide(m))
            }
        }
    }

    /// Reference binary (Stein) GCD used to validate `try_inv` errors.
    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();
        while a != 0 {
            a >>= a.trailing_zeros();
            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }
        b << shift
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn try_inv(n in (0..=u64::MAX).prop_map(|n| n | 1), x: u64) {
            let modulus = Modulus64::new(n);
            let res = modulus.residue(x);
            match res.try_inv() {
                Ok(inv) => assert_eq!((inv * res).get(), 1 % n),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }

    #[test]
    fn modulus_one() {
        let modulus = Modulus64::new(1);
        let res = modulus.residue(12345);
        assert_eq!(res.get(), 0);
        assert_eq!(res.pow(0).get(), 0);
        assert!(modulus.can_divide(7));
        assert_eq!(res.try_inv().map(|inv| inv.get()), Ok(0));
    }
}