use std::str::FromStr;
use rand::Rng;
use num_bigint::{ToBigUint, BigUint, RandBigInt, BigInt, Sign};
use num::{Zero, One, Integer};
use crate::helpers::generics::*;
/// Small odd primes (3 through 19) used by `rabin_miller` for cheap trial
/// division of a candidate before the randomized Miller–Rabin rounds.
const DISCARTERS: [u8; 7] = [3, 5, 7, 11, 13, 17, 19];
/// Draws a random `BigUint` of (at most) `bit_len` bits from the thread-local RNG.
///
/// The value comes straight from `RandBigInt::gen_biguint`, so leading bits may
/// be zero and the result can be shorter than `bit_len` bits.
pub fn gen_big_num(bit_len: &u32) -> BigUint {
    let bits = *bit_len as usize;
    rand::thread_rng().gen_biguint(bits)
}
/// Generates a random probable prime of `size` bits.
///
/// A random candidate below `2^size` is drawn and its top bit (bit `size - 1`)
/// is forced on, so the prime is not shorter than requested. The candidate is
/// then made odd and stepped up by 2 until `is_prime(candidate, threshold)`
/// accepts it. Because the search walks upward, a candidate that starts just
/// below `2^size` can, in rare cases, step past it and yield a `size + 1`-bit
/// result.
///
/// `threshold` is forwarded unchanged to `is_prime`.
pub fn gen_big_prime(size: &u32, threshold: u32) -> BigUint {
    let mut proposal = gen_big_num(size);
    if *size > 0 {
        // `gen_biguint(bits)` yields values below 2^bits with no guarantee on
        // the top bit, so roughly half the draws (and occasionally tiny ones)
        // would otherwise be shorter than `size` bits.
        proposal = proposal | (BigUint::one() << (*size as usize - 1));
    }
    if proposal.is_even() {
        proposal = proposal + BigUint::one();
    }
    // Only odd candidates can be prime here, so step by 2.
    let two = BigUint::from(2u32);
    while !is_prime(&proposal, threshold) {
        proposal = proposal + &two;
    }
    proposal
}
/// Reports whether `proposal` is (probably) prime.
///
/// Thin public wrapper around `rabin_miller`; `threshold` is passed through
/// unchanged as its `t` argument.
pub fn is_prime(proposal: &BigUint, threshold: u32) -> bool {
    rabin_miller(proposal, threshold)
}
fn rabin_miller(proposal: &BigUint, t: u32) -> bool {
let (z, o, tw) = gen_basic_biguints();
let (zero, one, two) = (&z, &o, &tw);
if proposal.clone() <= one.to_owned() {return false};
if proposal.clone() != two.to_owned() && proposal.clone() % two == zero.to_owned() {return false};
let discarts: Vec<bool> = DISCARTERS.into_iter().map(|x| (proposal % x.to_biguint().unwrap()).is_zero()).collect();
println!("{:?}", discarts);
for result in discarts {
if result == true {return false}
}
let (s,d) = refactor(proposal);
let mut counter = 0;
while counter < t {
let mut rng = rand::thread_rng();
let a = rng.gen_biguint_range(&two , &(proposal - two) );
let mut x = mod_exp_pow(&a, &d, proposal);
if x != one.to_owned() && x != proposal.to_owned() - one {
let mut i = zero.clone();
loop {
x = mod_exp_pow(&x, &two, proposal);
if x == proposal.to_owned() - one {break;}
if x == one.to_owned() || i >= s.clone()- one{return false;};
i = i.clone() + one;
}
}
counter +=2;
}
true
}
#[cfg(test)]
#[test]
fn rabin_miller_works() {
let res = rabin_miller(&179425357u32.to_biguint().unwrap(), 9);
assert_eq!(res, true);
let res2 = rabin_miller(&82589933u32.to_biguint().unwrap(), 64);
assert_eq!(res2, true);
let known_prime_str =
"118595363679537468261258276757550704318651155601593299292198496313960907653004730006758459999825003212944725610469590674020124506249770566394260832237809252494505683255861199449482385196474342481641301503121142740933186279111209376061535491003888763334916103110474472949854230628809878558752830476310536476569";
let known_prime: BigUint = FromStr::from_str(known_prime_str).unwrap();
assert!(rabin_miller(&known_prime, 64));
assert_eq!(rabin_miller(&19u32.to_biguint().unwrap(), 9), false);
}
/// Computes `base^exp mod md` using right-to-left square-and-multiply.
///
/// The result is always in `[0, md)`; in particular, any power modulo 1 is 0.
///
/// # Panics
/// Panics if `md` is zero.
pub fn mod_exp_pow(base: &BigUint, exp: &BigUint, md: &BigUint) -> BigUint {
    // Reducing the initial 1 makes `x^0 mod 1 == 0` and rejects a zero
    // modulus up front instead of only when `exp > 0`.
    let mut res = BigUint::one() % md;
    // Reduce the base first so every product below stays below md^2.
    let mut base = base % md;
    let mut exponent = exp.clone();
    while !exponent.is_zero() {
        if exponent.is_odd() {
            res = (res * &base) % md;
        }
        exponent >>= 1;
        base = (&base * &base) % md;
    }
    res
}
#[cfg(test)]
#[test]
fn mod_exp_works() {
    let big = |v: u32| BigUint::from(v);
    assert_eq!(mod_exp_pow(&big(4), &big(13), &big(497)), big(445));
    assert_eq!(mod_exp_pow(&big(5), &big(3), &big(13)), big(8));
    // x^0 is 1 for any modulus > 1, but every residue modulo 1 is 0.
    assert_eq!(mod_exp_pow(&big(7), &big(0), &big(13)), big(1));
    assert_eq!(mod_exp_pow(&big(7), &big(0), &big(1)), big(0));
    assert_eq!(mod_exp_pow(&big(7), &big(5), &big(1)), big(0));
    // Base larger than the modulus: 100^2 mod 7 == 2^2 == 4.
    assert_eq!(mod_exp_pow(&big(100), &big(2), &big(7)), big(4));
}
fn refactor(n: &BigUint) -> (BigUint, BigUint) {
let (mut s, one, two) = gen_basic_biguints();
let mut d = n.clone() - one.clone();
while d.is_even() {
d = d / two.clone();
s = s + one.clone();
}
(s, d)
}
/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` with `a*x + b*y == g`. For non-negative inputs `g` is
/// `gcd(a, b)`; with negative inputs the sign of `g` follows `BigInt`'s
/// truncating `/` and `%`. When `a == 0` the result is `(b, 0, 1)`.
/// Neither argument is modified; both are only read (cloned).
pub fn egcd<'a>(a: &'a mut BigInt, b: &'a mut BigInt) -> (BigInt, BigInt, BigInt) {
    // Invariant kept on every iteration:
    //   low  == a * low_x  + b * low_y
    //   high == a * high_x + b * high_y
    let (mut low, mut high) = (a.clone(), b.clone());
    let (mut low_x, mut low_y) = (BigInt::from(1 as i32), BigInt::from(0 as i32));
    let (mut high_x, mut high_y) = (BigInt::from(0 as i32), BigInt::from(1 as i32));
    while low != BigInt::from(0 as u32) {
        // (low, high) -> (high % low, low), carrying the coefficients along.
        let quotient = &high / &low;
        let remainder = &high % &low;
        let next_x = &high_x - &quotient * &low_x;
        let next_y = &high_y - &quotient * &low_y;
        high = std::mem::replace(&mut low, remainder);
        high_x = std::mem::replace(&mut low_x, next_x);
        high_y = std::mem::replace(&mut low_y, next_y);
    }
    (high, high_x, high_y)
}
#[cfg(test)]
#[test]
fn egcd_test() {
    use num_bigint::ToBigInt;

    // Bézout identity for a large and a small value.
    let mut large = 179425357u32.to_bigint().unwrap();
    let mut small = 97u32.to_bigint().unwrap();
    let (g, x, y) = egcd(&mut large, &mut small);
    assert_eq!(large.clone() * x + small.clone() * y, g);

    // gcd(1024, 512) == 512.
    let mut hi = 1024u32.to_bigint().unwrap();
    let mut lo = 512u32.to_bigint().unwrap();
    let (g, _, _) = egcd(&mut hi, &mut lo);
    assert_eq!(g, 512u32.to_bigint().unwrap());

    // Bézout identity for two very large numbers parsed from decimal strings.
    let first_str = "118595363679537468261258276757550704318651155601593299292198496313960907653004730006758459999825003212944725610469590674020124506249770566394260832237809252494505683255861199449482385196474342481641301503121142740933186279111209376061535491003888763334916103110474472949854230628809878558752830476310536476569";
    let second_str = "357111317192329313741434753596167717379838997101103107109113127131137139149151157163167173179181191193197199211223227229233239241251257263269271277281283293307311313317331337347349353359367373379383389397401409419421431433439443449457461463467479487491499503509521523541547557563569571577587593599601607613617619631641643647653659661673677683691701709719727733739743751757761769773787797809811821823827829839853857859863877881883887907911919929937941947953967971977983991997";
    let mut first = first_str.parse::<BigInt>().unwrap();
    let mut second = second_str.parse::<BigInt>().unwrap();
    let (g, x, y) = egcd(&mut first, &mut second);
    assert_eq!(first.clone() * x + second.clone() * y, g);
}
/// Searches for an odd value coprime to `fi_n`.
///
/// The search starts at a random point in `[fi_n/2, 3*fi_n/4)`, bumped to the
/// next odd number if needed. It then walks upward in steps of 2 and returns
/// the first candidate `<= fi_n - 1` with `gcd(fi_n, candidate) == 1`.
///
/// # Errors
/// Returns `Err(false)` in two cases:
/// - the starting range is empty (`fi_n <= 2`);
/// - no coprime candidate exists before `fi_n - 1`.
pub fn find_e(fi_n: &BigUint) -> Result<BigUint, bool> {
    let fi_n = BigInt::from_biguint(Sign::Plus, fi_n.clone());
    let one = BigInt::one();
    let two = BigInt::from(2);
    let low = &fi_n / &two;
    let high = (BigInt::from(3) * &fi_n) / BigInt::from(4);
    // gen_bigint_range panics on an empty range; report "not found" instead.
    if low >= high {
        return Err(false);
    }
    let mut candidate = rand::thread_rng().gen_bigint_range(&low, &high);
    if candidate.is_even() {
        candidate += &one;
    }
    let limit = &fi_n - &one;
    while candidate <= limit {
        // Only the gcd is needed here, not the Bézout coefficients.
        if fi_n.gcd(&candidate) == one {
            // candidate >= fi_n/2 >= 1 here, so it is positive; presumably the
            // helper only fails for negative inputs — confirm in helpers::generics.
            return Ok(biguint_from_bigint(&candidate).unwrap());
        }
        candidate += &two;
    }
    Err(false)
}