use alloc::vec::Vec;
use super::emsa;
use super::{Error, RsaPrivateKey, RsaPublicKey};
use crate::hash::Digest;
use crate::rng::RngCore;
impl<const LIMBS: usize> RsaPrivateKey<LIMBS> {
pub fn sign_pss<D: Digest, R: RngCore>(
&self,
msg: &[u8],
rng: &mut R,
) -> Result<Vec<u8>, Error> {
emsa::sign_pss::<D, _, R>(self, msg, rng)
}
pub fn sign_pss_with_salt_len<D: Digest, R: RngCore>(
&self,
msg: &[u8],
salt_len: usize,
rng: &mut R,
) -> Result<Vec<u8>, Error> {
emsa::sign_pss_with_salt_len::<D, _, R>(self, msg, salt_len, rng)
}
}
impl<const LIMBS: usize> RsaPublicKey<LIMBS> {
pub fn verify_pss<D: Digest>(&self, msg: &[u8], sig: &[u8]) -> Result<(), Error> {
emsa::verify_pss::<D, _>(self, msg, sig)
}
pub fn verify_pss_with_salt_len<D: Digest>(
&self,
msg: &[u8],
sig: &[u8],
salt_len: usize,
) -> Result<(), Error> {
emsa::verify_pss_with_salt_len::<D, _>(self, msg, sig, salt_len)
}
pub fn verify_pss_any_salt<D: Digest>(&self, msg: &[u8], sig: &[u8]) -> Result<(), Error> {
emsa::verify_pss_any_salt::<D, _>(self, msg, sig)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::Sha256;
    use crate::rng::HmacDrbg;
    use crate::test_util::rsa_test_key_a;

    /// Builds the HMAC-DRBG each test signs with.
    ///
    /// The DRBG is built from a per-test `seed` label plus fixed remaining inputs.
    fn seeded_rng(seed: &[u8]) -> HmacDrbg<Sha256> {
        HmacDrbg::<Sha256>::new(seed, b"nonce", &[])
    }

    #[test]
    fn sign_verify_roundtrip() {
        const MSG: &[u8] = b"pss message";
        let private = rsa_test_key_a();
        let public = private.public_key();
        let mut rng = seeded_rng(b"rsa-pss");

        let signature = private.sign_pss::<Sha256, _>(MSG, &mut rng).unwrap();
        public.verify_pss::<Sha256>(MSG, &signature).unwrap();

        // The same signature must not verify for a different message.
        assert_eq!(
            public.verify_pss::<Sha256>(b"other", &signature),
            Err(Error::Verification)
        );

        // Flipping one bit of the signature must break verification.
        let mut tampered = signature;
        tampered[20] ^= 1;
        assert_eq!(
            public.verify_pss::<Sha256>(MSG, &tampered),
            Err(Error::Verification)
        );
    }

    #[test]
    fn pss_is_randomized() {
        let private = rsa_test_key_a();
        let public = private.public_key();
        let mut rng = seeded_rng(b"rsa-pss-rand");

        let first = private.sign_pss::<Sha256, _>(b"m", &mut rng).unwrap();
        let second = private.sign_pss::<Sha256, _>(b"m", &mut rng).unwrap();
        assert_ne!(first, second);

        // Both distinct signatures must still be valid.
        for sig in [&first, &second] {
            public.verify_pss::<Sha256>(b"m", sig).unwrap();
        }
    }

    #[test]
    fn explicit_salt_len_roundtrip() {
        let private = rsa_test_key_a();
        let public = private.public_key();
        // One DRBG is shared across iterations, so the loop order is significant.
        let mut rng = seeded_rng(b"rsa-pss-salt");

        for slen in [0usize, 32, 56] {
            let sig = private
                .sign_pss_with_salt_len::<Sha256, _>(b"m", slen, &mut rng)
                .unwrap();

            public.verify_pss_any_salt::<Sha256>(b"m", &sig).unwrap();
            public
                .verify_pss_with_salt_len::<Sha256>(b"m", &sig, slen)
                .unwrap();

            let off_by_one = slen + 1;
            assert_eq!(
                public.verify_pss_with_salt_len::<Sha256>(b"m", &sig, off_by_one),
                Err(Error::Verification),
                "sLen={slen}: verify with wrong length must fail"
            );
        }
    }

    #[test]
    fn strict_verify_rejects_nonstandard_salt() {
        let private = rsa_test_key_a();
        let public = private.public_key();
        let mut rng = seeded_rng(b"rsa-pss-strict");

        // A 16-byte salt: only the any-salt verifier should accept it.
        let short_salt = private
            .sign_pss_with_salt_len::<Sha256, _>(b"m", 16, &mut rng)
            .unwrap();
        assert_eq!(
            public.verify_pss::<Sha256>(b"m", &short_salt),
            Err(Error::Verification),
            "strict verify must reject sLen=16 (!= 32)"
        );
        public.verify_pss_any_salt::<Sha256>(b"m", &short_salt).unwrap();

        // The default-salt signature must pass both verifiers.
        let standard = private.sign_pss::<Sha256, _>(b"m", &mut rng).unwrap();
        public.verify_pss::<Sha256>(b"m", &standard).unwrap();
        public.verify_pss_any_salt::<Sha256>(b"m", &standard).unwrap();
    }
}