use clear_on_drop::clear::Clear;
use num_bigint::traits::ModInverse;
use num_bigint::Sign::Plus;
use num_bigint::{BigInt, BigUint, RandBigInt};
use num_traits::{FromPrimitive, One, Signed, Zero};
use rand::{Rng, ThreadRng};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use algorithms::generate_multi_prime_key;
use errors::{Error, Result};
use hash::Hash;
use padding::PaddingScheme;
use pkcs1v15;
lazy_static! {
    // Smallest public exponent that `check_public` accepts.
    static ref MIN_PUB_EXPONENT: BigUint = BigUint::from_u64(2).unwrap();
    // Largest public exponent that `check_public` accepts: 2^31 - 1.
    // The previous expression `1 << (31 - 1)` evaluated to 2^30, which rejected
    // valid exponents between 2^30 and 2^31 - 1.
    static ref MAX_PUB_EXPONENT: BigUint = BigUint::from_u64((1 << 31) - 1).unwrap();
}
/// The public half of an RSA key: a modulus and a public exponent.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct RSAPublicKey {
/// Modulus.
n: BigUint,
/// Public exponent.
e: BigUint,
}
/// An RSA private key, including the public components and the prime factors.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct RSAPrivateKey {
/// Modulus.
n: BigUint,
/// Public exponent.
e: BigUint,
/// Private exponent.
d: BigUint,
/// Prime factors of `n`; `precompute` and `decrypt` treat the first two as `p` and `q`.
primes: Vec<BigUint>,
/// Cached values used to speed up private-key operations; filled in by
/// `precompute` and excluded from serde (de)serialization.
#[cfg_attr(feature = "serde1", serde(skip))]
precomputed: Option<PrecomputedValues>,
}
/// Two private keys are equal when `n`, `e`, `d` and the prime list all match.
/// The `precomputed` cache is deliberately left out of the comparison.
impl PartialEq for RSAPrivateKey {
    #[inline]
    fn eq(&self, other: &RSAPrivateKey) -> bool {
        // Tuple equality compares element by element, left to right.
        let lhs = (&self.n, &self.e, &self.d, &self.primes);
        let rhs = (&other.n, &other.e, &other.d, &other.primes);
        lhs == rhs
    }
}
impl Eq for RSAPrivateKey {}
/// Clears the secret parts of the key (`d`, `primes`, `precomputed`) when it is dropped.
impl Drop for RSAPrivateKey {
#[inline]
fn drop(&mut self) {
// `clear` on the big integer comes from the imported `Clear` trait.
self.d.clear();
// NOTE(review): `primes` is a `Vec`, so this call presumably resolves to the
// inherent `Vec::clear` rather than `Clear::clear` — confirm the elements are
// wiped as intended.
self.primes.clear();
self.precomputed.clear();
}
}
/// Values derived from the private key by `RSAPrivateKey::precompute` and used
/// by the CRT branch of `decrypt`.
#[derive(Debug, Clone)]
struct PrecomputedValues {
/// `d mod (p - 1)`, where `p` is the first prime.
dp: BigUint,
/// `d mod (q - 1)`, where `q` is the second prime.
dq: BigUint,
/// Inverse of `q` modulo `p`.
qinv: BigInt,
/// One entry per prime beyond the first two (multi-prime keys).
crt_values: Vec<CRTValue>,
}
/// Clears every precomputed field when the cache is dropped.
impl Drop for PrecomputedValues {
#[inline]
fn drop(&mut self) {
self.dp.clear();
self.dq.clear();
self.qinv.clear();
// NOTE(review): `crt_values` is a `Vec`, so this presumably resolves to the
// inherent `Vec::clear` rather than `Clear::clear` — confirm.
self.crt_values.clear();
}
}
/// CRT parameters for one additional prime (third and later) of a multi-prime key.
#[derive(Debug, Clone)]
struct CRTValue {
/// `d mod (prime - 1)`.
exp: BigInt,
/// Inverse of `r` modulo this prime.
coeff: BigInt,
/// Product of all primes that precede this one.
r: BigInt,
}
/// Extracts the public half (`n`, `e`) of a private key.
impl From<RSAPrivateKey> for RSAPublicKey {
    fn from(private_key: RSAPrivateKey) -> Self {
        // The fields are cloned rather than moved: `RSAPrivateKey` implements
        // `Drop`, so its fields cannot be moved out.
        RSAPublicKey {
            n: private_key.n.clone(),
            e: private_key.e.clone(),
        }
    }
}
/// Operations available on any type that carries the public part of an RSA key.
pub trait PublicKey {
    /// Returns the modulus.
    fn n(&self) -> &BigUint;

