use std::fmt;
use rand_core::RngCore;
use rug::{Assign, Complete, Integer};
mod small_primes;
pub fn external_rand(rng: &mut impl RngCore) -> rug::rand::ThreadRandState {
use bytemuck::TransparentWrapper;
#[derive(TransparentWrapper)]
#[repr(transparent)]
pub struct ExternalRand<R>(R);
impl<R: RngCore> rug::rand::ThreadRandGen for ExternalRand<R> {
fn gen(&mut self) -> u32 {
self.0.next_u32()
}
}
rug::rand::ThreadRandState::new_custom(ExternalRand::wrap_mut(rng))
}
/// Returns `true` when `x` is non-negative and coprime to `n`.
///
/// Only the sign of `x` and `gcd(x, n)` are inspected; `x` is not required to
/// be smaller than `n`.
#[inline(always)]
pub fn in_mult_group(x: &Integer, n: &Integer) -> bool {
    if x.cmp0().is_lt() {
        return false;
    }
    in_mult_group_abs(x, n)
}
/// Returns `true` when `gcd(x, n)` equals one, regardless of the sign of `x`.
#[inline(always)]
pub fn in_mult_group_abs(x: &Integer, n: &Integer) -> bool {
    let common_divisor = x.gcd_ref(n).complete();
    &common_divisor == Integer::ONE
}
/// Draws a uniformly random value below `n` that passes [`in_mult_group`].
///
/// Candidates are sampled with `random_below_ref` and rejected until one is
/// coprime to `n`; a single buffer is reused across attempts.
///
/// NOTE(review): `rug` documents `random_below` as panicking for a
/// non-positive bound, so `n` is presumably expected to be positive — confirm
/// against callers.
pub fn sample_in_mult_group(rng: &mut impl RngCore, n: &Integer) -> Integer {
    let mut rand_state = external_rand(rng);
    let mut candidate = Integer::new();
    loop {
        candidate.assign(n.random_below_ref(&mut rand_state));
        if in_mult_group(&candidate, n) {
            break candidate;
        }
    }
}
/// Generates a random safe prime that is `bits` bits long.
///
/// Delegates to [`sieve_generate_safe_primes`] with a fixed sieve size.
pub fn generate_safe_prime(rng: &mut impl RngCore, bits: u32) -> Integer {
    /// Number of small primes used by the sieve.
    const SIEVE_SIZE: usize = 135;
    sieve_generate_safe_primes(rng, bits, SIEVE_SIZE)
}
/// Generates a random safe prime `p = 2q + 1` of exactly `bits` bits, using a
/// sieve over the first `amount` entries of `small_primes::SMALL_PRIMES`.
///
/// A random odd candidate `q` with `bits - 1` bits (top bit forced to one) is
/// drawn; it is discarded early when some small prime `s` is a proper divisor
/// of `2q + 1`, which is detected via `q mod s == (s - 1) / 2`. Surviving
/// candidates are checked with `is_probably_prime`, first `q`, then `2q + 1`.
///
/// `amount` is clamped to the size of the small-prime table.
///
/// # Panics
///
/// Panics if `bits < 3`: no safe prime of the form produced here fits in fewer
/// bits, and smaller values would underflow the bit arithmetic or loop forever.
pub fn sieve_generate_safe_primes(rng: &mut impl RngCore, bits: u32, amount: usize) -> Integer {
    use rug::integer::IsPrime;

    /// Repetition count handed to `is_probably_prime` for both checks.
    const PRIMALITY_REPS: u32 = 25;

    assert!(
        bits >= 3,
        "safe prime generation requires at least 3 bits, got {}",
        bits
    );

    let amount = amount.min(small_primes::SMALL_PRIMES.len());
    let sieve = &small_primes::SMALL_PRIMES[..amount];
    let mut rng = external_rand(rng);
    let mut x = Integer::new();

    loop {
        // Draw `bits - 1` random bits, then force the top bit (so the length is
        // exact) and the lowest bit (so the candidate is odd).
        x.assign(Integer::random_bits(bits - 1, &mut rng));
        x.set_bit(bits - 2, true);
        x |= 1u32;

        // `x mod s == (s - 1) / 2` means `s` divides `2x + 1`. That only rules
        // the candidate out when `2x + 1` is a proper multiple of `s`; if
        // `x == (s - 1) / 2` then `2x + 1` *is* the prime `s` and must be kept,
        // otherwise small safe primes could never be produced.
        let has_small_factor = sieve.iter().any(|&small_prime| {
            let half = (small_prime - 1) / 2;
            x.mod_u(small_prime) == half && x != half
        });
        if has_small_factor {
            continue;
        }

        if !matches!(
            x.is_probably_prime(PRIMALITY_REPS),
            IsPrime::Yes | IsPrime::Probably
        ) {
            continue;
        }

        // x <- 2x + 1; from here on `x` holds the safe-prime candidate itself.
        x <<= 1;
        x += 1;
        if matches!(
            x.is_probably_prime(PRIMALITY_REPS),
            IsPrime::Yes | IsPrime::Probably
        ) {
            return x;
        }
    }
}
/// Precomputed data for exponentiation modulo `n = n1 * n2` via the Chinese
/// remainder theorem.
///
/// Instances are created through `CrtExp::build`, `CrtExp::build_n` or
/// `CrtExp::build_nn`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CrtExp {
/// Full modulus, the product `n1 * n2`.
n: Integer,
/// First factor of the modulus.
n1: Integer,
/// Value used to reduce exponents for computations modulo `n1`
/// (`p - 1` in `build_n`, `p^2 - p` in `build_nn`).
phi_n1: Integer,
/// Second factor of the modulus.
n2: Integer,
/// Value used to reduce exponents for computations modulo `n2`
/// (`q - 1` in `build_n`, `q^2 - q` in `build_nn`).
phi_n2: Integer,
/// Inverse of `n1` modulo `n2`, used to recombine the two partial results.
beta: Integer,
}
/// Exponent preprocessed by `CrtExp::prepare_exponent` for use with
/// `CrtExp::exp`.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Exponent {
/// Absolute value of the exponent reduced modulo the `phi_n1` of the
/// `CrtExp` that prepared it.
e_mod_phi_pp: Integer,
/// Absolute value of the exponent reduced modulo the `phi_n2` of the
/// `CrtExp` that prepared it.
e_mod_phi_qq: Integer,
/// Whether the original exponent was negative; if so `CrtExp::exp` inverts
/// the result modulo `n`.
is_negative: bool,
}
impl CrtExp {
    /// Builds the CRT context for the modulus `n1 * n2`.
    ///
    /// `phi_n1` and `phi_n2` are the values exponents get reduced by for the
    /// respective factor. They are not verified to be actual totients.
    ///
    /// Returns `None` when any argument is not strictly positive, when a `phi`
    /// is not strictly smaller than its modulus, or when `n1` has no inverse
    /// modulo `n2`.
    pub fn build(n1: Integer, phi_n1: Integer, n2: Integer, phi_n2: Integer) -> Option<Self> {
        let all_positive = [&n1, &n2, &phi_n1, &phi_n2]
            .iter()
            .all(|value| value.cmp0().is_gt());
        if !all_positive || phi_n1 >= n1 || phi_n2 >= n2 {
            return None;
        }

        // beta = n1^{-1} mod n2; fails when n1 and n2 share a factor.
        let beta = n1.invert_ref(&n2)?.into();
        let n = (&n1 * &n2).complete();
        Some(Self {
            n,
            n1,
            phi_n1,
            n2,
            phi_n2,
            beta,
        })
    }

