use crate::error::{SpecialError, SpecialResult};
use std::collections::HashMap;
/// Returns the greatest common divisor of `a` and `b` using the Euclidean
/// algorithm.
///
/// When one argument is zero the other is returned, so `gcd(0, 0) == 0`.
pub fn gcd(mut a: u64, mut b: u64) -> u64 {
    while b != 0 {
        // Replace (a, b) with (b, a mod b).
        a %= b;
        std::mem::swap(&mut a, &mut b);
    }
    a
}
/// Returns the least common multiple of `a` and `b`.
///
/// The result is `0` when either argument is zero, and saturates at
/// `u64::MAX` when the true value does not fit in a `u64`.
pub fn lcm(a: u64, b: u64) -> u64 {
    match (a, b) {
        (0, _) | (_, 0) => 0,
        // Divide before multiplying to keep the intermediate value small.
        _ => (a / gcd(a, b)).saturating_mul(b),
    }
}
/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` such that `a * x + b * y == g`, where `g` is the
/// non-negative greatest common divisor of `a` and `b`.
pub fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    if b != 0 {
        let quotient = a / b;
        // Solve the smaller instance, then back-substitute the quotient.
        let (g, s, t) = extended_gcd(b, a % b);
        return (g, t, s - quotient * t);
    }
    // Base case: gcd(a, 0) = |a|; the coefficient of `a` carries its sign.
    if a < 0 {
        (-a, -1, 0)
    } else {
        (a, 1, 0)
    }
}
/// Computes `base^exp mod m` by binary (square-and-multiply) exponentiation.
///
/// Returns `0` when `m == 1`.
///
/// # Panics
///
/// Panics when `m == 0`, because the base is reduced with `%` first.
pub fn pow_mod(mut base: u64, mut exp: u64, m: u64) -> u64 {
    if m == 1 {
        return 0;
    }
    base %= m;
    let mut acc = 1u64;
    while exp != 0 {
        // Multiply in the current power of the base whenever the low bit is set.
        if exp % 2 == 1 {
            acc = mul_mod(acc, base, m);
        }
        base = mul_mod(base, base, m);
        exp /= 2;
    }
    acc
}
/// Computes `(a * b) mod m` without overflow by widening the product to
/// `u128` before reducing.
#[inline]
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    let product = u128::from(a) * u128::from(b);
    (product % u128::from(m)) as u64
}
/// Computes the modular multiplicative inverse of `a` modulo `m`.
///
/// On success the returned value `x` satisfies `a * x ≡ 1 (mod m)` and lies
/// in `0..m`. For `m == 1` the result is `0`.
///
/// The extended Euclidean iteration runs in `i128`, so every `u64` input is
/// handled; casting the operands to `i64` would wrap for values above
/// `i64::MAX`.
///
/// # Errors
///
/// * `SpecialError::ValueError` when `m == 0`.
/// * `SpecialError::DomainError` when `gcd(a, m) != 1`, i.e. no inverse exists.
pub fn mod_inverse(a: u64, m: u64) -> SpecialResult<u64> {
    if m == 0 {
        return Err(SpecialError::ValueError(
            "mod_inverse: modulus must be > 0".to_string(),
        ));
    }
    if m == 1 {
        return Ok(0);
    }

    let modulus = i128::from(m);
    // Invariant: old_s * a ≡ old_r (mod m) and s * a ≡ r (mod m).
    let mut old_r = i128::from(a % m);
    let mut r = modulus;
    let mut old_s = 1i128;
    let mut s = 0i128;
    while r != 0 {
        let q = old_r / r;
        let next_r = old_r - q * r;
        old_r = r;
        r = next_r;
        let next_s = old_s - q * s;
        old_s = s;
        s = next_s;
    }

    let g = old_r;
    if g != 1 {
        return Err(SpecialError::DomainError(format!(
            "mod_inverse: gcd({a}, {m}) = {g} ≠ 1; inverse does not exist"
        )));
    }
    // `rem_euclid` yields a value in 0..m, which always fits in a u64.
    Ok(old_s.rem_euclid(modulus) as u64)
}
/// Tests whether `n` is prime.
///
/// Small cases and multiples of 2, 3 and 5 are settled directly; everything
/// else goes through Miller–Rabin rounds with the fixed witness set
/// `2, 3, 5, …, 37`.
pub fn is_prime(n: u64) -> bool {
    if n < 2 {
        return false;
    }
    if matches!(n, 2 | 3 | 5 | 7) {
        return true;
    }
    if n % 2 == 0 || n % 3 == 0 || n % 5 == 0 {
        return false;
    }

    // Write n - 1 = d * 2^r with d odd; n is odd here, so r >= 1.
    let r = (n - 1).trailing_zeros();
    let d = (n - 1) >> r;

    const WITNESSES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];
    WITNESSES
        .iter()
        .copied()
        .filter(|&a| a < n)
        .all(|a| {
            let mut x = pow_mod(a, d, n);
            if x == 1 || x == n - 1 {
                return true;
            }
            // Square up to r - 1 times looking for -1 (mod n).
            for _ in 1..r {
                x = mul_mod(x, x, n);
                if x == n - 1 {
                    return true;
                }
            }
            // This witness proves that n is composite.
            false
        })
}
/// Returns the prime factors of `n` with multiplicity, in ascending order.
///
/// For `n <= 1` the result is empty.
pub fn factorize(n: u64) -> Vec<u64> {
    let mut primes = Vec::new();
    if n > 1 {
        factorize_into(n, &mut primes);
        // The recursive helper emits factors in discovery order.
        primes.sort_unstable();
    }
    primes
}
/// Appends the prime factors of `n` (with multiplicity, unsorted) to `out`.
///
/// Primes up to 37 are removed by trial division. A remaining cofactor below
/// `37² = 1369`, or one that passes `is_prime`, is pushed as a prime; anything
/// else is split with `pollard_rho` and both parts are factored recursively.
fn factorize_into(mut n: u64, out: &mut Vec<u64>) {
    const SMALL_PRIMES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];

    if n == 1 {
        return;
    }
    for &p in SMALL_PRIMES.iter() {
        while n % p == 0 {
            out.push(p);
            n /= p;
        }
        if n == 1 {
            return;
        }
    }
    // No prime factor <= 37 remains, so anything below 37² must be prime.
    if n < 1369 || is_prime(n) {
        out.push(n);
        return;
    }
    let divisor = pollard_rho(n);
    factorize_into(divisor, out);
    factorize_into(n / divisor, out);
}
/// Looks for a non-trivial divisor of `n` with Pollard's rho method
/// (Floyd cycle detection on `x -> x² + c mod n`).
///
/// Returns `2` for even `n`. For odd `n` it retries with successive seeds
/// until the detected divisor differs from `n`; callers are expected to pass
/// a composite `n` (a prime would never produce such a divisor).
///
/// The `+ c` step is carried out in `u128`, so it cannot overflow even when
/// `n` is close to `u64::MAX`.
///
/// # Panics
///
/// Panics for `n == 1`, because the constant is derived with `% (n - 1)`.
fn pollard_rho(n: u64) -> u64 {
    if n % 2 == 0 {
        return 2;
    }
    let modulus = u128::from(n);
    for seed in 1u64.. {
        let c = seed % (n - 1) + 1;
        // One application of the iteration polynomial, reduced mod n.
        let step = |v: u64| ((u128::from(mul_mod(v, v, n)) + u128::from(c)) % modulus) as u64;

        let mut x = seed % n + 1;
        let mut y = x;
        let mut d = 1u64;
        while d == 1 {
            // Tortoise advances one step, hare advances two.
            x = step(x);
            y = step(step(y));
            let diff = if x > y { x - y } else { y - x };
            d = gcd(diff, n);
        }
        // d == n means the cycle closed without exposing a factor; try the next seed.
        if d != n {
            return d;
        }
    }
    // Only reached if the seed range is ever exhausted.
    n
}
/// Solves the system `x ≡ rems[i] (mod mods[i])` incrementally.
///
/// Moduli need not be pairwise coprime: each congruence is merged into the
/// running solution as long as it is consistent with it. The result is the
/// representative in `0..L`, where `L` is the least common multiple of the
/// moduli. An empty system yields `0`.
///
/// Intermediate products are evaluated in `i128`, and the new combined
/// modulus is range-checked *before* the solution is updated, so large
/// moduli yield either a correct answer or an `OverflowError`, not an
/// arithmetic overflow.
///
/// # Errors
///
/// * `SpecialError::ValueError` when the slices differ in length, a modulus
///   is not positive, or the congruences are inconsistent.
/// * `SpecialError::OverflowError` when the combined modulus exceeds `i64`.
pub fn chinese_remainder_theorem(rems: &[i64], mods: &[i64]) -> SpecialResult<i64> {
    if rems.len() != mods.len() {
        return Err(SpecialError::ValueError(
            "crt: rems and mods must have the same length".to_string(),
        ));
    }
    if rems.is_empty() {
        return Ok(0);
    }
    if let Some(&m) = mods.iter().find(|&&m| m <= 0) {
        return Err(SpecialError::ValueError(format!(
            "crt: modulus {m} must be positive"
        )));
    }

    // Invariant: 0 <= x < step, and x satisfies every congruence merged so far.
    let mut x = 0i64;
    let mut step = 1i64;
    for (&r, &m) in rems.iter().zip(mods.iter()) {
        let g = gcd(step.unsigned_abs(), m.unsigned_abs()) as i64;
        // (r - x) mod m, widened so the subtraction cannot overflow.
        let diff = (i128::from(r) - i128::from(x)).rem_euclid(i128::from(m));
        if diff % i128::from(g) != 0 {
            return Err(SpecialError::ValueError(
                "crt: moduli are not pairwise coprime (no solution exists)".to_string(),
            ));
        }

        let m_reduced = m / g;
        // Check the new modulus first: x + t * step below stays under it.
        let next_step = step
            .checked_mul(m_reduced)
            .ok_or_else(|| SpecialError::OverflowError("crt: lcm overflows i64".to_string()))?;

        let step_reduced = (step / g).rem_euclid(m_reduced);
        let inv = mod_inverse(step_reduced.unsigned_abs(), m_reduced.unsigned_abs())
            .map_err(|_| {
                SpecialError::ValueError("crt: step inverse does not exist".to_string())
            })?;

        // t = (diff / g) * inv mod m_reduced; all operands are non-negative.
        let reduced_modulus = i128::from(m_reduced);
        let t = (diff / i128::from(g) % reduced_modulus) * i128::from(inv) % reduced_modulus;

        // t < m_reduced and x < step, hence x + t * step < next_step <= i64::MAX.
        x += (t as i64) * step;
        step = next_step;
    }
    Ok(x.rem_euclid(step))
}
/// Returns the smallest prime strictly greater than `n`.
///
/// Every input below 2 yields 2. Candidates are tested one by one with
/// `is_prime`; the increments are unchecked, so an input with no larger
/// prime representable in `u64` overflows.
pub fn next_prime(n: u64) -> u64 {
    if n < 2 {
        return 2;
    }
    let mut candidate = n + 1;
    while !is_prime(candidate) {
        candidate += 1;
    }
    candidate
}
/// Returns the prime factorisation of `n` as `(prime, exponent)` pairs in
/// ascending order of prime.
///
/// For `n <= 1` the result is empty.
pub fn prime_factors(n: u64) -> Vec<(u64, u32)> {
    let mut grouped: Vec<(u64, u32)> = Vec::new();
    if n <= 1 {
        return grouped;
    }
    // `factorize` returns a sorted list, so equal primes are adjacent.
    for prime in factorize(n) {
        if let Some((last_prime, count)) = grouped.last_mut() {
            if *last_prime == prime {
                *count += 1;
                continue;
            }
        }
        grouped.push((prime, 1));
    }
    grouped
}
/// Euler's totient function φ(n): the count of integers in `1..=n` that are
/// coprime to `n`.
///
/// By convention here `euler_totient(0) == 0` and `euler_totient(1) == 1`.
/// The value is obtained by trial division, applying
/// `result -= result / p` once for every distinct prime factor `p`.
///
/// The loop bound is written as `p <= x / p` rather than `p * p <= x`, so it
/// cannot overflow when `n` has a prime factor above `2^32`.
pub fn euler_totient(n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    if n == 1 {
        return 1;
    }
    let mut result = n;
    let mut x = n;
    let mut p = 2u64;
    // Equivalent to p * p <= x, without the overflowing product.
    while p <= x / p {
        if x % p == 0 {
            while x % p == 0 {
                x /= p;
            }
            result -= result / p;
        }
        p += 1;
    }
    // Whatever is left above 1 is a single remaining prime factor.
    if x > 1 {
        result -= result / x;
    }
    result
}
/// Möbius function μ(n).
///
/// Returns `0` when `n` is divisible by the square of a prime, otherwise
/// `(-1)^k` where `k` is the number of distinct prime factors of `n`.
/// By convention here `mobius(0) == 0`; `mobius(1) == 1`.
///
/// The trial-division bound is written as `p <= x / p` rather than
/// `p * p <= x`, so it cannot overflow when `n` has a prime factor above
/// `2^32`.
pub fn mobius(n: u64) -> i32 {
    match n {
        0 => return 0,
        1 => return 1,
        _ => {}
    }
    let mut x = n;
    let mut distinct = 0u32;
    let mut p = 2u64;
    // Equivalent to p * p <= x, without the overflowing product.
    while p <= x / p {
        if x % p == 0 {
            x /= p;
            // A second factor of p means n is not square-free.
            if x % p == 0 {
                return 0;
            }
            distinct += 1;
        }
        p += 1;
    }
    // Whatever is left above 1 is one more prime factor.
    if x > 1 {
        distinct += 1;
    }
    if distinct % 2 == 0 {
        1
    } else {
        -1
    }
}
/// Divisor function σ_k(n): the sum of the `k`-th powers of the positive
/// divisors of `n`.
///
/// With `k == 0` this counts the divisors. `divisor_sum(0, k)` is defined as
/// `0`. All arithmetic saturates at `u64::MAX`.
///
/// The exponent is clamped to `u32::MAX` instead of being truncated with
/// `as u32`: every prime is at least 2, so such a power has already saturated
/// and the clamp does not change the result.
pub fn divisor_sum(n: u64, k: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    let exponent = k.min(u64::from(u32::MAX)) as u32;
    // n == 1 has no prime factors; the empty product is 1.
    prime_factors(n).into_iter().fold(1u64, |acc, (p, e)| {
        let term = if k == 0 {
            u64::from(e) + 1
        } else {
            // Geometric sum 1 + p^k + p^(2k) + ... + p^(e*k).
            let p_k = p.saturating_pow(exponent);
            let mut sum = 1u64;
            let mut power = 1u64;
            for _ in 0..e {
                power = power.saturating_mul(p_k);
                sum = sum.saturating_add(power);
            }
            sum
        };
        acc.saturating_mul(term)
    })
}
/// Jacobi symbol `(a / n)` for a positive odd `n`.
///
/// Returns `0` when `a` and `n` share a factor, otherwise `1` or `-1`.
///
/// `a` is reduced into `0..n` in `i128`, so the full `u64` range of `n` is
/// supported; casting `n` to `i64` would wrap for `n > i64::MAX`.
///
/// # Errors
///
/// Returns `SpecialError::ValueError` when `n` is zero or even.
pub fn jacobi_symbol(a: i64, n: u64) -> SpecialResult<i32> {
    // n % 2 == 0 also covers n == 0.
    if n % 2 == 0 {
        return Err(SpecialError::ValueError(format!(
            "jacobi_symbol: n = {n} must be a positive odd integer"
        )));
    }
    // The reduced value lies in 0..n and therefore fits in a u64.
    let mut a = i128::from(a).rem_euclid(i128::from(n)) as u64;
    let mut n = n;
    let mut result = 1i32;
    loop {
        if a == 0 {
            return Ok(if n == 1 { result } else { 0 });
        }
        // Strip factors of two: (2 / n) = -1 exactly when n ≡ 3 or 5 (mod 8).
        let twos = a.trailing_zeros();
        a >>= twos;
        if twos % 2 == 1 {
            let n_mod8 = n % 8;
            if n_mod8 == 3 || n_mod8 == 5 {
                result = -result;
            }
        }
        // Quadratic reciprocity: swapping flips the sign when both are ≡ 3 (mod 4).
        if a % 4 == 3 && n % 4 == 3 {
            result = -result;
        }
        let previous_a = a;
        a = n % a;
        n = previous_a;
        if n == 1 {
            return Ok(result);
        }
    }
}
/// Legendre symbol `(a / p)` for an odd prime `p`.
///
/// After validating `p`, the value is computed by `jacobi_symbol`.
///
/// # Errors
///
/// Returns `SpecialError::ValueError` when `p` is below 2, even, or not prime.
pub fn legendre_symbol(a: i64, p: u64) -> SpecialResult<i32> {
    let odd_and_large_enough = p >= 2 && p % 2 == 1;
    if !odd_and_large_enough {
        return Err(SpecialError::ValueError(format!(
            "legendre_symbol: p = {p} must be an odd prime"
        )));
    }
    if is_prime(p) {
        jacobi_symbol(a, p)
    } else {
        Err(SpecialError::ValueError(format!(
            "legendre_symbol: p = {p} is not prime"
        )))
    }
}
/// Solves `g^x ≡ h (mod p)` for `x` with the baby-step giant-step method.
///
/// With `m = ceil(sqrt(p)) + 1`, a table of `g^j` for `j` in `0..m` is built
/// and then probed with `h * (g^-m)^i` for `i` in `0..m`. The table keeps
/// the *first* (smallest) `j` for each value, so the returned `x = i * m + j`
/// is the smallest exponent found by the search. For `p == 1` the result
/// is `0`.
///
/// # Errors
///
/// * `SpecialError::ValueError` when `p == 0`.
/// * `SpecialError::DomainError` when `g^m` has no inverse modulo `p`
///   (for example when `g` and `p` are not coprime), or when no exponent
///   is found.
pub fn discrete_log(g: u64, h: u64, p: u64) -> SpecialResult<u64> {
    if p == 0 {
        return Err(SpecialError::ValueError(
            "discrete_log: modulus p must be > 0".to_string(),
        ));
    }
    if p == 1 {
        return Ok(0);
    }
    let target = h % p;
    let m = (p as f64).sqrt().ceil() as u64 + 1;

    // Baby steps: remember the smallest j with g^j ≡ value (mod p).
    let mut table: HashMap<u64, u64> = HashMap::with_capacity(m as usize);
    let mut baby = 1u64;
    for j in 0..m {
        table.entry(baby).or_insert(j);
        baby = mul_mod(baby, g, p);
    }

    let gm = pow_mod(g, m, p);
    let gm_inv = mod_inverse(gm, p).map_err(|_| {
        SpecialError::DomainError(format!(
            "discrete_log: g^m = {gm} has no inverse mod {p}; g and p may not be coprime"
        ))
    })?;

    // Giant steps: giant = h * (g^-m)^i.
    let mut giant = target;
    for i in 0..m {
        if let Some(&j) = table.get(&giant) {
            let x = i * m + j;
            // Verify the candidate before returning it.
            if pow_mod(g, x, p) == target {
                return Ok(x);
            }
        }
        giant = mul_mod(giant, gm_inv, p);
    }
    Err(SpecialError::DomainError(format!(
        "discrete_log: no solution found for g={g}, h={h}, p={p}"
    )))
}
/// Unit tests for the core number-theory routines.
#[cfg(test)]
mod tests {
    use super::*;

    /// `gcd` with zero operands, a shared factor, and coprime inputs.
    #[test]
    fn test_gcd() {
        for (a, b, expected) in [(0, 0, 0), (0, 5, 5), (12, 8, 4), (35, 14, 7), (7, 13, 1)] {
            assert_eq!(gcd(a, b), expected);
        }
    }

    /// `lcm` including the zero convention.
    #[test]
    fn test_lcm() {
        for (a, b, expected) in [(4, 6, 12), (7, 5, 35), (0, 10, 0), (1, 100, 100)] {
            assert_eq!(lcm(a, b), expected);
        }
    }

    /// The returned coefficients must satisfy Bézout's identity.
    #[test]
    fn test_extended_gcd() {
        for (a, b, expected_g) in [(35, 15, 5), (12, 8, 4), (7, 13, 1)] {
            let (g, x, y) = extended_gcd(a, b);
            assert_eq!(g, expected_g);
            assert_eq!(a * x + b * y, g);
        }
    }

    /// Modular exponentiation: 2^10 = 1024 ≡ 24 (mod 1000), 2^3 = 8 ≡ 3 (mod 5).
    #[test]
    fn test_pow_mod() {
        let cases = [(2, 10, 1000, 24), (3, 0, 7, 1), (0, 5, 7, 0), (2, 3, 5, 3)];
        for (base, exp, m, expected) in cases {
            assert_eq!(pow_mod(base, exp, m), expected);
        }
    }

    /// 3·5 = 15 ≡ 1 (mod 7) and 2·3 = 6 ≡ 1 (mod 5); non-units have no inverse.
    #[test]
    fn test_mod_inverse() {
        for (a, m, expected) in [(3, 7, 5), (2, 5, 3)] {
            assert_eq!(mod_inverse(a, m).expect("ok"), expected);
        }
        for (a, m) in [(2, 4), (0, 7)] {
            assert!(mod_inverse(a, m).is_err());
        }
    }

    /// Small primes, small composites, and the Carmichael number 561.
    #[test]
    fn test_is_prime() {
        for n in [2, 3, 5, 97, 104729] {
            assert!(is_prime(n));
        }
        for n in [0, 1, 4, 9, 104730, 561] {
            assert!(!is_prime(n));
        }
    }

    /// Flat factor lists, plus a smoke test on a large input.
    #[test]
    fn test_factorize() {
        let cases = [
            (1, vec![]),
            (2, vec![2]),
            (4, vec![2, 2]),
            (12, vec![2, 2, 3]),
            (30, vec![2, 3, 5]),
            (97, vec![97]),
        ];
        for (n, expected) in cases {
            assert_eq!(factorize(n), expected);
        }
        let f = factorize(9_007_199_254_740_997);
        assert!(f.len() >= 2 || (f.len() == 1 && is_prime(f[0])));
    }

    /// Quadratic residues modulo 5 are {1, 4}; an even modulus is rejected.
    #[test]
    fn test_legendre_symbol() {
        let cases = [(0, 5, 0), (1, 5, 1), (2, 5, -1), (3, 5, -1), (4, 5, 1), (-1, 5, 1)];
        for (a, p, expected) in cases {
            assert_eq!(legendre_symbol(a, p).expect("ok"), expected);
        }
        assert!(legendre_symbol(1, 4).is_err());
    }

    #[test]
    fn test_crt_two_congruences() {
        let (rems, mods) = ([2, 3], [3, 5]);
        let x = chinese_remainder_theorem(&rems, &mods).expect("ok");
        assert_eq!(x, 8);
        for (r, m) in rems.iter().zip(mods.iter()) {
            assert_eq!(x % m, *r);
        }
    }

    #[test]
    fn test_crt_three_congruences() {
        let (rems, mods) = ([0, 3, 4], [3, 4, 5]);
        let x = chinese_remainder_theorem(&rems, &mods).expect("ok");
        for (r, m) in rems.iter().zip(mods.iter()) {
            assert_eq!(x % m, *r);
        }
    }

    #[test]
    fn test_crt_error_length_mismatch() {
        let outcome = chinese_remainder_theorem(&[1, 2], &[3]);
        assert!(outcome.is_err());
    }

    /// The returned exponent must reproduce `h` when plugged back in.
    #[test]
    fn test_discrete_log_basic() {
        for (g, h, p) in [(2, 8, 11), (3, 1, 7)] {
            let x = discrete_log(g, h, p).expect("ok");
            assert_eq!(pow_mod(g, x, p), h % p);
        }
    }

    /// 3 is not a power of 2 modulo 7.
    #[test]
    fn test_discrete_log_no_solution() {
        assert!(discrete_log(2, 3, 7).is_err());
    }

    #[test]
    fn test_next_prime() {
        let cases = [(0, 2), (1, 2), (2, 3), (3, 5), (10, 11), (12, 13), (100, 101)];
        for (n, expected) in cases {
            assert_eq!(next_prime(n), expected);
        }
    }

    /// Grouped factorisation, plus a round trip back to the original number.
    #[test]
    fn test_prime_factors_with_exponents() {
        let cases = [
            (1, vec![]),
            (2, vec![(2, 1)]),
            (4, vec![(2, 2)]),
            (12, vec![(2, 2), (3, 1)]),
            (360, vec![(2, 3), (3, 2), (5, 1)]),
            (97, vec![(97, 1)]),
        ];
        for (n, expected) in cases {
            assert_eq!(prime_factors(n), expected);
        }

        let n = 360u64;
        let recovered: u64 = prime_factors(n).iter().map(|&(p, e)| p.pow(e)).product();
        assert_eq!(recovered, n);
    }

    #[test]
    fn test_euler_totient_nt() {
        for (n, expected) in [(1, 1), (2, 1), (6, 2), (9, 6), (12, 4), (0, 0)] {
            assert_eq!(euler_totient(n), expected);
        }
    }

    /// μ(4) = 0 (square factor), μ(6) = 1 (two primes), μ(30) = -1 (three primes).
    #[test]
    fn test_mobius_nt() {
        for (n, expected) in [(1, 1), (2, -1), (4, 0), (6, 1), (30, -1), (0, 0)] {
            assert_eq!(mobius(n), expected);
        }
    }

    /// σ_0 counts divisors, σ_1 sums them, σ_2(6) = 1 + 4 + 9 + 36 = 50.
    #[test]
    fn test_divisor_sum() {
        let cases = [
            (1, 0, 1),
            (6, 0, 4),
            (12, 0, 6),
            (1, 1, 1),
            (6, 1, 12),
            (12, 1, 28),
            (6, 2, 50),
        ];
        for (n, k, expected) in cases {
            assert_eq!(divisor_sum(n, k), expected);
        }
    }

    /// (5/15) = 0 because gcd(5, 15) > 1; even or zero `n` is rejected.
    #[test]
    fn test_jacobi_symbol() {
        let cases = [(0, 5, 0), (1, 5, 1), (2, 5, -1), (4, 5, 1), (5, 15, 0), (1, 15, 1)];
        for (a, n, expected) in cases {
            assert_eq!(jacobi_symbol(a, n).expect("ok"), expected);
        }
        for n in [4, 0] {
            assert!(jacobi_symbol(1, n).is_err());
        }
    }
}
/// Alias for [`euler_totient`]: returns Euler's totient φ(n).
#[inline]
pub fn euler_phi(n: u64) -> u64 {
    euler_totient(n)
}
/// Von Mangoldt function Λ(n).
///
/// Returns `ln(p)` when `n` is a power of a single prime `p` (including `p`
/// itself) and `0.0` otherwise, in particular for `n <= 1`.
pub fn von_mangoldt(n: u64) -> f64 {
    if n <= 1 {
        return 0.0;
    }
    // Exactly one distinct prime factor means n is a prime power.
    match prime_factors(n).as_slice() {
        [(p, _)] => (*p as f64).ln(),
        _ => 0.0,
    }
}
/// Ramanujan sum `c_k(n)`, evaluated with the closed form
/// `μ(k / g) · φ(k) / φ(k / g)` where `g = gcd(k, n)`.
///
/// `k == 0` yields `0` and `k == 1` yields `1`.
pub fn ramanujan_sum(k: u64, n: u64) -> i64 {
    match k {
        0 => return 0,
        1 => return 1,
        _ => {}
    }
    let reduced = k / gcd(k, n);
    let mu = i64::from(mobius(reduced));
    if mu == 0 {
        return 0;
    }
    let phi_reduced = euler_totient(reduced) as i64;
    // Defensive guard against a zero divisor.
    if phi_reduced == 0 {
        return 0;
    }
    let phi_full = euler_totient(k) as i64;
    mu * phi_full / phi_reduced
}
/// Jordan's totient function `J_k(n)`, computed as the product over the
/// prime powers `p^e` of `n` of `p^(k·(e-1)) · (p^k − 1)`.
///
/// `n == 0` yields `0`, and `k == 0` yields `1` for every `n >= 1`.
/// All arithmetic saturates at `u64::MAX`.
pub fn jordan_totient(n: u64, k: u32) -> u64 {
    if n == 0 {
        return 0;
    }
    if k == 0 {
        return 1;
    }
    // n == 1 has no prime factors; the empty product is 1.
    prime_factors(n).into_iter().fold(1u64, |acc, (p, e)| {
        let pk = p.saturating_pow(k);
        // Exponents reported by `prime_factors` are at least 1, so e - 1 is safe.
        let term = pk.saturating_pow(e - 1).saturating_mul(pk.saturating_sub(1));
        acc.saturating_mul(term)
    })
}
/// Arithmetic derivative `n'`: the sum of `e · n / p` over the prime powers
/// `p^e` in the factorisation of `n`.
///
/// Returns `0` for `n <= 1`. The sum saturates at `u64::MAX`.
pub fn arithmetic_derivative(n: u64) -> u64 {
    if n <= 1 {
        return 0;
    }
    prime_factors(n).iter().fold(0u64, |total, &(p, e)| {
        let contribution = (n / p).saturating_mul(u64::from(e));
        total.saturating_add(contribution)
    })
}
/// Partial sum of the prime zeta function: the sum of `p^(-s)` over the
/// first `n_primes` primes.
///
/// Returns `0.0` when `n_primes == 0`.
pub fn prime_zeta(s: f64, n_primes: usize) -> f64 {
    if n_primes == 0 {
        return 0.0;
    }
    let mut total = 0.0f64;
    let mut prime = 2u64;
    for _ in 0..n_primes {
        total += (prime as f64).powf(-s);
        prime = next_prime(prime);
    }
    total
}
/// Truncated Euler product: the product of `p^s / (p^s − 1)` over the first
/// `n_primes` primes.
///
/// Returns `1.0` (the empty product) when `n_primes == 0`.
pub fn euler_product(s: f64, n_primes: usize) -> f64 {
    if n_primes == 0 {
        return 1.0;
    }
    let mut acc = 1.0f64;
    let mut prime = 2u64;
    for _ in 0..n_primes {
        let p_s = (prime as f64).powf(s);
        acc *= p_s / (p_s - 1.0);
        prime = next_prime(prime);
    }
    acc
}
/// Unit tests for the derived arithmetic functions.
#[cfg(test)]
mod advanced_tests {
    use super::*;

