use super::bigint::BigInt;
use super::rsa::{RsaPublicKey, RsaSecretKey, rsa_decrypt_raw, rsa_encrypt_raw};
use crate::Hasher;
/// MGF1 mask generation function built on the hash `H`.
///
/// Emits exactly `len` bytes formed by concatenating
/// `H(seed || C)` for a big-endian 32-bit counter `C = 0, 1, 2, …`;
/// the final digest is cut short so the output never exceeds `len`.
fn mgf1<H: Hasher>(seed: &[u8], len: usize) -> Vec<u8> {
    let mut mask = Vec::with_capacity(len);
    let mut block_index: u32 = 0;
    while mask.len() < len {
        let digest = {
            let mut ctx = H::new();
            ctx.update(seed);
            ctx.update(&block_index.to_be_bytes());
            ctx.finalize()
        };
        // Only copy as many bytes as are still missing.
        let wanted = H::OUTPUT_LEN.min(len - mask.len());
        mask.extend_from_slice(&digest[..wanted]);
        block_index += 1;
    }
    mask
}
/// EMSA-PSS encoding operation (RFC 8017 §9.1.1) for a precomputed digest.
///
/// Builds `EM = maskedDB || H || 0xbc`, where
/// `H = Hash(0x00 * 8 || m_hash || salt)` and
/// `maskedDB = (PS || 0x01 || salt) XOR MGF1(H)`, with the leftmost
/// `8 * emLen - em_bits` bits forced to zero.
///
/// Returns `None` if `m_hash` is not `H::OUTPUT_LEN` bytes long, or if
/// `em_bits` is too small to hold the hash, the salt and two framing bytes.
fn emsa_pss_encode<H: Hasher>(m_hash: &[u8], em_bits: usize, salt: &[u8]) -> Option<Vec<u8>> {
    let h_len = H::OUTPUT_LEN;
    let em_len = (em_bits + 7) / 8;
    if m_hash.len() != h_len || em_len < h_len + salt.len() + 2 {
        return None;
    }

    // M' = 0x00 * 8 || mHash || salt
    let h = H::hash(&[&[0u8; 8][..], m_hash, salt].concat());

    // DB = PS (zeros) || 0x01 || salt
    let db_len = em_len - h_len - 1;
    let ps_len = db_len - salt.len() - 1;
    let mut db = vec![0u8; db_len];
    db[ps_len] = 0x01;
    db[ps_len + 1..].copy_from_slice(salt);

    let mask = mgf1::<H>(&h, db_len);
    db.iter_mut().zip(mask.iter()).for_each(|(d, m)| *d ^= m);

    // Clear the excess high bits so EM as an integer has at most em_bits bits
    // (clear_bits is always in 0..=7, so a zero shift leaves the byte intact).
    let clear_bits = 8 * em_len - em_bits;
    db[0] &= 0xff_u8 >> clear_bits;

    let mut em = db;
    em.extend_from_slice(&h);
    em.push(0xbc);
    Some(em)
}
/// EMSA-PSS verification operation (RFC 8017 §9.1.2) for a precomputed digest.
///
/// Checks that `em` is a well-formed PSS encoding of `m_hash` with a salt of
/// exactly `s_len` bytes, for an encoded message of `em_bits` bits:
/// trailer byte `0xbc`, zeroed excess high bits, `DB = PS || 0x01 || salt`
/// after unmasking, and `H == Hash(0x00 * 8 || m_hash || salt)`.
///
/// Returns `false` on any inconsistency. This includes a caller-supplied
/// `s_len` so large that `hLen + sLen + 2` does not fit in `usize`; that case
/// is rejected instead of overflowing.
fn emsa_pss_verify<H: Hasher>(m_hash: &[u8], em: &[u8], em_bits: usize, s_len: usize) -> bool {
    let h_len = H::OUTPUT_LEN;
    let em_len = (em_bits + 7) / 8;
    if m_hash.len() != h_len || em.len() != em_len {
        return false;
    }
    // `s_len` is caller-controlled: checked arithmetic keeps an absurd value
    // from panicking (debug) or wrapping past this guard (release).
    let min_em_len = match h_len.checked_add(s_len).and_then(|n| n.checked_add(2)) {
        Some(n) => n,
        None => return false,
    };
    if em_len < min_em_len || em[em_len - 1] != 0xbc {
        return false;
    }

    let db_len = em_len - h_len - 1;
    let (masked_db, rest) = em.split_at(db_len);
    let h = &rest[..h_len];

    // The leftmost 8*emLen - emBits bits of maskedDB must already be zero.
    let clear_bits = 8 * em_len - em_bits;
    if clear_bits > 0 && masked_db[0] >> (8 - clear_bits) != 0 {
        return false;
    }

    let db_mask = mgf1::<H>(h, db_len);
    let mut db: Vec<u8> = masked_db.iter().zip(&db_mask).map(|(m, k)| m ^ k).collect();
    // clear_bits is in 0..=7, so a zero shift leaves db[0] untouched.
    db[0] &= 0xff_u8 >> clear_bits;

    // DB must be PS (all zero) || 0x01 || salt.
    let ps_len = db_len - s_len - 1;
    if db[..ps_len].iter().any(|&b| b != 0) || db[ps_len] != 0x01 {
        return false;
    }
    let salt = &db[ps_len + 1..];

    let mut m_prime = Vec::with_capacity(8 + h_len + s_len);
    m_prime.extend_from_slice(&[0u8; 8]);
    m_prime.extend_from_slice(m_hash);
    m_prime.extend_from_slice(salt);
    let h_prime = H::hash(&m_prime);

    // OR together all byte differences instead of returning at the first mismatch.
    let diff = h
        .iter()
        .zip(h_prime.iter())
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    diff == 0
}
/// RSASSA-PSS signature generation (RFC 8017 §8.1.1) with an explicit salt.
///
/// `m_hash` must be the `H` digest of the message. The same salt always
/// yields the same signature, so this is the deterministic building block
/// behind [`pss_sign`].
///
/// Returns the signature as `sk.modulus_byte_len()` big-endian bytes. Returns
/// `None` when the modulus has zero bits, when `m_hash` has the wrong length,
/// or when the modulus is too small for the hash plus salt (see
/// `emsa_pss_encode`).
pub fn pss_sign_with_salt<H: Hasher>(sk: &RsaSecretKey, m_hash: &[u8], salt: &[u8]) -> Option<Vec<u8>> {
    let k = sk.modulus_byte_len();
    // emBits = modBits - 1. A zero-bit modulus is degenerate and would underflow;
    // reject it the same way pss_verify does.
    let em_bits = sk.n.bit_len().checked_sub(1)?;
    let em = emsa_pss_encode::<H>(m_hash, em_bits, salt)?;
    let m = BigInt::from_be_bytes(&em);
    let s = rsa_decrypt_raw(sk, &m);
    Some(s.to_be_bytes(k))
}
/// RSASSA-PSS signature over the digest `m_hash` with a fresh random salt.
///
/// Draws `s_len` salt bytes from `rng`. The RNG is not called at all when
/// `s_len == 0`. Signing is then delegated to [`pss_sign_with_salt`]. In
/// production `rng` should be a cryptographically secure source.
///
/// Returns `None` under the same conditions as [`pss_sign_with_salt`].
pub fn pss_sign<H: Hasher>(
    sk: &RsaSecretKey,
    m_hash: &[u8],
    s_len: usize,
    rng: &mut dyn FnMut(&mut [u8]),
) -> Option<Vec<u8>> {
    let salt = if s_len == 0 {
        Vec::new()
    } else {
        let mut fresh = vec![0u8; s_len];
        rng(&mut fresh);
        fresh
    };
    pss_sign_with_salt::<H>(sk, m_hash, &salt)
}
/// Hashes `msg` with `H`, then signs the digest via [`pss_sign`].
pub fn pss_sign_msg<H: Hasher>(
    sk: &RsaSecretKey,
    msg: &[u8],
    s_len: usize,
    rng: &mut dyn FnMut(&mut [u8]),
) -> Option<Vec<u8>> {
    pss_sign::<H>(sk, &H::hash(msg), s_len, rng)
}
/// RSASSA-PSS signature verification (RFC 8017 §8.1.2) over a precomputed digest.
///
/// `sig` must be exactly `pk.modulus_byte_len()` bytes, and its integer value
/// must be below the modulus (RSAVP1 range check). Otherwise the signature is
/// rejected. This ensures `s + n` cannot masquerade as a second valid
/// signature. The recovered representative must also fit in
/// `emLen = ceil((modBits - 1) / 8)` bytes (the I2OSP check) before the
/// EMSA-PSS structure is examined.
///
/// Returns `true` only for a valid signature with salt length `s_len`.
pub fn pss_verify<H: Hasher>(pk: &RsaPublicKey, m_hash: &[u8], s_len: usize, sig: &[u8]) -> bool {
    let k = pk.modulus_byte_len();
    if sig.len() != k {
        return false;
    }
    let mod_bits = pk.n.bit_len();
    if mod_bits == 0 {
        return false;
    }
    // RSAVP1: the signature representative must satisfy s < n. Both sides are
    // k-byte big-endian strings, so lexicographic slice order equals numeric order.
    let n_bytes = pk.n.to_be_bytes(k);
    if n_bytes.len() == k && sig >= n_bytes.as_slice() {
        return false;
    }
    let em_bits = mod_bits - 1;
    let em_len = (em_bits + 7) / 8;
    let s = BigInt::from_be_bytes(sig);
    let m = rsa_encrypt_raw(pk, &s);
    // I2OSP(m, emLen) fails with "integer too large" when m does not fit.
    if m.bit_len() > 8 * em_len {
        return false;
    }
    let em = m.to_be_bytes(em_len);
    if em.len() != em_len {
        return false;
    }
    emsa_pss_verify::<H>(m_hash, &em, em_bits, s_len)
}
/// Hashes `msg` with `H`, then checks `sig` against the digest via [`pss_verify`].
pub fn pss_verify_msg<H: Hasher>(pk: &RsaPublicKey, msg: &[u8], s_len: usize, sig: &[u8]) -> bool {
    pss_verify::<H>(pk, &H::hash(msg), s_len, sig)
}
#[cfg(test)]
mod tests {
    // Unit tests for MGF1, the EMSA-PSS encoding layer, and end-to-end
    // RSASSA-PSS sign/verify. Keys and salts come from a fixed-seed LCG,
    // so every run is reproducible.
    use super::*;
    use crate::hash::sha256::Sha256;
    use crate::hash::sha384::Sha384;
    use crate::hash::sha512::Sha512;

