/// Precomputed reducer for arithmetic modulo a fixed 32-bit modulus `n`.
///
/// Build one with [`Modulus32Any::new`], which rejects `n == 0` and `n == 1`.
#[derive(Clone, Copy, Debug)]
#[derive(Eq, Hash, PartialEq)]
pub struct Modulus32Any {
    /// The modulus, stored widened to `u64` so the reduction code can use it without casts.
    n: u64,
    /// `ceil(2^64 / n)`, computed once by `new` and shared by every reduction.
    magic: u64,
}
#[derive(thiserror::Error, Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum InvalidModulus {
#[error("modulo 0 is undefined")]
Zero,
#[error("modulo 1 is meaningless and not available for performance reason")]
One,
}
impl Modulus32Any {
    /// Creates a reducer for the modulus `n`, precomputing its 64-bit magic constant.
    ///
    /// The constant is `floor((2^64 - 1) / n) + 1`, which equals `ceil(2^64 / n)`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidModulus::Zero`] when `n == 0` and [`InvalidModulus::One`] when `n == 1`.
    pub const fn new(n: u32) -> Result<Self, InvalidModulus> {
        if n == 0 {
            return Err(InvalidModulus::Zero);
        }
        if n == 1 {
            return Err(InvalidModulus::One);
        }
        let n = n as u64;
        // With n >= 2, `u64::MAX / n` is at most 2^63 - 1, so the `+ 1` cannot overflow.
        let magic = u64::MAX / n + 1;
        Ok(Self { n, magic })
    }

    /// Returns the modulus `n` this reducer was built for.
    #[must_use]
    pub const fn modulus(&self) -> u32 {
        let Self { n, .. } = *self;
        n as u32
    }

    /// Returns `x mod n` for a 32-bit `x`, without a hardware division.
    ///
    /// This is Lemire's direct-remainder technique. The low 64 bits of `magic * x` hold the
    /// fractional part of `x / n` scaled by `2^64`. Multiplying that by `n` and keeping the
    /// high 64 bits yields the remainder.
    #[must_use]
    pub const fn residue32(&self, x: u32) -> u32 {
        let fraction = (x as u64).wrapping_mul(self.magic);
        let wide = fraction as u128 * self.n as u128;
        (wide >> 64) as u32
    }

    /// Returns `x mod n` for a full 64-bit `x`.
    ///
    /// Since `magic >= 2^64 / n`, the estimated quotient `floor(x * magic / 2^64)` is either
    /// `floor(x / n)` or one larger. A single correction by `n` therefore suffices.
    #[must_use]
    pub const fn residue64(&self, x: u64) -> u64 {
        let estimate = ((x as u128 * self.magic as u128) >> 64) as u64;
        let product = estimate * self.n;
        if x >= product {
            x - product
        } else {
            // The quotient was overestimated by one, so step back by one multiple of n.
            self.n.wrapping_sub(product - x)
        }
    }

    /// Returns `true` when `x` is a multiple of `n`, without computing the remainder.
    ///
    /// `x` is divisible by `n` exactly when the low 64 bits of `magic * x` are below `magic`.
    #[must_use]
    pub const fn can_divide(&self, x: u32) -> bool {
        let fraction = (x as u64).wrapping_mul(self.magic);
        fraction < self.magic
    }

    /// Returns `(a * b) mod n`.
    #[must_use]
    pub const fn mul(&self, a: u32, b: u32) -> u32 {
        self.carrying_mul_add(a, b, 0, 0)
    }

    /// Returns `(a * b + c) mod n`.
    #[must_use]
    pub const fn carrying_mul(&self, a: u32, b: u32, c: u32) -> u32 {
        self.carrying_mul_add(a, b, c, 0)
    }

    /// Returns `(a * b + c + d) mod n`.
    ///
    /// The widened sum is at most `(2^32 - 1)^2 + 2 * (2^32 - 1) = 2^64 - 1`, so it always
    /// fits in a `u64` before being reduced.
    #[must_use]
    pub const fn carrying_mul_add(&self, a: u32, b: u32, c: u32, d: u32) -> u32 {
        let wide = (a as u64) * (b as u64) + (c as u64) + (d as u64);
        self.residue64(wide) as u32
    }

    /// Returns `x^exp mod n` using square-and-multiply.
    ///
    /// `exp == 0` yields `1`, which is already reduced because `new` only accepts `n >= 2`.
    #[must_use]
    pub const fn pow(&self, x: u32, exp: u32) -> u32 {
        let (mut base, mut remaining, mut acc) = (x, exp, 1);
        while remaining != 0 {
            if remaining % 2 == 1 {
                acc = self.mul(acc, base);
            }
            base = self.mul(base, base);
            remaining /= 2;
        }
        acc
    }

    /// Returns the multiplicative inverse of `x` modulo `n`, reduced into `0..n`.
    ///
    /// # Errors
    ///
    /// When `x` is not invertible, returns `Err(g)`, where `g = gcd(x mod n, n)` and `g != 1`.
    /// In particular, a multiple of `n` yields `Err(n)`.
    pub const fn inv(&self, x: u32) -> Result<u32, u32> {
        // Extended Euclid on (n, x mod n), tracking only the coefficient of x. The loop keeps
        // two invariants (mod n): `coef_prev * x ≡ rem_prev` and `coef * x ≡ rem`.
        let (mut rem_prev, mut rem) = (self.n as i64, self.residue32(x) as i64);
        let (mut coef_prev, mut coef) = (0_i64, 1_i64);
        while rem > 0 {
            let q = rem_prev / rem;
            (rem_prev, rem) = (rem, rem_prev % rem);
            (coef_prev, coef) = (coef, coef_prev - q * coef);
        }
        // rem_prev now holds gcd(x mod n, n).
        if rem_prev != 1 {
            return Err(rem_prev as u32);
        }
        let inverse = if coef_prev < 0 {
            coef_prev + self.n as i64
        } else {
            coef_prev
        };
        Ok(inverse as u32)
    }
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{random_iter, rng};