    /// Returns the public exponent.
    fn e(&self) -> &BigUint;

    /// Returns the length of the modulus in bytes (bit length rounded up to a
    /// whole byte).
    fn size(&self) -> usize {
        let modulus_bits = self.n().bits();
        (modulus_bits + 7) / 8
    }

    /// Encrypts `msg` under this key using the given padding scheme, drawing
    /// any randomness from `rng`.
    fn encrypt<R: Rng>(&self, rng: &mut R, padding: PaddingScheme, msg: &[u8]) -> Result<Vec<u8>>;

    /// Verifies the signature `sig` over the digest `hashed` using the given
    /// padding scheme. Returns `Ok(())` when the signature is accepted.
    fn verify<H: Hash>(
        &self,
        padding: PaddingScheme,
        hash: Option<&H>,
        hashed: &[u8],
        sig: &[u8],
    ) -> Result<()>;
}
impl PublicKey for RSAPublicKey {
fn n(&self) -> &BigUint {
&self.n
}
fn e(&self) -> &BigUint {
&self.e
}
fn encrypt<R: Rng>(&self, rng: &mut R, padding: PaddingScheme, msg: &[u8]) -> Result<Vec<u8>> {
match padding {
PaddingScheme::PKCS1v15 => pkcs1v15::encrypt(rng, self, msg),
PaddingScheme::OAEP => unimplemented!("not yet implemented"),
_ => Err(Error::InvalidPaddingScheme),
}
}
fn verify<H: Hash>(
&self,
padding: PaddingScheme,
hash: Option<&H>,
hashed: &[u8],
sig: &[u8],
) -> Result<()> {
match padding {
PaddingScheme::PKCS1v15 => pkcs1v15::verify(self, hash, hashed, sig),
PaddingScheme::PSS => unimplemented!("not yet implemented"),
_ => Err(Error::InvalidPaddingScheme),
}
}
}
impl RSAPublicKey {
    /// Builds a public key from a modulus `n` and public exponent `e`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `check_public` when the exponent is
    /// outside the accepted range.
    pub fn new(n: BigUint, e: BigUint) -> Result<Self> {
        let key = RSAPublicKey { n, e };
        let checked = check_public(&key);
        checked.map(|_| key)
    }
}
/// Lets a borrowed `RSAPublicKey` be used wherever a `PublicKey` is expected.
/// Every method forwards to the implementation on the owned type.
impl<'a> PublicKey for &'a RSAPublicKey {
    fn n(&self) -> &BigUint {
        <RSAPublicKey as PublicKey>::n(*self)
    }

    fn e(&self) -> &BigUint {
        <RSAPublicKey as PublicKey>::e(*self)
    }

    fn encrypt<R: Rng>(&self, rng: &mut R, padding: PaddingScheme, msg: &[u8]) -> Result<Vec<u8>> {
        <RSAPublicKey as PublicKey>::encrypt(*self, rng, padding, msg)
    }

    fn verify<H: Hash>(
        &self,
        padding: PaddingScheme,
        hash: Option<&H>,
        hashed: &[u8],
        sig: &[u8],
    ) -> Result<()> {
        <RSAPublicKey as PublicKey>::verify(*self, padding, hash, hashed, sig)
    }
}
/// Public-key operations performed with the public components of a private key.
impl PublicKey for RSAPrivateKey {
    /// Returns the modulus.
    fn n(&self) -> &BigUint {
        &self.n
    }

    /// Returns the public exponent.
    fn e(&self) -> &BigUint {
        &self.e
    }

