use num_bigint_dig::algorithms::mod_inverse;
use num_bigint_dig::prime::probably_prime;
use num_bigint_dig::traits::ModInverse;
use num_bigint_dig::{BigUint, IntoBigInt, IntoBigUint, RandPrime};
use num_integer::Integer;
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::convert::{From, TryFrom, TryInto};
/// A number that is treated as prime by the rest of this module.
///
/// The field is private. Within this file a `Prime` is produced either by the
/// `TryFrom` conversions (which run `Prime::is_prime` first) or by
/// `Prime::random`.
#[derive(Clone, Debug, PartialEq)]
pub struct Prime {
/// The wrapped value.
prime: BigUint,
}
/// A serializable number that claims to be prime but has not been verified.
///
/// Convert it into a [`Prime`] with `TryFrom`/`TryInto`; that conversion runs
/// the primality test and fails for values that do not pass it.
#[repr(transparent)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct UncheckedPrime {
/// The unverified candidate value.
p: BigUint,
}
impl TryFrom<BigUint> for Prime {
    type Error = String;

    /// Wraps `value` in a [`Prime`] after checking it with `Prime::is_prime`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `value` does not pass the
    /// primality test.
    fn try_from(value: BigUint) -> Result<Self, Self::Error> {
        // Guard clause: reject anything the primality test refuses.
        if !Prime::is_prime(&value) {
            return Err("Given number is not a prime".to_owned());
        }
        Ok(Prime { prime: value })
    }
}
impl TryFrom<UncheckedPrime> for Prime {
type Error = String;
fn try_from(value: UncheckedPrime) -> Result<Self, Self::Error> {
value.p.try_into()
}
}
impl From<Prime> for UncheckedPrime {
fn from(value: Prime) -> Self {
UncheckedPrime { p: value.prime }
}
}
impl Prime {
    /// Runs the probabilistic primality test on `i`.
    fn is_prime(i: &BigUint) -> bool {
        // Second argument of `probably_prime`; in num-bigint-dig this is the
        // number of test rounds.
        let rounds = 256;
        probably_prime(i, rounds)
    }

    /// Draws a random prime of `num_bits` bits from `rng`.
    ///
    /// The generator must be cryptographically secure (`CryptoRng`).
    // NOTE(review): `gen_prime` presumably panics for very small `num_bits`
    // (num-bigint-dig documents a minimum of 2) - confirm for the crate
    // version in use.
    pub fn random<Rng: CryptoRng + RngCore>(num_bits: usize, rng: &mut Rng) -> Self {
        let prime = rng.gen_prime(num_bits);
        Prime { prime }
    }

    /// Returns the bit length of the wrapped number.
    pub fn num_bits(&self) -> usize {
        let value = &self.prime;
        value.bits()
    }
}
/// A serializable RSA exponent pair that has not been validated.
///
/// It carries no modulus. Turning it into an [`Rsa`] key requires pairing it
/// with an [`RsaParameter`] and going through `TryFrom`, which checks both
/// exponents against that parameter.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct UncheckedRsa {
/// Candidate encryption exponent.
e: BigUint,
/// Candidate decryption exponent.
d: BigUint,
}
impl From<Rsa> for UncheckedRsa {
fn from(value: Rsa) -> Self {
UncheckedRsa {
e: value.e,
d: value.d,
}
}
}
impl TryFrom<(UncheckedRsa, &RsaParameter)> for Rsa {
type Error = String;
fn try_from(value: (UncheckedRsa, &RsaParameter)) -> Result<Self, Self::Error> {
let (UncheckedRsa { e, d }, parameter) = value;
if e == 1u32.into() || e.gcd(¶meter.lambda_n) != 1u32.into() {
return Err("RSA encryption key not incorrect".to_owned());
}
if e.clone().mod_inverse(¶meter.lambda_n) != Some(d.clone().into_bigint().unwrap()) {
return Err("RSA decryption key incorrect".to_owned());
}
Ok(Rsa {
parameter: parameter.clone(),
e,
d,
})
}
}
/// An RSA key: a parameter (modulus data) together with both exponents.
///
/// `encrypt` uses `e` and `decrypt` uses `d`, both modulo the parameter's `n`.
#[derive(Clone, Debug, PartialEq)]
pub struct Rsa {
/// Modulus `n` and `lambda_n` this key belongs to.
parameter: RsaParameter,
/// Encryption exponent.
e: BigUint,
/// Decryption exponent.
d: BigUint,
}
/// The modulus data for RSA keys.
///
/// `RsaParameter::from_primes` derives both fields from a list of primes.
#[derive(Clone, Debug, PartialEq)]
pub struct RsaParameter {
/// The modulus; `from_primes` sets it to the product of the primes.
n: BigUint,
/// `from_primes` sets it to the least common multiple of `p - 1` over the primes.
lambda_n: BigUint,
}
impl RsaParameter {
    /// Builds the parameter from a list of primes.
    ///
    /// `n` is the product of all primes and `lambda_n` is the least common
    /// multiple of `p - 1` over all primes. For an empty slice both are 1.
    // NOTE(review): the lcm of `p - 1` presumably assumes the primes are
    // pairwise distinct - confirm that callers never pass duplicates.
    pub fn from_primes(primes: &[Prime]) -> RsaParameter {
        let one = BigUint::from(1u32);
        let mut n = one.clone();
        let mut lambda_n = one.clone();

        for factor in primes {
            let predecessor = &factor.prime - &one;
            lambda_n = lambda_n.lcm(&predecessor);
            n = n * &factor.prime;
        }

        RsaParameter { n, lambda_n }
    }