    /// `euler_phi` must agree with `euler_totient` everywhere.
    #[test]
    fn test_euler_phi_is_alias() {
        (0..=50u64).for_each(|n| assert_eq!(euler_phi(n), euler_totient(n), "n = {n}"));
    }

    /// Λ is ln(p) on prime powers and zero elsewhere.
    #[test]
    fn test_von_mangoldt() {
        assert_eq!(von_mangoldt(1), 0.0);

        for p in [2u64, 3, 5, 7, 11, 13, 97] {
            let val = von_mangoldt(p);
            let ln_p = (p as f64).ln();
            assert!((val - ln_p).abs() < 1e-14, "p = {p}: {val}");
        }

        for (p, p2) in [(2u64, 4u64), (3, 9), (5, 25)] {
            let val = von_mangoldt(p2);
            let ln_p = (p as f64).ln();
            assert!((val - ln_p).abs() < 1e-14, "p^2 = {p2}: {val}");
        }

        for c in [6u64, 10, 12, 15, 30] {
            assert_eq!(von_mangoldt(c), 0.0, "composite {c}");
        }
    }

    /// c_1(n) = 1, c_k(0) = φ(k), c_k(1) = μ(k), plus a few fixed values.
    #[test]
    fn test_ramanujan_sum_special_cases() {
        for n in 0..=10u64 {
            assert_eq!(ramanujan_sum(1, n), 1, "c_1({n})");
        }
        for k in 1..=10u64 {
            let phi = euler_totient(k) as i64;
            assert_eq!(ramanujan_sum(k, 0), phi, "c_{k}(0)");
        }
        for k in 1..=10u64 {
            let mu = mobius(k) as i64;
            assert_eq!(ramanujan_sum(k, 1), mu, "c_{k}(1)");
        }
        for (k, n, expected) in [(5, 1, -1), (6, 1, 1), (4, 2, -2)] {
            assert_eq!(ramanujan_sum(k, n), expected);
        }
    }

