use crate::{
encoding::{AsDer, Pkcs8V1Der, PublicKeyX509Der},
error::{KeyRejected, Unspecified},
fips::indicator_check,
ptr::{DetachableLcPtr, LcPtr},
};
use aws_lc::{
EVP_PKEY_CTX_new, EVP_PKEY_CTX_set0_rsa_oaep_label, EVP_PKEY_CTX_set_rsa_mgf1_md,
EVP_PKEY_CTX_set_rsa_oaep_md, EVP_PKEY_CTX_set_rsa_padding, EVP_PKEY_decrypt,
EVP_PKEY_decrypt_init, EVP_PKEY_encrypt, EVP_PKEY_encrypt_init, EVP_sha1, EVP_sha256,
EVP_sha384, EVP_sha512, OPENSSL_malloc, EVP_MD, EVP_PKEY, EVP_PKEY_CTX, RSA_PKCS1_OAEP_PADDING,
};
use core::{fmt::Debug, mem::size_of_val, ptr::null_mut};
use mirai_annotations::verify_unreachable;
use super::{
encoding,
key::{generate_rsa_key, is_rsa_key, key_size_bits, key_size_bytes},
KeySize,
};
/// RSA-OAEP parameters that use SHA-1 (`EVP_sha1`) for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA1_MGF1SHA1: OaepAlgorithm = OaepAlgorithm {
id: EncryptionAlgorithmId::OaepSha1Mgf1sha1,
oaep_hash_fn: EVP_sha1,
mgf1_hash_fn: EVP_sha1,
};
/// RSA-OAEP parameters that use SHA-256 (`EVP_sha256`) for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA256_MGF1SHA256: OaepAlgorithm = OaepAlgorithm {
id: EncryptionAlgorithmId::OaepSha256Mgf1sha256,
oaep_hash_fn: EVP_sha256,
mgf1_hash_fn: EVP_sha256,
};
/// RSA-OAEP parameters that use SHA-384 (`EVP_sha384`) for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA384_MGF1SHA384: OaepAlgorithm = OaepAlgorithm {
id: EncryptionAlgorithmId::OaepSha384Mgf1sha384,
oaep_hash_fn: EVP_sha384,
mgf1_hash_fn: EVP_sha384,
};
/// RSA-OAEP parameters that use SHA-512 (`EVP_sha512`) for both the OAEP digest and the MGF1 digest.
pub const OAEP_SHA512_MGF1SHA512: OaepAlgorithm = OaepAlgorithm {
id: EncryptionAlgorithmId::OaepSha512Mgf1sha512,
oaep_hash_fn: EVP_sha512,
mgf1_hash_fn: EVP_sha512,
};
/// Identifier reported by [`OaepAlgorithm::id`] to tell the supported RSA-OAEP
/// digest pairings apart.
///
/// The enum is `#[non_exhaustive]`, so callers matching on it need a wildcard arm.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum EncryptionAlgorithmId {
/// Identifier carried by `OAEP_SHA1_MGF1SHA1` (SHA-1 for OAEP and MGF1).
OaepSha1Mgf1sha1,
/// Identifier carried by `OAEP_SHA256_MGF1SHA256` (SHA-256 for OAEP and MGF1).
OaepSha256Mgf1sha256,
/// Identifier carried by `OAEP_SHA384_MGF1SHA384` (SHA-384 for OAEP and MGF1).
OaepSha384Mgf1sha384,
/// Identifier carried by `OAEP_SHA512_MGF1SHA512` (SHA-512 for OAEP and MGF1).
OaepSha512Mgf1sha512,
}
/// Signature of the AWS-LC digest accessor (for example `EVP_sha256`) whose result is
/// handed to `EVP_PKEY_CTX_set_rsa_oaep_md`.
type OaepHashFn = unsafe extern "C" fn() -> *const EVP_MD;
/// Signature of the AWS-LC digest accessor (for example `EVP_sha256`) whose result is
/// handed to `EVP_PKEY_CTX_set_rsa_mgf1_md`.
type Mgf1HashFn = unsafe extern "C" fn() -> *const EVP_MD;
/// Bundle of the parameters that distinguish one RSA-OAEP configuration from another.
///
/// All fields are private; the instances visible in this file are the `OAEP_*` constants.
pub struct OaepAlgorithm {
// Public-facing identifier returned by `id()`.
id: EncryptionAlgorithmId,
// Accessor for the digest applied as the OAEP hash.
oaep_hash_fn: OaepHashFn,
// Accessor for the digest applied inside MGF1.
mgf1_hash_fn: Mgf1HashFn,
}
impl OaepAlgorithm {
    /// Returns the identifier that names this OAEP configuration.
    #[must_use]
    pub fn id(&self) -> EncryptionAlgorithmId {
        let Self { id, .. } = self;
        *id
    }

