use super::random_prime;
use crate::bignum::{MontModulus, Uint, inv_mod};
use crate::ct::ConstantTimeEq;
use crate::hash::{Digest, Sha256};
use crate::rng::{CryptoRng, RngCore};
/// RSA public key: a modulus `n` paired with a public exponent `e`.
///
/// Both values are stored as `Uint<LIMBS>`; nothing in this type validates
/// them.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RsaPublicKey<const LIMBS: usize> {
// Modulus `n`.
n: Uint<LIMBS>,
// Public exponent `e`.
e: Uint<LIMBS>,
}
/// RSA private key together with the precomputed state used to blind the
/// private operation.
///
/// Only `Clone` is derived (no `Debug`, no `PartialEq`). The `Drop` impl in
/// this file overwrites `d`, `p`, `q`, `phi_n_minus_1` and `blinding_seed`
/// with zeros.
#[derive(Clone)]
pub struct RsaPrivateKey<const LIMBS: usize> {
// Modulus `n`.
n: Uint<LIMBS>,
// Public exponent `e`.
e: Uint<LIMBS>,
// Private exponent `d`.
d: Uint<LIMBS>,
// Prime factors of `n`; both are zero when the key was built through
// `from_components`, which does not receive them.
p: Uint<LIMBS>,
q: Uint<LIMBS>,
// `(p - 1) * (q - 1) - 1`, used as the exponent that inverts the blinding
// factor. Zero when `p` or `q` is zero, which makes `raw` skip blinding.
phi_n_minus_1: Uint<LIMBS>,
// SHA-256 of a fixed domain tag followed by the limbs of `d`; keys the HMAC
// that derives the per-input blinding factor.
blinding_seed: [u8; 32],
}
/// Derives the per-key blinding state from the prime factors and the private
/// exponent.
///
/// Returns `(phi_n_minus_1, seed)` where:
///
/// * `phi_n_minus_1` is `(p - 1) * (q - 1) - 1` (low half of the wide
///   product), or zero when either `p` or `q` is zero;
/// * `seed` is the SHA-256 digest of the tag
///   `purecrypto-rsa-blinding-seed-v1` followed by every limb of `d`, most
///   significant limb first, each limb serialized big-endian.
fn derive_blinding<const LIMBS: usize>(
    p: &Uint<LIMBS>,
    q: &Uint<LIMBS>,
    d: &Uint<LIMBS>,
) -> (Uint<LIMBS>, [u8; 32]) {
    // Both comparisons are always evaluated; only the combined flag is
    // turned into a `bool`.
    let factors_missing = bool::from(p.ct_eq(&Uint::ZERO) | q.ct_eq(&Uint::ZERO));

    let exponent = if factors_missing {
        Uint::ZERO
    } else {
        let totient = p
            .wrapping_sub(&Uint::ONE)
            .mul_wide(&q.wrapping_sub(&Uint::ONE))
            .0;
        totient.wrapping_sub(&Uint::ONE)
    };

    let mut hasher = Sha256::new();
    hasher.update(b"purecrypto-rsa-blinding-seed-v1");
    // Feed `d` from the highest limb down to the lowest.
    for limb in d.as_limbs().iter().rev() {
        hasher.update(&limb.to_be_bytes());
    }
    let digest = hasher.finalize();

    let mut seed = [0u8; 32];
    seed.copy_from_slice(digest.as_ref());
    (exponent, seed)
}
/// Computes the RSA private operation `c^d mod n`, blinding the base when the
/// key carries the data needed to do so.
///
/// # Parameters
///
/// * `n`, `e`, `d` - modulus, public exponent and private exponent.
/// * `phi_n_minus_1` - `phi(n) - 1`, the exponent used to invert the blinding
///   factor. A value of zero means "not available" and selects the plain,
///   unblinded exponentiation.
/// * `blinding_seed` - HMAC-SHA256 key from which the blinding factor is
///   derived.
/// * `c` - input to the private operation.
///
/// # Blinding
///
/// A factor `r` is derived deterministically as HMAC-SHA256 over `"r"`, a
/// 32-bit big-endian block counter and the big-endian limbs of `c`, repeated
/// until every limb of `r` is filled, and then reduced modulo `n`. The result
/// is `(c * r^e)^d * r^(phi(n) - 1) mod n`, which equals `c^d mod n` whenever
/// `r` is invertible modulo `n`.
fn raw_private_blinded<const LIMBS: usize>(
    n: &Uint<LIMBS>,
    e: &Uint<LIMBS>,
    d: &Uint<LIMBS>,
    phi_n_minus_1: &Uint<LIMBS>,
    blinding_seed: &[u8; 32],
    c: &Uint<LIMBS>,
) -> Uint<LIMBS> {
    use crate::hash::HmacSha256;

    let modulus = MontModulus::new(*n);

    // Without phi(n) - 1 the blinding factor cannot be inverted, so fall back
    // to the direct exponentiation.
    if bool::from(phi_n_minus_1.ct_eq(&Uint::ZERO)) {
        return modulus.pow(c, d);
    }

    // Expand the seed and the input into LIMBS 64-bit limbs, filling from the
    // most significant limb downwards. Each 32-byte tag supplies up to four
    // limbs.
    let mut r_limbs = [0u64; LIMBS];
    let mut counter: u32 = 0;
    let mut limbs_remaining = LIMBS;
    while limbs_remaining > 0 {
        let mut m = HmacSha256::new(blinding_seed);
        m.update(b"r");
        m.update(&counter.to_be_bytes());
        for limb in c.as_limbs().iter().rev() {
            m.update(&limb.to_be_bytes());
        }
        let tag = m.finalize();
        let tag_bytes = tag.as_ref();
        for j in 0..4 {
            if limbs_remaining == 0 {
                break;
            }
            limbs_remaining -= 1;
            let off = j * 8;
            let bytes: [u8; 8] = tag_bytes[off..off + 8]
                .try_into()
                .expect("HMAC-SHA256 emits 32 bytes");
            r_limbs[limbs_remaining] = u64::from_be_bytes(bytes);
        }
        counter += 1;
    }

    let r_raw = Uint::<LIMBS>::from_limbs(r_limbs);
    let r = r_raw.reduce(n);

    // r == 0 would zero the blinded input and r == 1 would not blind at all;
    // replace either with the constant 2.
    let r_is_zero = r.ct_eq(&Uint::ZERO);
    let r_is_one = r.ct_eq(&Uint::ONE);
    let bad = r_is_zero | r_is_one;
    // Fix: the operands were previously swapped, which used the constant for
    // every good `r` and kept the degenerate ones.
    // NOTE(review): assumes `crate::ct::ConditionallySelectable` follows the
    // `subtle` convention, where `conditional_select(a, b, choice)` yields `b`
    // when `choice` is set and `a` otherwise - confirm against `crate::ct`.
    let r = <Uint<LIMBS> as crate::ct::ConditionallySelectable>::conditional_select(
        &r,
        &Uint::from_u64(2),
        bad,
    );

    let r_e = modulus.pow(&r, e);
    // r^(phi(n) - 1) is the inverse of r modulo n when gcd(r, n) == 1.
    let r_inv = modulus.pow(&r, phi_n_minus_1);
    let c_blind = modulus.mul_mod(c, &r_e);
    let m_blind = modulus.pow(&c_blind, d);
    modulus.mul_mod(&m_blind, &r_inv)
}
impl<const LIMBS: usize> RsaPublicKey<LIMBS> {
    /// Wraps modulus `n` and public exponent `e` into a key without checking
    /// either value.
    pub fn new(n: Uint<LIMBS>, e: Uint<LIMBS>) -> Self {
        Self { n, e }
    }

