use openssl::{
pkey::{PKey, Private, Public},
pkey_ctx::PkeyCtx,
};
use zeroize::Zeroizing;
#[cfg(not(feature = "non-fips"))]
use super::FIPS_MIN_RSA_MODULUS_LENGTH;
#[cfg(not(feature = "non-fips"))]
use crate::crypto_bail;
use crate::error::CryptoError;
pub fn ckm_rsa_pkcs_key_wrap(pub_key: &PKey<Public>, dek: &[u8]) -> Result<Vec<u8>, CryptoError> {
let (mut ctx, mut ciphertext) = init_ckm_rsa_pkcs_encryption_context(pub_key)?;
ctx.encrypt_to_vec(dek, &mut ciphertext)?;
Ok(ciphertext)
}
/// Encrypts `plaintext` with an RSA public key using `CKM_RSA_PKCS`
/// (RSAES-PKCS1-v1_5).
///
/// The returned ciphertext is as long as the RSA modulus in bytes.
///
/// # Errors
///
/// Returns an error if `pub_key` is not an RSA key, if the OpenSSL encryption context
/// cannot be initialized, or if OpenSSL rejects the encryption (PKCS#1 v1.5 limits the
/// input to the modulus length in bytes minus 11).
pub fn ckm_rsa_pkcs_encrypt(
    pub_key: &PKey<Public>,
    plaintext: &[u8],
) -> Result<Vec<u8>, CryptoError> {
    let (mut encrypter, mut output) = init_ckm_rsa_pkcs_encryption_context(pub_key)?;
    // The byte count is redundant: `output` is left holding exactly the ciphertext.
    let _written = encrypter.encrypt_to_vec(plaintext, &mut output)?;
    Ok(output)
}
fn init_ckm_rsa_pkcs_encryption_context(
pub_key: &PKey<Public>,
) -> Result<(PkeyCtx<Public>, Vec<u8>), CryptoError> {
let rsa_pub_key = pub_key.rsa()?;
let encapsulation_bytes_len = usize::try_from(rsa_pub_key.size())?;
let ciphertext = Vec::with_capacity(encapsulation_bytes_len);
let mut ctx = PkeyCtx::new(pub_key)?;
ctx.encrypt_init()?;
ctx.set_rsa_padding(openssl::rsa::Padding::PKCS1)?;
Ok((ctx, ciphertext))
}
pub fn ckm_rsa_pkcs_key_unwrap(
priv_key: &PKey<Private>,
dek: &[u8],
) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
let (mut ctx, mut plaintext) = init_ckm_rsa_pkcs_decryption_context(priv_key)?;
ctx.decrypt_to_vec(dek, &mut plaintext)?;
Ok(plaintext)
}
/// Decrypts `ciphertext` with an RSA private key using `CKM_RSA_PKCS`
/// (RSAES-PKCS1-v1_5).
///
/// The recovered plaintext is returned in a [`Zeroizing`] buffer.
///
/// # Errors
///
/// Returns an error if `priv_key` is not an RSA key, if the OpenSSL decryption context
/// cannot be initialized, or if decryption fails (e.g. bad padding or wrong length).
pub fn ckm_rsa_pkcs_decrypt(
    priv_key: &PKey<Private>,
    ciphertext: &[u8],
) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
    let (mut decrypter, mut recovered) = init_ckm_rsa_pkcs_decryption_context(priv_key)?;
    // The byte count is redundant: `recovered` is truncated to exactly the plaintext.
    let _written = decrypter.decrypt_to_vec(ciphertext, &mut recovered)?;
    Ok(recovered)
}
fn init_ckm_rsa_pkcs_decryption_context(
priv_key: &PKey<Private>,
) -> Result<(PkeyCtx<Private>, Zeroizing<Vec<u8>>), CryptoError> {
let rsa_priv_key = priv_key.rsa()?;
let plaintext_bytes_len = usize::try_from(rsa_priv_key.size())? - 11;
let plaintext = Zeroizing::from(Vec::with_capacity(plaintext_bytes_len));
let mut ctx = PkeyCtx::new(priv_key)?;
ctx.decrypt_init()?;
ctx.set_rsa_padding(openssl::rsa::Padding::PKCS1)?;
Ok((ctx, plaintext))
}
#[expect(clippy::panic_in_result_fn)]
#[cfg(test)]
mod tests {
    use openssl::pkey::PKey;
    use zeroize::Zeroizing;

    use crate::{
        crypto::rsa::ckm_rsa_pkcs::{
            ckm_rsa_pkcs_decrypt, ckm_rsa_pkcs_encrypt, ckm_rsa_pkcs_key_unwrap,
            ckm_rsa_pkcs_key_wrap,
        },
        error::CryptoError,
    };

    /// RSA modulus length, in bits, of the test key.
    const MODULUS_BITS: u32 = 2048;
    /// RSA modulus length in bytes; every ciphertext has exactly this length.
    const MODULUS_BYTES: usize = 2048 / 8;
    /// RSAES-PKCS1-v1_5 can encrypt at most `k - 11` bytes (RFC 8017, section 7.2.1).
    const MAX_PLAINTEXT_LEN: usize = MODULUS_BYTES - 11;

    #[test]
    fn test_ckm_rsa_pkcs() -> Result<(), CryptoError> {
        let priv_key = PKey::from_rsa(openssl::rsa::Rsa::generate(MODULUS_BITS)?)?;
        let pub_key = PKey::public_key_from_pem(&priv_key.public_key_to_pem()?)?;

        // Key wrap / unwrap round trip with the largest DEK PKCS#1 v1.5 allows.
        let dek_to_wrap = Zeroizing::from(vec![0x01; MAX_PLAINTEXT_LEN]);
        let wrapped_key = ckm_rsa_pkcs_key_wrap(&pub_key, &dek_to_wrap)?;
        assert_eq!(wrapped_key.len(), MODULUS_BYTES);
        let unwrapped_key = ckm_rsa_pkcs_key_unwrap(&priv_key, &wrapped_key)?;
        assert_eq!(unwrapped_key.len(), MAX_PLAINTEXT_LEN);
        assert_eq!(unwrapped_key, dek_to_wrap);

        // Generic encrypt / decrypt round trip.
        let plaintext: &[u8] = b"CKM_RSA_PKCS plaintext";
        let ciphertext = ckm_rsa_pkcs_encrypt(&pub_key, plaintext)?;
        assert_eq!(ciphertext.len(), MODULUS_BYTES);
        let decrypted = ckm_rsa_pkcs_decrypt(&priv_key, &ciphertext)?;
        assert_eq!(decrypted.as_slice(), plaintext);

        // One byte over the PKCS#1 v1.5 limit must be rejected.
        let oversized = [0x01_u8; MAX_PLAINTEXT_LEN + 1];
        assert!(ckm_rsa_pkcs_key_wrap(&pub_key, &oversized).is_err());
        Ok(())
    }
}