    /// Returns a copy of the modulus `n`.
    pub fn n(&self) -> BigUint {
        BigUint::clone(&self.n)
    }

    /// Returns a copy of `lambda_n`.
    pub fn lambda_n(&self) -> BigUint {
        BigUint::clone(&self.lambda_n)
    }
}
impl Rsa {
    /// Returns a copy of the encryption exponent `e`.
    pub fn get_e(&self) -> BigUint {
        self.e.clone()
    }

    /// Returns a copy of the decryption exponent `d`.
    pub fn get_d(&self) -> BigUint {
        self.d.clone()
    }

    /// Computes `message^e mod n`. No padding is applied.
    pub fn encrypt(&self, message: BigUint) -> BigUint {
        message.modpow(&self.e, &self.parameter.n)
    }

    /// Computes `message^d mod n`, the inverse of [`Rsa::encrypt`] for a
    /// consistent key.
    pub fn decrypt(&self, message: BigUint) -> BigUint {
        message.modpow(&self.d, &self.parameter.n)
    }

    /// Generates a random exponent pair for `parameter`.
    ///
    /// Random values below `lambda_n` are drawn from `rng` until one is
    /// coprime to `lambda_n`; that value becomes `e` and its modular inverse
    /// becomes `d`.
    ///
    /// # Panics
    ///
    /// Panics if the modular inverse of the chosen `e` cannot be computed,
    /// which would contradict the coprimality check.
    // NOTE(review): e = 1 passes the coprimality check here, while the
    // `TryFrom<(UncheckedRsa, &RsaParameter)>` import path rejects it -
    // confirm whether the generator should skip it as well.
    pub fn gen_with_parameter<Rng: CryptoRng + RngCore>(
        parameter: RsaParameter,
        rng: &mut Rng,
    ) -> Rsa {
        // Loop-invariant: size the buffer once and refill it on every attempt.
        let num_bytes = (parameter.lambda_n.bits() + 7) / 8;
        let mut buffer = vec![0u8; num_bytes];
        let one = BigUint::from(1u32);

        let e = loop {
            rng.fill_bytes(&mut buffer);
            let candidate = BigUint::from_bytes_le(&buffer) % &parameter.lambda_n;
            if candidate.gcd(&parameter.lambda_n) == one {
                break candidate;
            }
        };

        let d = mod_inverse(Cow::Borrowed(&e), Cow::Borrowed(&parameter.lambda_n))
            .expect("e is coprime to lambda_n, so its modular inverse exists")
            .into_biguint()
            .expect("mod_inverse yields a non-negative value");

        Rsa { parameter, e, d }
    }