    /// Digest accessor used for the OAEP hash.
    #[inline]
    fn oaep_hash_fn(&self) -> OaepHashFn {
        let Self { oaep_hash_fn, .. } = self;
        *oaep_hash_fn
    }

    /// Digest accessor used for the MGF1 mask generation function.
    #[inline]
    fn mgf1_hash_fn(&self) -> Mgf1HashFn {
        let Self { mgf1_hash_fn, .. } = self;
        *mgf1_hash_fn
    }
}
impl Debug for OaepAlgorithm {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
Debug::fmt(&self.id, f)
}
}
/// Private key used for RSA decryption; a newtype over an AWS-LC `EVP_PKEY` handle.
///
/// Every constructor in this file goes through `validate_key`, which accepts only RSA
/// keys of 2048 to 8192 bits.
pub struct PrivateDecryptingKey(LcPtr<EVP_PKEY>);
impl PrivateDecryptingKey {
fn new(evp_pkey: LcPtr<EVP_PKEY>) -> Result<Self, Unspecified> {
Self::validate_key(&evp_pkey)?;
Ok(Self(evp_pkey))
}
fn validate_key(key: &LcPtr<EVP_PKEY>) -> Result<(), Unspecified> {
if !is_rsa_key(key) {
return Err(Unspecified);
};
match key_size_bits(key) {
2048..=8192 => Ok(()),
_ => Err(Unspecified),
}
}
pub fn generate(size: KeySize) -> Result<Self, Unspecified> {
let key = generate_rsa_key(size.bits(), false)?;
Self::new(key)
}
#[cfg(feature = "fips")]
pub fn generate_fips(size: KeySize) -> Result<Self, Unspecified> {
let key = generate_rsa_key(size.bits(), true)?;
Self::new(key)
}
pub fn from_pkcs8(pkcs8: &[u8]) -> Result<Self, KeyRejected> {
let key = encoding::pkcs8::decode_der(pkcs8)?;
Ok(Self::new(key)?)
}
#[cfg(feature = "fips")]
#[must_use]
pub fn is_valid_fips_key(&self) -> bool {
super::key::is_valid_fips_key(&self.0)
}
#[must_use]
pub fn key_size_bytes(&self) -> usize {
key_size_bytes(&self.0)
}
#[must_use]
pub fn key_size_bits(&self) -> usize {
key_size_bits(&self.0)
}
#[must_use]
#[allow(clippy::missing_panics_doc)]
pub fn public_key(&self) -> PublicEncryptingKey {
PublicEncryptingKey::new(self.0.clone()).expect(
"PublicEncryptingKey key size to be supported by PrivateDecryptingKey key sizes",
)
}
}
/// Prints only the type name; the wrapped key handle is never formatted.
impl Debug for PrivateDecryptingKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_tuple("PrivateDecryptingKey");
        builder.finish()
    }
}
/// Serializes the private key as a PKCS#8 v1 DER document.
impl AsDer<Pkcs8V1Der<'static>> for PrivateDecryptingKey {
    fn as_der(&self) -> Result<Pkcs8V1Der<'static>, Unspecified> {
        let Self(evp_pkey) = self;
        let der_bytes = encoding::pkcs8::encode_v1_der(evp_pkey)?;
        Ok(Pkcs8V1Der::new(der_bytes))
    }
}
impl Clone for PrivateDecryptingKey {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Public key used for RSA encryption; a newtype over an AWS-LC `EVP_PKEY` handle.
pub struct PublicEncryptingKey(LcPtr<EVP_PKEY>);
impl PublicEncryptingKey {
    /// Wraps `evp_pkey` once it has passed [`Self::validate_key`].
    fn new(evp_pkey: LcPtr<EVP_PKEY>) -> Result<Self, Unspecified> {
        Self::validate_key(&evp_pkey)?;
        Ok(Self(evp_pkey))
    }

