#![no_std]
/// Portable reference implementations built on plain `u128` arithmetic.
///
/// These are the baseline the optimised routines are compared against.
pub mod reference {
    /// Returns `(a * b) mod m`, computed through a 128-bit intermediate
    /// product so the multiplication itself can never overflow.
    ///
    /// # Panics
    ///
    /// Panics if `m == 0`.
    #[inline]
    pub fn mulmod_u64(a: u64, b: u64, m: u64) -> u64 {
        let product = u128::from(a) * u128::from(b);
        // The remainder is strictly below `m`, so narrowing back is lossless.
        (product % u128::from(m)) as u64
    }

    /// Returns `a mod m` for a 128-bit dividend and a 64-bit modulus.
    ///
    /// # Panics
    ///
    /// Panics if `m == 0`.
    #[inline]
    pub fn mod_u128u64(a: u128, m: u64) -> u64 {
        // The remainder is strictly below `m`, so narrowing back is lossless.
        (a % u128::from(m)) as u64
    }

    /// Returns `a^p mod m` using binary (square-and-multiply) exponentiation.
    ///
    /// `p == 0` yields `1`, except for `m == 1`, where the only residue is
    /// `0`.
    ///
    /// # Panics
    ///
    /// Panics if `m == 0` and `p > 0` (the first modular multiplication
    /// divides by zero).
    #[inline]
    pub fn powmod_u64(mut a: u64, mut p: u64, m: u64) -> u64 {
        // The empty product is 1, but modulo 1 that must reduce to 0.
        let mut y = if m == 1 { 0 } else { 1 };
        while p > 0 {
            // Fold the current power of `a` in whenever the low bit of `p` is set.
            if p % 2 == 1 {
                y = mulmod_u64(y, a, m);
            }
            a = mulmod_u64(a, a, m);
            p /= 2;
        }
        y
    }
}
/// Architecture-specific back-ends for the public helpers at the crate root.
///
/// On `x86_64` the 128-by-64-bit remainder is taken with one hardware `div`;
/// every other target forwards to `super::reference`.
mod internal {
    #[cfg(target_arch = "x86_64")]
    use core::arch::asm;

    /// Returns `(a * b) mod m`.
    ///
    /// The full 128-bit product is handed to [`mod_u128u64`], which keeps the
    /// hardware division in range for any operands. Feeding the raw product
    /// straight to `div` would fault whenever its high half is `>= m`, which
    /// can happen as soon as `a >= m` or `b >= m`.
    ///
    /// # Panics
    ///
    /// Panics if `m == 0`.
    #[inline]
    pub fn mulmod_u64(a: u64, b: u64, m: u64) -> u64 {
        mod_u128u64(u128::from(a) * u128::from(b), m)
    }

    /// Returns `a mod m` for an arbitrary 128-bit dividend.
    ///
    /// `div` needs the quotient to fit in 64 bits, which holds exactly when
    /// the high word of the dividend is below `m`. When it is not, the high
    /// word is reduced first; that does not change the remainder because
    /// `hi * 2^64 + lo` and `(hi mod m) * 2^64 + lo` are congruent modulo `m`.
    ///
    /// # Panics
    ///
    /// Panics if `m == 0`.
    #[inline]
    #[cfg(target_arch = "x86_64")]
    pub fn mod_u128u64(a: u128, m: u64) -> u64 {
        let hi = (a >> 64) as u64;
        if hi < m {
            // The quotient already fits in 64 bits.
            return mod_u128u64_unchecked(a, m);
        }
        // Here `hi >= m`; for `m == 0` the `%` below panics instead of
        // letting the hardware raise a divide error.
        let lo = a & u128::from(u64::MAX);
        let reduced = (u128::from(hi % m) << 64) | lo;
        mod_u128u64_unchecked(reduced, m)
    }

    /// Returns `a mod m` via the portable reference implementation.
    #[inline]
    #[cfg(not(target_arch = "x86_64"))]
    pub fn mod_u128u64(a: u128, m: u64) -> u64 {
        super::reference::mod_u128u64(a, m)
    }

    /// Returns `a mod m` with a single `div`, without range checks.
    ///
    /// The caller must guarantee `(a >> 64) < m` (which implies `m != 0`);
    /// otherwise the quotient does not fit in 64 bits and the CPU raises a
    /// divide error. The precondition is checked in debug builds only.
    ///
    /// NOTE(review): a function with this contract would normally be an
    /// `unsafe fn`; it stays safe here to keep the existing public signature.
    #[inline]
    #[cfg(target_arch = "x86_64")]
    pub fn mod_u128u64_unchecked(a: u128, m: u64) -> u64 {
        let hi = (a >> 64) as u64;
        let lo = a as u64;
        debug_assert!(
            hi < m,
            "mod_u128u64_unchecked: quotient does not fit in 64 bits"
        );
        let r: u64;
        // SAFETY: `div` reads rdx:rax and the divisor register and writes only
        // rax/rdx (both declared as outputs) and flags; it touches neither
        // memory nor the stack. It cannot fault as long as `hi < m`, which is
        // the documented precondition.
        unsafe {
            asm!(
                "div {divisor}",
                divisor = in(reg) m,
                inout("rax") lo => _,
                inout("rdx") hi => r,
                options(nomem, nostack),
            );
        }
        r
    }