    /// Deterministic byte source (64-bit LCG) used as the test RNG.
    /// Not cryptographically secure; for tests only.
    fn test_rng() -> impl FnMut(&mut [u8]) {
        let mut state: u64 = 0xdeadbeefcafebabe;
        move |buf: &mut [u8]| {
            for b in buf.iter_mut() {
                state = state
                    .wrapping_mul(6364136223846793005)
                    .wrapping_add(1442695040888963407);
                *b = (state >> 33) as u8;
            }
        }
    }

    #[test]
    fn test_mgf1_length_and_prefix_consistency() {
        let seed = b"mgf1 seed";
        let long = mgf1::<Sha256>(seed, 100);
        assert_eq!(long.len(), 100);
        // A shorter mask must be a prefix of a longer one for the same seed.
        let short = mgf1::<Sha256>(seed, 20);
        assert_eq!(&long[..20], &short[..]);
        assert!(mgf1::<Sha256>(seed, 0).is_empty());
    }

    #[test]
    fn test_emsa_pss_encode_verify_roundtrip_sha256() {
        let m_hash = Sha256::hash(b"hello PSS");
        let salt = [0x5a; 32];
        let em_bits = 2047;
        let em = emsa_pss_encode::<Sha256>(&m_hash, em_bits, &salt).expect("encode");
        assert_eq!(em.len(), (em_bits + 7) / 8);
        assert!(emsa_pss_verify::<Sha256>(&m_hash, &em, em_bits, salt.len()));
    }

