use crate::util::{blum_prime, get_inv_mod_p_minus_1_and_q_minus_1};
use crypto_bigint::{
modular::{MontyForm, MontyParams, SafeGcdInverter},
subtle::ConstantTimeGreater,
BitOps, Concat, Odd, PrecomputeInverter, Split, Uint,
};
use crypto_primes::generate_prime_with_rng;
use rand_core::CryptoRngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// The two prime factors `p` and `q` from which a [`Modulus`] (`n = p * q`) is built.
///
/// Both fields are wiped on drop via the derived `ZeroizeOnDrop`.
// NOTE(review): the derived `Debug` prints both primes in full; if these are secret key
// material, consider a redacting `Debug` impl.
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
pub struct Primes<const LIMBS: usize> {
/// First prime factor.
pub p: Odd<Uint<LIMBS>>,
/// Second prime factor.
pub q: Odd<Uint<LIMBS>>,
}
/// Values precomputed from a [`Primes`] pair, produced by its `From<Primes>` conversion.
///
/// `PRIME_LIMBS` is the width of each prime; `MODULUS_LIMBS` is the width of `n = p * q`.
#[derive(Debug, Clone, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
pub struct PrimesWithPrecomp<const PRIME_LIMBS: usize, const MODULUS_LIMBS: usize> {
/// `p^{-1} mod q`, held in Montgomery form with respect to modulus `q`.
pub p_inv: MontyForm<PRIME_LIMBS>,
/// First value returned by `get_inv_mod_p_minus_1_and_q_minus_1(n, p, q)`;
/// presumably `n^{-1} mod (p - 1)` — confirm against that helper.
pub n_inv_p_minus_1: Uint<PRIME_LIMBS>,
/// Second value returned by `get_inv_mod_p_minus_1_and_q_minus_1(n, p, q)`;
/// presumably `n^{-1} mod (q - 1)` — confirm against that helper.
pub n_inv_q_minus_1: Uint<PRIME_LIMBS>,
// NOTE(review): the four Montgomery parameter sets below are built directly from the
// secret primes, yet are excluded from zeroization. Confirm whether `MontyParams`
// implements `Zeroize` in the crypto-bigint version in use; if it does, drop the `skip`.
/// Montgomery parameters for modulus `p` at prime width.
#[zeroize(skip)]
pub p_mtg: MontyParams<PRIME_LIMBS>,
/// Montgomery parameters for modulus `q` at prime width.
#[zeroize(skip)]
pub q_mtg: MontyParams<PRIME_LIMBS>,
/// Montgomery parameters for modulus `p`, zero-extended to `MODULUS_LIMBS`.
#[zeroize(skip)]
pub p_mtg_1: MontyParams<MODULUS_LIMBS>,
/// Montgomery parameters for modulus `q`, zero-extended to `MODULUS_LIMBS`.
#[zeroize(skip)]
pub q_mtg_1: MontyParams<MODULUS_LIMBS>,
}
/// An odd modulus; [`Modulus::new`] builds it as the product `n = p * q` of a [`Primes`] pair.
///
/// The inner value is public, so any odd integer can also be wrapped directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Modulus<const LIMBS: usize>(pub Odd<Uint<LIMBS>>);
impl<const PRIME_LIMBS: usize> Primes<PRIME_LIMBS> {
    /// Generates two distinct random primes.
    ///
    /// Each prime is requested from [`generate_prime_with_rng`] with a bit length of
    /// `Uint::<PRIME_LIMBS>::BITS`. `q` is re-drawn in the (astronomically unlikely) event
    /// that it equals `p`: a modulus `p * p` is trivially factored by a square root, and
    /// `p^{-1} mod q` would not exist.
    pub fn new<R: CryptoRngCore>(rng: &mut R) -> Self {
        let (p, q): (Uint<PRIME_LIMBS>, Uint<PRIME_LIMBS>) =
            Self::draw_distinct(|| generate_prime_with_rng(&mut *rng, Uint::<PRIME_LIMBS>::BITS));
        // A prime with the full `BITS` length is far larger than 2, hence odd.
        Self {
            p: p.to_odd().unwrap(),
            q: q.to_odd().unwrap(),
        }
    }

    /// Generates two distinct primes using the crate's `blum_prime` helper.
    ///
    /// As in [`Primes::new`], `q` is re-drawn if it happens to equal `p`.
    pub fn new_with_blum_primes<R: CryptoRngCore>(rng: &mut R) -> Self {
        let (p, q) = Self::draw_distinct(|| blum_prime::<R, PRIME_LIMBS>(&mut *rng));
        // Odd primes are odd by definition; `blum_prime` is assumed never to return 2.
        Self {
            p: p.to_odd().unwrap(),
            q: q.to_odd().unwrap(),
        }
    }

