use num_bigint::{BigUint, ModInverse, RandPrime};
use num_integer::Integer;
use rand::rngs::StdRng;
use rand::SeedableRng;
use crate::error::Error;
pub mod codec;
pub mod traits;
use traits::{Codec, RsaKey};
#[derive(Debug, Clone, PartialEq)]
pub struct Rsa<C: Codec> {
pub codec: C,
pub key_pair: KeyPair,
}
impl<C: Codec> Rsa<C> {
    /// Encrypts `msg`.
    ///
    /// The codec first turns the bytes into an integer, which is then raised to the
    /// public exponent. Takes `&mut self` because the codec's `encode` needs mutable
    /// access to the codec.
    ///
    /// # Errors
    ///
    /// Propagates any error from the codec's `encode` and from the public key's `crypt`.
    pub fn encrypt(&mut self, msg: &[u8]) -> Result<BigUint, Error> {
        let encoded = self.codec.encode(msg)?;
        self.key_pair.public_key.crypt(&encoded)
    }

    /// Decrypts the ciphertext integer `msg` with the private key, then decodes the
    /// resulting integer back into bytes with the codec.
    ///
    /// # Errors
    ///
    /// Propagates any error from the private key's `crypt` and from the codec's `decode`.
    pub fn decrypt(&self, msg: &BigUint) -> Result<Vec<u8>, Error> {
        self.key_pair
            .private_key
            .crypt(msg)
            .and_then(|plain| self.codec.decode(&plain))
    }

    /// Same as [`Rsa::decrypt`], but the ciphertext is given as big-endian bytes.
    ///
    /// # Errors
    ///
    /// Same as [`Rsa::decrypt`].
    pub fn decrypt_bytes(&self, msg: &[u8]) -> Result<Vec<u8>, Error> {
        self.decrypt(&BigUint::from_bytes_be(msg))
    }
}
/// An RSA key pair together with the two primes it was built from.
// NOTE(review): the derived `Debug` prints `p`, `q` and the private exponent;
// consider a redacting `Debug` impl if values of this type are ever logged.
#[derive(Debug, Clone, PartialEq)]
pub struct KeyPair {
/// First prime factor of the modulus.
pub p: BigUint,
/// Second prime factor of the modulus.
pub q: BigUint,
/// Private exponent `d` with modulus `n = p * q`.
pub private_key: Key,
/// Public exponent `e` with modulus `n = p * q`.
pub public_key: Key,
}
/// One half of an RSA key pair: an exponent and the modulus it is applied under.
#[derive(Debug, Clone, PartialEq)]
pub struct Key {
/// Exponent applied by `crypt` (`e` for a public key, `d` for a private key).
pub exponent: BigUint,
/// Modulus `n`; `crypt` rejects messages that are `>= modulus`.
pub modulus: BigUint,
}
impl Key {
pub fn new(exponent: BigUint, modulus: BigUint) -> Self {
Self { exponent, modulus }
}
}
impl RsaKey for Key {
    /// Computes `message ^ exponent mod modulus`.
    ///
    /// This serves as encryption when called on a public key and as decryption when
    /// called on a private key.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSize`] when `message >= modulus`, because such a value
    /// is not a valid residue and would not survive the round trip.
    fn crypt(&self, message: &BigUint) -> Result<BigUint, Error> {
        // Deliberately no logging here: `message` may be plaintext or decrypted data.
        if message >= &self.modulus {
            return Err(Error::InvalidSize);
        }
        Ok(message.modpow(&self.exponent, &self.modulus))
    }
}
/// Raw (unpadded) codec: a message is its big-endian integer value.
///
/// Leading zero bytes do not survive a round trip, because the big-endian
/// conversion to and from `BigUint` uses the minimal representation.
impl Codec for Key {
    /// Interprets `chunk` as a big-endian integer.
    ///
    /// # Errors
    ///
    /// Returns [`Error::MessageTooLarge`] if the integer is not strictly less than
    /// the modulus.
    fn encode(&mut self, chunk: &[u8]) -> Result<BigUint, Error> {
        let msg = BigUint::from_bytes_be(chunk);
        // `crypt` only accepts values strictly below the modulus. Reject
        // `msg == modulus` here too, so the caller gets the encoding error instead
        // of a later size error from `crypt`.
        if msg >= self.modulus {
            return Err(Error::MessageTooLarge);
        }
        Ok(msg)
    }

    /// Converts `chunk` back into its big-endian byte representation.
    fn decode(&self, chunk: &BigUint) -> Result<Vec<u8>, Error> {
        Ok(chunk.to_bytes_be())
    }
}
impl KeyPair {
    /// Generates a key pair from two distinct random primes of `prime_bits` bits each.
    ///
    /// The RNG is a `StdRng` seeded from `seed`, so the same `(prime_bits, seed)`
    /// always produces the same key pair.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidKeyParams`] when `prime_bits < 8`, and propagates any
    /// error from [`KeyPair::new`].
    pub fn generate(prime_bits: usize, seed: u64) -> Result<Self, Error> {
        if prime_bits < 8 {
            return Err(Error::InvalidKeyParams { prime_bits });
        }
        let mut rng = StdRng::seed_from_u64(seed);
        let p: BigUint = rng.gen_prime(prime_bits);
        let mut q: BigUint = rng.gen_prime(prime_bits);
        // Equal primes do not form a usable key, so redraw `q` until it differs.
        while q == p {
            q = rng.gen_prime(prime_bits);
        }
        Self::new(&p, &q)
    }

