use num_traits::{Zero, One, FromPrimitive, PrimInt, Signed};
use std::mem::swap;
pub fn primitive_root(prime: u64) -> Option<u64> {
let test_exponents: Vec<u64> = distinct_prime_factors(prime - 1)
.iter()
.map(|factor| (prime - 1) / factor)
.collect();
'next: for potential_root in 2..prime {
for exp in &test_exponents {
if modular_exponent(potential_root, *exp, prime) == 1 {
continue 'next;
}
}
return Some(potential_root);
}
None
}
pub fn modular_exponent<T: PrimInt>(mut base: T, mut exponent: T, modulo: T) -> T {
let one = T::one();
let mut result = one;
while exponent > Zero::zero() {
if exponent & one == one {
result = result * base % modulo;
}
exponent = exponent >> One::one();
base = (base * base) % modulo;
}
result
}
pub fn multiplicative_inverse<T: PrimInt + FromPrimitive>(a: T, n: T) -> T {
let mut t = Zero::zero();
let mut t_new = One::one();
let mut r = n;
let mut r_new = a;
while r_new > Zero::zero() {
let quotient = r / r_new;
r = r - quotient * r_new;
swap(&mut r, &mut r_new);
let t_subtract = quotient * t_new;
t = if t_subtract < t {
t - t_subtract
} else {
n - (t_subtract - t) % n
};
swap(&mut t, &mut t_new);
}
t
}
pub fn extended_euclidean_algorithm<T: PrimInt + Signed + FromPrimitive>(a: T,
b: T)
-> (T, T, T) {
let mut s = Zero::zero();
let mut s_old = One::one();
let mut t = One::one();
let mut t_old = Zero::zero();
let mut r = b;
let mut r_old = a;
while r > Zero::zero() {
let quotient = r_old / r;
r_old = r_old - quotient * r;
swap(&mut r_old, &mut r);
s_old = s_old - quotient * s;
swap(&mut s_old, &mut s);
t_old = t_old - quotient * t;
swap(&mut t_old, &mut t);
}
(r_old, s_old, t_old)
}
pub fn distinct_prime_factors(mut n: u64) -> Vec<u64> {
let mut result = Vec::new();
if n % 2 == 0 {
while n % 2 == 0 {
n /= 2;
}
result.push(2);
}
if n > 1 {
let mut divisor = 3;
let mut limit = (n as f32).sqrt() as u64 + 1;
while divisor < limit {
if n % divisor == 0 {
while n % divisor == 0 {
n /= divisor;
}
result.push(divisor);
limit = (n as f32).sqrt() as u64 + 1;
}
divisor += 2;
}
if n > 1 {
result.push(n);
}
}
result
}
pub fn prime_factors(mut n: usize) -> Vec<usize> {
let mut result = Vec::new();
while n % 2 == 0 {
n /= 2;
result.push(2);
}
if n > 1 {
let mut divisor = 3;
let mut limit = (n as f32).sqrt() as usize + 1;
while divisor < limit {
while n % divisor == 0 {
n /= divisor;
result.push(divisor);
}
limit = (n as f32).sqrt() as usize + 1;
divisor += 2;
}
if n > 1 {
result.push(n);
}
}
result
}
#[cfg(test)]
mod unit_tests {
use super::*;
#[test]
fn test_modular_exponent() {
let test_list = vec![
((2,8,300), 256),
((2,9,300), 212),
((1,9,300), 1),
((3,416788,47), 8),
];
for (input, expected) in test_list {
let (base, exponent, modulo) = input;
let result = modular_exponent(base, exponent, modulo);
assert_eq!(result, expected);
}
}
#[test]
fn test_multiplicative_inverse() {
let prime_list = vec![3, 5, 7, 11, 13, 17, 19, 23, 29];
for modulo in prime_list {
for i in 2..modulo {
let inverse = multiplicative_inverse(i, modulo);
assert_eq!(i * inverse % modulo, 1);
}
}
}
#[test]
fn test_extended_euclidean() {
let test_list = vec![
((3,5), (1, 2, -1)),
((15,12), (3, 1, -1)),
((16,21), (1, 4, -3)),
];
for (input, expected) in test_list {
let (a, b) = input;
let result = extended_euclidean_algorithm(a, b);
assert_eq!(expected, result);
let (gcd, mut a_inverse, mut b_inverse) = result;
if gcd == 1 {
if a_inverse < 0 {
a_inverse += b;
}
if b_inverse < 0 {
b_inverse += a;
}
assert_eq!(1, a * a_inverse % b);
assert_eq!(1, b * b_inverse % a);
}
}
}
#[test]
fn test_primitive_root() {
let test_list = vec![(3, 2), (7, 3), (11, 2), (13, 2), (47, 5), (7919, 7)];
for (input, expected) in test_list {
let root = primitive_root(input).unwrap();
assert_eq!(root, expected);
}
}
#[test]
fn test_prime_factors() {
let test_list = vec![
(46, vec![2,23]),
(2, vec![2]),
(3, vec![3]),
(162, vec![2, 3]),
];
for (input, expected) in test_list {
let factors = distinct_prime_factors(input);
assert_eq!(factors, expected);
}
}
}