use alloc::vec;
use alloc::vec::Vec;
use super::{Error, Pkcs1Digest};
use crate::ct::{ConstantTimeEq, ConstantTimeLess};
use crate::hash::Digest;
use crate::rng::RngCore;
/// Public-key half of the raw RSA primitive consumed by the padding schemes in this file.
pub(crate) trait RawPublic {
/// Key size in bytes; the padding code sizes encoded messages with it and requires
/// signatures to be exactly this long.
fn key_size(&self) -> usize;
/// Bit length of the modulus; the PSS routines use `modulus_bits() - 1` as the bit
/// length of the encoded message.
fn modulus_bits(&self) -> usize;
/// Applies the raw public-key operation to `m` and returns the result as bytes.
/// NOTE(review): presumably modular exponentiation with the public exponent on a
/// big-endian integer; the implementation is not part of this file.
fn raw_public(&self, m: &[u8]) -> Vec<u8>;
}
/// Access to the public modulus as bytes, used by the verifiers to reject signature
/// representatives that are not smaller than the modulus.
pub(crate) trait PublicModulus {
/// Returns the modulus as big-endian bytes.
/// NOTE(review): the verifiers pass this to `ct_lt_be` next to a `key_size()`-byte
/// signature, and `ct_lt_be` debug-asserts equal lengths, so the result is presumably
/// `key_size()` bytes long; confirm against the implementors.
fn modulus_be_bytes(&self) -> Vec<u8>;
}
/// Branch-free comparison of two equal-length big-endian byte strings.
///
/// Returns `true` when `a`, read as an unsigned big-endian integer, is strictly smaller
/// than `b`. Every byte pair is visited; the first differing pair decides the result and
/// later pairs are masked out.
fn ct_lt_be(a: &[u8], b: &[u8]) -> bool {
    debug_assert_eq!(a.len(), b.len());
    // Both accumulators are all-ones/all-zero masks; at most one of them ever becomes set.
    let (mut less, mut greater) = (0u8, 0u8);
    for (lhs, rhs) in a.iter().copied().zip(b.iter().copied()) {
        // All-ones until some earlier byte pair has already decided the ordering.
        let open = !(less | greater);
        less |= open & ct_lt_u8(lhs, rhs);
        greater |= open & ct_lt_u8(rhs, lhs);
    }
    (less & 1) != 0
}
/// Returns `0xff` when `a < b` and `0x00` otherwise, using the crate's `ct_lt` primitive.
#[inline]
fn ct_lt_u8(a: u8, b: u8) -> u8 {
    // Stretch the single result bit (0 or 1) into a full-byte mask.
    let bit = a.ct_lt(&b).unwrap_u8();
    bit.wrapping_neg()
}
/// Private-key half of the raw RSA primitive consumed by the padding schemes in this file.
pub(crate) trait RawPrivate {
/// Key size in bytes; ciphertexts handed to the decryptors must be exactly this long.
fn key_size(&self) -> usize;
/// Bit length of the modulus; PSS signing encodes to `modulus_bits() - 1` bits.
fn modulus_bits(&self) -> usize;
/// Applies the raw private-key operation to `c` and returns the result as bytes.
/// NOTE(review): presumably modular exponentiation with the private exponent; the
/// implementation is not part of this file.
fn raw_private(&self, c: &[u8]) -> Vec<u8>;
/// 32-byte secret used as the HMAC key when `decrypt_pkcs1v15_session` derives its
/// implicit-rejection fallback output.
fn secret_seed(&self) -> [u8; 32];
}
/// Encrypts `msg` with PKCS#1 v1.5 encryption padding.
///
/// The encoded block is `0x00 || 0x02 || PS || 0x00 || msg`, where `PS` consists of
/// non-zero bytes drawn from `rng` and fills whatever space remains (at least 8 bytes).
///
/// # Errors
/// Returns `Error::MessageTooLong` when `msg` needs more than `key_size() - 11` bytes.
pub(crate) fn encrypt_pkcs1v15<K: RawPublic, R: RngCore>(
    key: &K,
    msg: &[u8],
    rng: &mut R,
) -> Result<Vec<u8>, Error> {
    let k = key.key_size();
    let m_len = msg.len();
    if m_len + 11 > k {
        return Err(Error::MessageTooLong);
    }
    // The message occupies the tail; the byte right before it stays zero as separator.
    let msg_start = k - m_len;
    let mut em = vec![0u8; k];
    em[1] = 0x02;
    fill_nonzero(&mut em[2..msg_start - 1], rng);
    em[msg_start..].copy_from_slice(msg);
    Ok(key.raw_public(&em))
}
/// Decrypts a PKCS#1 v1.5 ciphertext and strips the `0x00 || 0x02 || PS || 0x00` padding.
///
/// The padding is scanned without data-dependent branches; every failure is folded into
/// a single `bad` accumulator that is inspected once at the end.
///
/// # Errors
/// * `Error::InvalidLength` when `ct.len() != key.key_size()` or the key is shorter
///   than 11 bytes.
/// * `Error::Decryption` when the padding is malformed.
///
/// NOTE(review): the final `Err(Error::Decryption)` is observable by the caller, so a
/// caller that reveals it to a remote peer may act as a padding oracle;
/// `decrypt_pkcs1v15_session` exists in this file for the implicit-rejection use case.
pub(crate) fn decrypt_pkcs1v15<K: RawPrivate>(key: &K, ct: &[u8]) -> Result<Vec<u8>, Error> {
let k = key.key_size();
if ct.len() != k {
return Err(Error::InvalidLength);
}
// 11 = two header bytes + at least 8 padding bytes + the zero separator.
if k < 11 {
return Err(Error::InvalidLength);
}
let em = key.raw_private(ct);
// Header must be exactly 0x00 0x02; any deviation leaves a non-zero bit in `bad`.
let mut bad: u8 = em[0]; bad |= em[1] ^ 0x02;
// `found` is an all-ones/all-zero mask; `sep_idx` records the index of the first 0x00.
let mut found: u8 = 0;
let mut sep_idx: u32 = 0;
for (i, &b) in em.iter().enumerate().skip(2) {
// All-ones only for the first zero byte encountered.
let is_zero = ct_eq_u8(b, 0x00) & !found;
let mask = 0u32.wrapping_sub((is_zero & 1) as u32);
sep_idx |= (i as u32) & mask;
found |= is_zero;
}
// No separator at all is a failure.
bad |= !found;
// A separator before index 10 means fewer than 8 padding bytes.
let too_small = sep_idx.ct_lt(&10u32).unwrap_u8();
bad |= 0u8.wrapping_sub(too_small);
if bad != 0 {
return Err(Error::Decryption);
}
// Everything after the separator is the message.
Ok(em[(sep_idx as usize) + 1..].to_vec())
}
/// Decrypts a PKCS#1 v1.5 ciphertext that must carry a secret of exactly `expected_len`
/// bytes, using implicit rejection instead of reporting a padding error.
///
/// When the padding is valid *and* the recovered payload is exactly `expected_len` bytes
/// long, that payload is returned. In every other case the function still returns
/// `Ok` with `expected_len` bytes, but they are a deterministic pseudo-random fallback
/// derived with HMAC-SHA-256 from `key.secret_seed()` and the ciphertext. The fallback is
/// always computed, so the amount of work does not depend on whether the padding check
/// passed.
///
/// # Errors
/// Returns `Error::InvalidLength` when `ct.len() != key.key_size()` or the key is
/// shorter than 11 bytes. Padding failures are never reported as errors.
pub(crate) fn decrypt_pkcs1v15_session<K: RawPrivate>(
    key: &K,
    ct: &[u8],
    expected_len: usize,
) -> Result<Vec<u8>, Error> {
    use crate::hash::HmacSha256;

    let k = key.key_size();
    // 11 = two header bytes + at least 8 padding bytes + the zero separator.
    if ct.len() != k || k < 11 {
        return Err(Error::InvalidLength);
    }
    let em = key.raw_private(ct);

    // --- Padding check, folded into `bad` without data-dependent branches. ---
    // Header must be exactly 0x00 0x02.
    let mut bad: u8 = em[0];
    bad |= em[1] ^ 0x02;
    // `found` is an all-ones/all-zero mask; `sep_idx` records the first 0x00 after the header.
    let mut found: u8 = 0;
    let mut sep_idx: u32 = 0;
    for (i, &b) in em.iter().enumerate().skip(2) {
        let is_zero = ct_eq_u8(b, 0x00) & !found;
        let mask = 0u32.wrapping_sub((is_zero & 1) as u32);
        sep_idx |= (i as u32) & mask;
        found |= is_zero;
    }
    bad |= !found;
    // A separator before index 10 means fewer than 8 padding bytes.
    let too_small = sep_idx.ct_lt(&10u32).unwrap_u8();
    bad |= 0u8.wrapping_sub(too_small);

    // The payload must be exactly `expected_len` bytes; a well-padded message of any
    // other length is rejected like bad padding instead of being truncated or clamped.
    let real_start = (sep_idx as usize).saturating_add(1);
    let real_len = em.len().saturating_sub(real_start);
    let len_diff = real_len ^ expected_len;
    // (x | -x) has its top bit set exactly when x != 0.
    let len_mismatch = ((len_diff | len_diff.wrapping_neg()) >> (usize::BITS - 1)) as u8;
    bad |= 0u8.wrapping_sub(len_mismatch);

    // --- Fallback: HMAC-SHA-256 in counter mode, keyed with the per-key secret. ---
    let key_secret = key.secret_seed();
    let mut fallback = Vec::with_capacity(expected_len);
    let mut counter: u32 = 0;
    while fallback.len() < expected_len {
        let mut h = HmacSha256::new(&key_secret);
        h.update(b"purecrypto-rsa-pkcs1v15-implicit-reject-v1");
        h.update(ct);
        h.update(&counter.to_be_bytes());
        let tag = h.finalize();
        fallback.extend_from_slice(tag.as_ref());
        counter += 1;
    }
    fallback.truncate(expected_len);

    // --- Selection. ---
    // Collapse `bad` to one bit (1 = reject), then widen it to a byte mask.
    let mut fold = bad;
    fold |= fold >> 4;
    fold |= fold >> 2;
    fold |= fold >> 1;
    let bad_mask = 0u8.wrapping_sub(fold & 1);

    // Reads are clamped to the last byte of `em` so the rejected path never indexes out
    // of bounds; on the accepted path the clamp is a no-op because the length matched.
    let last = em.len().saturating_sub(1);
    let mut out = Vec::with_capacity(expected_len);
    for (i, &fallback_byte) in fallback.iter().enumerate() {
        let idx = real_start.saturating_add(i).min(last);
        let real_byte = em[idx];
        // Explicit mask select: rejected -> fallback byte, accepted -> real byte.
        out.push((real_byte & !bad_mask) | (fallback_byte & bad_mask));
    }
    Ok(out)
}
/// Signs `msg` with RSASSA-PKCS1-v1_5 using digest `D`.
///
/// # Errors
/// Propagates `Error::MessageTooLong` from the encoder when the key is too small for
/// the digest's encoded form.
pub(crate) fn sign_pkcs1v15<D: Pkcs1Digest, K: RawPrivate>(
    key: &K,
    msg: &[u8],
) -> Result<Vec<u8>, Error> {
    // Encode to a full key-sized block, then apply the private-key operation.
    encode_pkcs1v15::<D>(msg, key.key_size()).map(|encoded| key.raw_private(&encoded))
}
pub(crate) fn verify_pkcs1v15<D: Pkcs1Digest, K: RawPublic + PublicModulus>(
key: &K,
msg: &[u8],
sig: &[u8],
) -> Result<(), Error> {
let k = key.key_size();
if sig.len() != k {
return Err(Error::InvalidLength);
}
if !ct_lt_be(sig, &key.modulus_be_bytes()) {
return Err(Error::Verification);
}
let em = key.raw_public(sig);
let expected = encode_pkcs1v15::<D>(msg, k)?;
if bool::from(em.as_slice().ct_eq(expected.as_slice())) {
Ok(())
} else {
Err(Error::Verification)
}
}
/// Builds the PKCS#1 v1.5 signature block `0x00 || 0x01 || 0xff.. || 0x00 || t` around
/// caller-supplied bytes `t`, without adding any digest prefix.
///
/// # Errors
/// Returns `Error::MessageTooLong` when `t` needs more than `k - 11` bytes.
#[cfg(feature = "tls-legacy")]
fn encode_pkcs1v15_prehashed(t: &[u8], k: usize) -> Result<Vec<u8>, Error> {
    if t.len() + 11 > k {
        return Err(Error::MessageTooLong);
    }
    // `t` occupies the tail; the byte right before it stays zero as separator.
    let t_start = k - t.len();
    let mut em = vec![0u8; k];
    em[1] = 0x01;
    em[2..t_start - 1].fill(0xff);
    em[t_start..].copy_from_slice(t);
    Ok(em)
}
/// Signs caller-supplied bytes `t` with PKCS#1 v1.5 type-1 padding and no digest prefix.
///
/// # Errors
/// Propagates `Error::MessageTooLong` when `t` does not fit the key.
#[cfg(feature = "tls-legacy")]
pub(crate) fn sign_pkcs1v15_raw<K: RawPrivate>(key: &K, t: &[u8]) -> Result<Vec<u8>, Error> {
    encode_pkcs1v15_prehashed(t, key.key_size()).map(|encoded| key.raw_private(&encoded))
}
/// Verifies a PKCS#1 v1.5 signature over caller-supplied bytes `t` (no digest prefix).
///
/// # Errors
/// * `Error::InvalidLength` when `sig` is not `key_size()` bytes long.
/// * `Error::Verification` when `sig` is not numerically below the modulus or the
///   recovered block differs from the expected encoding.
/// * `Error::MessageTooLong` when `t` does not fit the key.
#[cfg(feature = "tls-legacy")]
pub(crate) fn verify_pkcs1v15_raw<K: RawPublic + PublicModulus>(
    key: &K,
    t: &[u8],
    sig: &[u8],
) -> Result<(), Error> {
    let k = key.key_size();
    if sig.len() != k {
        return Err(Error::InvalidLength);
    }
    // Reject signature representatives that are not in the range [0, n).
    let modulus = key.modulus_be_bytes();
    if !ct_lt_be(sig, &modulus) {
        return Err(Error::Verification);
    }
    let recovered = key.raw_public(sig);
    let reference = encode_pkcs1v15_prehashed(t, k)?;
    match bool::from(recovered.as_slice().ct_eq(reference.as_slice())) {
        true => Ok(()),
        false => Err(Error::Verification),
    }
}
/// Builds the `k`-byte PKCS#1 v1.5 signature block for `msg`:
/// `0x00 || 0x01 || 0xff.. || 0x00 || D::DIGEST_INFO_PREFIX || D::digest(msg)`.
///
/// # Errors
/// Returns `Error::MessageTooLong` when prefix plus digest need more than `k - 11` bytes.
fn encode_pkcs1v15<D: Pkcs1Digest>(msg: &[u8], k: usize) -> Result<Vec<u8>, Error> {
    let digest = D::digest(msg);
    let hash: &[u8] = digest.as_ref();
    let prefix = D::DIGEST_INFO_PREFIX;
    let t_len = prefix.len() + hash.len();
    if t_len + 11 > k {
        return Err(Error::MessageTooLong);
    }
    // Prefix and digest occupy the tail; the byte right before them is the zero separator.
    let t_start = k - t_len;
    let mut em = vec![0u8; k];
    em[1] = 0x01;
    em[2..t_start - 1].fill(0xff);
    let (info, tail) = em[t_start..].split_at_mut(prefix.len());
    info.copy_from_slice(prefix);
    tail.copy_from_slice(hash);
    Ok(em)
}
/// Fills `dst` with non-zero bytes from `rng`, drawing one byte at a time and
/// redrawing whenever a zero comes up.
fn fill_nonzero<R: RngCore>(dst: &mut [u8], rng: &mut R) {
    for slot in dst.iter_mut() {
        // Starts at zero, so at least one draw always happens.
        let mut draw = [0u8; 1];
        while draw[0] == 0 {
            rng.fill_bytes(&mut draw);
        }
        *slot = draw[0];
    }
}
/// Signs `msg` with RSASSA-PSS using digest `D` and a salt drawn from `rng`.
///
/// # Errors
/// Propagates the encoder's error when the key is too small for digest plus salt.
pub(crate) fn sign_pss<D: Digest, K: RawPrivate, R: RngCore>(
    key: &K,
    msg: &[u8],
    rng: &mut R,
) -> Result<Vec<u8>, Error> {
    // The encoded message is one bit shorter than the modulus.
    let em_bits = key.modulus_bits() - 1;
    let encoded = emsa_pss_encode::<D, R>(msg, em_bits, rng)?;
    Ok(key.raw_private(&encoded))
}
/// Verifies an RSASSA-PSS signature over `msg` with digest `D`.
///
/// # Errors
/// * `Error::InvalidLength` when `sig` is not `key_size()` bytes long.
/// * `Error::Verification` when `sig` is not numerically below the modulus, when the
///   recovered block has non-zero bytes ahead of the encoded message, or when the PSS
///   check itself fails.
pub(crate) fn verify_pss<D: Digest, K: RawPublic + PublicModulus>(
    key: &K,
    msg: &[u8],
    sig: &[u8],
) -> Result<(), Error> {
    let k = key.key_size();
    if sig.len() != k {
        return Err(Error::InvalidLength);
    }
    // Reject signature representatives that are not in the range [0, n).
    let modulus = key.modulus_be_bytes();
    if !ct_lt_be(sig, &modulus) {
        return Err(Error::Verification);
    }
    let recovered = key.raw_public(sig);
    let em_bits = key.modulus_bits() - 1;
    // The encoded message may be shorter than the key; the surplus leading bytes must be zero.
    let surplus = k - em_bits.div_ceil(8);
    let (leading, em) = recovered.split_at(surplus);
    if !leading.iter().all(|&b| b == 0) {
        return Err(Error::Verification);
    }
    emsa_pss_verify::<D>(msg, em, em_bits)
}
/// Encrypts `msg` with RSAES-OAEP using digest `D` (also for MGF1) and the given `label`.
///
/// The encoded block is `0x00 || maskedSeed || maskedDB`, where
/// `DB = D::digest(label) || 0x00.. || 0x01 || msg` and the seed comes from `rng`.
///
/// # Errors
/// Returns `Error::MessageTooLong` when the key cannot hold two digests plus two bytes,
/// or when `msg` exceeds `key_size() - 2 * D::OUTPUT_LEN - 2` bytes.
pub(crate) fn encrypt_oaep<D: Digest, K: RawPublic, R: RngCore>(
    key: &K,
    msg: &[u8],
    label: &[u8],
    rng: &mut R,
) -> Result<Vec<u8>, Error> {
    let k = key.key_size();
    let h_len = D::OUTPUT_LEN;
    if k < 2 * h_len + 2 || msg.len() > k - 2 * h_len - 2 {
        return Err(Error::MessageTooLong);
    }

    // Data block: label hash, zero padding, 0x01 marker, message.
    let db_len = k - h_len - 1;
    let mut db = vec![0u8; db_len];
    db[..h_len].copy_from_slice(D::digest(label).as_ref());
    let msg_at = db_len - msg.len();
    db[msg_at - 1] = 0x01;
    db[msg_at..].copy_from_slice(msg);

    let mut seed = vec![0u8; h_len];
    rng.fill_bytes(&mut seed);

    // Mask the data block with the seed, then the seed with the masked data block.
    let db_mask = mgf1::<D>(&seed, db_len);
    db.iter_mut().zip(&db_mask).for_each(|(byte, mask)| *byte ^= *mask);
    let seed_mask = mgf1::<D>(&db, h_len);
    seed.iter_mut().zip(&seed_mask).for_each(|(byte, mask)| *byte ^= *mask);

    let mut em = Vec::with_capacity(k);
    em.push(0u8);
    em.extend_from_slice(&seed);
    em.extend_from_slice(&db);
    Ok(key.raw_public(&em))
}
/// Decrypts an RSAES-OAEP ciphertext using digest `D` (also for MGF1) and `label`.
///
/// After unmasking, the data block must be `D::digest(label) || 0x00.. || 0x01 || msg`
/// and the leading byte of the encoded block must be zero. All checks are folded into a
/// single `bad` accumulator that is inspected once at the end.
///
/// # Errors
/// Returns `Error::Decryption` for a wrong ciphertext length, a key too small for the
/// digest, and every padding failure.
pub(crate) fn decrypt_oaep<D: Digest, K: RawPrivate>(
key: &K,
ciphertext: &[u8],
label: &[u8],
) -> Result<Vec<u8>, Error> {
let k = key.key_size();
let h_len = D::OUTPUT_LEN;
if ciphertext.len() != k || k < 2 * h_len + 2 {
return Err(Error::Decryption);
}
let em = key.raw_private(ciphertext);
// Split the encoded block: leading byte, masked seed, masked data block.
let y = em[0];
let masked_seed = &em[1..1 + h_len];
let masked_db = &em[1 + h_len..];
// Undo the masking in the reverse order of encryption: seed first, then data block.
let seed_mask = mgf1::<D>(masked_db, h_len);
let mut seed = vec![0u8; h_len];
for i in 0..h_len {
seed[i] = masked_seed[i] ^ seed_mask[i];
}
let db_mask = mgf1::<D>(&seed, k - h_len - 1);
let mut db = vec![0u8; k - h_len - 1];
for i in 0..db.len() {
db[i] = masked_db[i] ^ db_mask[i];
}
let l_hash = D::digest(label);
// The leading byte must be zero; any set bit marks the block as bad.
let mut bad: u8 = y;
// Accumulate every differing bit between the embedded and the expected label hash.
let mut diff: u8 = 0;
for (b, h) in db.iter().take(h_len).zip(l_hash.as_ref().iter()) {
diff |= b ^ h;
}
bad |= diff;
// Scan the rest for the 0x01 marker; `found` is an all-ones/all-zero mask.
let ps_region = &db[h_len..];
let mut found: u8 = 0;
let mut sep_idx: usize = 0;
let mut pre_bad: u8 = 0;
for (i, &b) in ps_region.iter().enumerate() {
// All-ones only for the first 0x01 encountered.
let is_one = ct_eq_u8(b, 0x01) & !found;
let mask = 0usize.wrapping_sub((is_one & 1) as usize);
// `found` is updated before `pre_bad`, so the marker itself is not counted as junk.
sep_idx |= i & mask; found |= is_one;
// Any non-zero byte ahead of the marker is a padding error.
pre_bad |= b & !found;
}
bad |= !found;
bad |= pre_bad;
if bad != 0 {
return Err(Error::Decryption);
}
// Everything after the marker is the message.
Ok(ps_region[sep_idx + 1..].to_vec())
}
/// Returns `0xff` when `a == b` and `0x00` otherwise, using the crate's `ct_eq` primitive.
#[inline]
fn ct_eq_u8(a: u8, b: u8) -> u8 {
    // Stretch the single result bit (0 or 1) into a full-byte mask.
    let bit = a.ct_eq(&b).unwrap_u8();
    bit.wrapping_neg()
}
/// MGF1 mask generation: concatenates `D(seed || counter)` for a 32-bit big-endian
/// counter starting at zero and returns the first `mask_len` bytes.
pub(crate) fn mgf1<D: Digest>(seed: &[u8], mask_len: usize) -> Vec<u8> {
    let mut out = Vec::with_capacity(mask_len);
    let mut block_index: u32 = 0;
    while out.len() < mask_len {
        let mut hasher = D::new();
        hasher.update(seed);
        hasher.update(&block_index.to_be_bytes());
        let block = hasher.finalize();
        out.extend_from_slice(block.as_ref());
        block_index += 1;
    }
    // The last block may overshoot; cut back to the requested length.
    out.truncate(mask_len);
    out
}
/// EMSA-PSS encoding of `msg` into `ceil(em_bits / 8)` bytes, with a salt as long as the
/// digest output drawn from `rng`.
///
/// The result is `maskedDB || H || 0xbc`, where `H = D(0x00 * 8 || D(msg) || salt)` and
/// `DB = 0x00.. || 0x01 || salt` is masked with `mgf1(H)`. Bits of the first byte beyond
/// `em_bits` are cleared.
///
/// # Errors
/// Returns `Error::MessageTooLong` when the encoded length cannot hold digest, salt and
/// the two framing bytes.
fn emsa_pss_encode<D: Digest, R: RngCore>(
    msg: &[u8],
    em_bits: usize,
    rng: &mut R,
) -> Result<Vec<u8>, Error> {
    let h_len = D::OUTPUT_LEN;
    let s_len = h_len;
    let em_len = em_bits.div_ceil(8);
    if em_len < h_len + s_len + 2 {
        return Err(Error::MessageTooLong);
    }

    let m_hash = D::digest(msg);
    let mut salt = vec![0u8; s_len];
    rng.fill_bytes(&mut salt);

    // M' = eight zero bytes || mHash || salt
    let mut m_prime = vec![0u8; 8];
    m_prime.extend_from_slice(m_hash.as_ref());
    m_prime.extend_from_slice(&salt);
    let h = D::digest(&m_prime);

    // Build DB in place in the output buffer, then mask it.
    let db_len = em_len - h_len - 1;
    let salt_at = db_len - s_len;
    let mut em = vec![0u8; db_len];
    em[salt_at - 1] = 0x01;
    em[salt_at..].copy_from_slice(&salt);
    let db_mask = mgf1::<D>(h.as_ref(), db_len);
    em.iter_mut().zip(&db_mask).for_each(|(byte, mask)| *byte ^= *mask);

    // Clear the bits of the first byte that lie beyond `em_bits`.
    let unused_bits = 8 * em_len - em_bits;
    if unused_bits > 0 {
        em[0] &= 0xff >> unused_bits;
    }

    em.extend_from_slice(h.as_ref());
    em.push(0xbc);
    Ok(em)
}
fn emsa_pss_verify<D: Digest>(msg: &[u8], em: &[u8], em_bits: usize) -> Result<(), Error> {
let h_len = D::OUTPUT_LEN;
let s_len = h_len;
let em_len = em.len();
if em_len < h_len + s_len + 2 || em[em_len - 1] != 0xbc {
return Err(Error::Verification);
}
let db_len = em_len - h_len - 1;
let masked_db = &em[..db_len];
let h = &em[db_len..db_len + h_len];
let clear = 8 * em_len - em_bits;
if clear > 0 && masked_db[0] & (0xffu8 << (8 - clear)) != 0 {
return Err(Error::Verification);
}
let db_mask = mgf1::<D>(h, db_len);
let mut db = vec![0u8; db_len];
for i in 0..db_len {
db[i] = masked_db[i] ^ db_mask[i];
}
if clear > 0 {
db[0] &= 0xff >> clear;
}
let ps_len = db_len - s_len - 1;
if db[..ps_len].iter().any(|&b| b != 0) || db[ps_len] != 0x01 {
return Err(Error::Verification);
}
let salt = &db[ps_len + 1..];
let m_hash = D::digest(msg);
let mut m_prime = vec![0u8; 8];
m_prime.extend_from_slice(m_hash.as_ref());
m_prime.extend_from_slice(salt);
let h_prime = D::digest(&m_prime);
if bool::from(h_prime.as_ref().ct_eq(h)) {
Ok(())
} else {
Err(Error::Verification)
}
}
// Regression tests for the padding code: round trips with large keys, rejection of
// out-of-range signature representatives (`s + n`), and the big-endian comparison helper.
#[cfg(test)]
mod tests {
use crate::bignum::BoxedUint;
use crate::hash::Sha256;
use crate::rng::HmacDrbg;
use crate::rsa::BoxedRsaPrivateKey;
// Generates a key from a DRBG with a fixed seed, so the key is the same on every run.
// NOTE(review): `3072` is presumably the modulus size in bits and `16` presumably a
// primality-test round count; confirm against `BoxedRsaPrivateKey::generate`.
fn sk_3072() -> BoxedRsaPrivateKey {
let mut rng = HmacDrbg::<Sha256>::new(b"emsa-sep-idx-regression", b"nonce", &[]);
BoxedRsaPrivateKey::generate(3072, BoxedUint::from_u64(65537), &mut rng, 16)
}
// With a short message and a large key the zero separator sits at an index above 255,
// which exercises the separator-index accumulator beyond the range of a single byte.
#[test]
#[ignore = "slow in debug; run with --release --ignored"]
fn pkcs1v15_roundtrip_separator_index_above_255() {
let key = sk_3072();
let pk = key.public_key();
let mut rng = HmacDrbg::<Sha256>::new(b"emsa-pkcs1-ct", b"nonce", &[]);
let msg = b"hello"; let ct = pk.encrypt_pkcs1v15(msg, &mut rng).unwrap();
let pt = key.decrypt_pkcs1v15(&ct).unwrap();
assert_eq!(pt.as_slice(), msg);
}
// Same idea for OAEP: a two-byte message pushes the 0x01 marker far into the data block.
#[test]
#[ignore = "slow in debug; run with --release --ignored"]
fn oaep_roundtrip_separator_index_above_255() {
let key = sk_3072();
let pk = key.public_key();
let mut rng = HmacDrbg::<Sha256>::new(b"emsa-oaep-ct", b"nonce", &[]);
let label = b"";
let msg = b"hi";
let ct = pk.encrypt_oaep::<Sha256, _>(msg, label, &mut rng).unwrap();
let pt = key.decrypt_oaep::<Sha256>(&ct, label).unwrap();
assert_eq!(pt.as_slice(), msg);
}
use crate::bignum::Uint;
use crate::rsa::{Error, RsaPrivateKey};
use crate::test_util::rsa_test_key_a;
// Re-hosts test key A in a `Uint<33>` by writing each component behind 8 leading zero
// bytes, so the key size in bytes exceeds the length of the modulus.
fn widened_test_key_33() -> RsaPrivateKey<33> {
let key = rsa_test_key_a();
let widen = |v: &Uint<32>| -> Uint<33> {
let mut be = [0u8; 33 * 8];
v.write_be_bytes(&mut be[8..]);
Uint::<33>::from_be_bytes(&be)
};
RsaPrivateKey::<33>::from_components(
widen(key.modulus()),
widen(key.exponent()),
widen(key.private_exponent()),
)
}
// Returns `sig + n` as a big-endian string of the same length as `sig`; asserts that the
// sum still fits in that length.
fn sig_plus_modulus(sig: &[u8], n: &Uint<33>) -> alloc::vec::Vec<u8> {
let k = sig.len();
let s = BoxedUint::from_be_bytes(sig);
let n_boxed = BoxedUint::from_be_bytes(&{
let mut b = alloc::vec![0u8; k];
n.write_be_bytes(&mut b);
b
});
let sum = s.add(&n_boxed);
assert!(
sum.bit_len() <= 8 * k,
"s + n overflowed k octets — test needs a non-full-width modulus"
);
sum.to_be_bytes(k)
}
// A valid signature `s` plus the modulus is congruent to `s` but is not below the
// modulus, so the verifier must reject it. Messages are tried until `s + n` fits in
// 256 bytes.
#[test]
fn pkcs1v15_rejects_s_plus_n() {
let sk = rsa_test_key_a();
let pk = sk.public_key();
let n = {
let mut nb = [0u8; 256];
sk.modulus().write_be_bytes(&mut nb);
BoxedUint::from_be_bytes(&nb)
};
let k = 256;
for i in 0u32..256 {
// The last four bytes of the message carry the attempt counter.
let mut msg = *b"f4-cg-s-plus-n-0000";
msg[15..].copy_from_slice(&i.to_be_bytes());
let sig = sk.sign_pkcs1v15::<Sha256>(&msg).unwrap();
pk.verify_pkcs1v15::<Sha256>(&msg, &sig).unwrap();
let s = BoxedUint::from_be_bytes(&sig);
let sum = s.add(&n);
// Skip attempts where the sum does not fit in `k` bytes.
if sum.bit_len() > 8 * k {
continue; }
let mal = sum.to_be_bytes(k);
assert_ne!(mal, sig, "s + n must differ from s");
assert_eq!(
pk.verify_pkcs1v15::<Sha256>(&msg, &mal),
Err(Error::Verification),
"const-generic: s + n must be rejected (RSAVP1 step 1)"
);
return;
}
panic!("no message yielded a representable s + n in 256 tries");
}
// Uses the widened key so the key size exceeds the encoded-message length
// (`k - em_len > 0`), covering the leading-byte handling in `verify_pss` as well as
// the `s + n` rejection.
#[test]
fn pss_rejects_s_plus_n_and_exercises_leading_octet() {
let sk = widened_test_key_33();
let pk = sk.public_key();
let k = 33 * 8;
let em_len = (pk.modulus().bit_len() - 1).div_ceil(8);
assert!(
k - em_len > 0,
"expected a non-full-width modulus (k - em_len = {})",
k - em_len
);
let mut rng = HmacDrbg::<Sha256>::new(b"f4-pss", b"nonce", &[]);
let msg = b"f4 pss";
let sig = sk.sign_pss::<Sha256, _>(msg, &mut rng).unwrap();
pk.verify_pss::<Sha256>(msg, &sig).unwrap();
let mal = sig_plus_modulus(&sig, pk.modulus());
assert_ne!(mal, sig, "s + n must differ from s");
assert_eq!(
pk.verify_pss::<Sha256>(msg, &mal),
Err(Error::Verification),
"PSS s + n must be rejected (RSAVP1 step 1)"
);
}
// Same `s + n` rejection check, verified through the boxed public-key type built from
// the components of test key A.
#[test]
fn boxed_pkcs1v15_rejects_s_plus_n() {
let key = rsa_test_key_a();
let mut nb = [0u8; 256];
key.modulus().write_be_bytes(&mut nb);
let mut eb = [0u8; 256];
key.exponent().write_be_bytes(&mut eb);
let boxed_pk = crate::rsa::BoxedRsaPublicKey::new(
BoxedUint::from_be_bytes(&nb),
BoxedUint::from_be_bytes(&eb),
);
let n = BoxedUint::from_be_bytes(&nb);
let k = 256;
for i in 0u32..256 {
// The last four bytes of the message carry the attempt counter.
let mut msg = *b"f4-boxed-s-plus-n-0000";
msg[18..].copy_from_slice(&i.to_be_bytes());
let sig = key.sign_pkcs1v15::<Sha256>(&msg).unwrap();
boxed_pk.verify_pkcs1v15::<Sha256>(&msg, &sig).unwrap();
let s = BoxedUint::from_be_bytes(&sig);
let sum = s.add(&n);
// Skip attempts where the sum does not fit in `k` bytes.
if sum.bit_len() > 8 * k {
continue; }
let mal = sum.to_be_bytes(k);
assert_ne!(mal, sig);
assert_eq!(
boxed_pk.verify_pkcs1v15::<Sha256>(&msg, &mal),
Err(Error::Verification),
"boxed: s + n must be rejected (RSAVP1 step 1)"
);
return; }
panic!("no message yielded a representable s + n in 256 tries");
}
// Table-driven check that `ct_lt_be` agrees with unsigned big-endian integer ordering,
// including equal inputs and differences in each byte position.
#[test]
fn ct_lt_be_matches_integer_order() {
let cases: &[(&[u8], &[u8], bool)] = &[
(&[0, 0, 0], &[0, 0, 0], false), (&[0, 0, 1], &[0, 0, 2], true), (&[0, 1, 0], &[0, 2, 0], true), (&[1, 0, 0], &[2, 0, 0], true), (&[2, 0, 0], &[1, 0, 0], false), (&[0, 2, 0], &[1, 0, 0], true), (&[0xff, 0xff], &[0xff, 0xff], false),
(&[0x7f, 0xff], &[0x80, 0x00], true),
];
for (a, b, want) in cases {
assert_eq!(super::ct_lt_be(a, b), *want, "ct_lt_be({a:?}, {b:?})");
}
}
}