    /// Builds a key pair from the primes `p` and `q`.
    ///
    /// The totient is `lcm(p - 1, q - 1)`. The public exponent is the smallest odd
    /// `e >= 3` coprime to the totient, and the private exponent is `e⁻¹ mod totient`.
    /// Both keys share the modulus `n = p * q`. Primality of `p` and `q` is not checked.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnknownError`] in any of these cases:
    /// - `p` or `q` is below 2, or `p == q`;
    /// - no suitable public exponent exists below the totient;
    /// - the modular inverse cannot be computed.
    pub fn new(p: &BigUint, q: &BigUint) -> Result<Self, Error> {
        let one: BigUint = 1u8.into();
        // Validate before computing `p - 1`:
        // - `BigUint` subtraction panics on underflow (p == 0 or q == 0).
        // - With p == q the modulus is p², whose totient is not lcm(p - 1, q - 1),
        //   so the resulting keys would not decrypt correctly.
        if *p <= one || *q <= one || p == q {
            return Err(Error::UnknownError);
        }
        let euler_p: BigUint = p - 1u8;
        let euler_q: BigUint = q - 1u8;
        let totient = euler_p.lcm(&euler_q);

        // Only odd candidates are tried: an even `e` can never be coprime to an
        // even totient.
        let mut public_exponent: BigUint = 3u8.into();
        while public_exponent < totient && public_exponent.gcd(&totient) != one {
            public_exponent += 2u8;
        }
        if public_exponent >= totient {
            return Err(Error::UnknownError);
        }

        let private_exponent = public_exponent
            .clone()
            .mod_inverse(&totient)
            .ok_or(Error::UnknownError)?
            .to_biguint()
            .ok_or(Error::UnknownError)?;

        let modulus = p * q;
        Ok(Self {
            p: p.clone(),
            q: q.clone(),
            private_key: Key::new(private_exponent, modulus.clone()),
            public_key: Key::new(public_exponent, modulus),
        })
    }
}
// Unit and property tests for key generation and the raw / PKCS#1 v1.5 codecs.
// Several tests cross-check interoperability against the `rsa` crate.
#[cfg(test)]
mod tests {
use std::num::NonZeroU8;
use proptest::prelude::*;
use rsa::{Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
use super::*;
// Fixed seed so key generation (and codec padding) is deterministic across runs.
const SEED: u64 = 1234;
// `KeyPair::generate` must reject prime sizes below 8 bits.
#[test]
fn fails_to_generate_primes_too_small() {
let prime_bits = 7usize;
let result = KeyPair::generate(prime_bits, SEED);
assert!(result.is_err());
}
proptest! {
// Any prime size in 8..128 bits yields a key pair with two distinct primes.
#[test]
fn generates_keys_with_valid_params(primebits in 8usize..128) {
let result = KeyPair::generate(primebits, SEED);
assert!(result.is_ok());
let keypair = result.unwrap();
assert!(keypair.p != keypair.q);
}
}
// Rebuild the key with the `rsa` crate from our p, q and e. The public key it
// derives must equal one constructed directly from n = p * q and e.
#[test]
fn keys_match_known_good_implementation() {
let prime_bits = 128usize;
let keypair = KeyPair::generate(prime_bits, SEED).unwrap();
let p = keypair.p.clone();
let q = keypair.q.clone();
let n = p.clone() * q.clone();
let public_exponent = keypair.public_key.exponent.clone();
let other_keypair = RsaPrivateKey::from_p_q(p, q, public_exponent.clone()).unwrap();
let other_pubkey = RsaPublicKey::new(n, public_exponent).unwrap();
assert_eq!(other_pubkey, other_keypair.to_public_key());
}
// Encrypt with our PKCS#1 v1.5 (block type 02) codec; the `rsa` crate must decrypt it.
#[test]
fn can_be_decrypted_by_known_good_implementation() {
let prime_bits = 128usize;
let keypair = KeyPair::generate(prime_bits, SEED).unwrap();
let modulus_bytes = keypair.private_key.modulus.to_bytes_be().len();
// NOTE(review): the third argument appears to be the maximum message length
// (k - 11, the PKCS#1 v1.5 padding overhead) — confirm against `codec::Pkcs1V1_5::new`.
let codec = codec::Pkcs1V1_5::new(
codec::BlockType::Type02,
modulus_bytes,
modulus_bytes - 11,
SEED,
);
let mut rsa = Rsa::<codec::Pkcs1V1_5> {
codec,
key_pair: keypair.clone(),
};
let kg_keypair = RsaPrivateKey::from_p_q(
keypair.p.clone(),
keypair.q.clone(),
keypair.public_key.exponent.clone(),
)
.expect("failed to generate keypair");
let plaintext = b"Hello, world!";
let ciphertext = rsa.encrypt(plaintext).unwrap();
let decrypted = kg_keypair
.decrypt(Pkcs1v15Encrypt, &ciphertext.to_bytes_be())
.expect("should decrypt");
assert_eq!(plaintext.as_slice(), &decrypted);
}
// The reverse direction: the `rsa` crate encrypts with PKCS#1 v1.5, and our codec decrypts.
#[test]
fn can_encrypt_for_known_good_implementation() {
let mut rng = StdRng::seed_from_u64(1002);
let prime_bits = 128usize;
let keypair = KeyPair::generate(prime_bits, SEED).unwrap();
let modulus_bytes = keypair.private_key.modulus.to_bytes_be().len();
let codec = codec::Pkcs1V1_5::new(
codec::BlockType::Type02,
modulus_bytes,
modulus_bytes - 11,
SEED,
);
let kg_keypair = RsaPrivateKey::from_p_q(
keypair.p.clone(),
keypair.q.clone(),
keypair.public_key.exponent.clone(),
)
.expect("failed to generate keypair");
let plaintext = b"Hello, world!";
let ciphertext = kg_keypair
.to_public_key()
.encrypt(&mut rng, Pkcs1v15Encrypt, &plaintext[..])
.expect("failed to encrypt");
let rsa = Rsa::<codec::Pkcs1V1_5> {
codec,
key_pair: keypair.clone(),
};
let cipher_text_int = BigUint::from_bytes_be(ciphertext.as_slice());
let decrypted = rsa.decrypt(&cipher_text_int).unwrap();
assert_eq!(plaintext.as_slice(), &decrypted);
}
// Fixed-input round trip through the raw `Key` codec with 33-bit primes.
#[test]
fn round_trips_encryption_plain_2() {
let primebits = 33;
let plaintext = vec![1; 7];
let mut keypair = KeyPair::generate(primebits, SEED).unwrap();
let encoded_plaintext = keypair.public_key.encode(&plaintext).unwrap();
let ciphertext = keypair.public_key.crypt(&encoded_plaintext).unwrap();
let decrypted = keypair.private_key.crypt(&ciphertext).unwrap();
let decrypted_text = keypair.private_key.decode(&decrypted).unwrap();
assert_eq!(plaintext, decrypted_text);
}
proptest! {
// Raw `Key` codec round trip. Bytes are drawn as non-zero, which avoids leading
// zero bytes; those are lost by the big-endian BigUint conversion.
#[test]
fn round_trips_encryption_plain(
plaintext in prop::collection::vec(any::<NonZeroU8>(), 1..10),
) {
let primebits = 64;
let mut keypair = KeyPair::generate(primebits, SEED).unwrap();
let plaintext = plaintext.into_iter().map(|nz| nz.get()).collect::<Vec<u8>>();
let encoded_plaintext = keypair.public_key.encode(&plaintext).unwrap();
let ciphertext = keypair.public_key.crypt(&encoded_plaintext).unwrap();
let decrypted = keypair.private_key.crypt(&ciphertext).unwrap();
let decrypted_text = keypair.private_key.decode(&decrypted).unwrap();
assert_eq!(plaintext, decrypted_text);
}
}
proptest! {
// PKCS#1 v1.5 block type 01 round trip through `Rsa`.
#[test]
fn round_trips_encryption_pkcs_type01(
plaintext in prop::collection::vec(any::<u8>(), 1..10),
) {
let primebits = 64;
let keypair = KeyPair::generate(primebits, SEED).unwrap();
// NOTE(review): `bits() / 8` can be one less than `to_bytes_be().len()`, which
// the interop tests use. Confirm the difference is intentional.
let modulus_bytes = keypair.private_key.modulus.bits() / 8;
let codec = codec::Pkcs1V1_5::new(codec::BlockType::Type01, modulus_bytes, modulus_bytes - 11, SEED);
let mut rsa = Rsa::<codec::Pkcs1V1_5> {
codec,
key_pair: keypair,
};
let ciphertext = rsa.encrypt(&plaintext).unwrap();
let decrypted_text = rsa.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext, decrypted_text);
}
}
proptest! {
// PKCS#1 v1.5 block type 02 round trip through `Rsa`.
#[test]
fn round_trips_encryption_pkcs_type02(
plaintext in prop::collection::vec(any::<u8>(), 1..10),
) {
let primebits = 64;
let keypair = KeyPair::generate(primebits, SEED).unwrap();
let modulus_bytes = keypair.private_key.modulus.bits() / 8;
let codec = codec::Pkcs1V1_5::new(codec::BlockType::Type02, modulus_bytes, modulus_bytes - 11, SEED);
let mut rsa = Rsa::<codec::Pkcs1V1_5> {
codec,
key_pair: keypair.clone(),
};
let ciphertext = rsa.encrypt(&plaintext).unwrap();
let decrypted_text = rsa.decrypt(&ciphertext).unwrap();
assert_eq!(plaintext, decrypted_text);
}
}
}