use crate::error::Unspecified;
use crate::fips::indicator_check;
use crate::sealed::Sealed;
use crate::wolfcrypt_rs::{
AES_set_decrypt_key, AES_set_encrypt_key, AES_unwrap_key, AES_unwrap_key_padded, AES_wrap_key,
AES_wrap_key_padded, AES_KEY,
};
use core::fmt::Debug;
use core::mem::MaybeUninit;
use core::ptr::null;
#[cfg(not(feature = "std"))]
use crate::prelude::*;
// Unit tests live in `tests.rs`; only compile them for test builds so their
// imports/helpers do not leak into (or warn in) the library build.
#[cfg(test)]
mod tests;
/// Identifies which block cipher a key-encryption key uses.
// Marked `non_exhaustive` so further variants can be added without a breaking change.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockCipherId {
    /// AES with a 128-bit (16-byte) key.
    Aes128,
    /// AES with a 256-bit (32-byte) key.
    Aes256,
}
/// A block cipher that a [`KeyEncryptionKey`] can be built on.
///
/// Implementors must also implement the crate's `Sealed` marker trait
/// (presumably restricting implementations to this crate).
pub trait BlockCipher: Debug + Sealed + 'static {
    /// Length, in bytes, of the key this cipher requires.
    fn key_len(&self) -> usize;

    /// Identifier of this cipher.
    fn id(&self) -> BlockCipherId;
}
/// An AES variant with a fixed key size; see [`AES_128`] and [`AES_256`].
pub struct AesBlockCipher {
    /// Which AES variant this value describes.
    id: BlockCipherId,
    /// Required key length in bytes.
    key_len: usize,
}

impl Sealed for AesBlockCipher {}

impl BlockCipher for AesBlockCipher {
    #[inline]
    fn key_len(&self) -> usize {
        self.key_len
    }

    #[inline]
    fn id(&self) -> BlockCipherId {
        self.id
    }
}

impl Debug for AesBlockCipher {
    /// Prints only the cipher identifier, using its `Debug` form.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        <BlockCipherId as Debug>::fmt(&self.id, f)
    }
}

/// AES with a 16-byte (128-bit) key.
pub const AES_128: AesBlockCipher = AesBlockCipher {
    key_len: 16,
    id: BlockCipherId::Aes128,
};

/// AES with a 32-byte (256-bit) key.
pub const AES_256: AesBlockCipher = AesBlockCipher {
    key_len: 32,
    id: BlockCipherId::Aes256,
};
/// Key wrapping and unwrapping without padding (compare [`KeyWrapPadded`]).
// The trait name intentionally repeats the module name for clarity at use sites.
#[allow(clippy::module_name_repetitions)]
pub trait KeyWrap: Sealed {
    /// Wraps `plaintext`, writing the wrapped bytes into `output`.
    ///
    /// Consumes the key. On success, returns the prefix of `output` that holds
    /// the wrapped result.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` when the inputs are unacceptable (for example,
    /// `output` is too small) or the underlying primitive fails.
    fn wrap<'output>(
        self,
        plaintext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified>;

    /// Unwraps `ciphertext`, writing the recovered bytes into `output`.
    ///
    /// Consumes the key. On success, returns the prefix of `output` that holds
    /// the unwrapped result.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` when the inputs are unacceptable or the underlying
    /// primitive fails.
    fn unwrap<'output>(
        self,
        ciphertext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified>;
}
/// Key wrapping and unwrapping with padding, so the plaintext length need not
/// be block-aligned (compare [`KeyWrap`]).
// The trait name intentionally repeats the module name for clarity at use sites.
#[allow(clippy::module_name_repetitions)]
pub trait KeyWrapPadded: Sealed {
    /// Pads and wraps `plaintext`, writing the wrapped bytes into `output`.
    ///
    /// Consumes the key. On success, returns the prefix of `output` that holds
    /// the wrapped result.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` when the inputs are unacceptable or the underlying
    /// primitive fails.
    fn wrap_with_padding<'output>(
        self,
        plaintext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified>;

    /// Unwraps `ciphertext` and strips its padding, writing the recovered bytes
    /// into `output`.
    ///
    /// Consumes the key. On success, returns the prefix of `output` that holds
    /// the unwrapped result.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` when the inputs are unacceptable or the underlying
    /// primitive fails.
    fn unwrap_with_padding<'output>(
        self,
        ciphertext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified>;
}
/// Convenience alias for a key-encryption key backed by AES.
pub type AesKek = KeyEncryptionKey<AesBlockCipher>;

/// A key-encryption key: owned raw key bytes bound to a specific block cipher.
///
/// NOTE(review): the key bytes live in a plain `Box<[u8]>` and no `Drop` impl
/// is visible here, so they are presumably not zeroized on drop — confirm.
pub struct KeyEncryptionKey<Cipher>
where
    Cipher: BlockCipher,
{
    /// The cipher this key is used with.
    cipher: &'static Cipher,
    /// Raw key bytes; `new` checks their length against `cipher.key_len()`.
    key: Box<[u8]>,
}
impl<Cipher: BlockCipher> KeyEncryptionKey<Cipher> {
    /// Creates a key-encryption key for `cipher` from a copy of `key`.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` when `key.len()` differs from `cipher.key_len()`.
    pub fn new(cipher: &'static Cipher, key: &[u8]) -> Result<Self, Unspecified> {
        if key.len() == cipher.key_len() {
            // Copy the caller's bytes into an owned, fixed-size allocation.
            Ok(Self {
                cipher,
                key: key.into(),
            })
        } else {
            Err(Unspecified)
        }
    }

