use crate::error::{LinalgError, LinalgResult};
/// Returns `true` when `n` is prime, using a deterministic Miller–Rabin test.
///
/// Values below 4 and multiples of 2 or 3 are decided directly. Every other
/// `n` is tested against the first twelve primes (2..=37) as witness bases;
/// bases that are not smaller than `n` are skipped. This base set is known to
/// make Miller–Rabin exact for every 64-bit input.
pub fn is_prime(n: u64) -> bool {
    const BASES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37];

    match n {
        0 | 1 => return false,
        2 | 3 => return true,
        _ if n % 2 == 0 || n % 3 == 0 => return false,
        _ => {}
    }

    // Decompose n - 1 = odd_part * 2^twos (n is odd here, so twos >= 1).
    let n_minus_one = n - 1;
    let twos = n_minus_one.trailing_zeros();
    let odd_part = n_minus_one >> twos;

    // `n` survives base `a` if a^odd_part is 1 or n - 1, or if one of the
    // following twos - 1 squarings reaches n - 1.
    BASES
        .iter()
        .copied()
        .filter(|&a| a < n)
        .all(|a| {
            let mut x = mod_pow(a, odd_part, n);
            if x == 1 || x == n_minus_one {
                return true;
            }
            for _ in 1..twos {
                x = mul_mod(x, x, n);
                if x == n_minus_one {
                    return true;
                }
            }
            false
        })
}
/// Returns every prime `p` with `2 <= p <= limit`, in increasing order.
///
/// Uses the sieve of Eratosthenes: for each surviving `p` with `p * p <= limit`,
/// the multiples `p*p, p*p + p, ...` are crossed off. Allocates `limit + 1` flags.
pub fn primes_up_to(limit: usize) -> Vec<usize> {
    if limit < 2 {
        return Vec::new();
    }
    // sieve[i] stays true while i is still a prime candidate.
    let mut sieve = vec![true; limit + 1];
    sieve[0] = false;
    sieve[1] = false;
    let mut p = 2usize;
    while p * p <= limit {
        if sieve[p] {
            // Multiples below p*p were already crossed off by smaller primes.
            for multiple in (p * p..=limit).step_by(p) {
                sieve[multiple] = false;
            }
        }
        p += 1;
    }
    sieve
        .iter()
        .enumerate()
        .filter(|&(_, &candidate)| candidate)
        .map(|(i, _)| i)
        .collect()
}
/// Greatest common divisor of `a` and `b` via the Euclidean algorithm.
///
/// `gcd(x, 0) == gcd(0, x) == x`, so `gcd(0, 0)` is 0.
pub fn gcd(a: u64, b: u64) -> u64 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        // (x, y) <- (y, x mod y)
        x %= y;
        std::mem::swap(&mut x, &mut y);
    }
    x
}
/// Least common multiple of `a` and `b`; 0 if either argument is 0.
///
/// Divides by the gcd before multiplying to keep the intermediate small. The
/// final multiplication is unchecked, so a result above `u64::MAX` panics in
/// debug builds and wraps in release builds.
pub fn lcm(a: u64, b: u64) -> u64 {
    match (a, b) {
        (0, _) | (_, 0) => 0,
        _ => {
            let reduced = a / gcd(a, b);
            reduced * b
        }
    }
}
/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` such that `a * x + b * y == g`, where `g` is the gcd of
/// `a` and `b` and is never negative: when the recursion bottoms out on a
/// negative value, the sign is moved into the coefficient.
pub fn extended_gcd(a: i64, b: i64) -> (i64, i64, i64) {
    match b {
        0 if a < 0 => (-a, -1, 0),
        0 => (a, 1, 0),
        _ => {
            // g = b*s + (a % b)*t and a % b = a - (a / b)*b, so
            // g = a*t + b*(s - (a / b)*t).
            let (g, s, t) = extended_gcd(b, a % b);
            (g, t, s - (a / b) * t)
        }
    }
}
/// Computes `base^exp mod modulus` by binary (square-and-multiply) exponentiation.
///
/// Returns 0 when `modulus == 1`. Intermediate products go through `mul_mod`.
/// Panics if `modulus == 0`, on the initial reduction of `base`.
pub fn mod_pow(base: u64, exp: u64, modulus: u64) -> u64 {
    if modulus == 1 {
        return 0;
    }
    let mut acc: u64 = 1;
    let mut square = base % modulus;
    let mut remaining = exp;
    while remaining != 0 {
        if remaining & 1 == 1 {
            acc = mul_mod(acc, square, modulus);
        }
        square = mul_mod(square, square, modulus);
        remaining >>= 1;
    }
    acc
}
/// Multiplicative inverse of `a` modulo `m`, normalised into `[0, m)`.
///
/// Returns `None` when `m <= 1` or when `a` and `m` are not coprime.
/// Negative `a` is accepted; it is reduced with `rem_euclid` first.
pub fn mod_inverse(a: i64, m: i64) -> Option<i64> {
    if m <= 1 {
        return None;
    }
    match extended_gcd(a.rem_euclid(m), m) {
        (1, coeff, _) => Some(coeff.rem_euclid(m)),
        _ => None,
    }
}
/// Solves `x ≡ remainders[i] (mod moduli[i])` for all `i` using the Chinese
/// Remainder Theorem.
///
/// Returns the solution in `[0, M)`, where `M` is the product of the moduli.
/// The solution is unique in that range when the moduli are pairwise coprime.
/// A modulus of 1 imposes no constraint.
///
/// Returns `None` when:
/// - the slices differ in length or are empty;
/// - any modulus is not positive;
/// - `M` does not fit in an `i64`;
/// - some `M / moduli[i]` has no inverse modulo `moduli[i]`, e.g. because the
///   moduli are not pairwise coprime.
pub fn crt(remainders: &[i64], moduli: &[i64]) -> Option<i64> {
    if remainders.len() != moduli.len() || remainders.is_empty() {
        return None;
    }
    // A zero modulus would divide by zero below; negative ones are meaningless.
    if moduli.iter().any(|&mi| mi <= 0) {
        return None;
    }
    // Product of the moduli, rejecting overflow instead of panicking/wrapping.
    let m = moduli.iter().try_fold(1i64, |acc, &mi| acc.checked_mul(mi))?;
    let mut x: i128 = 0;
    for (&r, &mi) in remainders.iter().zip(moduli) {
        // Every integer satisfies a congruence mod 1, and its basis term
        // (M / 1) * inv is a multiple of M, so it contributes nothing.
        // mod_inverse rejects a modulus of 1, hence the explicit skip.
        if mi == 1 {
            continue;
        }
        let big_m = m / mi;
        let inv = mod_inverse(big_m % mi, mi)?;
        // Reduce r * inv modulo mi before scaling by big_m. The term then stays
        // below m (< 2^63), so neither this product nor the sum can overflow i128.
        let coeff = (r as i128).rem_euclid(mi as i128) * inv as i128 % mi as i128;
        x = (x + coeff * big_m as i128) % m as i128;
    }
    // x lies in [0, m) and m fits in i64, so the cast is lossless.
    Some(x as i64)
}
/// Euler's totient φ(n): the number of integers in `1..=n` coprime to `n`.
///
/// Computed by trial division, applying `result -= result / p` once per
/// distinct prime factor `p`. Worst case is O(√n) iterations. Returns 0 for
/// `n == 0`.
pub fn euler_totient(mut n: u64) -> u64 {
    if n == 0 {
        return 0;
    }
    let mut result = n;
    let mut p = 2u64;
    // `p <= n / p` is equivalent to `p * p <= n`, but cannot overflow when a
    // large prime cofactor pushes p up to 2^32.
    while p <= n / p {
        if n % p == 0 {
            while n % p == 0 {
                n /= p;
            }
            result -= result / p;
        }
        p += 1;
    }
    // Any remainder above 1 is a single prime factor larger than sqrt(n).
    if n > 1 {
        result -= result / n;
    }
    result
}
/// Factors `n` into `(prime, exponent)` pairs in increasing prime order.
///
/// Uses trial division; the worst case is O(√n) iterations. Both `0` and `1`
/// yield an empty vector.
pub fn prime_factorization(mut n: u64) -> Vec<(u64, u32)> {
    let mut factors = Vec::new();
    let mut p = 2u64;
    // `p <= n / p` is equivalent to `p * p <= n`, but cannot overflow when a
    // large prime cofactor pushes p up to 2^32.
    while p <= n / p {
        if n % p == 0 {
            let mut exp = 0u32;
            while n % p == 0 {
                n /= p;
                exp += 1;
            }
            factors.push((p, exp));
        }
        p += 1;
    }
    // Any remainder above 1 is a single prime factor larger than sqrt(n).
    if n > 1 {
        factors.push((n, 1));
    }
    factors
}
/// Legendre symbol `(a / p)` computed with Euler's criterion.
///
/// Returns 0 when `p <= 1` or `a ≡ 0 (mod p)`. Otherwise returns 1 if
/// `a^((p-1)/2) ≡ 1 (mod p)` and -1 for any other value. Primality of `p` is
/// not checked; presumably callers pass an odd prime.
pub fn legendre_symbol(a: i64, p: i64) -> i32 {
    if p <= 1 {
        return 0;
    }
    let modulus = p as u64;
    let residue = a.rem_euclid(p) as u64;
    if residue == 0 {
        return 0;
    }
    match mod_pow(residue, (modulus - 1) / 2, modulus) {
        1 => 1,
        _ => -1,
    }
}
/// Returns `true` if `a` is a square modulo `p`, using Euler's criterion.
///
/// `p == 2` and `a ≡ 0 (mod p)` always count as residues. `p` is presumably
/// prime (not checked); `p == 0` panics on the reduction of `a`.
pub fn is_quadratic_residue(a: u64, p: u64) -> bool {
    if p == 2 {
        return true;
    }
    let reduced = a % p;
    reduced == 0 || mod_pow(reduced, (p - 1) / 2, p) == 1
}
/// Computes a square root of `n` modulo the prime `p`, if one exists.
///
/// - Returns `Some(0)` when `n ≡ 0 (mod p)`.
/// - Returns `None` when `is_quadratic_residue` rejects `n`.
/// - For `p ≡ 3 (mod 4)` the root is `n^((p+1)/4)`; otherwise Tonelli–Shanks
///   is used.
///
/// Only one of the two roots `r` / `p - r` is returned. `p` is presumably
/// prime (not checked), and `p == 0` panics on the initial reduction.
pub fn sqrt_mod_prime(n: u64, p: u64) -> Option<u64> {
    let residue = n % p;
    if residue == 0 {
        return Some(0);
    }
    if !is_quadratic_residue(residue, p) {
        return None;
    }
    if p % 4 == 3 {
        return Some(mod_pow(residue, (p + 1) / 4, p));
    }

    // Factor p - 1 = odd * 2^two_adic (p >= 2 here, so p - 1 is non-zero).
    let two_adic = (p - 1).trailing_zeros();
    let odd = (p - 1) >> two_adic;

    // Any quadratic non-residue works; fall back to 2 if the search finds none.
    let non_residue = (2..p).find(|&z| !is_quadratic_residue(z, p)).unwrap_or(2);

    let mut order = two_adic;
    let mut c = mod_pow(non_residue, odd, p);
    let mut t = mod_pow(residue, odd, p);
    let mut root = mod_pow(residue, (odd + 1) / 2, p);

    while t != 1 {
        // Least i in 1..order with t^(2^i) == 1.
        let mut i = 1u32;
        let mut probe = mul_mod(t, t, p);
        while probe != 1 && i < order {
            probe = mul_mod(probe, probe, p);
            i += 1;
        }
        if i == order {
            return None;
        }
        // b = c^(2^(order - i - 1)); the exponent is taken mod p - 1 (Fermat).
        let b = mod_pow(c, mod_pow(2, u64::from(order - i - 1), p - 1), p);
        order = i;
        c = mul_mod(b, b, p);
        t = mul_mod(t, c, p);
        root = mul_mod(root, b, p);
    }
    Some(root)
}
/// In-place number-theoretic transform (NTT) over the integers mod `modulus`.
///
/// Runs an iterative radix-2 transform: a bit-reversal permutation, then
/// butterfly passes of length 2, 4, ..., n. With `invert = true` it computes
/// the inverse transform, including the final scaling by `n^(modulus-2)`,
/// i.e. `n^{-1}` when `modulus` is prime.
///
/// Parameters:
/// - `a`: coefficients, overwritten with the transform. Outputs of every
///   butterfly pass are reduced into `[0, modulus)`.
/// - `invert`: compute the inverse transform instead of the forward one.
/// - `modulus`: presumably a prime with `modulus - 1` divisible by `a.len()`.
/// - `primitive_root`: presumably a generator of the multiplicative group
///   mod `modulus` (the tests use 3 with 998 244 353).
///
/// # Panics
/// Panics if `a.len()` is neither 0 nor a power of two, or if `modulus - 1`
/// is not divisible by `a.len()`. In both cases the required roots of unity
/// do not exist and the output would otherwise be silently wrong.
pub fn ntt(a: &mut [i64], invert: bool, modulus: i64, primitive_root: i64) {
    let n = a.len();
    if n == 0 {
        return;
    }
    assert!(n.is_power_of_two(), "ntt: length {} is not a power of two", n);
    assert!(
        (modulus - 1) % n as i64 == 0,
        "ntt: modulus - 1 ({}) is not divisible by length {}",
        modulus - 1,
        n
    );

    // Bit-reversal permutation so the butterflies can run in place.
    let mut j = 0usize;
    for i in 1..n {
        let mut bit = n >> 1;
        while j & bit != 0 {
            j ^= bit;
            bit >>= 1;
        }
        j ^= bit;
        if i < j {
            a.swap(i, j);
        }
    }

    let mut len = 2usize;
    while len <= n {
        let half = len / 2;
        // len-th root of unity g^((p-1)/len), or its inverse
        // g^((p-1) - (p-1)/len) for the inverse transform.
        let step = (modulus - 1) / len as i64;
        let w = if invert {
            mod_pow_i64(primitive_root, modulus - 1 - step, modulus)
        } else {
            mod_pow_i64(primitive_root, step, modulus)
        };
        for start in (0..n).step_by(len) {
            let mut wn = 1i64;
            for k in start..start + half {
                let u = a[k];
                let v = (a[k + half] as i128 * wn as i128).rem_euclid(modulus as i128) as i64;
                a[k] = (u + v).rem_euclid(modulus);
                a[k + half] = (u - v).rem_euclid(modulus);
                wn = (wn as i128 * w as i128).rem_euclid(modulus as i128) as i64;
            }
        }
        len <<= 1;
    }

    if invert {
        let n_inv = mod_pow_i64(n as i64, modulus - 2, modulus);
        for x in a.iter_mut() {
            *x = (*x as i128 * n_inv as i128).rem_euclid(modulus as i128) as i64;
        }
    }
}
/// Multiplies two polynomials with coefficients mod `modulus` using the NTT.
///
/// `a` and `b` hold coefficients in increasing degree order. Inputs are first
/// reduced into `[0, modulus)`. The result has `a.len() + b.len() - 1`
/// coefficients. If either input is empty, the product is the empty
/// polynomial and an empty vector is returned.
///
/// `modulus` is presumably an NTT-friendly prime whose `modulus - 1` is
/// divisible by the padded length (the next power of two at or above the
/// result length). `primitive_root` is presumably a generator mod `modulus`.
pub fn ntt_multiply(a: &[i64], b: &[i64], modulus: i64, primitive_root: i64) -> Vec<i64> {
    // `a.len() + b.len() - 1` would underflow for two empty inputs, and a
    // single empty input would otherwise produce a spurious run of zeros.
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let result_len = a.len() + b.len() - 1;
    let n = result_len.next_power_of_two();

    let mut fa: Vec<i64> = a.iter().map(|&x| x.rem_euclid(modulus)).collect();
    fa.resize(n, 0);
    let mut fb: Vec<i64> = b.iter().map(|&x| x.rem_euclid(modulus)).collect();
    fb.resize(n, 0);

    ntt(&mut fa, false, modulus, primitive_root);
    ntt(&mut fb, false, modulus, primitive_root);
    // A pointwise product in the transformed domain is a convolution of the
    // coefficients.
    for (x, &y) in fa.iter_mut().zip(&fb) {
        *x = (*x as i128 * y as i128).rem_euclid(modulus as i128) as i64;
    }
    ntt(&mut fa, true, modulus, primitive_root);
    fa.truncate(result_len);
    fa
}
/// Reduces a lattice basis with the Lenstra–Lenstra–Lovász (LLL) algorithm.
///
/// `basis` is a list of row vectors, all of the same dimension. They are
/// presumably linearly independent: dependent inputs are not detected, and the
/// Gram–Schmidt step only guards its division against a zero denominator.
/// `delta` is the Lovász parameter; 0.75 is the classical choice and the usual
/// range is `(0.25, 1]`.
///
/// The full Gram–Schmidt orthogonalisation is recomputed from scratch after
/// every size-reduction step and every swap. This is simple and numerically
/// conservative, but costs O(n^2 * d) per recomputation.
///
/// Returns the reduced basis with the same number of vectors. An empty `basis`
/// gives an empty result.
///
/// # Panics
/// Panics if `delta > 1.0` or `delta` is NaN. For such values the Lovász
/// condition can fail forever; for example, the unit basis would swap back and
/// forth endlessly. Also panics if the basis vectors do not all have the same
/// length.
pub fn lll_reduce(basis: &[Vec<f64>], delta: f64) -> Vec<Vec<f64>> {
    if basis.is_empty() {
        return vec![];
    }
    // Written as `delta <= 1.0` so that NaN is rejected as well.
    assert!(delta <= 1.0, "lll_reduce: delta must be <= 1.0 (got {})", delta);
    let n = basis.len();
    let d = basis[0].len();
    assert!(
        basis.iter().all(|row| row.len() == d),
        "lll_reduce: all basis vectors must have the same dimension"
    );
    let mut b: Vec<Vec<f64>> = basis.to_vec();
    let mut b_star: Vec<Vec<f64>> = vec![vec![0.0; d]; n];
    let mut mu: Vec<Vec<f64>> = vec![vec![0.0; n]; n];
    let dot = |u: &[f64], v: &[f64]| -> f64 { u.iter().zip(v).map(|(a, b)| a * b).sum() };
    let norm_sq = |u: &[f64]| -> f64 { u.iter().map(|x| x * x).sum() };
    // Classical Gram–Schmidt: b_star[i] = b[i] - sum_{j<i} mu[i][j] * b_star[j].
    let gram_schmidt = |b: &[Vec<f64>], b_star: &mut Vec<Vec<f64>>, mu: &mut Vec<Vec<f64>>| {
        let n = b.len();
        let d = b[0].len();
        for i in 0..n {
            b_star[i] = b[i].clone();
            for j in 0..i {
                // `.max(1e-300)` avoids dividing by zero for degenerate b_star[j].
                let mu_ij = dot(&b[i], &b_star[j]) / dot(&b_star[j], &b_star[j]).max(1e-300);
                mu[i][j] = mu_ij;
                for k in 0..d {
                    b_star[i][k] -= mu_ij * b_star[j][k];
                }
            }
        }
    };
    gram_schmidt(&b, &mut b_star, &mut mu);
    let mut k = 1usize;
    while k < n {
        // Size reduction: make |mu[k][j]| <= 1/2 for all j < k.
        for j in (0..k).rev() {
            let mu_kj = mu[k][j];
            if mu_kj.abs() > 0.5 {
                let rounded = mu_kj.round();
                for l in 0..d {
                    let bj_l = b[j][l];
                    b[k][l] -= rounded * bj_l;
                }
                gram_schmidt(&b, &mut b_star, &mut mu);
            }
        }
        // Lovász condition: |b*_k|^2 >= (delta - mu_{k,k-1}^2) |b*_{k-1}|^2.
        let lhs = norm_sq(&b_star[k]);
        let rhs = (delta - mu[k][k - 1].powi(2)) * norm_sq(&b_star[k - 1]);
        if lhs >= rhs {
            k += 1;
        } else {
            b.swap(k, k - 1);
            gram_schmidt(&b, &mut b_star, &mut mu);
            if k > 1 {
                k -= 1;
            }
        }
    }
    b
}
/// Returns `(a * b) mod m` without overflow.
///
/// The product is formed in `u128`. The remainder is below `m <= u64::MAX`, so
/// narrowing back to `u64` is lossless. Panics if `m == 0`.
#[inline(always)]
fn mul_mod(a: u64, b: u64, m: u64) -> u64 {
    let wide = u128::from(a) * u128::from(b);
    (wide % u128::from(m)) as u64
}
/// Signed wrapper around `mod_pow`: returns `base^exp mod modulus` in
/// `[0, modulus)`.
///
/// `base` is reduced with `rem_euclid`, so negative bases work. A non-negative
/// `exp` is used unchanged. A negative `exp` is read as a power of the inverse
/// via Fermat's little theorem, i.e. reduced modulo `modulus - 1`; that is only
/// meaningful when `modulus` is prime and `base` is not a multiple of it.
/// Assumes `modulus >= 1`.
fn mod_pow_i64(base: i64, exp: i64, modulus: i64) -> i64 {
    if modulus == 1 {
        // Everything is 0 mod 1; this also avoids `rem_euclid(0)` below.
        return 0;
    }
    let b = base.rem_euclid(modulus) as u64;
    // Reducing a non-negative exponent mod (modulus - 1) is wrong in two cases:
    // a zero base (0^(k*(p-1)) would become 0^0 = 1) and composite moduli.
    // So only negative exponents are folded into range.
    let e = if exp >= 0 {
        exp as u64
    } else {
        exp.rem_euclid(modulus - 1) as u64
    };
    mod_pow(b, e, modulus as u64) as i64
}
/// `Result`-returning variant of `crt`.
///
/// # Errors
/// Returns `LinalgError::ComputationError` whenever `crt` returns `None`.
pub fn crt_result(remainders: &[i64], moduli: &[i64]) -> LinalgResult<i64> {
    match crt(remainders, moduli) {
        Some(x) => Ok(x),
        None => Err(LinalgError::ComputationError(
            "CRT: no solution exists (moduli not pairwise coprime?)".into(),
        )),
    }
}
/// `Result`-returning variant of `sqrt_mod_prime`.
///
/// # Errors
/// Returns `LinalgError::ComputationError` naming `n` and `p` whenever
/// `sqrt_mod_prime` returns `None`.
pub fn sqrt_mod_prime_result(n: u64, p: u64) -> LinalgResult<u64> {
    if let Some(root) = sqrt_mod_prime(n, p) {
        Ok(root)
    } else {
        Err(LinalgError::ComputationError(format!(
            "{} is not a quadratic residue mod {}",
            n, p
        )))
    }
}
/// `Result`-returning variant of `mod_inverse`.
///
/// # Errors
/// Returns `LinalgError::ComputationError` naming `a` and `m` whenever
/// `mod_inverse` returns `None`. The message is only formatted on failure.
pub fn mod_inverse_result(a: i64, m: i64) -> LinalgResult<i64> {
    mod_inverse(a, m).map_or_else(
        || {
            Err(LinalgError::ComputationError(format!(
                "No modular inverse for {} mod {}",
                a, m
            )))
        },
        Ok,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_prime() {
        assert!(!is_prime(0));
        assert!(!is_prime(1));
        assert!(is_prime(2));
        assert!(is_prime(3));
        assert!(!is_prime(4));
        assert!(is_prime(7_919));
        assert!(!is_prime(7_921)); // 89^2
        assert!(is_prime(998_244_353));
        assert!(is_prime(1_000_000_007));
        // Strong pseudoprime to bases 2, 3, 5 and 7; must still be rejected.
        assert!(!is_prime(3_215_031_751));
        // Largest prime below 2^64, and 2^64 - 1 (divisible by 3).
        assert!(is_prime(18_446_744_073_709_551_557));
        assert!(!is_prime(u64::MAX));
    }

    #[test]
    fn test_primes_up_to() {
        assert_eq!(primes_up_to(0), Vec::<usize>::new());
        assert_eq!(primes_up_to(1), Vec::<usize>::new());
        assert_eq!(primes_up_to(2), vec![2]);
        assert_eq!(primes_up_to(10), vec![2, 3, 5, 7]);
        assert_eq!(primes_up_to(20), vec![2, 3, 5, 7, 11, 13, 17, 19]);
        assert_eq!(primes_up_to(100).len(), 25);
    }

    #[test]
    fn test_gcd_lcm() {
        assert_eq!(gcd(0, 5), 5);
        assert_eq!(gcd(0, 0), 0);
        assert_eq!(gcd(48, 18), 6);
        assert_eq!(gcd(17, 5), 1);
        assert_eq!(lcm(4, 6), 12);
        assert_eq!(lcm(21, 6), 42);
        assert_eq!(lcm(0, 5), 0);
    }

    #[test]
    fn test_extended_gcd() {
        let (g, x, y) = extended_gcd(30, 12);
        assert_eq!(g, 6);
        assert_eq!(30 * x + 12 * y, g);
        let (g2, x2, y2) = extended_gcd(7, 3);
        assert_eq!(g2, 1);
        assert_eq!(7 * x2 + 3 * y2, 1);
        // Negative inputs still give a non-negative gcd and a valid identity.
        let (g3, x3, y3) = extended_gcd(-30, 12);
        assert_eq!(g3, 6);
        assert_eq!(-30 * x3 + 12 * y3, g3);
    }

    #[test]
    fn test_mod_pow() {
        assert_eq!(mod_pow(2, 10, 1000), 24);
        assert_eq!(mod_pow(3, 0, 5), 1);
        assert_eq!(mod_pow(5, 3, 1), 0);
        assert_eq!(mod_pow(2, 62, 1_000_000_007), 145_586_002);
    }

    #[test]
    fn test_mod_inverse() {
        assert_eq!(mod_inverse(3, 7), Some(5));
        assert_eq!(mod_inverse(2, 4), None);
        assert_eq!(mod_inverse(-3, 7), Some(2));
        assert_eq!(mod_inverse(5, 1), None);
        // Check the defining property on the value actually returned.
        let inv = mod_inverse(10, 17).expect("10 is invertible mod 17");
        assert_eq!((10 * inv) % 17, 1);
    }

    #[test]
    fn test_crt() {
        let x = crt(&[2, 3, 2], &[3, 5, 7]).expect("moduli are pairwise coprime");
        assert_eq!(x % 3, 2);
        assert_eq!(x % 5, 3);
        assert_eq!(x % 7, 2);
        // The solution is normalised into [0, 3 * 5 * 7).
        assert_eq!(x, 23);
        assert_eq!(crt(&[1, 2], &[3]), None);
        assert_eq!(crt(&[], &[]), None);
    }

    #[test]
    fn test_euler_totient() {
        assert_eq!(euler_totient(0), 0);
        assert_eq!(euler_totient(1), 1);
        assert_eq!(euler_totient(6), 2);
        assert_eq!(euler_totient(7), 6);
        assert_eq!(euler_totient(12), 4);
        assert_eq!(euler_totient(36), 12);
    }

    #[test]
    fn test_prime_factorization() {
        assert_eq!(prime_factorization(1), Vec::<(u64, u32)>::new());
        assert_eq!(prime_factorization(12), vec![(2, 2), (3, 1)]);
        assert_eq!(prime_factorization(2), vec![(2, 1)]);
        assert_eq!(prime_factorization(97), vec![(97, 1)]);
        assert_eq!(prime_factorization(360), vec![(2, 3), (3, 2), (5, 1)]);
    }

    #[test]
    fn test_legendre_symbol() {
        assert_eq!(legendre_symbol(2, 7), 1);
        assert_eq!(legendre_symbol(4, 7), 1);
        assert_eq!(legendre_symbol(3, 7), -1);
        assert_eq!(legendre_symbol(-1, 7), -1);
        assert_eq!(legendre_symbol(7, 7), 0);
    }

    #[test]
    fn test_is_quadratic_residue() {
        // Non-zero squares mod 7 are {1, 2, 4}.
        assert!(is_quadratic_residue(2, 7));
        assert!(is_quadratic_residue(4, 7));
        assert!(!is_quadratic_residue(3, 7));
        assert!(is_quadratic_residue(0, 7));
        assert!(is_quadratic_residue(1, 2));
    }

    #[test]
    fn test_sqrt_mod_prime() {
        let r = sqrt_mod_prime(2, 7).expect("2 is a square mod 7");
        assert_eq!((r * r) % 7, 2);
        assert!(sqrt_mod_prime(3, 7).is_none());
        let r2 = sqrt_mod_prime(4, 7).expect("4 is a square mod 7");
        assert_eq!((r2 * r2) % 7, 4);
        assert_eq!(sqrt_mod_prime(0, 7), Some(0));
        // 13 ≡ 1 (mod 4) exercises the Tonelli–Shanks path.
        let r3 = sqrt_mod_prime(10, 13).expect("10 is a square mod 13");
        assert_eq!((r3 * r3) % 13, 10);
    }

    #[test]
    fn test_ntt_roundtrip() {
        let modulus: i64 = 998_244_353;
        let g: i64 = 3;
        let original = vec![1i64, 2, 3, 4];
        let mut a = original.clone();
        ntt(&mut a, false, modulus, g);
        ntt(&mut a, true, modulus, g);
        assert_eq!(a, original);
    }

    #[test]
    fn test_ntt_multiply() {
        let modulus: i64 = 998_244_353;
        let g: i64 = 3;
        let prod = ntt_multiply(&[1, 2, 3], &[4, 5, 6], modulus, g);
        assert_eq!(prod, vec![4, 13, 28, 27, 18]);
        // (x - 1)(x + 1) = x^2 - 1, with -1 represented as modulus - 1.
        let diff = ntt_multiply(&[-1, 1], &[1, 1], modulus, g);
        assert_eq!(diff, vec![modulus - 1, 0, 1]);
        assert_eq!(ntt_multiply(&[3], &[5], modulus, g), vec![15]);
    }

    #[test]
    fn test_lll_reduce() {
        fn norm(v: &[f64]) -> f64 {
            v.iter().map(|x| x * x).sum::<f64>().sqrt()
        }
        let basis = vec![
            vec![1.0_f64, 1.0, 1.0],
            vec![-1.0, 0.0, 2.0],
            vec![3.0, 5.0, 6.0],
        ];
        let reduced = lll_reduce(&basis, 0.75);
        assert_eq!(reduced.len(), 3);
        // LLL never lengthens the first basis vector.
        assert!(norm(&reduced[0]) <= norm(&basis[0]) + 1e-9);
        // Textbook result for this basis with delta = 3/4.
        assert_eq!(
            reduced,
            vec![vec![0.0, 1.0, 0.0], vec![1.0, 0.0, 1.0], vec![-1.0, 0.0, 2.0]]
        );
    }
}