    /// J_1 coincides with Euler's totient.
    #[test]
    fn test_jordan_totient_k1() {
        for n in 1..=20u64 {
            let phi = euler_totient(n);
            assert_eq!(jordan_totient(n, 1), phi, "J_1({n})");
        }
    }

    /// With `k == 0` the function returns 1 for every positive `n`.
    #[test]
    fn test_jordan_totient_k0() {
        for n in 1..=10u64 {
            assert_eq!(jordan_totient(n, 0), 1, "J_0({n})");
        }
    }

    /// J_2(6) = (4 − 1) · (9 − 1) = 24.
    #[test]
    fn test_jordan_totient_known() {
        let value = jordan_totient(6, 2);
        assert_eq!(value, 24);
    }

    /// Fixed values, then the product rule (ab)' = a'b + ab' for a = 5, b = 7.
    #[test]
    fn test_arithmetic_derivative() {
        for (n, expected) in [(0, 0), (1, 0), (2, 1), (3, 1), (4, 4), (6, 5), (9, 6)] {
            assert_eq!(arithmetic_derivative(n), expected);
        }

        let (a, b) = (5u64, 7u64);
        let ab_d = arithmetic_derivative(a * b);
        let manual = arithmetic_derivative(a) * b + a * arithmetic_derivative(b);
        assert_eq!(ab_d, manual, "(5·7)' = {ab_d}, manual = {manual}");
    }

    /// P(2) ≈ 0.45224742 using the first 10000 primes.
    #[test]
    fn test_prime_zeta_s2() {
        let val = prime_zeta(2.0, 10000);
        let error = (val - 0.45224742).abs();
        assert!(error < 1e-4, "P(2) = {val}");
    }

    /// The Euler product at s = 2 approaches ζ(2) = π² / 6.
    #[test]
    fn test_euler_product_s2() {
        let expected = std::f64::consts::PI.powi(2) / 6.0;
        let val = euler_product(2.0, 5000);
        let error = (val - expected).abs();
        assert!(error < 0.01, "ζ(2) ≈ {val}, expected {expected}");
    }

    /// The Euler product at s = 4 approaches ζ(4) = π⁴ / 90.
    #[test]
    fn test_euler_product_s4() {
        let expected = std::f64::consts::PI.powi(4) / 90.0;
        let val = euler_product(4.0, 5000);
        let error = (val - expected).abs();
        assert!(error < 0.001, "ζ(4) ≈ {val}, expected {expected}");
    }
}