use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// A precomputed odd modulus `n <= Modulus32::MAX` for fast arithmetic on
/// 32-bit residues.
// `Hash` is derived over every field while `PartialEq` (implemented by hand)
// compares only `n`. The two stay consistent because `new` computes `inv_n`,
// `init` and `recip` from `n` alone, so equal moduli always hash identically.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash, Eq)]
pub struct Modulus32 {
    /// The odd modulus, widened to `u64`.
    n: u64,
    /// `n^-1 mod 2^64`.
    inv_n: u64,
    /// `(2^128 mod n) * n^-1 mod 2^64`; used by `residue` to move a reduced value
    /// into the internal representation.
    init: u64,
    /// `ceil(2^64 / n) mod 2^64` (0 when `n == 1`); used by `residue` for reduction
    /// and by `can_divide`.
    recip: u64,
}
impl Modulus32 {
    /// Largest modulus accepted by [`Modulus32::new`]: `floor(2^32 / phi)`, with
    /// `phi` the golden ratio.
    ///
    /// `mul` keeps only the high 32 bits of a 64-bit product before scaling by `n`.
    /// That shortcut yields the exact result when `x*y/2^64 + n/2^32 < 1`; for
    /// `x, y < n <= 2^32/phi` this holds because `1/phi^2 + 1/phi = 1`.
    pub const MAX: u32 = 2_654_435_769;
    /// Precomputes the constants used to reduce and multiply modulo `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is even or if `n` is greater than [`Self::MAX`].
    #[inline]
    #[must_use]
    pub const fn new(n: u32) -> Self {
        assert!(
            n & 1 == 1,
            "invalid modulus: modulus should be an odd integer."
        );
        assert!(
            n <= Self::MAX,
            "invalid modulus: modulus should be no more than 2_654_435_769."
        );
        let n = n as u64;
        // n^-1 mod 2^64 by Newton-Hensel lifting: each step doubles the number of
        // correct low bits. `n & 3` is correct to 2 bits for odd `n`, and
        // ilog2(64) - 1 = 5 steps lift that to 2 * 2^5 = 64 bits.
        let inv_n = {
            let mut inv_n = n & 3;
            let mut i = u64::BITS.ilog2() - 1;
            while i > 0 {
                i -= 1;
                inv_n = inv_n.wrapping_mul(2_u64.wrapping_sub(n.wrapping_mul(inv_n)));
            }
            debug_assert!(n.wrapping_mul(inv_n) == 1);
            inv_n
        };
        // 2^64 - n = div * n + rem, hence 2^64 = (div + 1) * n + rem and
        // rem = 2^64 mod n.
        let (div, rem) = {
            let denom = n.wrapping_neg();
            (denom / n, denom % n)
        };
        // 2^128 mod n; rem < n < 2^32, so the square cannot overflow.
        let init = rem * rem % n;
        // ceil(2^64 / n), wrapped to 64 bits (it becomes 0 when n == 1).
        let recip = div.wrapping_add(if rem > 0 { 2 } else { 1 });
        Self {
            n,
            inv_n,
            // Pre-multiplied by n^-1 so `residue` can skip the first multiply of
            // the `mul` formula.
            init: init.wrapping_mul(inv_n),
            recip,
        }
    }
    /// Multiplies two values held in the internal representation.
    ///
    /// For `x, y < n` this returns the unique `k` in `0..n` with
    /// `k == -x*y*2^-64 (mod n)`. With `t = x*y*n^-1 mod 2^64`, `k` is the high
    /// 64-bit half of `t*n`; it is computed here from the top 32 bits of `t`,
    /// which is exact as long as `n <= MAX`.
    #[inline(always)]
    const fn mul(&self, x: u64, y: u64) -> u64 {
        let z = self.inv_n.wrapping_mul(x).wrapping_mul(y) >> 32;
        // Bumping the truncated top half of `t` up by one keeps the estimate just
        // above t*n/2^64 without reaching the next integer.
        let z = ((z as u32).wrapping_add(1) as u64 * self.n) >> 32;
        debug_assert!(z < self.n, "this is a bug in lib-modulo");
        z
    }
    /// Converts `x` into a residue modulo `n`.
    ///
    /// `x` is first reduced to `x mod n` (Lemire's direct-remainder trick) and then
    /// stored in the internal representation `-x*2^64 mod n` that `mul` expects.
    #[must_use]
    pub const fn residue(&self, x: u32) -> Residue32<'_> {
        // x mod n = high 64 bits of (recip * x mod 2^64) * n.
        let x = {
            let lo = self.recip.wrapping_mul(x as u64);
            ((lo as u128 * self.n as u128) >> 64) as u64
        };
        // Same computation as `mul(2^128 mod n, x)`, with n^-1 already folded
        // into `init`.
        let x = {
            let x = self.init.wrapping_mul(x) >> 32;
            ((x as u32).wrapping_add(1) as u64 * self.n) >> 32
        };
        Residue32 { x, modulus: self }
    }
    /// Returns `true` if `n` divides `x`.
    ///
    /// Uses the test `recip * x mod 2^64 < recip`. For `n == 1`, `recip` is 0 and
    /// `recip - 1` wraps to `u64::MAX`, so every `x` passes.
    #[must_use]
    pub const fn can_divide(&self, x: u32) -> bool {
        self.recip.wrapping_mul(x as u64) <= self.recip.wrapping_sub(1)
    }
}
impl PartialEq for Modulus32 {
    /// Two moduli are equal when they reduce by the same `n`.
    ///
    /// `Modulus32::new` computes every other field from `n` alone, so comparing
    /// those fields as well could never change the outcome.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.n, other.n);
        lhs == rhs
    }
}
/// A residue modulo a borrowed [`Modulus32`].
///
/// The derived `PartialEq`/`Eq` compare the modulus (through `Modulus32`'s
/// equality, i.e. by `n`) as well as the stored representation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Residue32<'a> {
    /// The modulus this residue is reduced by.
    modulus: &'a Modulus32,
    /// Internal representation (not the value itself; use `get` to read it),
    /// expected to lie in `0..modulus.n`.
    x: u64,
}
impl Residue32<'_> {
    /// Detaches the internal representation from its modulus.
    ///
    /// The stored number is the internal representation rather than the value
    /// (use [`Residue32::get`] for that); it is only meaningful when reattached to
    /// the same modulus with [`Raw32::into_residue`].
    #[must_use]
    pub const fn into_raw(self) -> Raw32 {
        Raw32 { x: self.x }
    }

    /// Returns `true` if this residue is congruent to zero modulo `n`.
    ///
    /// Zero is represented by the raw value 0, so no conversion is needed.
    #[must_use]
    pub const fn is_zero(self) -> bool {
        self.x == 0
    }

    /// Returns the value of this residue, in `0..self.modulus()`.
    #[must_use]
    pub const fn get(self) -> u64 {
        // Multiplying by the raw value 1 removes the internal scaling.
        self.modulus.mul(self.x, 1)
    }

    /// Returns the modulus `n` this residue is reduced by.
    #[must_use]
    pub const fn modulus(&self) -> u64 {
        self.modulus.n
    }

    /// Raises this residue to the power `exp` by square-and-multiply.
    ///
    /// `pow(0)` is the residue of 1.
    #[must_use]
    pub const fn pow(self, mut exp: u32) -> Self {
        let Self { mut x, modulus } = self;
        let mut prod = modulus.residue(1).x;
        // Loop while more than one bit remains so the most significant bit does
        // not trigger a useless final squaring.
        while exp > 1 {
            if exp & 1 == 1 {
                prod = modulus.mul(prod, x);
            }
            exp >>= 1;
            x = modulus.mul(x, x);
        }
        if exp != 0 {
            prod = modulus.mul(prod, x);
        }
        Self { x: prod, modulus }
    }

    /// Computes the multiplicative inverse of this residue using a binary
    /// extended Euclidean algorithm.
    ///
    /// # Errors
    ///
    /// Returns `Err(g)` with `g = gcd(self.get(), self.modulus())` when `g != 1`,
    /// i.e. when no inverse exists (this includes a zero residue whenever `n > 1`).
    pub const fn inv(self) -> Result<Self, u64> {
        let Self { modulus, .. } = self;
        let n = modulus.n;
        // Invariants, with v = self.get(): a == x*v and b == y*v (mod n), where
        // `x` and `y` are kept in the internal representation.
        let mut a = self.get();
        let mut b = n;
        let mut x = modulus.residue(1).x;
        let mut y = 0;
        while a > 0 {
            // Strip factors of two from `a`, halving `x` alongside it. `n` is odd,
            // so halving modulo `n` is exact: an odd `x` is made even by adding `n`
            // (no overflow, since x < n < 2^32). The internal representation is
            // linear in the value, so halving it halves the represented value.
            while a & 1 == 0 {
                a >>= 1;
                x = if x & 1 == 0 { x >> 1 } else { (x + n) >> 1 };
            }
            if a < b {
                (a, b) = (b, a);
                (x, y) = (y, x);
            }
            a -= b;
            let (diff, borrowed) = x.overflowing_sub(y);
            x = if borrowed { diff.wrapping_add(n) } else { diff };
        }
        // `b` is now gcd(v, n).
        if b == 1 {
            Ok(Self { x: y, modulus })
        } else {
            Err(b)
        }
    }
}
impl Add for Residue32<'_> {
    type Output = Self;
    /// Adds two residues of the same modulus.
    ///
    /// Mixing moduli is a logic error: the result would be reduced by `self`'s
    /// modulus only, so debug builds reject it up front.
    fn add(self, rhs: Self) -> Self::Output {
        debug_assert_eq!(
            self.modulus, rhs.modulus,
            "cannot add residues of different moduli"
        );
        self + rhs.into_raw()
    }
}
impl AddAssign for Residue32<'_> {
fn add_assign(&mut self, rhs: Self) {
*self = *self + rhs;
}
}
impl Sub for Residue32<'_> {
    type Output = Self;
    /// Subtracts two residues of the same modulus.
    ///
    /// Mixing moduli is a logic error: the result would be reduced by `self`'s
    /// modulus only, so debug builds reject it up front.
    fn sub(self, rhs: Self) -> Self::Output {
        debug_assert_eq!(
            self.modulus, rhs.modulus,
            "cannot subtract residues of different moduli"
        );
        self - rhs.into_raw()
    }
}
impl SubAssign for Residue32<'_> {
fn sub_assign(&mut self, rhs: Self) {
*self = *self - rhs;
}
}
impl Mul for Residue32<'_> {
    type Output = Self;
    /// Multiplies two residues of the same modulus.
    ///
    /// Mixing moduli is a logic error (the other operand may not even be in range
    /// for `self`'s modulus), so debug builds reject it up front.
    fn mul(self, rhs: Self) -> Self::Output {
        debug_assert_eq!(
            self.modulus, rhs.modulus,
            "cannot multiply residues of different moduli"
        );
        self * rhs.into_raw()
    }
}
impl MulAssign for Residue32<'_> {
fn mul_assign(&mut self, rhs: Self) {
*self = *self * rhs;
}
}
impl Neg for Residue32<'_> {
    type Output = Self;
    /// Additive inverse: maps the stored value `x` to `n - x`, keeping zero at zero.
    fn neg(self) -> Self::Output {
        let x = match self.x {
            0 => 0,
            nonzero => self.modulus() - nonzero,
        };
        Self { x, ..self }
    }
}
/// The internal representation of a [`Residue32`] with its modulus detached.
///
/// Unlike `Residue32`, it carries no borrow of the modulus. Reattach one with
/// [`Raw32::into_residue`]. The stored number is the internal representation,
/// not the residue's value.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct Raw32 {
    /// Internal representation, copied verbatim from a `Residue32`.
    x: u64,
}
impl Raw32 {
    /// Reattaches `modulus` to this raw representation.
    ///
    /// `modulus` should equal the modulus the raw value was taken from. The
    /// arithmetic relies on the stored value lying in `0..n`, which debug builds
    /// check here instead of failing later with a confusing internal assertion.
    #[must_use]
    pub const fn into_residue(self, modulus: &Modulus32) -> Residue32<'_> {
        debug_assert!(
            self.x < modulus.n,
            "raw value is out of range for this modulus"
        );
        Residue32 { modulus, x: self.x }
    }
}
impl<'a> From<Residue32<'a>> for Raw32 {
    /// Equivalent to [`Residue32::into_raw`].
    fn from(residue: Residue32<'a>) -> Self {
        residue.into_raw()
    }
}
impl<'a> Add<Raw32> for Residue32<'a> {
type Output = Residue32<'a>;
fn add(mut self, rhs: Raw32) -> Self::Output {
let (sum, b) = self.x.overflowing_add(rhs.x);
self.x = if b || sum >= self.modulus.n {
sum.wrapping_sub(self.modulus.n)
} else {
sum
};
self
}
}
impl<'a> Add<Residue32<'a>> for Raw32 {
    type Output = Residue32<'a>;
    /// Addition commutes, so this delegates to `Residue32 + Raw32`.
    fn add(self, rhs: Residue32<'a>) -> Self::Output {
        <Residue32<'a> as Add<Raw32>>::add(rhs, self)
    }
}
impl AddAssign<Raw32> for Residue32<'_> {
    /// In-place form of `self + rhs`.
    fn add_assign(&mut self, rhs: Raw32) {
        let sum = *self + rhs;
        *self = sum;
    }
}
impl<'a> Sub<Raw32> for Residue32<'a> {
type Output = Residue32<'a>;
fn sub(mut self, rhs: Raw32) -> Self::Output {
let (diff, b) = self.x.overflowing_sub(rhs.x);
self.x = if b {
diff.wrapping_add(self.modulus.n)
} else {
diff
};
self
}
}
impl<'a> Sub<Residue32<'a>> for Raw32 {
type Output = Residue32<'a>;
fn sub(self, mut rhs: Residue32<'a>) -> Self::Output {
let (diff, b) = self.x.overflowing_sub(rhs.x);
rhs.x = if b {
diff.wrapping_add(rhs.modulus.n)
} else {
diff
};
rhs
}
}
impl SubAssign<Raw32> for Residue32<'_> {
    /// In-place form of `self - rhs`.
    fn sub_assign(&mut self, rhs: Raw32) {
        let lhs = *self;
        *self = lhs - rhs;
    }
}
impl<'a> Mul<Raw32> for Residue32<'a> {
type Output = Residue32<'a>;
fn mul(mut self, rhs: Raw32) -> Self::Output {
self.x = self.modulus.mul(self.x, rhs.x);
self
}
}
impl<'a> Mul<Residue32<'a>> for Raw32 {
    type Output = Residue32<'a>;
    /// Multiplication commutes, so this delegates to `Residue32 * Raw32`.
    fn mul(self, rhs: Residue32<'a>) -> Self::Output {
        <Residue32<'a> as Mul<Raw32>>::mul(rhs, self)
    }
}
impl MulAssign<Raw32> for Residue32<'_> {
fn mul_assign(&mut self, rhs: Raw32) {
*self = *self * rhs;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn mul(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);
            assert_eq!(res.get() as u32, x % n)
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn pow(n in (0..=Modulus32::MAX as u64).prop_map(|n| n | 1), x in 0u64..1 << 32) {
            let modulus = Modulus32::new(n as u32);
            let res = modulus.residue(x as u32);
            // x^0 is 1, which reduces to 0 when n == 1.
            let mut naive = 1 % n;
            for i in 0..100 {
                assert_eq!(res.pow(i).get(), naive, "exp = {i}");
                naive = naive * x % n
            }
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn arithmetic(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32, y: u32) {
            let modulus = Modulus32::new(n);
            let (a, b) = (modulus.residue(x), modulus.residue(y));
            let (x, y, n) = (u64::from(x % n), u64::from(y % n), u64::from(n));
            assert_eq!((a + b).get(), (x + y) % n);
            assert_eq!((a - b).get(), (x + n - y) % n);
            assert_eq!((-a).get(), (n - x) % n);
            assert_eq!((a * b).get(), x * y % n);
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn raw_round_trip(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);
            assert_eq!(res.into_raw().into_residue(&modulus), res);
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            assert_eq!(modulus.can_divide(x), x % n == 0);
        }
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn divisible_by_1(x: u32) {
            assert!(Modulus32::new(1).can_divide(x))
        }
    }

    /// The random strategies above almost never produce `n == 1`, so pin that
    /// degenerate modulus (where every residue is 0) deterministically.
    #[test]
    fn modulus_one() {
        let modulus = Modulus32::new(1);
        for x in [0, 1, 2, u32::MAX] {
            let res = modulus.residue(x);
            assert_eq!(res.get(), 0);
            assert!(res.is_zero());
            assert_eq!(res.pow(0).get(), 0);
            assert_eq!(res.inv().map(Residue32::get), Ok(0));
        }
    }

    fn binary_gcd(mut a: u64, mut b: u64) -> u64 {
        if b == 0 {
            return a;
        }
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();
        while a != 0 {
            a >>= a.trailing_zeros();
            if a < b {
                (a, b) = (b, a)
            }
            a -= b
        }
        b << shift
    }
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]
        #[test]
        fn inv(n in (0..=Modulus32::MAX).prop_map(|n| n | 1), x: u32) {
            let modulus = Modulus32::new(n);
            let res = modulus.residue(x);
            match res.inv() {
                // `1 % n` keeps the check valid for n == 1, where 1 reduces to 0.
                Ok(inv) => assert_eq!((inv * res).get(), 1 % res.modulus()),
                Err(gcd) => {
                    assert!(res.get() % gcd == 0);
                    assert!(res.modulus() % gcd == 0);
                    assert_eq!(binary_gcd(res.get() / gcd, res.modulus() / gcd), 1);
                }
            }
        }
    }
}