/// Computes `(a * b) mod m` without overflow by widening both operands to `u128`.
///
/// # Panics
///
/// Panics if `m == 0` (remainder by zero).
#[inline]
pub fn mod_mul(a: u64, b: u64, m: u64) -> u64 {
    let product = u128::from(a) * u128::from(b);
    let reduced = product % u128::from(m);
    // `reduced < m <= u64::MAX`, so narrowing back to u64 cannot lose bits.
    reduced as u64
}
/// Computes `base^exp mod m` by right-to-left binary exponentiation.
///
/// `base` is reduced modulo `m` up front, so any `u64` base is accepted. When `m == 1` the
/// result is `0` (every residue is zero), including for `exp == 0`.
///
/// # Panics
///
/// Panics if `m == 0` (remainder by zero).
#[inline]
pub fn mod_pow(mut base: u64, mut exp: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    base %= m;
    let mut acc = 1u64;
    while exp != 0 {
        // Multiply in the current power of `base` whenever the low bit of `exp` is set.
        if exp & 1 != 0 {
            acc = mod_mul(acc, base, m);
        }
        base = mod_mul(base, base, m);
        exp >>= 1;
    }
    acc
}
/// Returns the multiplicative inverse of `a` modulo `m`, if one exists.
///
/// `a` does not need to be reduced first. A returned inverse always lies in `0..m`.
///
/// Returns `None` when `m == 0` or when `gcd(a, m) != 1`. For `m == 1` every residue is
/// zero, so the answer is `Some(0)`.
pub fn mod_inv(a: u64, m: u64) -> Option<u64> {
    match m {
        0 => return None,
        1 => return Some(0),
        _ => {}
    }
    // Extended Euclid in i128: Bézout coefficients go negative and remainders can be as
    // large as u64::MAX, so neither u64 nor i64 is wide enough.
    let modulus = i128::from(m);
    let (mut prev_rem, mut cur_rem) = (i128::from(a), modulus);
    let (mut prev_coef, mut cur_coef) = (1_i128, 0_i128);
    while cur_rem != 0 {
        let q = prev_rem / cur_rem;
        (prev_rem, cur_rem) = (cur_rem, prev_rem - q * cur_rem);
        (prev_coef, cur_coef) = (cur_coef, prev_coef - q * cur_coef);
    }
    // `prev_rem` is now gcd(a, m); an inverse exists only when that gcd is 1.
    // `rem_euclid` maps the (possibly negative) coefficient into 0..m.
    (prev_rem == 1).then(|| prev_coef.rem_euclid(modulus) as u64)
}
/// Deterministic primality test covering the whole `u64` range.
///
/// Candidates are first screened by trial division against the twelve primes up to 37.
/// Survivors (odd and greater than 37) go through Miller–Rabin with those same twelve primes
/// as witnesses. That base set has no strong pseudoprime below about 3.3 * 10^24, so the
/// answer is exact for every 64-bit input.
pub fn is_prime(n: u64) -> bool {
    const SMALL_PRIMES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];

    if n < 2 {
        return false;
    }
    // The first small prime that divides `n` settles it: `n` is prime only if it *is* that prime.
    if let Some(&p) = SMALL_PRIMES.iter().find(|&&p| n.is_multiple_of(p)) {
        return n == p;
    }

    // Write n - 1 = d * 2^s with d odd. `n` is odd here, so s >= 1.
    let n_minus_1 = n - 1;
    let s = n_minus_1.trailing_zeros();
    let d = n_minus_1 >> s;

    SMALL_PRIMES.iter().all(|&a| {
        let mut x = mod_pow(a, d, n);
        if x == 1 || x == n_minus_1 {
            return true;
        }
        // Square up to s - 1 more times looking for -1; falling through means `a` is a witness
        // to compositeness.
        for _ in 1..s {
            x = mod_mul(x, x, n);
            if x == n_minus_1 {
                return true;
            }
        }
        false
    })
}
/// Returns the smallest primitive root modulo the prime `p`.
///
/// Returns `None` when `p < 2` or `p` is not prime. For `p == 2` the multiplicative group
/// is trivial and `Some(1)` is returned.
pub fn primitive_root(p: u64) -> Option<u64> {
    match p {
        0 | 1 => None,
        2 => Some(1),
        _ if !is_prime(p) => None,
        _ => {
            let order = p - 1;
            let prime_divisors = factorize(order);
            // `g` generates (Z/pZ)* exactly when g^(order / q) != 1 for every prime q | order.
            (2..p).find(|&g| {
                prime_divisors
                    .iter()
                    .all(|&q| mod_pow(g, order / q, p) != 1)
            })
        }
    }
}
/// Returns the distinct prime factors of `n` in ascending order.
///
/// Uses trial division, so the worst case is O(sqrt(n)) steps (when `n` is prime or has a
/// large prime cofactor). `factorize(0)` and `factorize(1)` both return an empty vector.
fn factorize(mut n: u64) -> crate::prelude::Vec<u64> {
    let mut factors = crate::prelude::Vec::new();
    if n < 2 {
        return factors;
    }
    // Strip the factor 2 up front so the main loop only needs to try odd candidates.
    if n.is_multiple_of(2) {
        factors.push(2);
        n >>= n.trailing_zeros();
    }
    let mut d = 3u64;
    // `d <= n / d` is the overflow-free form of `d * d <= n`. The product would overflow u64
    // once `d` reaches 2^32, which is reachable when `n` has a prime factor of at least
    // (2^32 - 1)^2, e.g. 2^64 - 59.
    while d <= n / d {
        if n.is_multiple_of(d) {
            factors.push(d);
            while n.is_multiple_of(d) {
                n /= d;
            }
        }
        d += 2;
    }
    // Whatever remains above 1 has no divisor <= its square root, so it is prime.
    if n > 1 {
        factors.push(n);
    }
    factors
}
/// Returns the 2-adic valuation of `n`: the exponent of the largest power of two dividing it.
///
/// Zero is divisible by every power of two and so has no finite valuation; it yields `None`.
pub fn two_adic_valuation(n: u64) -> Option<u32> {
    // For nonzero `n`, the count of trailing zero bits is exactly the power of two dividing it.
    (n != 0).then_some(n.trailing_zeros())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mod_mul() {
        assert_eq!(mod_mul(0, 0, 7), 0);
        assert_eq!(mod_mul(3, 4, 7), 5);
        assert_eq!(mod_mul(6, 6, 7), 1);

        // The product of these operands overflows u64; the widened path must still reduce it.
        let big = (1u64 << 62) - 1;
        let m = 998_244_353;
        let result = mod_mul(big, big, m);
        assert!(result < m);

        // Extremes of the u64 range: MAX ≡ 0 and MAX - 1 ≡ -1 (mod MAX).
        assert_eq!(mod_mul(u64::MAX, u64::MAX, u64::MAX), 0);
        assert_eq!(mod_mul(u64::MAX - 1, u64::MAX - 1, u64::MAX), 1);
    }

    #[test]
    fn test_mod_pow() {
        assert_eq!(mod_pow(2, 10, 1000), 24);
        assert_eq!(mod_pow(3, 0, 7), 1);
        assert_eq!(mod_pow(0, 0, 7), 1);
        assert_eq!(mod_pow(5, 3, 13), 8);
        assert_eq!(mod_pow(2, 23, 998_244_353), 8_388_608);
        // Fermat: a^(p-1) ≡ 1 (mod p) for a prime p not dividing a.
        assert_eq!(mod_pow(3, 998_244_352, 998_244_353), 1);
        // 2^64 = (2^64 - 1) + 1.
        assert_eq!(mod_pow(2, 64, u64::MAX), 1);
    }

    #[test]
    fn test_mod_pow_modulus_one() {
        assert_eq!(mod_pow(5, 3, 1), 0);
        assert_eq!(mod_pow(5, 0, 1), 0);
    }

    #[test]
    fn test_mod_inv() {
        assert_eq!(mod_inv(3, 7), Some(5));
        assert_eq!(mod_inv(0, 7), None);
        assert_eq!(mod_inv(4, 8), None);
        assert_eq!(mod_inv(1, 998_244_353), Some(1));
        // Unreduced input: 10 ≡ 3 (mod 7).
        assert_eq!(mod_inv(10, 7), Some(5));
        // 2 * 2^63 = 2^64 ≡ 1 (mod 2^64 - 1).
        assert_eq!(mod_inv(2, u64::MAX), Some(1 << 63));

        let inv = mod_inv(42, 998_244_353).expect("42 is coprime to the prime 998_244_353");
        assert_eq!(mod_mul(42, inv, 998_244_353), 1);
    }

    #[test]
    fn test_mod_inv_edge() {
        assert_eq!(mod_inv(1, 0), None);
        assert_eq!(mod_inv(5, 1), Some(0));
    }

    #[test]
    fn test_is_prime() {
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(!is_prime(4));
        assert!(is_prime(5));
        assert!(is_prime(998_244_353));
        assert!(is_prime(469_762_049));
        assert!(is_prime(167_772_161));
        assert!(!is_prime(998_244_354));
        assert!(is_prime(1_000_000_007));
    }

    #[test]
    fn test_is_prime_hard_cases() {
        // 41^2 has no factor among the trial-division primes, so Miller–Rabin must reject it.
        assert!(!is_prime(1_681));
        // Carmichael number 561 = 3 * 11 * 17.
        assert!(!is_prime(561));
        // Strong pseudoprime to bases 2, 3, 5 and 7 (151 * 751 * 28_351).
        assert!(!is_prime(3_215_031_751));
        // Largest prime below 2^64, and u64::MAX itself (divisible by 3).
        assert!(is_prime(18_446_744_073_709_551_557));
        assert!(!is_prime(u64::MAX));
    }

    #[test]
    fn test_primitive_root() {
        assert_eq!(primitive_root(998_244_353), Some(3));
        assert_eq!(primitive_root(469_762_049), Some(3));
        assert_eq!(primitive_root(167_772_161), Some(3));
        assert_eq!(primitive_root(2), Some(1));
        assert_eq!(primitive_root(3), Some(2));
        assert_eq!(primitive_root(5), Some(2));
        assert_eq!(primitive_root(7), Some(3));
        assert_eq!(primitive_root(4), None);
        assert_eq!(primitive_root(9), None);
        assert_eq!(primitive_root(0), None);
        assert_eq!(primitive_root(1), None);
    }

    #[test]
    fn test_two_adic_valuation() {
        assert_eq!(two_adic_valuation(0), None);
        assert_eq!(two_adic_valuation(1), Some(0));
        assert_eq!(two_adic_valuation(8), Some(3));
        assert_eq!(two_adic_valuation(998_244_352), Some(23));
        assert_eq!(two_adic_valuation(1 << 63), Some(63));
        assert_eq!(two_adic_valuation(u64::MAX), Some(0));
    }

    #[test]
    fn test_factorize() {
        assert_eq!(factorize(0), vec![]);
        assert_eq!(factorize(1), vec![]);
        assert_eq!(factorize(12), vec![2, 3]);
        assert_eq!(factorize(998_244_352), vec![2, 7, 17]);
        // 2^64 - 1 = (2^32 - 1)(2^32 + 1) = 3 * 5 * 17 * 257 * 65_537 * 641 * 6_700_417.
        assert_eq!(
            factorize(u64::MAX),
            vec![3, 5, 17, 257, 641, 65_537, 6_700_417]
        );
    }
}