use num_bigint_dig::{BigInt, BigUint, RandPrime, Sign, ToBigInt};
use num_integer::Integer;
use num_traits::{One, Zero};
use rand::{rngs::StdRng, SeedableRng};
/// RSA public key: the modulus `n` and the public exponent `e`.
///
/// Equality compares both components, so two keys are equal exactly when they
/// describe the same public RSA function.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RSAPublicKey {
    /// Modulus.
    pub n: BigUint,
    /// Public exponent.
    pub e: BigUint,
}
/// RSA private key: the modulus, both exponents, and the two prime factors.
///
/// `Debug` is implemented by hand so that formatting a key (for example with `{:?}`
/// in a log line) never prints the secret values `d`, `p`, and `q`.
pub struct RSAPrivateKey {
    /// Modulus (`p * q` for keys built by `RSAKeyPair::generate`).
    pub n: BigUint,
    /// Public exponent.
    pub e: BigUint,
    /// Private exponent — secret.
    pub d: BigUint,
    /// First prime factor of the modulus — secret.
    pub p: BigUint,
    /// Second prime factor of the modulus — secret.
    pub q: BigUint,
}

impl std::fmt::Debug for RSAPrivateKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the public components are shown; secret fields are masked.
        f.debug_struct("RSAPrivateKey")
            .field("n", &self.n)
            .field("e", &self.e)
            .field("d", &format_args!("<redacted>"))
            .field("p", &format_args!("<redacted>"))
            .field("q", &format_args!("<redacted>"))
            .finish()
    }
}
/// A matching pair of RSA keys, as returned by `RSAKeyPair::generate`.
#[derive(Debug)]
pub struct RSAKeyPair {
    /// The half that may be shared freely.
    pub public_key: RSAPublicKey,
    /// The half that must be kept secret.
    pub private_key: RSAPrivateKey,
}
/// Parameters for `RSAKeyPair::generate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RSAKeyGenConfig {
    /// Requested modulus size in bits.
    pub key_size: usize,
    /// Public exponent `e`.
    pub public_exponent: u64,
    /// Optional RNG seed. `Some` makes key generation reproducible and is meant for
    /// tests only; `None` draws a fresh seed from system entropy.
    pub seed: Option<u64>,
}

impl Default for RSAKeyGenConfig {
    /// 2048-bit modulus, `e = 65537`, entropy-seeded RNG.
    fn default() -> Self {
        RSAKeyGenConfig {
            key_size: 2048,
            public_exponent: 65537,
            seed: None,
        }
    }
}
impl RSAKeyPair {
    /// Generates a fresh RSA key pair according to `config`.
    ///
    /// Two distinct random primes `p` and `q` are drawn whose bit sizes add up to
    /// `config.key_size` (the extra bit goes to `p` when the size is odd). They are
    /// resampled until `e` is invertible modulo `φ = (p - 1)(q - 1)`; that inverse is
    /// the private exponent `d`.
    ///
    /// When `config.seed` is `Some`, the RNG is seeded deterministically, so the same
    /// seed always yields the same key pair — suitable for tests, never for real keys.
    ///
    /// # Panics
    ///
    /// Panics if `config.key_size` is below 16 bits, or if `config.public_exponent`
    /// is not an odd integer of at least 3. Both primes are odd, so `φ` is even and an
    /// even exponent could never be inverted: the search would loop forever.
    pub fn generate(config: &RSAKeyGenConfig) -> Self {
        // With smaller sizes each prime has only a handful of candidates (sometimes
        // just one), so drawing two distinct primes that satisfy gcd(e, φ) = 1 may
        // never succeed.
        const MIN_KEY_SIZE: usize = 16;
        assert!(
            config.key_size >= MIN_KEY_SIZE,
            "key_size must be at least {} bits, got {}",
            MIN_KEY_SIZE,
            config.key_size
        );
        assert!(
            config.public_exponent >= 3 && config.public_exponent % 2 == 1,
            "public_exponent must be an odd integer >= 3, got {}",
            config.public_exponent
        );

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_entropy(),
        };