    #[test]
    fn test_emsa_pss_encode_verify_roundtrip_sha384() {
        let m_hash = Sha384::hash(b"hello PSS sha384");
        let salt = [0x17; 48];
        let em_bits = 3071;
        let em = emsa_pss_encode::<Sha384>(&m_hash, em_bits, &salt).unwrap();
        assert_eq!(em.len(), 384);
        assert!(emsa_pss_verify::<Sha384>(&m_hash, &em, em_bits, salt.len()));
    }

    #[test]
    fn test_emsa_pss_encode_verify_roundtrip_sha512() {
        let m_hash = Sha512::hash(b"hello PSS sha512");
        let salt = [0x00; 64];
        let em_bits = 4095;
        let em = emsa_pss_encode::<Sha512>(&m_hash, em_bits, &salt).unwrap();
        assert_eq!(em.len(), 512);
        assert!(emsa_pss_verify::<Sha512>(&m_hash, &em, em_bits, salt.len()));
    }

    #[test]
    fn test_emsa_pss_rejects_wrong_hash_length() {
        let short_hash = [0u8; 16];
        assert!(emsa_pss_encode::<Sha256>(&short_hash, 2047, &[0u8; 32]).is_none());
        let m_hash = Sha256::hash(b"msg");
        let em = emsa_pss_encode::<Sha256>(&m_hash, 2047, &[0u8; 32]).unwrap();
        assert!(!emsa_pss_verify::<Sha256>(&short_hash, &em, 2047, 32));
    }