    /// Encrypts `msg` with the requested padding.
    ///
    /// Only PKCS#1 v1.5 is implemented. OAEP panics via `unimplemented!`, and
    /// any other scheme yields `Error::InvalidPaddingScheme`.
    fn encrypt<R: Rng>(&self, rng: &mut R, padding: PaddingScheme, msg: &[u8]) -> Result<Vec<u8>> {
        match padding {
            PaddingScheme::OAEP => unimplemented!("not yet implemented"),
            PaddingScheme::PKCS1v15 => pkcs1v15::encrypt(rng, self, msg),
            _ => Err(Error::InvalidPaddingScheme),
        }
    }

    /// Verifies `sig` against the digest `hashed` with the requested padding.
    ///
    /// Only PKCS#1 v1.5 is implemented. PSS panics via `unimplemented!`, and
    /// any other scheme yields `Error::InvalidPaddingScheme`.
    fn verify<H: Hash>(
        &self,
        padding: PaddingScheme,
        hash: Option<&H>,
        hashed: &[u8],
        sig: &[u8],
    ) -> Result<()> {
        match padding {
            PaddingScheme::PSS => unimplemented!("not yet implemented"),
            PaddingScheme::PKCS1v15 => pkcs1v15::verify(self, hash, hashed, sig),
            _ => Err(Error::InvalidPaddingScheme),
        }
    }
}
/// Lets a borrowed `RSAPrivateKey` be used wherever a `PublicKey` is expected.
/// Every method forwards to the implementation on the owned type.
impl<'a> PublicKey for &'a RSAPrivateKey {
    fn n(&self) -> &BigUint {
        <RSAPrivateKey as PublicKey>::n(*self)
    }

    fn e(&self) -> &BigUint {
        <RSAPrivateKey as PublicKey>::e(*self)
    }

    fn encrypt<R: Rng>(&self, rng: &mut R, padding: PaddingScheme, msg: &[u8]) -> Result<Vec<u8>> {
        <RSAPrivateKey as PublicKey>::encrypt(*self, rng, padding, msg)
    }

    fn verify<H: Hash>(
        &self,
        padding: PaddingScheme,
        hash: Option<&H>,
        hashed: &[u8],
        sig: &[u8],
    ) -> Result<()> {
        <RSAPrivateKey as PublicKey>::verify(*self, padding, hash, hashed, sig)
    }
}
impl RSAPrivateKey {
    /// Generates a new private key by delegating to `generate_multi_prime_key`
    /// with two primes and the requested `bit_size`.
    pub fn new<R: Rng>(rng: &mut R, bit_size: usize) -> Result<RSAPrivateKey> {
        generate_multi_prime_key(rng, 2, bit_size)
    }

    /// Assembles a private key from its raw components and precomputes the
    /// CRT values.
    ///
    /// The components are not validated here; call [`validate`](#method.validate)
    /// to check them.
    ///
    /// # Panics
    ///
    /// May panic inside `precompute` when the primes are not usable (for
    /// example when the second prime has no inverse modulo the first).
    pub fn from_components(
        n: BigUint,
        e: BigUint,
        d: BigUint,
        primes: Vec<BigUint>,
    ) -> RSAPrivateKey {
        let mut k = RSAPrivateKey {
            n,
            e,
            d,
            primes,
            precomputed: None,
        };
        k.precompute();
        k
    }