    /// Accepts only RSA keys whose modulus is between 2048 and 8192 bits inclusive.
    fn validate_key(key: &LcPtr<EVP_PKEY>) -> Result<(), Unspecified> {
        if !is_rsa_key(key) {
            return Err(Unspecified);
        };
        match key_size_bits(key) {
            2048..=8192 => Ok(()),
            _ => Err(Unspecified),
        }
    }

    /// Builds a public key from a DER-encoded public key, decoded by
    /// `encoding::rfc5280::decode_public_key_der`.
    ///
    /// The decoded key goes through the same validation as every other constructor, so
    /// externally supplied keys that are not RSA, or that fall outside 2048 to 8192 bits,
    /// are rejected here instead of reaching the encryption code.
    ///
    /// # Errors
    /// `KeyRejected` if decoding fails or the decoded key does not pass validation.
    pub fn from_der(value: &[u8]) -> Result<Self, KeyRejected> {
        let evp_pkey = encoding::rfc5280::decode_public_key_der(value)?;
        // Must go through `new`: wrapping the handle directly would skip `validate_key`.
        Ok(Self::new(evp_pkey)?)
    }

    /// Key size in bytes.
    #[must_use]
    pub fn key_size_bytes(&self) -> usize {
        key_size_bytes(&self.0)
    }

    /// Key size in bits.
    #[must_use]
    pub fn key_size_bits(&self) -> usize {
        key_size_bits(&self.0)
    }
}
/// Prints only the type name; the wrapped key handle is not formatted.
impl Debug for PublicEncryptingKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_tuple("PublicEncryptingKey");
        builder.finish()
    }
}
impl Clone for PublicEncryptingKey {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Wrapper around a [`PublicEncryptingKey`] that exposes RSA-OAEP encryption.
pub struct OaepPublicEncryptingKey {
// Key used to create the `EVP_PKEY_CTX` for each encrypt call.
public_key: PublicEncryptingKey,
}
impl OaepPublicEncryptingKey {
    /// Wraps `public_key` for RSA-OAEP encryption.
    ///
    /// # Errors
    /// The signature is fallible, but this implementation always returns `Ok`.
    pub fn new(public_key: PublicEncryptingKey) -> Result<Self, Unspecified> {
        let oaep_key = Self { public_key };
        Ok(oaep_key)
    }