    /// Wraps caller-supplied values as-is.
    ///
    /// No primality or distinctness check is performed; the caller is responsible for both.
    pub fn from_primes(p: Odd<Uint<PRIME_LIMBS>>, q: Odd<Uint<PRIME_LIMBS>>) -> Self {
        Self { p, q }
    }

    /// Calls `draw` until it has produced two unequal values and returns them in draw order.
    ///
    /// The first value is always kept; only the second is re-drawn on a collision.
    fn draw_distinct<T: PartialEq>(mut draw: impl FnMut() -> T) -> (T, T) {
        let first = draw();
        let mut second = draw();
        while second == first {
            second = draw();
        }
        (first, second)
    }
}
impl<const MODULUS_LIMBS: usize> Modulus<MODULUS_LIMBS> {
    /// Builds the modulus `n = p * q` from a pair of primes.
    ///
    /// Instantiating this with `MODULUS_LIMBS != 2 * PRIME_LIMBS` is rejected at compile
    /// time by the inline `const` assertion.
    pub fn new<const PRIME_LIMBS: usize>(dk: &Primes<PRIME_LIMBS>) -> Self
    where
        Uint<PRIME_LIMBS>: Concat<Output = Uint<MODULUS_LIMBS>>,
    {
        const { assert!(2 * PRIME_LIMBS == MODULUS_LIMBS) };
        let n: Uint<MODULUS_LIMBS> = dk.p.widening_mul(&dk.q).into();
        // `p` and `q` are both odd, so their product is odd and `to_odd` cannot fail.
        Self(n.to_odd().unwrap())
    }

    /// Returns the byte precision reported by `bytes_precision`, encoded as a
    /// 2-byte little-endian integer for use as a hash input.
    ///
    /// # Panics
    ///
    /// Panics if the precision does not fit in a `u16` (65 536 bytes or more).
    pub fn size_for_hashing(&self) -> [u8; 2] {
        u16::try_from(self.0.bytes_precision())
            .expect("modulus byte precision must fit in two bytes")
            .to_le_bytes()
    }

    /// Returns `true` if this modulus is strictly greater than `rhs`.
    ///
    /// The comparison is done with `ConstantTimeGreater::ct_gt`; only the final
    /// conversion to `bool` reveals the result.
    pub fn is_greater_than(&self, rhs: &Uint<MODULUS_LIMBS>) -> bool {
        self.0.ct_gt(rhs).into()
    }
}
impl<
        const PRIME_LIMBS: usize,
        const MODULUS_LIMBS: usize,
        const WIDE_MODULUS_LIMBS: usize,
        const PRIME_UNSAT_LIMBS: usize,
    > From<Primes<PRIME_LIMBS>> for PrimesWithPrecomp<PRIME_LIMBS, MODULUS_LIMBS>
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<MODULUS_LIMBS>>,
    Uint<MODULUS_LIMBS>: Split<Output = Uint<PRIME_LIMBS>>,
    Uint<MODULUS_LIMBS>: Concat<Output = Uint<WIDE_MODULUS_LIMBS>>,
    Uint<WIDE_MODULUS_LIMBS>: Split<Output = Uint<MODULUS_LIMBS>>,
    Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
        Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
        Output = Uint<PRIME_LIMBS>,
    >,
{
    /// Derives every precomputed quantity from the two primes.
    ///
    /// # Panics
    ///
    /// - if `p` has no inverse modulo `q` (e.g. when `p == q`);
    /// - if `get_inv_mod_p_minus_1_and_q_minus_1` reports failure for `n = p * q`.
    fn from(primes: Primes<PRIME_LIMBS>) -> Self {
        // Montgomery parameters at the primes' own width.
        let (p_mtg, q_mtg) = (MontyParams::new(primes.p), MontyParams::new(primes.q));

        // The same moduli zero-extended to the modulus width; widening keeps them odd.
        let widen = |m: Odd<Uint<PRIME_LIMBS>>| {
            MontyParams::new(m.resize::<MODULUS_LIMBS>().to_odd().unwrap())
        };
        let (p_mtg_1, q_mtg_1) = (widen(primes.p), widen(primes.q));

        // p^{-1} mod q, kept in Montgomery form.
        let p_inv = MontyForm::new(&primes.p, q_mtg).inv().unwrap();

        // Inverses of n = p * q with respect to p - 1 and q - 1 (per the helper's name).
        let modulus: Uint<MODULUS_LIMBS> = primes.p.widening_mul(&primes.q).into();
        let (n_inv_p_minus_1, n_inv_q_minus_1) =
            get_inv_mod_p_minus_1_and_q_minus_1(&modulus, &primes.p, &primes.q).unwrap();

        Self { p_inv, n_inv_p_minus_1, n_inv_q_minus_1, p_mtg, q_mtg, p_mtg_1, q_mtg_1 }
    }
}