    /// Computes and caches the values used by the CRT branch of `decrypt`.
    ///
    /// Does nothing when the values are already cached, or when the key holds
    /// fewer than two primes (in which case `decrypt` falls back to a plain
    /// modular exponentiation with `d`).
    ///
    /// # Panics
    ///
    /// Panics when the required modular inverses do not exist.
    pub fn precompute(&mut self) {
        if self.precomputed.is_some() {
            return;
        }
        // The CRT parameters need both `p` and `q`; without them indexing
        // below would panic, so leave the cache empty instead.
        if self.primes.len() < 2 {
            return;
        }

        let dp = &self.d % (&self.primes[0] - BigUint::one());
        let dq = &self.d % (&self.primes[1] - BigUint::one());
        let qinv = self.primes[1]
            .clone()
            .mod_inverse(&self.primes[0])
            .expect("invalid prime");

        // `r` is the running product of all primes handled so far.
        let mut r: BigUint = &self.primes[0] * &self.primes[1];
        let crt_values: Vec<CRTValue> = self
            .primes
            .iter()
            .skip(2)
            .map(|prime| {
                let res = CRTValue {
                    exp: BigInt::from_biguint(Plus, &self.d % (prime - BigUint::one())),
                    r: BigInt::from_biguint(Plus, r.clone()),
                    coeff: BigInt::from_biguint(
                        Plus,
                        r.clone()
                            .mod_inverse(prime)
                            .expect("invalid coeff")
                            .to_biguint()
                            .unwrap(),
                    ),
                };
                // Fold this prime into the product for the next entry.
                r *= prime;
                res
            })
            .collect();

        self.precomputed = Some(PrecomputedValues {
            dp,
            dq,
            qinv,
            crt_values,
        });
    }

    /// Returns the private exponent.
    pub fn d(&self) -> &BigUint {
        &self.d
    }

    /// Returns the prime factors of the modulus.
    pub fn primes(&self) -> &[BigUint] {
        &self.primes
    }

    /// Performs basic sanity checks on the key.
    ///
    /// # Errors
    ///
    /// * the error from `check_public` when the public exponent is out of range;
    /// * `Error::InvalidPrime` when any prime is `<= 1`;
    /// * `Error::InvalidModulus` when the product of the primes differs from `n`;
    /// * `Error::InvalidExponent` when `d * e` is not congruent to 1 modulo
    ///   `prime - 1` for some prime.
    pub fn validate(&self) -> Result<()> {
        check_public(self)?;

        // Check that the product of the primes equals the modulus.
        let one = BigUint::one();
        let mut m = BigUint::one();
        for prime in &self.primes {
            // A prime of 0 or 1 is never valid, and a prime of exactly 1 would
            // make `prime - 1` zero and trigger a division by zero below.
            if *prime <= one {
                return Err(Error::InvalidPrime);
            }
            m *= prime;
        }
        if m != self.n {
            return Err(Error::InvalidModulus);
        }

        // Check that d * e ≡ 1 (mod prime - 1) for every prime.
        let de = &self.e * &self.d;
        for prime in &self.primes {
            let congruence: BigUint = &de % (prime - BigUint::one());
            if !congruence.is_one() {
                return Err(Error::InvalidExponent);
            }
        }
        Ok(())
    }

    /// Decrypts `ciphertext` with the given padding scheme, without blinding.
    ///
    /// Only PKCS#1 v1.5 is implemented. OAEP panics via `unimplemented!`, and
    /// any other scheme yields `Error::InvalidPaddingScheme`.
    pub fn decrypt(&self, padding: PaddingScheme, ciphertext: &[u8]) -> Result<Vec<u8>> {
        match padding {
            PaddingScheme::PKCS1v15 => pkcs1v15::decrypt::<ThreadRng>(None, self, ciphertext),
            PaddingScheme::OAEP => unimplemented!("not yet implemented"),
            _ => Err(Error::InvalidPaddingScheme),
        }
    }

    /// Decrypts `ciphertext` with the given padding scheme, passing `rng`
    /// down so the private-key operation can be blinded.
    ///
    /// Only PKCS#1 v1.5 is implemented. OAEP panics via `unimplemented!`, and
    /// any other scheme yields `Error::InvalidPaddingScheme`.
    pub fn decrypt_blinded<R: Rng>(
        &self,
        rng: &mut R,
        padding: PaddingScheme,
        ciphertext: &[u8],
    ) -> Result<Vec<u8>> {
        match padding {
            PaddingScheme::PKCS1v15 => pkcs1v15::decrypt(Some(rng), self, ciphertext),
            PaddingScheme::OAEP => unimplemented!("not yet implemented"),
            _ => Err(Error::InvalidPaddingScheme),
        }
    }