    /// Encrypts `plaintext` with RSA-OAEP under `algorithm`, writing into `ciphertext`.
    ///
    /// `label` is the optional OAEP label; `None` and an empty slice are treated alike.
    /// On success the returned slice is the prefix of `ciphertext` that was written.
    ///
    /// # Errors
    /// `Unspecified` if creating or configuring the context fails, or if AWS-LC reports
    /// a failure from the encrypt call.
    pub fn encrypt<'ciphertext>(
        &self,
        algorithm: &'static OaepAlgorithm,
        plaintext: &[u8],
        ciphertext: &'ciphertext mut [u8],
        label: Option<&[u8]>,
    ) -> Result<&'ciphertext mut [u8], Unspecified> {
        let ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new(*self.public_key.0, null_mut()) })?;

        let init_rc = unsafe { EVP_PKEY_encrypt_init(*ctx) };
        if init_rc != 1 {
            return Err(Unspecified);
        }

        let oaep_md = algorithm.oaep_hash_fn();
        let mgf1_md = algorithm.mgf1_hash_fn();
        configure_oaep_crypto_operation(&ctx, oaep_md, mgf1_md, label)?;

        // In: capacity of `ciphertext`. Out: number of bytes AWS-LC reports as written.
        let mut written = ciphertext.len();
        let encrypt_rc = indicator_check!(unsafe {
            EVP_PKEY_encrypt(
                *ctx,
                ciphertext.as_mut_ptr(),
                &mut written,
                plaintext.as_ptr(),
                plaintext.len(),
            )
        });
        if encrypt_rc != 1 {
            return Err(Unspecified);
        }

        Ok(&mut ciphertext[..written])
    }

    /// Key size in bytes.
    #[must_use]
    pub fn key_size_bytes(&self) -> usize {
        let Self { public_key } = self;
        public_key.key_size_bytes()
    }

    /// Key size in bits.
    #[must_use]
    pub fn key_size_bits(&self) -> usize {
        let Self { public_key } = self;
        public_key.key_size_bits()
    }

    /// Largest plaintext, in bytes, that fits under `algorithm` with this key:
    /// `key_size_bytes - 2 * hash_len - 2`.
    #[must_use]
    pub fn max_plaintext_size(&self, algorithm: &'static OaepAlgorithm) -> usize {
        // The wildcard arm is needed for the `#[non_exhaustive]` enum; every variant
        // defined in this file is listed explicitly.
        #[allow(unreachable_patterns)]
        let digest_len: usize = match algorithm.id() {
            EncryptionAlgorithmId::OaepSha1Mgf1sha1 => 20,
            EncryptionAlgorithmId::OaepSha256Mgf1sha256 => 32,
            EncryptionAlgorithmId::OaepSha384Mgf1sha384 => 48,
            EncryptionAlgorithmId::OaepSha512Mgf1sha512 => 64,
            _ => verify_unreachable!(),
        };
        let padding_overhead = 2 * digest_len + 2;
        self.key_size_bytes() - padding_overhead
    }

    /// Length in bytes of a ciphertext produced by this key; equal to the key size in bytes.
    #[must_use]
    pub fn ciphertext_size(&self) -> usize {
        self.key_size_bytes()
    }
}
/// Prints the type name with its fields elided.
impl Debug for OaepPublicEncryptingKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("OaepPublicEncryptingKey");
        builder.finish_non_exhaustive()
    }
}
/// Wrapper around a [`PrivateDecryptingKey`] that exposes RSA-OAEP decryption.
pub struct OaepPrivateDecryptingKey {
// Key used to create the `EVP_PKEY_CTX` for each decrypt call.
private_key: PrivateDecryptingKey,
}
impl OaepPrivateDecryptingKey {
    /// Wraps `private_key` for RSA-OAEP decryption.
    ///
    /// # Errors
    /// The signature is fallible, but this implementation always returns `Ok`.
    pub fn new(private_key: PrivateDecryptingKey) -> Result<Self, Unspecified> {
        let oaep_key = Self { private_key };
        Ok(oaep_key)
    }