    /// Returns the identifier of the block cipher this key was created for.
    #[must_use]
    pub fn block_cipher_id(&self) -> BlockCipherId {
        Cipher::id(self.cipher)
    }
}

impl<Cipher: BlockCipher> Sealed for KeyEncryptionKey<Cipher> {}
impl KeyWrap for KeyEncryptionKey<AesBlockCipher> {
    /// Wraps `plaintext` with `AES_wrap_key` (AES Key Wrap) under this key.
    ///
    /// The wrapped result is exactly 8 bytes longer than `plaintext`. It is
    /// written to the front of `output`, and the returned slice covers only
    /// that prefix.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` in these cases:
    /// - `output` is shorter than `plaintext.len() + 8`;
    /// - the AES encryption key schedule cannot be set up;
    /// - `AES_wrap_key` reports failure (e.g. for a plaintext length it rejects).
    fn wrap<'output>(
        self,
        plaintext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified> {
        // Wrapping adds one 8-byte block; `checked_add` rules out overflow.
        let wrapped_len = plaintext.len().checked_add(8).ok_or(Unspecified)?;
        if output.len() < wrapped_len {
            return Err(Unspecified);
        }

        let mut aes_key = MaybeUninit::<AES_KEY>::uninit();
        let key_bits: u32 = (self.key.len() * 8).try_into().map_err(|_| Unspecified)?;
        // SAFETY: `self.key` is a live buffer of `key_bits / 8` bytes and
        // `aes_key` provides writable storage for one `AES_KEY`.
        if 0 != unsafe { AES_set_encrypt_key(self.key.as_ptr(), key_bits, aes_key.as_mut_ptr()) } {
            return Err(Unspecified);
        }
        // SAFETY: a zero return from `AES_set_encrypt_key` is treated as success,
        // meaning the key schedule in `aes_key` has been written.
        let aes_key = unsafe { aes_key.assume_init() };

        // SAFETY: `output` has room for `wrapped_len` bytes (checked above), and
        // `plaintext` is valid for `plaintext.len()` bytes. The null IV presumably
        // selects the default key-wrap IV — confirm against the binding.
        let out_len = indicator_check!(unsafe {
            AES_wrap_key(
                &aes_key,
                null(),
                output.as_mut_ptr(),
                plaintext.as_ptr(),
                plaintext.len(),
            )
        });
        // A non-positive return signals failure.
        if out_len <= 0 {
            return Err(Unspecified);
        }
        let out_len: usize = out_len.try_into().map_err(|_| Unspecified)?;
        debug_assert_eq!(out_len, wrapped_len);
        Ok(&mut output[..out_len])
    }