    use super::{InvalidModulus, Modulus32Any};

    /// Reference gcd via the binary (Stein) algorithm. It is deliberately independent of the
    /// Euclidean loop used by `Modulus32Any::inv`.
    fn binary_gcd(mut a: u32, mut b: u32) -> u32 {
        if b == 0 {
            return a;
        }
        // The power of two common to `a` and `b`, restored at the end.
        let shift = (a | b).trailing_zeros();
        b >>= b.trailing_zeros();
        while a != 0 {
            a >>= a.trailing_zeros();
            if a < b {
                (a, b) = (b, a);
            }
            a -= b;
        }
        b << shift
    }

    /// Reference `x^exp mod n`, computed with plain `%` on `u64` so no operand can overflow.
    fn pow_ref(x: u32, mut exp: u32, n: u32) -> u32 {
        let n = n as u64;
        let mut base = x as u64 % n;
        let mut acc = 1 % n;
        while exp > 0 {
            if exp & 1 == 1 {
                acc = acc * base % n;
            }
            base = base * base % n;
            exp >>= 1;
        }
        acc as u32
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(1 << 15))]

        #[test]
        fn mul(n in 2..=u32::MAX, a: u32, b: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(
                modulus.mul(a, b),
                (a as u64 * b as u64 % n as u64) as u32,
                "{:?}", modulus
            );
        }

        #[test]
        fn carrying_mul_add(n in 2..=u32::MAX, a: u32, b: u32, c: u32, d: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            // Cannot overflow: (2^32 - 1)^2 + 2 * (2^32 - 1) == 2^64 - 1.
            let wide = a as u64 * b as u64 + c as u64;
            prop_assert_eq!(
                modulus.carrying_mul(a, b, c),
                (wide % n as u64) as u32,
                "{:?}", modulus
            );
            prop_assert_eq!(
                modulus.carrying_mul_add(a, b, c, d),
                ((wide + d as u64) % n as u64) as u32,
                "{:?}", modulus
            );
        }

        #[test]
        fn residue32(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(modulus.residue32(a), a % n, "{:?}", modulus);
        }

        #[test]
        fn residue64(n in 2..=u32::MAX, a: u64) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(modulus.residue64(a), a % n as u64, "{:?}", modulus);
        }

        #[test]
        fn can_divide(n in 2..=u32::MAX, a: u32, k: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(modulus.can_divide(a), a % n == 0, "{:?}", modulus);
            // A random `a` is almost never a multiple of a large `n`, so also probe an exact
            // multiple. `k` is folded into `0..=u32::MAX / n` so the product fits in a u32.
            let multiple = k % (u32::MAX / n + 1) * n;
            prop_assert!(modulus.can_divide(multiple), "{} should divide {}", n, multiple);
        }

        #[test]
        fn pow(n in 2..=u32::MAX, x: u32, exp: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            prop_assert_eq!(modulus.pow(x, exp), pow_ref(x, exp, n), "{:?}", modulus);
        }

        #[test]
        fn inv(n in 2..=u32::MAX, a: u32) {
            let modulus = Modulus32Any::new(n).unwrap();
            match modulus.inv(a) {
                Ok(inv) => {
                    prop_assert!(inv < n, "inverse {} of {} is not reduced mod {}", inv, a, n);
                    prop_assert_eq!(modulus.mul(a, inv), 1, "{} * {} is not 1 mod {}", a, inv, n);
                }
                Err(gcd) => {
                    prop_assert_ne!(gcd, 1, "{} reported non-invertible mod {} with gcd 1", a, n);
                    prop_assert_eq!(gcd, binary_gcd(a, n), "wrong gcd for {} mod {}", a, n);
                }
            }
        }
    }

    #[test]
    fn new_rejects_zero_and_one() {
        assert_eq!(Modulus32Any::new(0), Err(InvalidModulus::Zero));
        assert_eq!(Modulus32Any::new(1), Err(InvalidModulus::One));
    }

    #[test]
    fn mul_small() {
        let mut rng = rng();
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for _ in 0..1 << 12 {
                let a = rng.random();
                let b = rng.random();
                assert_eq!(modulus.mul(a, b) as u64, a as u64 * b as u64 % n as u64);
            }
        }
    }

    #[test]
    fn residue32_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue32(a), a % n);
            }
        }
    }

    #[test]
    fn residue64_small() {
        for n in 2..1 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in random_iter().take(1 << 12) {
                assert_eq!(modulus.residue64(a), a % n as u64);
            }
        }
    }

    #[test]
    fn can_divide_small() {
        // Exhaustive over a dense range, where multiples of small n are frequent.
        for n in 2..1u32 << 8 {
            let modulus = Modulus32Any::new(n).unwrap();
            for a in 0..1u32 << 12 {
                assert_eq!(modulus.can_divide(a), a % n == 0, "{} mod {}", a, n);
            }
        }
    }
}