use super::bigint::BigInt;
/// Public half of an RSA key pair, as produced by `rsa_keygen` and consumed by
/// `rsa_encrypt_raw`.
#[derive(Clone, Debug)]
pub struct RsaPublicKey {
/// Modulus; `rsa_keygen` sets it to the product of the two secret primes.
pub n: BigInt,
/// Public exponent, used as the power in `rsa_encrypt_raw` (`rsa_keygen` always picks 65537).
pub e: BigInt,
}
/// Secret half of an RSA key pair, including the precomputed values that
/// `rsa_decrypt_raw` uses for CRT-based decryption.
///
/// NOTE(review): `Debug` is derived, so formatting this struct with `{:?}` will presumably
/// print every secret field (depends on `BigInt`'s `Debug` impl) — confirm that is acceptable.
#[derive(Clone, Debug)]
pub struct RsaSecretKey {
/// Modulus, equal to `p * q` when built by `rsa_keygen`.
pub n: BigInt,
/// Private exponent; `rsa_keygen` computes it as the inverse of `e` modulo `lcm(p - 1, q - 1)`.
pub d: BigInt,
/// First prime factor of `n`.
pub p: BigInt,
/// Second prime factor of `n`.
pub q: BigInt,
/// `d mod (p - 1)`: exponent for the mod-`p` half of CRT decryption.
pub dp: BigInt,
/// `d mod (q - 1)`: exponent for the mod-`q` half of CRT decryption.
pub dq: BigInt,
/// Inverse of `q` modulo `p`, used to recombine the two CRT halves.
pub qinv: BigInt,
}
impl RsaPublicKey {
    /// Returns the size of the modulus `n` in bytes, as reported by `BigInt::byte_len`.
    pub fn modulus_byte_len(&self) -> usize {
        let Self { n: modulus, .. } = self;
        modulus.byte_len()
    }
}
impl RsaSecretKey {
    /// Returns the size of the modulus `n` in bytes, as reported by `BigInt::byte_len`.
    pub fn modulus_byte_len(&self) -> usize {
        let Self { n: modulus, .. } = self;
        modulus.byte_len()
    }
}
/// Generates an RSA key pair whose modulus `n = p * q` is exactly `bits` bits long.
///
/// The public exponent is fixed at 65537. The private exponent `d` is the inverse of `e`
/// modulo `lcm(p - 1, q - 1)`, and the CRT values `dp`, `dq` and `qinv` are precomputed for
/// [`rsa_decrypt_raw`].
///
/// `rng` is forwarded to `BigInt::random_prime` as the source of random bytes. Candidate
/// prime pairs are drawn repeatedly until one satisfies every check, so the number of `rng`
/// calls is not fixed.
///
/// For even `bits` both primes are requested with `bits / 2` bits. For odd `bits`, `p` is
/// requested one bit longer than `q` so that a `bits`-bit product is reachable.
///
/// NOTE(review): very small `bits` values are not rejected here; whether they terminate
/// depends on `BigInt::random_prime` — confirm against its contract.
pub fn rsa_keygen(bits: usize, rng: &mut dyn FnMut(&mut [u8])) -> (RsaPublicKey, RsaSecretKey) {
    // Split so that p_bits + q_bits == bits. Requesting bits / 2 for both primes caps the
    // product at bits - 1 bits when `bits` is odd (assuming `random_prime(k, ..)` yields at
    // most k bits), which would make the length check below reject every candidate forever.
    let q_bits = bits / 2;
    let p_bits = bits - q_bits;

    let e = BigInt::from_u64(65537);
    let one = BigInt::from_u64(1);

    loop {
        let p = BigInt::random_prime(p_bits, rng);
        let q = BigInt::random_prime(q_bits, rng);
        if p == q {
            continue;
        }

        // The product of the two primes may come out one bit short; retry until it is exact.
        let n = p.mul(&q);
        if n.bit_len() != bits {
            continue;
        }

        // lambda = lcm(p - 1, q - 1) = (p - 1)(q - 1) / gcd(p - 1, q - 1).
        let p1 = p.sub(&one);
        let q1 = q.sub(&one);
        let phi = p1.mul(&q1);
        let g = gcd(&p1, &q1);
        let lambda = phi.div_rem(&g).0;

        // No inverse means e shares a factor with lambda; draw fresh primes.
        let d = match e.mod_inv(&lambda) {
            Some(d) => d,
            None => continue,
        };

        let dp = d.rem(&p1);
        let dq = d.rem(&q1);
        let qinv = match q.mod_inv(&p) {
            Some(qi) => qi,
            None => continue,
        };

        let pk = RsaPublicKey { n: n.clone(), e };
        let sk = RsaSecretKey {
            n,
            d,
            p,
            q,
            dp,
            dq,
            qinv,
        };
        return (pk, sk);
    }
}
pub fn rsa_encrypt_raw(pk: &RsaPublicKey, m: &BigInt) -> BigInt {
m.mod_exp(&pk.e, &pk.n)
}
pub fn rsa_decrypt_raw(sk: &RsaSecretKey, c: &BigInt) -> BigInt {
let cp = c.rem(&sk.p);
let cq = c.rem(&sk.q);
let m1 = cp.mod_exp(&sk.dp, &sk.p);
let m2 = cq.mod_exp(&sk.dq, &sk.q);
let m2_mod_p = m2.rem(&sk.p);
let diff = if m1 >= m2_mod_p {
m1.sub(&m2_mod_p)
} else {
m1.add(&sk.p).sub(&m2_mod_p)
};
let h = sk.qinv.mul(&diff).rem(&sk.p);
m2.add(&h.mul(&sk.q))
}
/// Greatest common divisor of `a` and `b`, computed with the remainder-based Euclidean
/// algorithm. Returns a clone of `a` when `b` is zero.
fn gcd(a: &BigInt, b: &BigInt) -> BigInt {
    let (mut x, mut y) = (a.clone(), b.clone());
    loop {
        if y.is_zero() {
            return x;
        }
        // Step (x, y) -> (y, x mod y).
        let remainder = x.rem(&y);
        x = std::mem::replace(&mut y, remainder);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic byte source for tests: a 64-bit linear congruential generator with a
    /// fixed seed, emitting one byte (bits 33..41 of the state) per step.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        let mut lcg: u64 = 0xdeadbeefcafebabe;
        move |out: &mut [u8]| {
            for slot in out.iter_mut() {
                lcg = lcg
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                *slot = (lcg >> 33) as u8;
            }
        }
    }

    /// Round-trips the value 42 through a hand-built toy key (p = 61, q = 53, e = 17).
    #[test]
    fn test_raw_encrypt_decrypt_small() {
        let unit = BigInt::from_u64(1);
        let p = BigInt::from_u64(61);
        let q = BigInt::from_u64(53);
        let e = BigInt::from_u64(17);

        let p_minus_1 = p.sub(&unit);
        let q_minus_1 = q.sub(&unit);
        // The plain product (p - 1)(q - 1) serves as the inversion modulus for this toy key.
        let phi = p_minus_1.mul(&q_minus_1);
        let d = e.mod_inv(&phi).unwrap();

        let sk = RsaSecretKey {
            n: p.mul(&q),
            dp: d.rem(&p_minus_1),
            dq: d.rem(&q_minus_1),
            qinv: q.mod_inv(&p).unwrap(),
            d,
            p,
            q,
        };
        let pk = RsaPublicKey { n: sk.n.clone(), e };

        let message = BigInt::from_u64(42);
        let ciphertext = rsa_encrypt_raw(&pk, &message);
        let recovered = rsa_decrypt_raw(&sk, &ciphertext);
        assert_eq!(message, recovered);
    }

    /// Generates a 512-bit key from the deterministic RNG and round-trips a small message.
    #[test]
    fn test_keygen_roundtrip() {
        let mut rng = test_rng();
        let (pk, sk) = rsa_keygen(512, &mut rng);

        let message = BigInt::from_u64(12345678);
        let ciphertext = rsa_encrypt_raw(&pk, &message);
        let recovered = rsa_decrypt_raw(&sk, &ciphertext);
        assert_eq!(message, recovered);
    }
}