    /// Builds the context for the modulus `p * q`, using `p - 1` and `q - 1`
    /// as exponent moduli.
    ///
    /// Returns `None` under the same conditions as [`CrtExp::build`].
    pub fn build_n(p: &Integer, q: &Integer) -> Option<Self> {
        let phi_p = (p - 1u8).complete();
        let phi_q = (q - 1u8).complete();
        Self::build(p.clone(), phi_p, q.clone(), phi_q)
    }

    /// Builds the context for the modulus `p^2 * q^2`, using `p^2 - p` and
    /// `q^2 - q` as exponent moduli.
    ///
    /// Returns `None` under the same conditions as [`CrtExp::build`].
    pub fn build_nn(p: &Integer, q: &Integer) -> Option<Self> {
        let pp = p.square_ref().complete();
        let qq = q.square_ref().complete();
        let phi_pp = (&pp - p).complete();
        let phi_qq = (&qq - q).complete();
        Self::build(pp, phi_pp, qq, phi_qq)
    }

    /// Preprocesses `e` for repeated use with [`CrtExp::exp`].
    ///
    /// Stores `|e|` reduced modulo each of the two exponent moduli, together
    /// with the sign of `e`.
    pub fn prepare_exponent(&self, e: &Integer) -> Exponent {
        let is_negative = e.cmp0().is_lt();

        // Only pay for the negation when the exponent is actually negative.
        let negated;
        let magnitude = if is_negative {
            negated = (-e).complete();
            &negated
        } else {
            e
        };

        Exponent {
            e_mod_phi_pp: magnitude.modulo_ref(&self.phi_n1).complete(),
            e_mod_phi_qq: magnitude.modulo_ref(&self.phi_n2).complete(),
            is_negative,
        }
    }

