use {bits, der, digest, error};
use rand;
use std;
use super::{blinding, bigint, N};
use super::bigint::{R, RR, RRR};
use untrusted;
/// An RSA key pair: the public modulus and exponent together with the
/// private values that `RSASigningState::sign` needs.
///
/// Within this file, values are constructed only by `RSAKeyPair::from_der`,
/// which validates the decoded components before storing them.
pub struct RSAKeyPair {
// Public modulus, as returned by `check_public_modulus_and_exponent`.
n: bigint::Modulus<N>,
// Public exponent; `sign` hands it to the blinding step and uses it to
// re-check every computed result.
e: bigint::PublicExponent,
// The prime `p` bundled with its exponent `dP` (see `PrivatePrime`).
p: PrivatePrime<P>,
// The prime `q` bundled with its exponent `dQ`; `from_der` requires `q < p`.
q: PrivatePrime<Q>,
// `qInv` as an element modulo `p`, multiplied by `p.oneRR`. `from_der`
// verifies that `qInv * q` is one modulo `p`.
qInv: bigint::Elem<P, R>,
// Result of `bigint::One::newRR(&n)`.
oneRR_mod_n: bigint::One<N, RR>,
// Product of `q` with itself, computed modulo `n` and then converted into a
// modulus. `sign` reduces its input by this before exponentiating for `q`.
qq: bigint::Modulus<QQ>,
// `q` as an element modulo `n`, multiplied by `oneRR_mod_n`.
q_mod_n: bigint::Elem<N, R>,
// Bit length of `n`, measured on the decoded value before the range checks.
n_bits: bits::BitLength,
}
// NOTE(review): this `unsafe impl` is a manual promise that sharing
// `&RSAKeyPair` between threads is sound. Every use of a key pair visible in
// this file goes through a shared reference and only reads its fields;
// confirm that the `bigint` field types contain no interior mutability.
unsafe impl Sync for RSAKeyPair {}
impl RSAKeyPair {
pub fn from_der(input: untrusted::Input)
-> Result<RSAKeyPair, error::Unspecified> {
input.read_all(error::Unspecified, |input| {
der::nested(input, der::Tag::Sequence, error::Unspecified, |input| {
let version = try!(der::small_nonnegative_integer(input));
if version != 0 {
return Err(error::Unspecified);
}
let n = try!(bigint::Positive::from_der(input));
let e = try!(bigint::Positive::from_der(input));
let d = try!(bigint::Positive::from_der(input));
let p = try!(bigint::Positive::from_der(input));
let q = try!(bigint::Positive::from_der(input));
let dP = try!(bigint::Positive::from_der(input));
let dQ = try!(bigint::Positive::from_der(input));
let qInv = try!(bigint::Positive::from_der(input));
let n_bits = n.bit_length();
let (n, e) = try!(super::check_public_modulus_and_exponent(
n, e, bits::BitLength::from_usize_bits(2048),
super::PRIVATE_KEY_PUBLIC_MODULUS_MAX_BITS,
bits::BitLength::from_usize_bits(17)));
let half_n_bits = n_bits.half_rounded_up();
if p.bit_length() != half_n_bits {
return Err(error::Unspecified);
}
let p = try!(p.into_odd_positive());
if p.bit_length() != q.bit_length() {
return Err(error::Unspecified);
}
let q = try!(q.into_odd_positive());
let n = try!(n.into_modulus::<N>());
let oneRR_mod_n = try!(bigint::One::newRR(&n));
let q_mod_n_decoded = {
let q = try!(q.try_clone());
try!(q.into_elem(&n))
};
try!(q.verify_less_than(&p));
{
let p_mod_n = {
let p = try!(p.try_clone());
try!(p.into_elem(&n))
};
let p_minus_q_bits = {
let p_minus_q =
try!(bigint::elem_sub(p_mod_n, &q_mod_n_decoded,
&n));
p_minus_q.bit_length()
};
let min_pq_bitlen_diff = try!(half_n_bits.try_sub(
bits::BitLength::from_usize_bits(100)));
if p_minus_q_bits <= min_pq_bitlen_diff {
return Err(error::Unspecified);
}
}
let q_mod_n = {
let clone = try!(q_mod_n_decoded.try_clone());
try!(bigint::elem_mul(oneRR_mod_n.as_ref(), clone, &n))
};
let p_mod_n = {
let p = try!(p.try_clone());
try!(p.into_elem(&n))
};
let pq_mod_n =
try!(bigint::elem_mul(&q_mod_n, p_mod_n, &n));
if !pq_mod_n.is_zero() {
return Err(error::Unspecified);
}
if !(half_n_bits < d.bit_length()) {
return Err(error::Unspecified);
}
let d = try!(d.into_odd_positive());
try!(d.verify_less_than(&n.value()));
let p = try!(PrivatePrime::new(p, dP));
let qInv = try!(qInv.into_elem(&p.modulus));
let qInv =
try!(bigint::elem_mul(p.oneRR.as_ref(), qInv, &p.modulus));
let q_mod_p = {
let q = try!(q.try_clone());
try!(q.into_elem(&p.modulus))
};
let qInv_times_q_mod_p =
try!(bigint::elem_mul(&qInv, q_mod_p, &p.modulus));
if !qInv_times_q_mod_p.is_one() {
return Err(error::Unspecified);
}
let q = try!(PrivatePrime::new(q, dQ));
let qq =
try!(bigint::elem_mul(&q_mod_n, q_mod_n_decoded, &n));
let qq = try!(qq.into_modulus::<QQ>());
Ok(RSAKeyPair {
n: n,
e: e,
p: p,
q: q,
qInv: qInv,
oneRR_mod_n: oneRR_mod_n,
q_mod_n: q_mod_n,
qq: qq,
n_bits: n_bits,
})
})
})
}
pub fn public_modulus_len(&self) -> usize {
self.n_bits.as_usize_bytes_rounded_up()
}
}
// One private prime together with its exponent and the constants that
// `PrivatePrime::new` derives from it.
struct PrivatePrime<M: Prime> {
// The prime itself, converted into a modulus.
modulus: bigint::Modulus<M>,
// The exponent passed to `new` (`dP` or `dQ` in `from_der`); `new` converts
// it to an odd positive value and checks that it is less than the prime.
exponent: bigint::OddPositive,
// Result of `bigint::One::newR` for `modulus`.
oneR: bigint::One<M, R>,
// Result of `bigint::One::newRR` for `modulus`.
oneRR: bigint::One<M, RR>,
// Result of `bigint::One::newRRR` for `modulus`.
oneRRR: bigint::One<M, RRR>,
}
impl<M: Prime> PrivatePrime<M> {
    /// Builds the per-prime state from the prime `p` and its exponent `dP`.
    ///
    /// `dP` is converted to an odd positive value and must be less than `p`.
    ///
    /// # Errors
    ///
    /// Returns `error::Unspecified` if any conversion, check or derived
    /// constant fails.
    fn new(p: bigint::OddPositive, dP: bigint::Positive)
           -> Result<Self, error::Unspecified> {
        let exponent = dP.into_odd_positive()?;
        exponent.verify_less_than(&p)?;

        let modulus = p.into_modulus()?;
        let oneRR = bigint::One::newRR(&modulus)?;
        // `newRRR` takes its argument by value, so give it a copy and keep
        // the original for the struct.
        let oneRR_copy = oneRR.try_clone()?;
        let oneR = bigint::One::newR(&oneRR, &modulus)?;
        let oneRRR = bigint::One::newRRR(oneRR_copy, &modulus)?;

        Ok(PrivatePrime {
            modulus,
            exponent,
            oneR,
            oneRR,
            oneRRR,
        })
    }
}
fn elem_exp_consttime<M, MM>(c: &bigint::Elem<MM>, p: &PrivatePrime<M>)
-> Result<bigint::Elem<M>, error::Unspecified>
where M: bigint::NotMuchSmallerModulus<MM>,
M: Prime {
let c_mod_m = try!(bigint::elem_reduced(c, &p.modulus));
let c_mod_m = try!(bigint::elem_mul(p.oneRRR.as_ref(), c_mod_m, &p.modulus));
bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.oneR, &p.modulus)
}
// Marker for the type-level moduli that `PrivatePrime` accepts (`P` and `Q`).
// NOTE(review): the trait is `unsafe`, so implementing it is a promise the
// compiler cannot check -- presumably that the modulus is prime; confirm.
unsafe trait Prime {}
// Type-level tag for the modulus `p`. It has no variants, so it is only ever
// used as a type parameter.
enum P {}
unsafe impl Prime for P {}
// NOTE(review): the `unsafe impl`s of `bigint`'s `*SmallerModulus` traits in
// this group assert size relationships between the tagged moduli. What each
// trait requires is defined in `bigint`, not in this file. The facts about
// the actual values are established by `RSAKeyPair::from_der` (for example
// `q < p`, and `p * q` being zero modulo `n`).
unsafe impl bigint::SmallerModulus<N> for P {}
unsafe impl bigint::NotMuchSmallerModulus<N> for P {}
// Type-level tag for the modulus that `from_der` builds from `q * q`.
enum QQ {}
unsafe impl bigint::SmallerModulus<N> for QQ {}
unsafe impl bigint::NotMuchSmallerModulus<N> for QQ {}
unsafe impl bigint::SlightlySmallerModulus<N> for QQ {}
// Type-level tag for the modulus `q`.
enum Q {}
unsafe impl Prime for Q {}
unsafe impl bigint::SmallerModulus<N> for Q {}
unsafe impl bigint::SmallerModulus<P> for Q {}
unsafe impl bigint::SlightlySmallerModulus<P> for Q {}
unsafe impl bigint::SmallerModulus<QQ> for Q {}
unsafe impl bigint::NotMuchSmallerModulus<QQ> for Q {}
/// State used for signing with an `RSAKeyPair`.
///
/// It pairs a shared handle to the key with a `blinding::Blinding` value.
/// `sign` takes `&mut self` because it borrows the blinding state mutably.
pub struct RSASigningState {
// Shared, reference-counted handle to the key pair used for signing.
key_pair: std::sync::Arc<RSAKeyPair>,
// Blinding state used by `sign`; created by `blinding::Blinding::new()`.
blinding: blinding::Blinding,
}
impl RSASigningState {
pub fn new(key_pair: std::sync::Arc<RSAKeyPair>)
-> Result<Self, error::Unspecified> {
Ok(RSASigningState {
key_pair: key_pair,
blinding: blinding::Blinding::new(),
})
}
pub fn key_pair(&self) -> &RSAKeyPair { self.key_pair.as_ref() }
#[allow(non_shorthand_field_patterns)] pub fn sign(&mut self, padding_alg: &'static ::signature::RSAEncoding,
rng: &rand::SecureRandom, msg: &[u8], signature: &mut [u8])
-> Result<(), error::Unspecified> {
let mod_bits = self.key_pair.n_bits;
if signature.len() != mod_bits.as_usize_bytes_rounded_up() {
return Err(error::Unspecified);
}
let &mut RSASigningState {
key_pair: ref key,
blinding: ref mut blinding,
} = self;
let m_hash = digest::digest(padding_alg.digest_alg(), msg);
try!(padding_alg.encode(&m_hash, signature, mod_bits, rng));
let base = try!(bigint::Positive::from_be_bytes_padded(
untrusted::Input::from(signature)));
let base = try!(base.into_elem(&key.n));
let result = try!(blinding.blind(base, key.e, &key.n, &key.oneRR_mod_n,
rng, |c| {
let m_1 = try!(elem_exp_consttime(&c, &key.p));
let c_mod_qq = try!(bigint::elem_reduced_once(&c, &key.qq));
let m_2 = try!(elem_exp_consttime(&c_mod_qq, &key.q));
let p = &key.p.modulus;
let m_2 = bigint::elem_widen(m_2);
let m_1_minus_m_2 = try!(bigint::elem_sub(m_1, &m_2, p));
let h = try!(bigint::elem_mul(&key.qInv, m_1_minus_m_2, p));
let h = bigint::elem_widen(h);
let q_times_h = try!(bigint::elem_mul(&key.q_mod_n, h, &key.n));
let m_2 = bigint::elem_widen(m_2);
let m = try!(bigint::elem_add(&m_2, q_times_h, &key.n));
let computed = try!(m.try_clone());
let computed =
try!(bigint::elem_mul(&key.oneRR_mod_n.as_ref(), computed,
&key.n));
let verify =
try!(bigint::elem_exp_vartime(computed, key.e, &key.n));
let verify = try!(verify.into_unencoded(&key.n));
try!(bigint::elem_verify_equal_consttime(&verify, &c));
Ok(m)
}));
result.fill_be_bytes(signature)
}
}
#[cfg(test)]
mod tests {
use core;
use {error, rand, signature, test};
use std;
use super::super::blinding;
use untrusted;
// Runs the PKCS#1 signing vectors in `rsa_pkcs1_sign_tests.txt`.
//
// A vector marked "Fail-Invalid-Key" must be rejected by `from_der`. For all
// other vectors the produced signature must match `Sig` exactly when the
// expected result is "Pass".
#[test]
fn test_signature_rsa_pkcs1_sign() {
    let rng = rand::SystemRandom::new();
    test::from_file("src/rsa/rsa_pkcs1_sign_tests.txt",
                    |section, test_case| {
        assert_eq!(section, "");

        let digest_name = test_case.consume_string("Digest");
        let alg = match digest_name.as_ref() {
            "SHA256" => &signature::RSA_PKCS1_SHA256,
            "SHA384" => &signature::RSA_PKCS1_SHA384,
            "SHA512" => &signature::RSA_PKCS1_SHA512,
            _ => { panic!("Unsupported digest: {}", digest_name) }
        };

        let key_der = test_case.consume_bytes("Key");
        let msg = test_case.consume_bytes("Msg");
        let expected = test_case.consume_bytes("Sig");
        let result = test_case.consume_string("Result");

        let parsed =
            signature::RSAKeyPair::from_der(untrusted::Input::from(&key_der));
        if result == "Fail-Invalid-Key" {
            assert!(parsed.is_err());
            return Ok(());
        }

        let key_pair = std::sync::Arc::new(parsed.unwrap());
        let mut signing_state =
            signature::RSASigningState::new(key_pair).unwrap();

        let mut actual: std::vec::Vec<u8> =
            vec![0; signing_state.key_pair().public_modulus_len()];
        signing_state.sign(alg, &rng, &msg, actual.as_mut_slice()).unwrap();

        assert_eq!(actual.as_slice() == &expected[..], result == "Pass");
        Ok(())
    });
}
// `sign` must reject an output buffer that is one byte too short or one byte
// too long, and accept one that is exactly the public modulus length.
#[test]
fn test_signature_rsa_pkcs1_sign_output_buffer_len() {
    const MESSAGE: &'static [u8] = b"hello, world";
    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let rng = rand::SystemRandom::new();
    let alg = &signature::RSA_PKCS1_SHA256;

    let key_pair = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let mut signing_state =
        signature::RSASigningState::new(std::sync::Arc::new(key_pair))
            .unwrap();

    // One byte too short.
    let mut sig_buf =
        vec![0; signing_state.key_pair().public_modulus_len() - 1];
    assert!(signing_state.sign(alg, &rng, MESSAGE, &mut sig_buf).is_err());

    // Exactly the right length.
    sig_buf.push(0);
    assert!(signing_state.sign(alg, &rng, MESSAGE, &mut sig_buf).is_ok());

    // One byte too long.
    sig_buf.push(0);
    assert!(signing_state.sign(alg, &rng, MESSAGE, &mut sig_buf).is_err());
}
// Signs `REMAINING_MAX + 1` times and checks, after every call, that the
// blinding counter satisfies `(after + 1) % REMAINING_MAX == before`.
#[test]
fn test_signature_rsa_pkcs1_sign_blinding_reuse() {
    const MESSAGE: &'static [u8] = b"hello, world";
    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let rng = rand::SystemRandom::new();

    let key_pair = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let key_pair = std::sync::Arc::new(key_pair);
    let mut sig_buf = vec![0; key_pair.public_modulus_len()];
    let mut signing_state =
        signature::RSASigningState::new(key_pair).unwrap();

    for _ in 0..(blinding::REMAINING_MAX + 1) {
        let before = signing_state.blinding.remaining();
        // The outcome of the signature itself is deliberately ignored; only
        // the counter is under test.
        let _ = signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng,
                                   MESSAGE, &mut sig_buf);
        let after = signing_state.blinding.remaining();
        assert_eq!((after + 1) % blinding::REMAINING_MAX, before);
    }
}
// Feeds `sign` a scripted RNG and expects signing to fail.
//
// The script is one modulus-sized buffer whose first byte is 1, followed by
// 100 all-zero buffers of the same size.
// NOTE(review): presumably `FixedSliceSequenceRandom` hands these out in
// order, one per `fill` call -- confirm in `test::rand`.
#[test]
fn test_signature_rsa_pkcs1_sign_blinding_creation_failure() {
    const MESSAGE: &'static [u8] = b"hello, world";
    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let key_pair = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let modulus_len = key_pair.public_modulus_len();

    let mut inverse_blinding_factor = vec![0u8; modulus_len];
    inverse_blinding_factor[0] = 1;
    let zero = vec![0u8; modulus_len];

    // 101 entries in total: slot 0 is replaced below, the other 100 stay
    // all-zero.
    let mut bytes = vec![&zero[..]; 101];
    bytes[0] = &inverse_blinding_factor[..];

    let rng = test::rand::FixedSliceSequenceRandom {
        bytes: &bytes[..],
        current: core::cell::UnsafeCell::new(0),
    };

    let mut signing_state =
        signature::RSASigningState::new(std::sync::Arc::new(key_pair))
            .unwrap();
    let mut sig_buf = vec![0; signing_state.key_pair().public_modulus_len()];
    let result = signing_state.sign(&signature::RSA_PKCS1_SHA256, &rng,
                                    MESSAGE, &mut sig_buf);

    assert!(result.is_err());
}
// Runs the PSS signing vectors in `rsa_pss_sign_tests.txt`, using an RNG
// wrapper that returns the vector's salt so the output is reproducible.
//
// NOTE(review): unlike the PKCS#1 vector test in this module, a vector marked
// "Fail-Invalid-Key" whose key parses successfully is not reported as a
// failure here; it falls through to the signing comparison instead.
#[cfg(feature = "rsa_signing")]
#[test]
fn test_signature_rsa_pss_sign() {
    // Hands out `salt` whenever the requested length equals the salt length;
    // every other request is forwarded to the wrapped `rng`.
    struct DeterministicSalt<'a> {
        salt: &'a [u8],
        rng: &'a rand::SecureRandom
    }

    impl<'a> rand::SecureRandom for DeterministicSalt<'a> {
        fn fill(&self, dest: &mut [u8]) -> Result<(), error::Unspecified> {
            if dest.len() == self.salt.len() {
                dest.copy_from_slice(&self.salt);
                Ok(())
            } else {
                self.rng.fill(dest)
            }
        }
    }

    let rng = rand::SystemRandom::new();

    test::from_file("src/rsa/rsa_pss_sign_tests.txt", |section, test_case| {
        assert_eq!(section, "");

        let digest_name = test_case.consume_string("Digest");
        let alg = match digest_name.as_ref() {
            "SHA256" => &signature::RSA_PSS_SHA256,
            "SHA384" => &signature::RSA_PSS_SHA384,
            "SHA512" => &signature::RSA_PSS_SHA512,
            _ => { panic!("Unsupported digest: {}", digest_name) }
        };

        let result = test_case.consume_string("Result");
        let key_der = test_case.consume_bytes("Key");
        let parsed =
            signature::RSAKeyPair::from_der(untrusted::Input::from(&key_der));
        if parsed.is_err() && result == "Fail-Invalid-Key" {
            return Ok(());
        }
        let key_pair = std::sync::Arc::new(parsed.unwrap());

        let msg = test_case.consume_bytes("Msg");
        let salt = test_case.consume_bytes("Salt");
        let expected = test_case.consume_bytes("Sig");

        let salted_rng = DeterministicSalt { salt: &salt, rng: &rng };

        let mut signing_state =
            signature::RSASigningState::new(key_pair).unwrap();
        let mut actual: std::vec::Vec<u8> =
            vec![0; signing_state.key_pair().public_modulus_len()];
        signing_state.sign(alg, &salted_rng, &msg, actual.as_mut_slice())?;

        assert_eq!(actual.as_slice() == &expected[..], result == "Pass");
        Ok(())
    });
}
// Compile-time check: `Arc<RSAKeyPair>` must be `Send` and `Sync`, and
// `RSASigningState` must be `Send`. Nothing is asserted at run time beyond
// the key parsing and state construction succeeding.
#[test]
fn test_sync_and_send() {
    // These helpers compile only when the argument's type meets the bound.
    fn assert_send<T: Send>(_: &T) {}
    fn assert_sync<T: Sync>(_: &T) {}

    const PRIVATE_KEY_DER: &'static [u8] =
        include_bytes!("signature_rsa_example_private_key.der");

    let key_pair = signature::RSAKeyPair::from_der(
        untrusted::Input::from(PRIVATE_KEY_DER)).unwrap();
    let key_pair = std::sync::Arc::new(key_pair);
    assert_send(&key_pair);
    assert_sync(&key_pair);

    let signing_state = signature::RSASigningState::new(key_pair).unwrap();
    assert_send(&signing_state);
}
}