    /// Signs `digest` with the given padding scheme, without blinding.
    ///
    /// `hash` is forwarded unchanged to the padding implementation.
    /// Only PKCS#1 v1.5 is implemented. PSS panics via `unimplemented!`, and
    /// any other scheme yields `Error::InvalidPaddingScheme`.
    pub fn sign<H: Hash>(
        &self,
        padding: PaddingScheme,
        hash: Option<&H>,
        digest: &[u8],
    ) -> Result<Vec<u8>> {
        match padding {
            PaddingScheme::PKCS1v15 => pkcs1v15::sign::<ThreadRng, _>(None, self, hash, digest),
            PaddingScheme::PSS => unimplemented!("not yet implemented"),
            _ => Err(Error::InvalidPaddingScheme),
        }
    }

    /// Signs `digest` with the given padding scheme, passing `rng` down so the
    /// private-key operation can be blinded.
    ///
    /// `hash` is forwarded unchanged to the padding implementation.
    /// Only PKCS#1 v1.5 is implemented. PSS panics via `unimplemented!`, and
    /// any other scheme yields `Error::InvalidPaddingScheme`.
    pub fn sign_blinded<R: Rng, H: Hash>(
        &self,
        rng: &mut R,
        padding: PaddingScheme,
        hash: Option<&H>,
        digest: &[u8],
    ) -> Result<Vec<u8>> {
        match padding {
            PaddingScheme::PKCS1v15 => pkcs1v15::sign(Some(rng), self, hash, digest),
            PaddingScheme::PSS => unimplemented!("not yet implemented"),
            _ => Err(Error::InvalidPaddingScheme),
        }
    }
}
#[inline]
pub fn check_public(public_key: &impl PublicKey) -> Result<()> {
if public_key.e() < &*MIN_PUB_EXPONENT {
return Err(Error::PublicExponentTooSmall);
}
if public_key.e() > &*MAX_PUB_EXPONENT {
return Err(Error::PublicExponentTooLarge);
}
Ok(())
}
/// Raw RSA encryption: computes `m^e mod n` for the given key.
/// No padding is applied here.
#[inline]
pub fn encrypt<K: PublicKey>(key: &K, m: &BigUint) -> BigUint {
    let (exponent, modulus) = (key.e(), key.n());
    m.modpow(exponent, modulus)
}
/// Raw RSA decryption of the integer `c` with `priv_key`.
///
/// When `rng` is `Some`, the ciphertext is blinded with a random factor before
/// the private-key operation and unblinded afterwards. When the key carries
/// precomputed values, the CRT path is used; otherwise `c^d mod n` is computed
/// directly.
///
/// Returns `Error::Decryption` when `c >= n` or when `n` is zero.
#[inline]
pub fn decrypt<R: Rng>(
mut rng: Option<&mut R>,
priv_key: &RSAPrivateKey,
c: &BigUint,
) -> Result<BigUint> {
// The ciphertext must be smaller than the modulus.
if c >= priv_key.n() {
return Err(Error::Decryption);
}
if priv_key.n().is_zero() {
return Err(Error::Decryption);
}
// Inverse of the blinding factor; stays `None` when no blinding is done.
let mut ir = None;
let c = if let Some(ref mut rng) = rng {
// Blinding: pick a random `r` below `n` that is invertible modulo `n`,
// then work on `c * r^e mod n` instead of `c`.
let mut r: BigUint;
loop {
r = rng.gen_biguint_below(priv_key.n());
if r.is_zero() {
r = BigUint::one();
}
ir = r.clone().mod_inverse(priv_key.n());
if ir.is_some() {
break;
}
}
let e = priv_key.e();
let rpowe = r.modpow(&e, priv_key.n());
(c * &rpowe) % priv_key.n()
} else {
c.clone()
};
let m = match priv_key.precomputed {
// No cached CRT values: plain exponentiation with the private exponent.
None => c.modpow(priv_key.d(), priv_key.n()),
Some(ref precomputed) => {
// CRT path using the first two primes `p` and `q`.
let p = &priv_key.primes[0];
let q = &priv_key.primes[1];
// m = c^dp mod p, m2 = c^dq mod q
let mut m = BigInt::from_biguint(Plus, c.modpow(&precomputed.dp, p));
let mut m2 = BigInt::from_biguint(Plus, c.modpow(&precomputed.dq, q));
m -= &m2;
// Signed copies of the primes for the signed arithmetic below.
let primes: Vec<_> = priv_key
.primes
.iter()
.map(|v| BigInt::from_biguint(Plus, v.clone()))
.collect();
// Bring the difference back into the non-negative range by adding `p`.
while m.is_negative() {
m += &primes[0];
}
// m = ((m - m2) * qinv mod p) * q + m2
m *= &precomputed.qinv;
m %= &primes[0];
m *= &primes[1];
m += m2;
let c = BigInt::from_biguint(Plus, c);
// Fold in each additional prime of a multi-prime key.
for (i, value) in precomputed.crt_values.iter().enumerate() {
let prime = &primes[2 + i];
m2 = c.modpow(&value.exp, prime);
m2 -= &m;
m2 *= &value.coeff;
m2 %= prime;
// `%=` may leave a negative remainder; normalise it.
while m2.is_negative() {
m2 += prime;
}
m2 *= &value.r;
m += &m2;
}
m.to_biguint().expect("failed to decrypt")
}
};
match ir {
Some(ref ir) => {
// Remove the blinding factor: m * r^-1 mod n.
Ok((m * ir.to_biguint().unwrap()) % priv_key.n())
}
None => Ok(m),
}
}
#[inline]
pub fn decrypt_and_check<R: Rng>(
rng: Option<&mut R>,
priv_key: &RSAPrivateKey,
c: &BigUint,
) -> Result<BigUint> {
let m = decrypt(rng, priv_key, c)?;
let check = encrypt(priv_key, &m);
if c != &check {
return Err(Error::Internal);
}
Ok(m)
}
/// Returns a new vector of exactly `size` bytes holding `input` right-aligned
/// and padded on the left with zero bytes.
///
/// When `input` is longer than `size`, only its first `size` bytes are kept.
/// (Previously this case panicked, because the full `input` was copied into a
/// destination of only `size` bytes.)
#[inline]
pub fn left_pad(input: &[u8], size: usize) -> Vec<u8> {
    // Number of input bytes that fit into the output.
    let n = input.len().min(size);
    let mut out = vec![0u8; size];
    // Source and destination must have the same length for `copy_from_slice`.
    out[size - n..].copy_from_slice(&input[..n]);
    out
}
// Unit tests for key conversion, generation, validation, raw encrypt/decrypt
// round trips and (optionally) serde serialization.
#[cfg(test)]
mod tests {
use super::*;
use num_traits::{FromPrimitive, ToPrimitive};
use rand::{thread_rng, ThreadRng};
// Converting a private key into a public key keeps `n` and `e`.
#[test]
fn test_from_into() {
let private_key = RSAPrivateKey {
n: BigUint::from_u64(100).unwrap(),
e: BigUint::from_u64(200).unwrap(),
d: BigUint::from_u64(123).unwrap(),
primes: vec![],
precomputed: None,
};
let public_key: RSAPublicKey = private_key.into();
assert_eq!(public_key.n().to_u64(), Some(100));
assert_eq!(public_key.e().to_u64(), Some(200));
}
// Shared helper: validates the key, then checks that raw encryption followed
// by decryption (both without and with blinding) returns the original message.
fn test_key_basics(private_key: &RSAPrivateKey) {
private_key.validate().expect("invalid private key");
assert!(
private_key.d() < private_key.n(),
"private exponent too large"
);
let pub_key: RSAPublicKey = private_key.clone().into();
let m = BigUint::from_u64(42).expect("invalid 42");
let c = encrypt(&pub_key, &m);
let m2 = decrypt::<ThreadRng>(None, &private_key, &c)
.expect("unable to decrypt without blinding");
assert_eq!(m, m2);
let mut rng = thread_rng();
let m3 =
decrypt(Some(&mut rng), &private_key, &c).expect("unable to decrypt with blinding");
assert_eq!(m, m3);
}
// Generates a test named `$name` that creates ten keys with `$multi` primes
// and `$size` bits, checks the modulus bit length and runs `test_key_basics`.
macro_rules! key_generation {
($name:ident, $multi:expr, $size:expr) => {
#[test]
fn $name() {
let mut rng = thread_rng();
for _ in 0..10 {
let private_key = if $multi == 2 {
RSAPrivateKey::new(&mut rng, $size).expect("failed to generate key")
} else {
generate_multi_prime_key(&mut rng, $multi, $size).unwrap()
};
assert_eq!(private_key.n().bits(), $size);
test_key_basics(&private_key);
}
}
};
}
key_generation!(key_generation_128, 2, 128);
key_generation!(key_generation_1024, 2, 1024);
key_generation!(key_generation_multi_3_256, 3, 256);
key_generation!(key_generation_multi_4_64, 4, 64);
key_generation!(key_generation_multi_5_64, 5, 64);
key_generation!(key_generation_multi_8_576, 8, 576);
key_generation!(key_generation_multi_16_1024, 16, 1024);
// Key generation with very small bit sizes must not panic; the results
// themselves are discarded.
#[test]
fn test_impossible_keys() {
let mut rng = thread_rng();
for i in 0..32 {
let _ = RSAPrivateKey::new(&mut rng, i).is_err();
let _ = generate_multi_prime_key(&mut rng, 3, i);
let _ = generate_multi_prime_key(&mut rng, 4, i);
let _ = generate_multi_prime_key(&mut rng, 5, i);
}
}
// Runs the round-trip checks repeatedly against a fixed 128-bit key built from
// explicit components (named for a negative intermediate decryption value).
#[test]
fn test_negative_decryption_value() {
let private_key = RSAPrivateKey::from_components(
BigUint::from_bytes_le(&vec![
99, 192, 208, 179, 0, 220, 7, 29, 49, 151, 75, 107, 75, 73, 200, 180,
]),
BigUint::from_bytes_le(&vec![1, 0, 1]),
BigUint::from_bytes_le(&vec![
81, 163, 254, 144, 171, 159, 144, 42, 244, 133, 51, 249, 28, 12, 63, 65,
]),
vec![
BigUint::from_bytes_le(&vec![105, 101, 60, 173, 19, 153, 3, 192]),
BigUint::from_bytes_le(&vec![235, 65, 160, 134, 32, 136, 6, 241]),
],
);
for _ in 0..1000 {
test_key_basics(&private_key);
}
}
// Checks the serde token streams of a private key generated from a fixed seed
// and of the public key derived from it.
#[test]
#[cfg(feature = "serde1")]
fn test_serde() {
use rand::{SeedableRng, XorShiftRng};
use serde_test::{assert_tokens, Token};
let mut rng = XorShiftRng::from_seed([1; 16]);
let priv_key = RSAPrivateKey::new(&mut rng, 64).expect("failed to generate key");
// Expected tokens for the private key; `precomputed` is skipped, so only
// four fields appear.
let priv_tokens = [
Token::Struct {
name: "RSAPrivateKey",
len: 4,
},
Token::Str("n"),
Token::Seq { len: Some(2) },
Token::U32(1296829443),
Token::U32(2444363981),
Token::SeqEnd,
Token::Str("e"),
Token::Seq { len: Some(1) },
Token::U32(65537),
Token::SeqEnd,
Token::Str("d"),
Token::Seq { len: Some(2) },
Token::U32(298985985),
Token::U32(2349628418),
Token::SeqEnd,
Token::Str("primes"),
Token::Seq { len: Some(2) },
Token::Seq { len: Some(1) },
Token::U32(3238068481),
Token::SeqEnd,
Token::Seq { len: Some(1) },
Token::U32(3242199299),
Token::SeqEnd,
Token::SeqEnd,
Token::StructEnd,
];
assert_tokens(&priv_key, &priv_tokens);
// Expected tokens for the derived public key (the variable name is reused).
let priv_tokens = [
Token::Struct {
name: "RSAPublicKey",
len: 2,
},
Token::Str("n"),
Token::Seq { len: Some(2) },
Token::U32(1296829443),
Token::U32(2444363981),
Token::SeqEnd,
Token::Str("e"),
Token::Seq { len: Some(1) },
Token::U32(65537),
Token::SeqEnd,
Token::StructEnd,
];
assert_tokens(&RSAPublicKey::from(priv_key), &priv_tokens);
}
}