    /// Computes `x^e mod n` from the two half-size exponentiations.
    ///
    /// `e` should come from [`CrtExp::prepare_exponent`] called on this same
    /// context.
    ///
    /// Returns `None` when the exponent is negative and the result has no
    /// inverse modulo `n`, or when `e` carries a negative residue (possible
    /// only for an `Exponent` that did not come from `prepare_exponent`, e.g.
    /// a deserialized one) and the base is not invertible.
    ///
    /// NOTE(review): reducing the exponent modulo `phi` is only sound when `x`
    /// is coprime to the modulus — confirm that callers guarantee this.
    pub fn exp(&self, x: &Integer, e: &Exponent) -> Option<Integer> {
        let s1 = x.modulo_ref(&self.n1).complete();
        let s2 = x.modulo_ref(&self.n2).complete();

        // `pow_mod` fails only for a negative exponent with a non-invertible
        // base. `prepare_exponent` never produces negative residues, but the
        // type does not enforce that, so report `None` rather than panic.
        let r1 = s1.pow_mod(&e.e_mod_phi_pp, &self.n1).ok()?;
        let r2 = s2.pow_mod(&e.e_mod_phi_qq, &self.n2).ok()?;

        // Recombine: result = r1 + n1 * (((r2 - r1) * beta) mod n2).
        let lifted = ((r2 - &r1) * &self.beta).modulo(&self.n2);
        let result = lifted * &self.n1 + &r1;

        if e.is_negative {
            result.invert(&self.n).ok()
        } else {
            Some(result)
        }
    }
}
/// Opaque `Debug` output: prints only the type name, never the field values.
impl fmt::Debug for CrtExp {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "CrtExp")
    }
}
/// Opaque `Debug` output: prints a fixed label, never the field values.
impl fmt::Debug for Exponent {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "CrtExponent")
    }
}
#[cfg(test)]
mod test {
    use super::generate_safe_prime;
    use rug::Integer;

    /// A generated safe prime must be exactly as long as requested: shifting
    /// out everything below the top bit has to leave exactly one.
    #[test]
    fn safe_prime_size() {
        let mut rng = rand_dev::DevRng::new();
        for bit_len in [500, 512, 513, 514] {
            let prime = generate_safe_prime(&mut rng, bit_len);
            let top_bits = prime >> (bit_len - 1);
            assert_eq!(&top_bits, Integer::ONE);
        }
    }
}