    #[test]
    fn test_emsa_pss_verify_wrong_salt_length_rejects() {
        let m_hash = Sha256::hash(b"msg");
        let salt = [0xAB; 32];
        let em = emsa_pss_encode::<Sha256>(&m_hash, 2047, &salt).unwrap();
        assert!(!emsa_pss_verify::<Sha256>(&m_hash, &em, 2047, 16));
    }

    #[test]
    fn test_emsa_pss_verify_tampered_rejects() {
        let m_hash = Sha256::hash(b"msg");
        let salt = [0x12; 32];
        let mut em = emsa_pss_encode::<Sha256>(&m_hash, 2047, &salt).unwrap();
        // Flipping a low bit of maskedDB corrupts the all-zero PS after unmasking.
        em[0] ^= 0x01;
        assert!(!emsa_pss_verify::<Sha256>(&m_hash, &em, 2047, salt.len()));
    }

    #[test]
    fn test_emsa_pss_verify_missing_bc_byte_rejects() {
        let m_hash = Sha256::hash(b"msg");
        let salt = [0x12; 32];
        let mut em = emsa_pss_encode::<Sha256>(&m_hash, 2047, &salt).unwrap();
        let last = em.len() - 1;
        em[last] = 0xbd;
        assert!(!emsa_pss_verify::<Sha256>(&m_hash, &em, 2047, salt.len()));
    }