        // Ceil/floor split so the two prime sizes always sum to the requested size;
        // for even sizes both are `key_size / 2`, exactly as before.
        let p_bits = (config.key_size + 1) / 2;
        let q_bits = config.key_size / 2;
        let e = BigUint::from(config.public_exponent);
        let one = BigUint::one();

        let (p, q, d) = loop {
            let p: BigUint = RandPrime::gen_prime(&mut rng, p_bits);
            let q: BigUint = RandPrime::gen_prime(&mut rng, q_bits);
            if p == q {
                continue;
            }
            let phi = (&p - &one) * (&q - &one);
            // Keep the inverse found by the acceptance test instead of recomputing it.
            if let Some(d) = mod_inverse(&e, &phi) {
                break (p, q, d);
            }
        };

        let n = &p * &q;
        RSAKeyPair {
            public_key: RSAPublicKey {
                n: n.clone(),
                e: e.clone(),
            },
            private_key: RSAPrivateKey { n, e, d, p, q },
        }
    }
}
/// Encrypts `plaintext` with textbook (unpadded) RSA, returning `plaintext^e mod n`.
///
/// This applies no padding scheme, so equal plaintexts give equal ciphertexts.
/// Use it only for teaching or as the raw primitive beneath a proper padding scheme.
///
/// # Panics
///
/// Panics if `plaintext >= n`. Such a value would be silently reduced modulo `n`,
/// so decryption could never recover it.
pub fn rsa_encrypt(public_key: &RSAPublicKey, plaintext: &BigUint) -> BigUint {
    assert!(
        plaintext < &public_key.n,
        "plaintext must be smaller than the RSA modulus"
    );
    plaintext.modpow(&public_key.e, &public_key.n)
}
pub fn rsa_decrypt(private_key: &RSAPrivateKey, ciphertext: &BigUint) -> BigUint {
ciphertext.modpow(&private_key.d, &private_key.n)
}
/// Signs `message_hash` with textbook (unpadded) RSA, returning `message_hash^d mod n`.
///
/// The hash is used directly as the message representative. No padding (such as
/// PSS or PKCS#1 v1.5) is applied.
///
/// # Panics
///
/// Panics if `message_hash >= n`. Such a value would be silently reduced modulo `n`,
/// so verification would recover a different number than the one signed.
pub fn rsa_sign(private_key: &RSAPrivateKey, message_hash: &BigUint) -> BigUint {
    assert!(
        message_hash < &private_key.n,
        "message hash must be smaller than the RSA modulus"
    );
    message_hash.modpow(&private_key.d, &private_key.n)
}
pub fn rsa_verify(public_key: &RSAPublicKey, signature: &BigUint) -> BigUint {
signature.modpow(&public_key.e, &public_key.n)
}
/// Computes the multiplicative inverse of `a` modulo `m`.
///
/// Returns `Some(x)` with `0 <= x < m` and `a * x ≡ 1 (mod m)` when `gcd(a, m) == 1`,
/// and `None` otherwise. A zero modulus has no inverse and yields `None`, rather than
/// panicking on a division by zero.
fn mod_inverse(a: &BigUint, m: &BigUint) -> Option<BigUint> {
    if m.is_zero() {
        return None;
    }
    // Bézout coefficients can be negative, so the computation runs over signed integers.
    let a_int = a.to_bigint().expect("a BigUint always converts to a BigInt");
    let m_int = m.to_bigint().expect("a BigUint always converts to a BigInt");
    let (g, x, _) = extended_gcd(&a_int, &m_int);
    if !g.is_one() {
        return None;
    }
    // `mod_floor` rounds toward negative infinity, so for a positive modulus the
    // result already lies in `[0, m)`; no separate sign fix-up is needed.
    let inverse = x.mod_floor(&m_int);
    Some(
        inverse
            .to_biguint()
            .expect("mod_floor by a positive modulus is non-negative"),
    )
}
/// Extended Euclidean algorithm.
///
/// Returns `(g, x, y)` with `a * x + b * y == g`, where `g` is the gcd that Euclid's
/// algorithm (with truncating division) reaches. `g` is non-negative when both inputs
/// are. When `b` is zero the result is `(a, 1, 0)`.
///
/// The loop is iterative, so stack usage stays constant regardless of how many
/// division steps large operands need. The coefficients are the same ones the
/// textbook recursive formulation produces.
fn extended_gcd(a: &BigInt, b: &BigInt) -> (BigInt, BigInt, BigInt) {
    // Invariants: a * old_x + b * old_y == old_r  and  a * x + b * y == r.
    let mut old_r = a.clone();
    let mut r = b.clone();
    let mut old_x = BigInt::one();
    let mut x = BigInt::zero();
    let mut old_y = BigInt::zero();
    let mut y = BigInt::one();

    while !r.is_zero() {
        let (quotient, remainder) = old_r.div_rem(&r);
        old_r = r;
        r = remainder;

        let next_x = &old_x - &quotient * &x;
        old_x = x;
        x = next_x;

        let next_y = &old_y - &quotient * &y;
        old_y = y;
        y = next_y;
    }

    (old_r, old_x, old_y)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, seeded configuration so every run produces the same toy key pair.
    fn toy_config() -> RSAKeyGenConfig {
        RSAKeyGenConfig {
            key_size: 128,
            public_exponent: 65537,
            seed: Some(42),
        }
    }

    #[test]
    fn test_toy_rsa() {
        let keypair = RSAKeyPair::generate(&toy_config());

        // Textbook RSA only round-trips values strictly below the modulus, so check
        // that precondition before the value is used.
        let msg = BigUint::from(42u64);
        assert!(
            msg < keypair.public_key.n,
            "Message must be smaller than modulus"
        );
        let enc = rsa_encrypt(&keypair.public_key, &msg);
        let dec = rsa_decrypt(&keypair.private_key, &enc);
        assert_eq!(dec, msg, "RSA encryption/decryption mismatch");

        let hash = BigUint::from(24u64);
        assert!(
            hash < keypair.public_key.n,
            "Hash must be smaller than modulus"
        );
        let sig = rsa_sign(&keypair.private_key, &hash);
        let recovered = rsa_verify(&keypair.public_key, &sig);
        assert_eq!(recovered, hash, "RSA sign/verify mismatch");
    }

    #[test]
    fn test_key_components_are_consistent() {
        let keypair = RSAKeyPair::generate(&toy_config());
        let sk = &keypair.private_key;
        let one = BigUint::from(1u32);

        assert_eq!(&sk.p * &sk.q, sk.n, "n must equal p * q");
        assert_eq!(keypair.public_key.n, sk.n);
        assert_eq!(keypair.public_key.e, sk.e);

        let phi = (&sk.p - &one) * (&sk.q - &one);
        assert_eq!((&sk.e * &sk.d) % &phi, one, "e * d must be 1 mod phi");
    }

    #[test]
    fn test_same_seed_gives_same_keys() {
        let first = RSAKeyPair::generate(&toy_config());
        let second = RSAKeyPair::generate(&toy_config());
        assert_eq!(first.public_key.n, second.public_key.n);
        assert_eq!(first.private_key.d, second.private_key.d);
    }

    #[test]
    fn test_mod_inverse() {
        assert_eq!(
            mod_inverse(&BigUint::from(3u32), &BigUint::from(11u32)),
            Some(BigUint::from(4u32))
        );
        assert_eq!(
            mod_inverse(&BigUint::from(2u32), &BigUint::from(4u32)),
            None
        );
    }

    #[test]
    fn test_extended_gcd_bezout_identity() {
        let a = BigInt::from(240);
        let b = BigInt::from(46);
        let (g, x, y) = extended_gcd(&a, &b);
        assert_eq!(g, BigInt::from(2));
        assert_eq!(&a * &x + &b * &y, g);
    }
}