use crypto_bigint::modular::MontyForm;
use crypto_bigint::modular::MontyParams;
use crypto_bigint::rand_core::RngCore;
use crypto_bigint::subtle::Choice;
use crypto_bigint::Concat;
use crypto_bigint::RandomBits;
use crypto_bigint::Split;
use crypto_bigint::{Integer, Odd, RandomMod, Uint};
/// Miller–Rabin compositeness test for `p`, running `t` rounds.
///
/// Each round draws a witness `a` uniformly from `[2, p - 2]` using `rng`.
///
/// Returns a [`Choice`] that is:
/// - `1` if `p` is definitely composite, or if `p < 2` (0 and 1 are not prime,
///   so callers using this as a primality gate reject them);
/// - `0` if `p` is 2 or 3, or if `p` survived all `t` rounds and is therefore a
///   probable prime.
///
/// Note: although the result is a `Choice`, this function branches on `p` and
/// on intermediate values, so it is not constant-time.
///
/// # Panics
///
/// Panics if `t == 0`.
pub fn is_composite<const LIMBS: usize, const LIMBS_DOUBLE: usize>(
    p: Uint<LIMBS>,
    t: u32,
    rng: &mut impl RngCore,
) -> Choice
where
    Uint<LIMBS>: Concat<Output = Uint<LIMBS_DOUBLE>>,
    Uint<LIMBS_DOUBLE>: Split<Output = Uint<LIMBS>>,
{
    assert!(t > 0);
    let composite = Choice::from(1u8);
    let not_composite = Choice::from(0u8);

    let two = two::<LIMBS>();
    let three = two + Uint::ONE;

    // 0 and 1 are not prime. Rejecting them here also keeps the subtractions
    // below from underflowing.
    if p < two {
        return composite;
    }
    if p == two || p == three {
        return not_composite;
    }
    if bool::from(p.is_even()) {
        return composite;
    }

    // From here on p is odd and >= 5, so p - 3 >= 2 is non-zero and the
    // unwraps below cannot fail.
    let p_minus_1 = p - Uint::ONE;
    let witness_range = (p_minus_1 - two).to_nz().unwrap();
    // The Montgomery parameters depend only on p: build them once, not per round.
    let params = MontyParams::new(Odd::new(p).unwrap());

    // Decompose p - 1 = 2^s * q with q odd (s >= 1 because p is odd).
    let mut q = p_minus_1;
    let mut s: u32 = 0;
    while bool::from(q.is_even()) {
        q = q / two;
        s += 1;
    }

    'rounds: for _ in 0..t {
        // Uniform witness in [2, p - 2].
        let a = Uint::<LIMBS>::random_mod(rng, &witness_range) + two;

        let mut x = MontyForm::new(&a, params).pow(&q);
        let x_plain = x.retrieve();
        if x_plain == Uint::ONE || x_plain == p_minus_1 {
            continue;
        }

        // Square at most s - 1 more times, looking for -1 (mod p). Reaching 1
        // without first passing through -1 exposes a non-trivial square root of
        // 1, which proves p composite. So 1 must NOT be accepted here; otherwise
        // the test collapses to a Fermat test and Carmichael numbers (e.g. 561)
        // pass.
        for _ in 1..s {
            x = x.square();
            if x.retrieve() == p_minus_1 {
                continue 'rounds;
            }
        }
        return composite;
    }
    not_composite
}
/// The constant 2 as a `Uint<LIMBS>`.
#[inline]
fn two<const LIMBS: usize>() -> Uint<LIMBS> {
    Uint::<LIMBS>::from_u8(2)
}
pub fn generate_probable_prime<const LIMBS: usize, const LIMBS_DOUBLE: usize, R: RngCore>(
bits: u32,
t: u32,
mut rng: &mut R,
) -> Uint<LIMBS>
where
Uint<LIMBS>: Concat<Output = Uint<LIMBS_DOUBLE>>,
Uint<LIMBS_DOUBLE>: Split<Output = Uint<LIMBS>>,
{
assert!(Uint::<LIMBS>::BITS >= bits, "bits are larger than limbs");
let two = Uint::<LIMBS>::ONE + Uint::<LIMBS>::ONE;
let mut p = Uint::<LIMBS>::ZERO;
while p.is_even().unwrap_u8() == 1u8 {
p = Uint::<LIMBS>::random_bits(&mut rng, bits);
}
loop {
if is_composite(p, t, &mut rng).unwrap_u8() == 0u8 {
return p;
}
p += two;
}
}
#[cfg(test)]
mod tests;