    /// Returns `a mod m` via the portable reference implementation, which has
    /// no precondition beyond `m != 0`.
    #[inline]
    #[cfg(not(target_arch = "x86_64"))]
    pub fn mod_u128u64_unchecked(a: u128, m: u64) -> u64 {
        super::reference::mod_u128u64(a, m)
    }
}
#[inline]
pub fn mulmod_u64(a: u64, b: u64, m: u64) -> u64 {
internal::mulmod_u64(a, b, m)
}
#[inline]
pub fn mod_u128u64(a: u128, m: u64) -> u64 {
internal::mod_u128u64(a, m)
}
#[inline]
pub fn mod_u128u64_unchecked(a: u128, m: u64) -> u64 {
internal::mod_u128u64_unchecked(a, m)
}
/// Returns `a^p mod m` using binary (square-and-multiply) exponentiation on
/// top of [`mulmod_u64`].
///
/// `p == 0` yields `1`, except for `m == 1`, where the only residue is `0`.
/// The base does not have to be reduced by the caller.
///
/// # Panics
///
/// Panics if `m == 0` and `p > 0`.
#[inline]
pub fn powmod_u64(mut a: u64, mut p: u64, m: u64) -> u64 {
    // The empty product is 1, but modulo 1 that must reduce to 0.
    let mut y = if m == 1 { 0 } else { 1 };
    // Reduce the base once up front. From here on every operand handed to
    // `mulmod_u64` is below `m`, so the product divided by `m` always has a
    // quotient that fits in 64 bits, which a hardware `div` back-end needs.
    if p > 0 && a >= m {
        a %= m;
    }
    while p > 0 {
        // Fold the current power of `a` in whenever the low bit of `p` is set.
        if p % 2 == 1 {
            y = mulmod_u64(y, a, m);
        }
        a = mulmod_u64(a, a, m);
        p /= 2;
    }
    y
}
/// Randomised cross-checks of the optimised routines against `reference`.
///
/// Requires the `rand` crate as a dev-dependency.
#[cfg(test)]
mod tests {
    use crate::*;
    use rand::prelude::*;

    /// Number of random samples drawn by each test.
    const SAMPLES: usize = 1_000_000;

    /// `mulmod_u64` agrees with the reference for operands reduced modulo `m`.
    #[test]
    fn modulo_u64_mul() {
        let mut rng = rand::thread_rng();
        for _ in 0..SAMPLES {
            let m: u64 = rng.gen();
            let a = rng.gen::<u64>() % m;
            let b = rng.gen::<u64>() % m;
            let expected = reference::mulmod_u64(a, b, m);
            assert_eq!(expected, mulmod_u64(a, b, m));
        }
    }

    /// `mod_u128u64` agrees with the reference for random 128-bit dividends
    /// and random 64-bit moduli.
    #[test]
    fn modulo_u128u64() {
        let mut rng = rand::thread_rng();
        for _ in 0..SAMPLES {
            let m: u64 = rng.gen();
            let a: u128 = rng.gen();
            let expected = reference::mod_u128u64(a, m);
            assert_eq!(expected, mod_u128u64(a, m));
        }
    }

    /// Same as `modulo_u128u64`, but with moduli drawn from the `u32` range
    /// so the quotient is much wider than 64 bits.
    #[test]
    fn modulo_u128u64_small_divisor() {
        let mut rng = rand::thread_rng();
        for _ in 0..SAMPLES {
            let m: u64 = rng.gen::<u32>().into();
            let a: u128 = rng.gen();
            let expected = reference::mod_u128u64(a, m);
            assert_eq!(expected, mod_u128u64(a, m));
        }
    }

    /// `powmod_u64` agrees with the reference for reduced bases and random
    /// exponents.
    #[test]
    fn modulo_u64_pow() {
        let mut rng = rand::thread_rng();
        for _ in 0..SAMPLES {
            let m: u64 = rng.gen();
            let a = rng.gen::<u64>() % m;
            let p = rng.gen::<u64>();
            let expected = reference::powmod_u64(a, p, m);
            assert_eq!(expected, powmod_u64(a, p, m));
        }
    }
}