use {bits, der, digest, error, pkcs8};
use rand;
use std;
use super::{blinding, bigint, bigint::Prime, N};
use arithmetic::montgomery::{R, RR, RRR};
use untrusted;
/// An RSA key pair, used for signing.
///
/// Values are created by `from_pkcs8` or `from_der`, which parse the private
/// key, check its components, and store the derived values listed below.
pub struct RSAKeyPair {
/// The public modulus, as returned by
/// `super::check_public_modulus_and_exponent` in `from_der`.
n: bigint::Modulus<N>,
/// The public exponent, returned by the same check as `n`.
e: bigint::PublicExponent,
/// The first prime with its private exponent. `from_der` swaps the two
/// primes when needed so that `q` is never larger than `p`.
p: PrivatePrime<P>,
/// The second prime with its private exponent.
q: PrivatePrime<Q>,
/// The inverse of `q` modulo `p`, multiplied by `p.oneRR` and checked with
/// `bigint::verify_inverses_consttime` in `from_der`.
qInv: bigint::Elem<P, R>,
/// `bigint::One::newRR(&n)`; `sign` multiplies it into a value modulo `n`
/// before the public-exponent check.
oneRR_mod_n: bigint::One<N, RR>,
/// Modulus built in `from_der` from the product of `q` with itself
/// modulo `n`.
qq: bigint::Modulus<QQ>,
/// `q` as an element modulo `n`, multiplied by `oneRR_mod_n`.
q_mod_n: bigint::Elem<N, R>,
/// Bit length of `n`; `public_modulus_len` reports it rounded up to bytes.
n_bits: bits::BitLength,
}
impl RSAKeyPair {
/// Parses an RSA private key wrapped in a PKCS#8 structure.
///
/// The wrapper is opened with `pkcs8::unwrap_key_`, using the algorithm
/// identifier bytes from `alg-rsa-encryption.der` and
/// `pkcs8::Version::V1Only`. The key bytes it yields are then parsed and
/// validated by `from_der`.
///
/// # Errors
///
/// Returns `error::Unspecified` if unwrapping fails or if `from_der` rejects
/// the contained key.
pub fn from_pkcs8(input: untrusted::Input)
                  -> Result<RSAKeyPair, error::Unspecified> {
    const RSA_ENCRYPTION: &'static [u8] =
        include_bytes!("../data/alg-rsa-encryption.der");

    let unwrapped = pkcs8::unwrap_key_(&RSA_ENCRYPTION,
                                       pkcs8::Version::V1Only, input)?;
    // Only the first component is the private key; the second is unused here.
    let (private_key, _) = unwrapped;
    Self::from_der(private_key)
}
/// Parses an RSA private key from DER and validates its components.
///
/// The input must consist of one DER SEQUENCE holding, in order, a version
/// (which must be 0) followed by the positive integers `n`, `e`, `d`, `p`,
/// `q`, `dP`, `dQ` and `qInv`. This is the field order of the RFC 8017
/// `RSAPrivateKey` structure.
///
/// The two primes may appear in either order: if `q` is not less than `p`
/// they are swapped, and `qInv` is recomputed instead of taken from the input.
///
/// `d` is only sanity-checked; it is not stored in the returned key pair.
///
/// # Errors
///
/// Returns `error::Unspecified` if parsing fails or if any of the checks in
/// the body rejects the key.
pub fn from_der(input: untrusted::Input)
-> Result<RSAKeyPair, error::Unspecified> {
input.read_all(error::Unspecified, |input| {
der::nested(input, der::Tag::Sequence, error::Unspecified, |input| {
// Only version 0 is accepted.
let version = der::small_nonnegative_integer(input)?;
if version != 0 {
return Err(error::Unspecified);
}
// Read the eight integers in their encoded order.
let n = der::positive_integer(input)?;
let e = der::positive_integer(input)?;
let d = der::positive_integer(input)?;
let p = der::positive_integer(input)?;
let q = der::positive_integer(input)?;
let dP = der::positive_integer(input)?;
let dQ = der::positive_integer(input)?;
let qInv = der::positive_integer(input)?;
// Decode both primes, keeping their bit lengths for the size checks below.
let (p, p_bits) =
bigint::Nonnegative::from_be_bytes_with_bit_length(p)?;
let (q, q_bits) =
bigint::Nonnegative::from_be_bytes_with_bit_length(q)?;
// From here on `q` is never larger than `p`. When the primes have to be
// swapped, each keeps its own exponent and the supplied `qInv` is dropped
// (`None`); it is recomputed further down.
let ((p, p_bits, dP), (q, q_bits, dQ, qInv)) =
match q.verify_less_than(&p) {
Ok(_) => ((p, p_bits, dP), (q, q_bits, dQ, Some(qInv))),
Err(_) => {
((q, q_bits, dQ), (p, p_bits, dP, None))
},
};
// Validate the public modulus and exponent.
// NOTE(review): the arguments look like (minimum modulus bits, maximum
// modulus bits, minimum exponent value) - confirm against
// `check_public_modulus_and_exponent`.
let (n, n_bits, e) = super::check_public_modulus_and_exponent(
n, e, bits::BitLength::from_usize_bits(2048),
super::PRIVATE_KEY_PUBLIC_MODULUS_MAX_BITS, 65537)?;
// Both primes must have exactly half the bit length of `n`, rounded up.
let half_n_bits = n_bits.half_rounded_up();
if p_bits != half_n_bits {
return Err(error::Unspecified);
}
if p_bits != q_bits {
return Err(error::Unspecified);
}
let oneRR_mod_n = bigint::One::newRR(&n);
// `q` as an element modulo `n`. The plain value is kept for computing `qq`
// below; the copy multiplied by `oneRR_mod_n` is what the key pair stores.
let q_mod_n_decoded = q.to_elem(&n)?;
let q_mod_n = bigint::elem_mul(oneRR_mod_n.as_ref(),
q_mod_n_decoded.clone(), &n);
// Reject the key unless p * q is congruent to 0 modulo n. This is a
// modular stand-in for checking p * q == n.
let p_mod_n = p.to_elem(&n)?;
let pq_mod_n = bigint::elem_mul(&q_mod_n, p_mod_n, &n);
if !pq_mod_n.is_zero() {
return Err(error::Unspecified);
}
// Sanity checks on `d`: it must have more bits than `half_n_bits`, be
// less than `n`, and be odd.
let (d, d_bits) =
bigint::Nonnegative::from_be_bytes_with_bit_length(d)?;
if !(half_n_bits < d_bits) {
return Err(error::Unspecified);
}
d.verify_less_than_modulus(&n)?;
if !d.is_odd() {
return Err(error::Unspecified);
}
// Pair each prime with its private exponent and precomputed constants.
let p = PrivatePrime::new(p, dP)?;
let q = PrivatePrime::new(q, dQ)?;
let q_mod_p = q.modulus.to_elem(&p.modulus);
// Use the supplied `qInv` when the primes were not swapped; otherwise
// compute the inverse of `q` modulo `p` here.
let qInv = if let Some(qInv) = qInv {
bigint::Elem::from_be_bytes_padded(qInv, &p.modulus)?
} else {
let q_mod_p = bigint::elem_mul(p.oneRR.as_ref(),
q_mod_p.clone(),
&p.modulus);
bigint::elem_inverse_consttime(q_mod_p, &p.modulus, &p.oneR)?
};
// Multiply `qInv` by `p.oneRR`, then verify that it and `q mod p` are
// inverses. This check runs for both the supplied and the computed value.
let qInv =
bigint::elem_mul(p.oneRR.as_ref(), qInv, &p.modulus);
bigint::verify_inverses_consttime(&qInv, q_mod_p, &p.modulus)?;
// `qq` is the product of `q` with itself modulo `n`, turned into a
// modulus of its own.
let qq = bigint::elem_mul(&q_mod_n, q_mod_n_decoded, &n)
.into_modulus::<QQ>()?;
Ok(RSAKeyPair {
n,
e,
p,
q,
qInv,
oneRR_mod_n,
q_mod_n,
qq,
n_bits
})
})
})
}
/// Returns the length of the public modulus in bytes, rounded up.
///
/// `RSASigningState::sign` requires its output buffer to have exactly this
/// length.
pub fn public_modulus_len(&self) -> usize {
    let modulus_bits = self.n_bits;
    modulus_bits.as_usize_bytes_rounded_up()
}
}
/// One prime factor of the key together with the values needed to
/// exponentiate modulo that prime (see `elem_exp_consttime`).
struct PrivatePrime<M: Prime> {
/// The prime itself, as a modulus.
modulus: bigint::Modulus<M>,
/// The private exponent paired with this prime (`dP` or `dQ` in
/// `RSAKeyPair::from_der`).
exponent: bigint::PrivateExponent<M>,
/// `bigint::One::newR` for `modulus`.
oneR: bigint::One<M, R>,
/// `bigint::One::newRR` for `modulus`.
oneRR: bigint::One<M, RR>,
/// `bigint::One::newRRR` for `modulus`.
oneRRR: bigint::One<M, RRR>,
}
impl<M: Prime + Clone> PrivatePrime<M> {
    /// Builds a `PrivatePrime` from the prime `p` and the big-endian bytes of
    /// its private exponent `dP`.
    ///
    /// # Errors
    ///
    /// Returns `error::Unspecified` if `p` cannot be turned into a modulus or
    /// if `dP` is rejected by `PrivateExponent::from_be_bytes_padded` for that
    /// modulus.
    fn new(p: bigint::Nonnegative, dP: untrusted::Input)
           -> Result<Self, error::Unspecified> {
        let modulus = bigint::Modulus::from(p)?;
        let exponent =
            bigint::PrivateExponent::from_be_bytes_padded(dP, &modulus)?;

        // `oneRR` comes first because the other two constants are derived
        // from it.
        let oneRR = bigint::One::newRR(&modulus);
        let oneR = bigint::One::newR(&oneRR, &modulus);
        let oneRRR = bigint::One::newRRR(oneRR.clone(), &modulus);

        Ok(PrivatePrime { modulus, exponent, oneR, oneRR, oneRRR })
    }
}
fn elem_exp_consttime<M, MM>(c: &bigint::Elem<MM>, p: &PrivatePrime<M>)
-> Result<bigint::Elem<M>, error::Unspecified>
where M: bigint::NotMuchSmallerModulus<MM>,
M: Prime {
let c_mod_m = bigint::elem_reduced(c, &p.modulus)?;
let c_mod_m = bigint::elem_mul(p.oneRRR.as_ref(), c_mod_m, &p.modulus);
bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.oneR, &p.modulus)
}
/// Type-level tag for the modulus `p`.
///
/// The enum has no variants, so it only ever appears as a type parameter
/// (for example in `PrivatePrime<P>` and `bigint::Elem<P, R>`).
#[derive(Copy, Clone)]
enum P {}
// Each `unsafe impl` below is an unchecked promise to `bigint` about this
// modulus; nothing in these lines verifies it.
// NOTE(review): the size relations to `N` are presumably upheld by the checks
// in `RSAKeyPair::from_der` (`p` has half the bit length of `n` and
// `p * q` is 0 modulo `n`) - confirm against the trait definitions.
unsafe impl Prime for P {}
unsafe impl bigint::SmallerModulus<N> for P {}
unsafe impl bigint::NotMuchSmallerModulus<N> for P {}
/// Type-level tag for the modulus stored in `RSAKeyPair::qq`, which
/// `RSAKeyPair::from_der` builds from the product of `q` with itself
/// modulo `n`.
///
/// The enum has no variants, so it only ever appears as a type parameter.
#[derive(Copy, Clone)]
enum QQ {}
// Each `unsafe impl` below is an unchecked promise to `bigint` about how this
// modulus compares with `N`; nothing in these lines verifies it.
// NOTE(review): presumably justified by `q` not exceeding `p` and
// `p * q` being 0 modulo `n`, both established in `RSAKeyPair::from_der` -
// confirm against the trait definitions.
unsafe impl bigint::SmallerModulus<N> for QQ {}
unsafe impl bigint::NotMuchSmallerModulus<N> for QQ {}
unsafe impl bigint::SlightlySmallerModulus<N> for QQ {}
/// Type-level tag for the modulus `q`.
///
/// The enum has no variants, so it only ever appears as a type parameter
/// (for example in `PrivatePrime<Q>`).
#[derive(Copy, Clone)]
enum Q {}
// Each `unsafe impl` below is an unchecked promise to `bigint` about how this
// modulus compares with `N`, `P` and `QQ`; nothing in these lines verifies it.
// NOTE(review): the relations to `P` presumably rest on
// `RSAKeyPair::from_der` ordering the primes so that `q` is never larger than
// `p` and requiring equal bit lengths - confirm against the trait definitions.
unsafe impl Prime for Q {}
unsafe impl bigint::SmallerModulus<N> for Q {}
unsafe impl bigint::SmallerModulus<P> for Q {}
unsafe impl bigint::SlightlySmallerModulus<P> for Q {}
unsafe impl bigint::SmallerModulus<QQ> for Q {}
unsafe impl bigint::NotMuchSmallerModulus<QQ> for Q {}
/// State used for signing with an RSA key pair.
///
/// `sign` takes `&mut self` because it uses `blinding` mutably on every call.
pub struct RSASigningState {
/// The key pair used for signing, shared through an `Arc`.
key_pair: std::sync::Arc<RSAKeyPair>,
/// Blinding state passed through on every `sign` call.
blinding: blinding::Blinding,
}
impl RSASigningState {
/// Creates signing state for `key_pair` with a freshly constructed
/// `blinding::Blinding`.
///
/// As written this always returns `Ok`; the `Result` return type is part of
/// the public signature.
pub fn new(key_pair: std::sync::Arc<RSAKeyPair>)
           -> Result<Self, error::Unspecified> {
    let fresh_blinding = blinding::Blinding::new();
    Ok(RSASigningState { key_pair, blinding: fresh_blinding })
}
/// Returns a reference to the key pair this state signs with.
pub fn key_pair(&self) -> &RSAKeyPair {
    &*self.key_pair
}
/// Signs `msg` and writes the signature into `signature`.
///
/// `msg` is hashed with `padding_alg.digest_alg()`, encoded by
/// `padding_alg.encode`, and the encoded value is exponentiated with the
/// private key. The result is checked with the public exponent before it is
/// written out.
///
/// `signature` must be exactly `public_modulus_len()` bytes long. It is also
/// used as scratch space for the encoded message, so its contents may have
/// changed even when an error is returned.
///
/// # Errors
///
/// Returns `error::Unspecified` if `signature` has the wrong length, if
/// encoding fails, or if any step of the private-key operation or its
/// verification fails.
// NOTE(review): the destructuring pattern below renames `key_pair` to `key`,
// which does not look like a `field: field` pattern; it is unclear from this
// file why the lint needs silencing - confirm before removing the attribute.
#[allow(non_shorthand_field_patterns)] pub fn sign(&mut self, padding_alg: &'static ::signature::RSAEncoding,
rng: &rand::SecureRandom, msg: &[u8], signature: &mut [u8])
-> Result<(), error::Unspecified> {
// The output buffer must match the modulus length exactly.
let mod_bits = self.key_pair.n_bits;
if signature.len() != mod_bits.as_usize_bytes_rounded_up() {
return Err(error::Unspecified);
}
// Borrow the two fields separately so the closure below can read the key
// while `blinding` is in use.
let RSASigningState { key_pair: key, blinding } = self;
// Hash the message and encode it; the encoded value is read back from
// `signature` as a big-endian element modulo `n`.
let m_hash = digest::digest(padding_alg.digest_alg(), msg);
padding_alg.encode(&m_hash, signature, mod_bits, rng)?;
let base = bigint::Elem::from_be_bytes_padded(
untrusted::Input::from(signature), &key.n)?;
// NOTE(review): `Blinding::blind` is defined elsewhere. From this call it
// takes the base and hands the closure a value `c` to exponentiate,
// presumably a blinded form of `base` whose result it unblinds - confirm in
// the `blinding` module.
let result = blinding.blind(base, key.e, &key.oneRR_mod_n, &key.n, rng,
|c| {
// Exponentiate separately modulo each prime. `c` is first reduced modulo
// `qq` before the exponentiation modulo `q`.
let m_1 = elem_exp_consttime(&c, &key.p)?;
let c_mod_qq = bigint::elem_reduced_once(&c, &key.qq);
let m_2 = elem_exp_consttime(&c_mod_qq, &key.q)?;
// h = qInv * (m_1 - m_2) mod p.
let p = &key.p.modulus;
let m_2 = bigint::elem_widen(m_2, &p);
let m_1_minus_m_2 = bigint::elem_sub(m_1, &m_2, p);
let h = bigint::elem_mul(&key.qInv, m_1_minus_m_2, p);
// m = m_2 + q * h, computed modulo n. These steps follow the
// two-prime form of RSADP in RFC 8017 section 5.1.2.
let h = bigint::elem_widen(h, &key.n);
let q_times_h = bigint::elem_mul(&key.q_mod_n, h, &key.n);
let m_2 = bigint::elem_widen(m_2, &key.n);
let m = bigint::elem_add(m_2, q_times_h, &key.n);
// Check the result before releasing it: raising `m` to the public
// exponent must give back `c`, otherwise the signature is withheld.
let computed =
bigint::elem_mul(&key.oneRR_mod_n.as_ref(), m.clone(), &key.n);
let verify = bigint::elem_exp_vartime(computed, key.e, &key.n);
let verify = verify.into_unencoded(&key.n);
bigint::elem_verify_equal_consttime(&verify, &c)?;
Ok(m)
})?;
// Overwrite the encoded message with the final signature bytes.
result.fill_be_bytes(signature);
Ok(())
}
}
#[cfg(test)]
mod tests {
use core;
use {rand, signature, test};
use std;
use super::super::blinding;
use untrusted;
// Signing must fail when the output buffer is one byte shorter or one byte
// longer than the public modulus, and succeed when the length is exact.
#[test]
fn test_signature_rsa_pkcs1_sign_output_buffer_len() {
    const MESSAGE: &'static [u8] = b"hello, world";
    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let rng = rand::SystemRandom::new();
    let parsed = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let shared_key = std::sync::Arc::new(parsed);
    let mut signing_state =
        signature::RSASigningState::new(shared_key).unwrap();

    // Start one byte too short.
    let modulus_len = signing_state.key_pair().public_modulus_len();
    let mut sig_buf = vec![0; modulus_len - 1];
    assert!(signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng, MESSAGE,
                               &mut sig_buf).is_err());