    /// Builds a key from explicit exponents after validating them against
    /// `parameter`.
    ///
    /// # Errors
    ///
    /// * `e` or `d` is not smaller than `lambda_n`.
    /// * `e` is not coprime to `lambda_n`.
    /// * `d` is not the modular inverse of `e` modulo `lambda_n`.
    pub fn from_e_d(e: BigUint, d: BigUint, parameter: RsaParameter) -> Result<Rsa, &'static str> {
        if e >= parameter.lambda_n {
            return Err("e has to be smaller than lambda_n");
        }
        if d >= parameter.lambda_n {
            return Err("d has to be smaller than lambda_n");
        }
        if e.gcd(&parameter.lambda_n) != BigUint::from(1u32) {
            return Err("invalid parameter e");
        }

        // Compare on the unsigned side so `d` does not have to be cloned.
        let expected_d = mod_inverse(Cow::Borrowed(&e), Cow::Borrowed(&parameter.lambda_n))
            .and_then(|inverse| inverse.into_biguint());
        if expected_d.as_ref() != Some(&d) {
            return Err("invalid parameter d");
        }

        Ok(Self { parameter, e, d })
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::convert::TryInto;

    /// Small fixed parameter: n = 3233 (61 * 53) with lambda_n = 780.
    fn textbook_parameter() -> RsaParameter {
        RsaParameter {
            n: BigUint::from(3233u32),
            lambda_n: BigUint::from(780u32),
        }
    }

    /// Encrypts 65 with `key`, checks the known ciphertext 2790 and that
    /// decryption restores the plaintext.
    fn assert_textbook_roundtrip(key: &Rsa) {
        let plain = BigUint::from(65u8);
        let cipher = key.encrypt(plain.clone());
        assert_eq!(cipher, BigUint::from(2790u32));
        let recovered = key.decrypt(cipher);
        assert_eq!(recovered, plain);
    }

    /// Generates `prime_count` random 128-bit primes, derives a key from them
    /// and checks that a short message survives an encrypt/decrypt roundtrip.
    fn assert_generated_roundtrip(prime_count: usize) {
        let mut rng = rand::thread_rng();
        let primes: Vec<Prime> = (0..prime_count)
            .map(|_| Prime::random(128, &mut rng))
            .collect();
        let parameter = RsaParameter::from_primes(&primes);
        let key = Rsa::gen_with_parameter(parameter, &mut rng);

        let plain = BigUint::from_bytes_be(&[65u8, 66, 67, 68]);
        let cipher = key.encrypt(plain.clone());
        let recovered = key.decrypt(cipher);
        assert_eq!(recovered, plain);
    }

    #[test]
    fn encrypt_decrypt() {
        let key = Rsa {
            parameter: textbook_parameter(),
            e: BigUint::from(17u32),
            d: BigUint::from(413u32),
        };
        assert_textbook_roundtrip(&key);
    }

    #[test]
    fn from_e_d() {
        let key = Rsa::from_e_d(
            BigUint::from(17u32),
            BigUint::from(413u32),
            textbook_parameter(),
        )
        .unwrap();
        assert_textbook_roundtrip(&key);
    }

    #[test]
    fn generate_keys_1() {
        assert_generated_roundtrip(1);
    }

    #[test]
    fn generate_keys_2() {
        assert_generated_roundtrip(2);
    }

    #[test]
    fn serde() {
        let mut rng = rand::thread_rng();
        let original = UncheckedPrime::from(Prime::random(128, &mut rng));
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: UncheckedPrime = serde_json::from_str(&encoded).unwrap();
        assert_eq!(original, decoded)
    }

    #[test]
    fn import() {
        let mut rng = rand::thread_rng();
        let primes = [Prime::random(128, &mut rng), Prime::random(128, &mut rng)];
        let parameter = RsaParameter::from_primes(&primes);
        let key = Rsa::gen_with_parameter(parameter.clone(), &mut rng);

        // Strip the parameter, then re-import against the same parameter.
        let exported = UncheckedRsa::from(key.clone());
        let imported: Result<Rsa, String> = (exported, &parameter).try_into();
        assert_eq!(Ok(key), imported)
    }
}