    /// Borrows the modulus `n`.
    #[inline]
    pub fn modulus(&self) -> &Uint<LIMBS> {
        &self.n
    }

    /// Borrows the public exponent `e`.
    #[inline]
    pub fn exponent(&self) -> &Uint<LIMBS> {
        &self.e
    }

    /// Raw public operation: returns `m^e mod n` with no padding or encoding.
    pub fn raw(&self, m: &Uint<LIMBS>) -> Uint<LIMBS> {
        let modulus = MontModulus::new(self.n);
        modulus.pow(m, &self.e)
    }
}
impl<const LIMBS: usize> Drop for RsaPrivateKey<LIMBS> {
    /// Overwrites every secret field with zeros when the key is dropped.
    ///
    /// Each cleared field is then passed by reference to
    /// `core::hint::black_box`, presumably so the optimizer treats the zero
    /// stores as observed. `black_box` is only a hint, so this is best-effort.
    fn drop(&mut self) {
        self.d = Uint::ZERO;
        let _ = core::hint::black_box(&self.d);

        self.p = Uint::ZERO;
        let _ = core::hint::black_box(&self.p);

        self.q = Uint::ZERO;
        let _ = core::hint::black_box(&self.q);

        self.phi_n_minus_1 = Uint::ZERO;
        let _ = core::hint::black_box(&self.phi_n_minus_1);

        self.blinding_seed.fill(0);
        let _ = core::hint::black_box(&self.blinding_seed);
    }
}
impl<const LIMBS: usize> RsaPrivateKey<LIMBS> {
pub fn generate<R: RngCore + CryptoRng>(e: Uint<LIMBS>, rng: &mut R, rounds: usize) -> Self {
let half_bits = LIMBS * 32;
loop {
let p = random_prime::<LIMBS, R>(rng, half_bits, rounds);
let q = random_prime::<LIMBS, R>(rng, half_bits, rounds);
if p == q {
continue;
}
let n = p.mul_wide(&q).0; let phi = p
.wrapping_sub(&Uint::ONE)
.mul_wide(&q.wrapping_sub(&Uint::ONE))
.0;
if let Some(d) = inv_mod(&e, &phi) {
let (phi_n_minus_1, blinding_seed) = derive_blinding(&p, &q, &d);
return RsaPrivateKey {
n,
e,
d,
p,
q,
phi_n_minus_1,
blinding_seed,
};
}
}
}
pub fn from_components(n: Uint<LIMBS>, e: Uint<LIMBS>, d: Uint<LIMBS>) -> Self {
let (phi_n_minus_1, blinding_seed) =
derive_blinding(&Uint::<LIMBS>::ZERO, &Uint::<LIMBS>::ZERO, &d);
RsaPrivateKey {
n,
e,
d,
p: Uint::ZERO,
q: Uint::ZERO,
phi_n_minus_1,
blinding_seed,
}
}
pub fn public_key(&self) -> RsaPublicKey<LIMBS> {
RsaPublicKey {
n: self.n,
e: self.e,
}
}
#[inline]
pub fn modulus(&self) -> &Uint<LIMBS> {
&self.n
}
#[inline]
pub fn primes(&self) -> (&Uint<LIMBS>, &Uint<LIMBS>) {
(&self.p, &self.q)
}
#[inline]
pub fn exponent(&self) -> &Uint<LIMBS> {
&self.e
}
#[inline]
pub fn private_exponent(&self) -> &Uint<LIMBS> {
&self.d
}
pub(crate) fn from_raw_parts(
n: Uint<LIMBS>,
e: Uint<LIMBS>,
d: Uint<LIMBS>,
p: Uint<LIMBS>,
q: Uint<LIMBS>,
) -> Self {
let (phi_n_minus_1, blinding_seed) = derive_blinding(&p, &q, &d);
RsaPrivateKey {
n,
e,
d,
p,
q,
phi_n_minus_1,
blinding_seed,
}
}
pub fn raw(&self, c: &Uint<LIMBS>) -> Uint<LIMBS> {
raw_private_blinded(
&self.n,
&self.e,
&self.d,
&self.phi_n_minus_1,
&self.blinding_seed,
c,
)
}
pub(crate) fn secret_seed_bytes(&self) -> [u8; 32] {
self.blinding_seed
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::Sha256;
    use crate::rng::HmacDrbg;

    /// DRBG with fixed inputs so key generation in tests is reproducible.
    fn rng() -> HmacDrbg<Sha256> {
        HmacDrbg::new(b"rsa-keygen-test", b"nonce", &[])
    }

    /// A key with primes (blinded path) and the same key rebuilt from
    /// `(n, e, d)` only (unblinded path) must produce the same output.
    #[cfg(all(feature = "alloc", feature = "der"))]
    #[test]
    fn blinding_does_not_alter_result() {
        let key = crate::test_util::rsa_test_key_a();
        let n = *key.modulus();
        let e = *key.exponent();
        let d = *key.private_exponent();
        let unblinded = RsaPrivateKey::<32>::from_components(n, e, d);

        let c = Uint::<32>::from_be_bytes(b"some-message-bytes-for-the-rsa-private-op-1234567")
            .reduce(&n);

        let blinded_result = key.raw(&c);
        let unblinded_result = unblinded.raw(&c);
        assert_eq!(
            blinded_result, unblinded_result,
            "base blinding must not change the mathematical result"
        );
    }

    /// Calling `raw` twice on the same input yields the same output.
    #[cfg(all(feature = "alloc", feature = "der"))]
    #[test]
    fn blinding_is_deterministic_for_same_input() {
        let key = crate::test_util::rsa_test_key_a();
        let n = *key.modulus();
        let c = Uint::<32>::from_be_bytes(b"deterministic-c-bytes-here-xxxxx").reduce(&n);

        let first = key.raw(&c);
        let second = key.raw(&c);
        assert_eq!(first, second);
    }

    /// Two different test keys must not share a blinding seed.
    #[cfg(all(feature = "alloc", feature = "der"))]
    #[test]
    fn blinding_seed_differs_across_keys() {
        let a = crate::test_util::rsa_test_key_a();
        let b = crate::test_util::rsa_test_key_b();
        assert_ne!(
            a.blinding_seed, b.blinding_seed,
            "distinct private keys must derive distinct blinding seeds"
        );
    }

    /// Generates a 2048-bit key and checks a public/private round trip.
    #[test]
    #[ignore = "slow in debug; run with --release --ignored"]
    fn keygen_roundtrip_rsa2048() {
        let mut drbg = rng();
        let e = Uint::<32>::from_u64(65537);
        let key = RsaPrivateKey::<32>::generate(e, &mut drbg, 16);
        let pubkey = key.public_key();

        assert!(bool::from(key.modulus().is_odd()));
        assert_eq!(pubkey.exponent(), &e);
        assert_eq!(key.modulus().bit_len(), 2048);

        let m = Uint::<32>::from_u64(0x0123_4567_89ab_cdef);
        let c = pubkey.raw(&m);
        assert_ne!(c, m);
        assert_eq!(key.raw(&c), m);
    }
}