    // Exactly the modulus length.
    sig_buf.push(0);
    assert!(signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng, MESSAGE,
                               &mut sig_buf).is_ok());

    // One byte too long.
    sig_buf.push(0);
    assert!(signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng, MESSAGE,
                               &mut sig_buf).is_err());
}
// Every call to `sign` must lower `blinding.remaining()` by one, wrapping
// around modulo `blinding::REMAINING_MAX`.
#[test]
fn test_signature_rsa_pkcs1_sign_blinding_reuse() {
    const MESSAGE: &'static [u8] = b"hello, world";
    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let rng = rand::SystemRandom::new();
    let parsed = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let shared_key = std::sync::Arc::new(parsed);
    let mut sig_buf = vec![0; shared_key.public_modulus_len()];
    let mut signing_state =
        signature::RSASigningState::new(shared_key).unwrap();

    // One round more than the maximum, so the wrap-around is exercised too.
    let rounds = blinding::REMAINING_MAX + 1;
    for _ in 0..rounds {
        let before = signing_state.blinding.remaining();
        // Only the counter is under test, so the signing result is ignored.
        let _ = signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng,
                                   MESSAGE, &mut sig_buf);
        let after = signing_state.blinding.remaining();
        assert_eq!((after + 1) % blinding::REMAINING_MAX,
                   before);
    }
}
// Signing must fail when the random source replays a fixed sequence: one
// value whose first byte is 1 and the rest 0, followed by 100 all-zero values.
// NOTE(review): going by the test name, this sequence is presumably meant to
// make creation of the blinding factor fail - confirm in the `blinding`
// module.
#[test]
fn test_signature_rsa_pkcs1_sign_blinding_creation_failure() {
    const MESSAGE: &'static [u8] = b"hello, world";
    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let key_pair = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let modulus_len = key_pair.public_modulus_len();

    let mut inverse_blinding_factor = vec![0u8; modulus_len];
    inverse_blinding_factor[0] = 1;
    let zero = vec![0u8; modulus_len];

    // The sequence handed to the fake RNG: the crafted value, then 100 zeros.
    let mut bytes = vec![&inverse_blinding_factor[..]];
    bytes.extend(std::iter::repeat(&zero[..]).take(100));

    let rng = test::rand::FixedSliceSequenceRandom {
        bytes: &bytes[..],
        current: core::cell::UnsafeCell::new(0),
    };

    let shared_key = std::sync::Arc::new(key_pair);
    let mut signing_state =
        signature::RSASigningState::new(shared_key).unwrap();
    let mut sig_buf = vec![0; modulus_len];
    let result = signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng,
                                    MESSAGE, &mut sig_buf);
    assert!(result.is_err());
}
}