    #[test]
    fn test_pss_sign_verify_roundtrip_sha256() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let msg = b"PSS end-to-end with SHA-256";
        let sig = pss_sign_msg::<Sha256>(&sk, msg, 32, &mut rng).expect("sign");
        assert_eq!(sig.len(), pk.modulus_byte_len());
        assert!(pss_verify_msg::<Sha256>(&pk, msg, 32, &sig));
    }

    #[test]
    fn test_pss_deterministic_signatures_agree() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let m_hash = Sha256::hash(b"determinism");
        // An empty salt makes PSS deterministic.
        let sig1 = pss_sign_with_salt::<Sha256>(&sk, &m_hash, &[]).expect("sign 1");
        let sig2 = pss_sign_with_salt::<Sha256>(&sk, &m_hash, &[]).expect("sign 2");
        assert_eq!(sig1, sig2);
        assert!(pss_verify::<Sha256>(&pk, &m_hash, 0, &sig1));
    }

    #[test]
    fn test_pss_random_signatures_differ() {
        let mut rng = test_rng();
        let (_pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let msg = b"randomisation";
        let sig1 = pss_sign_msg::<Sha256>(&sk, msg, 32, &mut rng).unwrap();
        let sig2 = pss_sign_msg::<Sha256>(&sk, msg, 32, &mut rng).unwrap();
        assert_ne!(sig1, sig2);
    }

    #[test]
    fn test_pss_verify_rejects_wrong_message() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let sig = pss_sign_msg::<Sha256>(&sk, b"original", 32, &mut rng).unwrap();
        assert!(!pss_verify_msg::<Sha256>(&pk, b"tampered", 32, &sig));
    }

    #[test]
    fn test_pss_verify_rejects_tampered_signature() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let msg = b"msg";
        let mut sig = pss_sign_msg::<Sha256>(&sk, msg, 32, &mut rng).unwrap();
        sig[0] ^= 0x01;
        assert!(!pss_verify_msg::<Sha256>(&pk, msg, 32, &sig));
    }

    #[test]
    fn test_pss_verify_rejects_wrong_key() {
        let mut rng = test_rng();
        let (_pk_a, sk_a) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let (pk_b, _sk_b) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let sig = pss_sign_msg::<Sha256>(&sk_a, b"msg", 32, &mut rng).unwrap();
        assert!(!pss_verify_msg::<Sha256>(&pk_b, b"msg", 32, &sig));
    }

    #[test]
    fn test_pss_verify_rejects_wrong_salt_length() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(1024, &mut rng);
        let sig = pss_sign_msg::<Sha256>(&sk, b"msg", 32, &mut rng).unwrap();
        assert!(!pss_verify_msg::<Sha256>(&pk, b"msg", 16, &sig));
    }

    #[test]
    fn test_pss_sign_rejects_too_small_modulus_for_salt() {
        let mut rng = test_rng();
        let (_pk, sk) = super::super::rsa::rsa_keygen(512, &mut rng);
        let m_hash = Sha256::hash(b"msg");
        // 512-bit modulus => emLen = 64 < hLen (32) + sLen (32) + 2.
        let result = pss_sign::<Sha256>(&sk, &m_hash, 32, &mut rng);
        assert!(result.is_none());
    }

    #[test]
    fn test_pss_sign_accepts_short_salt_on_small_modulus() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(512, &mut rng);
        let msg = b"small-modulus PSS";
        let sig = pss_sign_msg::<Sha256>(&sk, msg, 16, &mut rng).expect("sign");
        assert!(pss_verify_msg::<Sha256>(&pk, msg, 16, &sig));
    }

    #[test]
    fn test_pss_hash_mismatch_rejected() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(2048, &mut rng);
        let msg = b"hash flexibility matters";
        let sig = pss_sign_msg::<Sha256>(&sk, msg, 32, &mut rng).unwrap();
        assert!(!pss_verify_msg::<Sha384>(&pk, msg, 32, &sig));
    }

    #[test]
    fn test_pss_sha384_3072_roundtrip() {
        let mut rng = test_rng();
        let (pk, sk) = super::super::rsa::rsa_keygen(3072, &mut rng);
        let msg = b"PSS SHA-384 / RSA-3072";
        let sig = pss_sign_msg::<Sha384>(&sk, msg, 48, &mut rng).unwrap();
        assert!(pss_verify_msg::<Sha384>(&pk, msg, 48, &sig));
    }
}