    /// Decrypts `ciphertext` with RSA-OAEP under `algorithm`, writing into `plaintext`.
    ///
    /// `label` is the optional OAEP label; `None` and an empty slice are treated alike.
    /// On success the returned slice is the prefix of `plaintext` that was written.
    ///
    /// # Errors
    /// `Unspecified` if creating or configuring the context fails, or if AWS-LC reports
    /// a failure from the decrypt call.
    pub fn decrypt<'plaintext>(
        &self,
        algorithm: &'static OaepAlgorithm,
        ciphertext: &[u8],
        plaintext: &'plaintext mut [u8],
        label: Option<&[u8]>,
    ) -> Result<&'plaintext mut [u8], Unspecified> {
        let ctx = LcPtr::new(unsafe { EVP_PKEY_CTX_new(*self.private_key.0, null_mut()) })?;

        let init_rc = unsafe { EVP_PKEY_decrypt_init(*ctx) };
        if init_rc != 1 {
            return Err(Unspecified);
        }

        let oaep_md = algorithm.oaep_hash_fn();
        let mgf1_md = algorithm.mgf1_hash_fn();
        configure_oaep_crypto_operation(&ctx, oaep_md, mgf1_md, label)?;

        // In: capacity of `plaintext`. Out: number of bytes AWS-LC reports as written.
        let mut written = plaintext.len();
        let decrypt_rc = indicator_check!(unsafe {
            EVP_PKEY_decrypt(
                *ctx,
                plaintext.as_mut_ptr(),
                &mut written,
                ciphertext.as_ptr(),
                ciphertext.len(),
            )
        });
        if decrypt_rc != 1 {
            return Err(Unspecified);
        }

        Ok(&mut plaintext[..written])
    }

    /// Key size in bytes.
    #[must_use]
    pub fn key_size_bytes(&self) -> usize {
        let Self { private_key } = self;
        private_key.key_size_bytes()
    }

    /// Key size in bits.
    #[must_use]
    pub fn key_size_bits(&self) -> usize {
        let Self { private_key } = self;
        private_key.key_size_bits()
    }

    /// Returns the key size in bytes, offered as the minimum length for the `plaintext`
    /// buffer passed to [`Self::decrypt`].
    #[must_use]
    pub fn min_output_size(&self) -> usize {
        self.key_size_bytes()
    }
}
/// Prints the type name with its fields elided, so no key material is formatted.
impl Debug for OaepPrivateDecryptingKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("OaepPrivateDecryptingKey");
        builder.finish_non_exhaustive()
    }
}
/// Configures `evp_pkey_ctx` for RSA-OAEP: sets the padding mode, the OAEP digest, the
/// MGF1 digest, and finally the label.
///
/// A missing label and an empty label are handled the same way: a null pointer with
/// length 0 is passed, so no pointer derived from a zero-length slice reaches C code.
/// A non-empty label is copied into a buffer obtained from `OPENSSL_malloc` before
/// being handed over.
///
/// # Errors
/// `Unspecified` if any AWS-LC call does not return 1 or the label buffer cannot be
/// allocated.
fn configure_oaep_crypto_operation(
    evp_pkey_ctx: &LcPtr<EVP_PKEY_CTX>,
    oaep_hash_fn: OaepHashFn,
    mgf1_hash_fn: Mgf1HashFn,
    label: Option<&[u8]>,
) -> Result<(), Unspecified> {
    let ctx = **evp_pkey_ctx;

    let padding_rc = unsafe { EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) };
    if padding_rc != 1 {
        return Err(Unspecified);
    }

    let oaep_md_rc = unsafe { EVP_PKEY_CTX_set_rsa_oaep_md(ctx, oaep_hash_fn()) };
    if oaep_md_rc != 1 {
        return Err(Unspecified);
    }

    let mgf1_md_rc = unsafe { EVP_PKEY_CTX_set_rsa_mgf1_md(ctx, mgf1_hash_fn()) };
    if mgf1_md_rc != 1 {
        return Err(Unspecified);
    }

    let label_bytes: &[u8] = label.unwrap_or(&[]);

    if label_bytes.is_empty() {
        let clear_rc = unsafe { EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, null_mut(), 0) };
        return if clear_rc == 1 {
            Ok(())
        } else {
            Err(Unspecified)
        };
    }

    // Copy the label into an AWS-LC allocation. While `owned_label` is alive and not
    // detached, dropping it (e.g. on the error return below) releases the buffer.
    let owned_label =
        DetachableLcPtr::<u8>::new(unsafe { OPENSSL_malloc(size_of_val(label_bytes)) }.cast())?;
    unsafe { core::slice::from_raw_parts_mut(*owned_label, label_bytes.len()) }
        .copy_from_slice(label_bytes);

    let set_rc =
        unsafe { EVP_PKEY_CTX_set0_rsa_oaep_label(ctx, *owned_label, label_bytes.len()) };
    if set_rc != 1 {
        return Err(Unspecified);
    }

    // The call succeeded, so give up our ownership and do not free the buffer here.
    // NOTE(review): relies on the `set0` call taking ownership of the label on success.
    owned_label.detach();
    Ok(())
}
/// Serializes the public key to DER via `encoding::rfc5280::encode_public_key_der`.
impl AsDer<PublicKeyX509Der<'static>> for PublicEncryptingKey {
    fn as_der(&self) -> Result<PublicKeyX509Der<'static>, Unspecified> {
        let Self(evp_pkey) = self;
        encoding::rfc5280::encode_public_key_der(evp_pkey)
    }
}