    /// Unwraps `ciphertext` with `AES_unwrap_key` under this key.
    ///
    /// The recovered bytes are exactly 8 bytes shorter than `ciphertext`. They
    /// are written to the front of `output`, and the returned slice covers
    /// only that prefix.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` in these cases:
    /// - `ciphertext` is shorter than 8 bytes;
    /// - `output` is shorter than `ciphertext.len() - 8`;
    /// - the AES decryption key schedule cannot be set up;
    /// - `AES_unwrap_key` reports failure (presumably including integrity-check
    ///   failures).
    fn unwrap<'output>(
        self,
        ciphertext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified> {
        // `checked_sub` rejects ciphertexts shorter than one 8-byte block instead
        // of underflowing (which panics in debug builds).
        let unwrapped_len = ciphertext.len().checked_sub(8).ok_or(Unspecified)?;
        if output.len() < unwrapped_len {
            return Err(Unspecified);
        }

        let mut aes_key = MaybeUninit::<AES_KEY>::uninit();
        let key_bits: u32 = (self.key.len() * 8).try_into().map_err(|_| Unspecified)?;
        // SAFETY: `self.key` is a live buffer of `key_bits / 8` bytes and
        // `aes_key` provides writable storage for one `AES_KEY`.
        if 0 != unsafe { AES_set_decrypt_key(self.key.as_ptr(), key_bits, aes_key.as_mut_ptr()) } {
            return Err(Unspecified);
        }
        // SAFETY: a zero return from `AES_set_decrypt_key` is treated as success,
        // meaning the key schedule in `aes_key` has been written.
        let aes_key = unsafe { aes_key.assume_init() };

        // SAFETY: `output` has room for `unwrapped_len` bytes (checked above), and
        // `ciphertext` is valid for `ciphertext.len()` bytes. The null IV presumably
        // selects the default key-wrap IV — confirm against the binding.
        let out_len = indicator_check!(unsafe {
            AES_unwrap_key(
                &aes_key,
                null(),
                output.as_mut_ptr(),
                ciphertext.as_ptr(),
                ciphertext.len(),
            )
        });
        // A non-positive return signals failure.
        if out_len <= 0 {
            return Err(Unspecified);
        }
        let out_len: usize = out_len.try_into().map_err(|_| Unspecified)?;
        debug_assert_eq!(out_len, unwrapped_len);
        Ok(&mut output[..out_len])
    }
}
impl KeyWrapPadded for KeyEncryptionKey<AesBlockCipher> {
    /// Pads and wraps `plaintext` with `AES_wrap_key_padded` under this key.
    ///
    /// The plaintext length is rounded up to a multiple of 8 bytes, and one
    /// 8-byte block is added. The result is written to the front of `output`,
    /// and the returned slice covers only the `out_len` bytes the C call reports.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` in these cases:
    /// - the padded size overflows `usize`;
    /// - `output` is too small for the padded result;
    /// - the AES encryption key schedule cannot be set up;
    /// - `AES_wrap_key_padded` does not return 1.
    fn wrap_with_padding<'output>(
        self,
        plaintext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified> {
        // Required output size: plaintext rounded up to a multiple of 8, plus one
        // 8-byte block. Computed with overflow checks so `output` is validated in
        // Rust before the C call, as `KeyWrap::wrap` does. (This replaces an
        // arbitrary 256-byte cap on the padded length that `unwrap_with_padding`
        // never enforced.)
        let wrapped_len = plaintext
            .len()
            .checked_add(7)
            .map(|len| len & !7)
            .and_then(|padded| padded.checked_add(8))
            .ok_or(Unspecified)?;
        if output.len() < wrapped_len {
            return Err(Unspecified);
        }

        let mut aes_key = MaybeUninit::<AES_KEY>::uninit();
        let key_bits: u32 = (self.key.len() * 8).try_into().map_err(|_| Unspecified)?;
        // SAFETY: `self.key` is a live buffer of `key_bits / 8` bytes and
        // `aes_key` provides writable storage for one `AES_KEY`.
        if 0 != unsafe { AES_set_encrypt_key(self.key.as_ptr(), key_bits, aes_key.as_mut_ptr()) } {
            return Err(Unspecified);
        }
        // SAFETY: a zero return from `AES_set_encrypt_key` is treated as success,
        // meaning the key schedule in `aes_key` has been written.
        let aes_key = unsafe { aes_key.assume_init() };

        let mut out_len: usize = 0;
        // SAFETY: `output.len()` is passed as the capacity bound for `output`,
        // and `plaintext` is valid for `plaintext.len()` bytes.
        if 1 != indicator_check!(unsafe {
            AES_wrap_key_padded(
                &aes_key,
                output.as_mut_ptr(),
                &mut out_len,
                output.len(),
                plaintext.as_ptr(),
                plaintext.len(),
            )
        }) {
            return Err(Unspecified);
        }
        Ok(&mut output[..out_len])
    }

    /// Unwraps `ciphertext` with `AES_unwrap_key_padded` and strips the padding.
    ///
    /// The recovered bytes are written to the front of `output`, and the
    /// returned slice covers only the `out_len` bytes the C call reports.
    ///
    /// # Errors
    ///
    /// Returns `Unspecified` when the AES decryption key schedule cannot be set
    /// up or when `AES_unwrap_key_padded` does not return 1. Input and output
    /// length validation is left to that C function.
    fn unwrap_with_padding<'output>(
        self,
        ciphertext: &[u8],
        output: &'output mut [u8],
    ) -> Result<&'output mut [u8], Unspecified> {
        let mut aes_key = MaybeUninit::<AES_KEY>::uninit();
        let key_bits: u32 = (self.key.len() * 8).try_into().map_err(|_| Unspecified)?;
        // SAFETY: `self.key` is a live buffer of `key_bits / 8` bytes and
        // `aes_key` provides writable storage for one `AES_KEY`.
        if 0 != unsafe { AES_set_decrypt_key(self.key.as_ptr(), key_bits, aes_key.as_mut_ptr()) } {
            return Err(Unspecified);
        }
        // SAFETY: a zero return from `AES_set_decrypt_key` is treated as success,
        // meaning the key schedule in `aes_key` has been written.
        let aes_key = unsafe { aes_key.assume_init() };

        let mut out_len: usize = 0;
        // SAFETY: `output.len()` is passed as the capacity bound for `output`,
        // and `ciphertext` is valid for `ciphertext.len()` bytes.
        if 1 != indicator_check!(unsafe {
            AES_unwrap_key_padded(
                &aes_key,
                output.as_mut_ptr(),
                &mut out_len,
                output.len(),
                ciphertext.as_ptr(),
                ciphertext.len(),
            )
        }) {
            return Err(Unspecified);
        }
        Ok(&mut output[..out_len])
    }
}
impl<Cipher: BlockCipher> Debug for KeyEncryptionKey<Cipher> {
    /// Shows only the cipher. The key bytes are deliberately omitted, and the
    /// output is marked non-exhaustive (`..`).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut builder = f.debug_struct("KeyEncryptionKey");
        builder.field("cipher", &self.cipher);
        builder.finish_non_exhaustive()
    }
}