use core::cell::UnsafeCell;
use core::ffi::c_void;
use aead_trait::generic_array::GenericArray;
use aead_trait::{AeadCore, AeadInPlace, KeyInit, KeySizeUser};
use typenum::{U0, U13, U16, U32};
use crate::error::{check, len_as_u32, WolfCryptError};
// Generates an AES-CCM AEAD type for one key size, backed by a wolfCrypt AES context.
//
// Parameters:
// - `$name`: identifier of the generated struct.
// - `$key_size`: type used as `KeySizeUser::KeySize` (the length of the `KeyInit` key array).
// - `$key_len`: key length in bytes handed to `wc_AesCcmSetKey`; it should agree with
//   `$key_size`.
// - `$doc`: string expression attached to the struct as its rustdoc.
//
// The generated type implements `KeySizeUser`, `KeyInit`, `AeadCore` (13-byte nonce,
// 16-byte tag, no ciphertext overhead), `AeadInPlace` and `Drop`.
macro_rules! impl_aes_ccm {
    ($name:ident, $key_size:ty, $key_len:expr, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            // Interior mutability: the `AeadInPlace` methods only receive `&self`, while
            // `wc_AesCcmEncrypt` / `wc_AesCcmDecrypt` take a mutable context pointer and
            // may mutate the context through it.
            aes: UnsafeCell<wolfcrypt_rs::WcAes>,
        }

        // NOTE(review): assumes the wolfCrypt AES context has no thread affinity, so it
        // may be moved to another thread -- confirm for the configured wolfCrypt build.
        // No `Sync` impl is provided in this macro.
        unsafe impl Send for $name {}

        impl KeySizeUser for $name {
            type KeySize = $key_size;
        }

        impl KeyInit for $name {
            /// Initialises a wolfCrypt AES context and installs `key` for CCM use.
            ///
            /// # Panics
            ///
            /// Panics if `wc_AesInit` or `wc_AesCcmSetKey` returns a non-zero code; the
            /// trait signature offers no way to report the failure.
            fn new(key: &GenericArray<u8, $key_size>) -> Self {
                let mut aes = wolfcrypt_rs::WcAes::zeroed();
                // SAFETY: `aes` is a live local that is exclusively borrowed for the call.
                let rc = unsafe {
                    wolfcrypt_rs::wc_AesInit(
                        &mut aes as *mut wolfcrypt_rs::WcAes,
                        core::ptr::null_mut::<c_void>(),
                        wolfcrypt_rs::INVALID_DEVID,
                    )
                };
                assert_eq!(rc, 0, "wc_AesInit failed (OOM or invalid device)");
                // SAFETY: `aes` was initialised above; `key` is a live array and
                // `$key_len` is the length declared for it at the macro call site.
                let rc = unsafe {
                    wolfcrypt_rs::wc_AesCcmSetKey(
                        &mut aes as *mut wolfcrypt_rs::WcAes,
                        key.as_ptr(),
                        $key_len,
                    )
                };
                if rc != 0 {
                    // Release the initialised context before panicking so that it is
                    // not leaked; no `Self` exists yet, so `Drop` would never run.
                    // SAFETY: `aes` was successfully initialised by `wc_AesInit`.
                    unsafe {
                        wolfcrypt_rs::wc_AesFree(&mut aes);
                    }
                }
                assert_eq!(rc, 0, "wc_AesCcmSetKey failed (invalid key length)");
                Self {
                    aes: UnsafeCell::new(aes),
                }
            }
        }

        impl AeadCore for $name {
            type NonceSize = U13;
            type TagSize = U16;
            type CiphertextOverhead = U0;
        }

        impl AeadInPlace for $name {
            /// Encrypts `buffer` in place and returns the detached 16-byte tag.
            ///
            /// Any non-zero return code from `wc_AesCcmEncrypt` is reported as
            /// `aead_trait::Error`.
            fn encrypt_in_place_detached(
                &self,
                nonce: &aead_trait::Nonce<Self>,
                associated_data: &[u8],
                buffer: &mut [u8],
            ) -> aead_trait::Result<aead_trait::Tag<Self>> {
                let mut tag = GenericArray::<u8, U16>::default();
                // Empty inputs are passed as NULL with a zero length.
                let aad_ptr = if associated_data.is_empty() {
                    core::ptr::null()
                } else {
                    associated_data.as_ptr()
                };
                let (in_ptr, out_ptr) = if buffer.is_empty() {
                    (core::ptr::null(), core::ptr::null_mut())
                } else {
                    // Derive the read pointer from the write pointer: a separate
                    // `as_ptr()` followed by `as_mut_ptr()` would invalidate the first
                    // pointer under Rust's aliasing rules.
                    let out = buffer.as_mut_ptr();
                    (out as *const u8, out)
                };
                // SAFETY: every pointer is either NULL with a zero length or derived
                // from a live slice/array whose length is passed alongside it; the
                // literals 13 and 16 match `NonceSize` and `TagSize` above. The context
                // was initialised in `new`.
                let rc = unsafe {
                    wolfcrypt_rs::wc_AesCcmEncrypt(
                        self.aes.get(),
                        out_ptr,
                        in_ptr,
                        len_as_u32(buffer.len()),
                        nonce.as_ptr(),
                        13,
                        tag.as_mut_ptr(),
                        16,
                        aad_ptr,
                        len_as_u32(associated_data.len()),
                    )
                };
                if rc == 0 {
                    Ok(tag)
                } else {
                    Err(aead_trait::Error)
                }
            }

            /// Decrypts `buffer` in place, checking it against the detached `tag`.
            ///
            /// Any non-zero return code from `wc_AesCcmDecrypt` is reported as
            /// `aead_trait::Error`; the contents of `buffer` should not be used in
            /// that case.
            fn decrypt_in_place_detached(
                &self,
                nonce: &aead_trait::Nonce<Self>,
                associated_data: &[u8],
                buffer: &mut [u8],
                tag: &aead_trait::Tag<Self>,
            ) -> aead_trait::Result<()> {
                // Empty inputs are passed as NULL with a zero length.
                let aad_ptr = if associated_data.is_empty() {
                    core::ptr::null()
                } else {
                    associated_data.as_ptr()
                };
                let (in_ptr, out_ptr) = if buffer.is_empty() {
                    (core::ptr::null(), core::ptr::null_mut())
                } else {
                    // Same single-origin derivation as in the encrypt path.
                    let out = buffer.as_mut_ptr();
                    (out as *const u8, out)
                };
                // SAFETY: every pointer is either NULL with a zero length or derived
                // from a live slice/array whose length is passed alongside it; the
                // literals 13 and 16 match `NonceSize` and `TagSize` above. The context
                // was initialised in `new`.
                let rc = unsafe {
                    wolfcrypt_rs::wc_AesCcmDecrypt(
                        self.aes.get(),
                        out_ptr,
                        in_ptr,
                        len_as_u32(buffer.len()),
                        nonce.as_ptr(),
                        13,
                        tag.as_ptr(),
                        16,
                        aad_ptr,
                        len_as_u32(associated_data.len()),
                    )
                };
                if rc == 0 {
                    Ok(())
                } else {
                    Err(aead_trait::Error)
                }
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                // SAFETY: the context was initialised in `new`, and `&mut self`
                // guarantees exclusive access while it is released.
                unsafe {
                    wolfcrypt_rs::wc_AesFree(self.aes.get_mut());
                }
            }
        }
    };
}
// AES-CCM with a 128-bit key: `KeySize` is `U16`, and 16 is the byte length the macro
// passes to `wc_AesCcmSetKey`.
impl_aes_ccm!(
    Aes128Ccm,
    U16,
    16,
    "AES-128-CCM AEAD cipher, implementing `AeadInPlace` and `KeyInit`.\n\n\
     The trait-level nonce size is fixed to 13 bytes. Use the standalone\n\
     `aes_ccm_encrypt` / `aes_ccm_decrypt` functions for variable nonce lengths."
);
// AES-CCM with a 256-bit key: `KeySize` is `U32`, and 32 is the byte length the macro
// passes to `wc_AesCcmSetKey`.
impl_aes_ccm!(
Aes256Ccm,
U32,
32,
"AES-256-CCM AEAD cipher, implementing `AeadInPlace` and `KeyInit`.\n\n\
The trait-level nonce size is fixed to 13 bytes. Use the standalone\n\
`aes_ccm_encrypt` / `aes_ccm_decrypt` functions for variable nonce lengths."
);
/// One-shot AES-CCM encryption with caller-chosen nonce and tag lengths.
///
/// Sets up a temporary wolfCrypt AES context keyed with `key`, asks
/// `wc_AesCcmEncrypt` to process `plaintext.len()` bytes from `plaintext` into
/// `ciphertext`, and has the authentication tag written to `tag` (its slice length
/// selects the tag length). `aad` is passed as additional authenticated data; an
/// empty `aad` is passed as NULL with length 0.
///
/// # Errors
///
/// * `WolfCryptError::InvalidInput` if `ciphertext` is shorter than `plaintext`,
///   if `tag` is empty or longer than 16 bytes, or if `nonce.len()` lies outside
///   `CCM_NONCE_MIN_SZ..=CCM_NONCE_MAX_SZ`.
/// * `WolfCryptError::Ffi` if `wc_AesCcmSetKey` rejects the key.
/// * Whatever `check` produces for the return codes of `wc_AesInit` and
///   `wc_AesCcmEncrypt`.
pub fn aes_ccm_encrypt(
    key: &[u8],
    nonce: &[u8],
    aad: &[u8],
    plaintext: &[u8],
    ciphertext: &mut [u8],
    tag: &mut [u8],
) -> Result<(), WolfCryptError> {
    // Owns an initialised wolfCrypt AES context and frees it on drop, so the context
    // is released on every exit path (early return or unwinding).
    struct AesCtx(wolfcrypt_rs::WcAes);

    impl Drop for AesCtx {
        fn drop(&mut self) {
            // SAFETY: an `AesCtx` is only built from a context that `wc_AesInit`
            // initialised successfully.
            unsafe {
                wolfcrypt_rs::wc_AesFree(&mut self.0);
            }
        }
    }

    // Validate the arguments before any wolfCrypt state is created.
    if ciphertext.len() < plaintext.len() {
        return Err(WolfCryptError::InvalidInput);
    }
    if tag.is_empty() || tag.len() > 16 {
        return Err(WolfCryptError::InvalidInput);
    }
    if nonce.len() < wolfcrypt_rs::CCM_NONCE_MIN_SZ as usize
        || nonce.len() > wolfcrypt_rs::CCM_NONCE_MAX_SZ as usize
    {
        return Err(WolfCryptError::InvalidInput);
    }

    let mut aes = wolfcrypt_rs::WcAes::zeroed();
    // SAFETY: `aes` is a live local that is exclusively borrowed for the call.
    let rc = unsafe {
        wolfcrypt_rs::wc_AesInit(
            &mut aes as *mut wolfcrypt_rs::WcAes,
            core::ptr::null_mut::<c_void>(),
            wolfcrypt_rs::INVALID_DEVID,
        )
    };
    // A context whose initialisation failed is not freed.
    check(rc, "wc_AesInit")?;

    // From here on the guard is responsible for calling `wc_AesFree`.
    let mut ctx = AesCtx(aes);
    let aes_ptr: *mut wolfcrypt_rs::WcAes = &mut ctx.0;

    // SAFETY: `aes_ptr` points at the initialised context owned by `ctx`; `key` is a
    // live slice passed together with its length.
    let rc = unsafe {
        wolfcrypt_rs::wc_AesCcmSetKey(aes_ptr, key.as_ptr(), len_as_u32(key.len()))
    };
    if rc != 0 {
        return Err(WolfCryptError::Ffi {
            code: rc,
            func: "wc_AesCcmSetKey",
        });
    }

    let aad_ptr = if aad.is_empty() {
        core::ptr::null()
    } else {
        aad.as_ptr()
    };

    // SAFETY: each pointer comes from a live slice and is passed with that slice's
    // length (or NULL with length 0 for empty `aad`); `ciphertext` was checked above
    // to hold at least `plaintext.len()` bytes.
    let rc = unsafe {
        wolfcrypt_rs::wc_AesCcmEncrypt(
            aes_ptr,
            ciphertext.as_mut_ptr(),
            plaintext.as_ptr(),
            len_as_u32(plaintext.len()),
            nonce.as_ptr(),
            len_as_u32(nonce.len()),
            tag.as_mut_ptr(),
            len_as_u32(tag.len()),
            aad_ptr,
            len_as_u32(aad.len()),
        )
    };

    // Free the context before the return code is inspected.
    drop(ctx);
    check(rc, "wc_AesCcmEncrypt")?;
    Ok(())
}
/// One-shot AES-CCM decryption with caller-chosen nonce and tag lengths.
///
/// Sets up a temporary wolfCrypt AES context keyed with `key`, asks
/// `wc_AesCcmDecrypt` to process `ciphertext.len()` bytes from `ciphertext` into
/// `plaintext`, and supplies `tag` (its slice length is the tag length) and `aad`
/// for verification. An empty `aad` is passed as NULL with length 0.
///
/// # Errors
///
/// * `WolfCryptError::InvalidInput` if `plaintext` is shorter than `ciphertext`,
///   if `tag` is empty or longer than 16 bytes, or if `nonce.len()` lies outside
///   `CCM_NONCE_MIN_SZ..=CCM_NONCE_MAX_SZ`.
/// * `WolfCryptError::Ffi` if `wc_AesCcmSetKey` rejects the key.
/// * Whatever `check` produces for the return codes of `wc_AesInit` and
///   `wc_AesCcmDecrypt` (presumably including tag mismatch -- confirm against
///   `check`). The contents of `plaintext` should not be used when an error is
///   returned.
pub fn aes_ccm_decrypt(
    key: &[u8],
    nonce: &[u8],
    aad: &[u8],
    ciphertext: &[u8],
    plaintext: &mut [u8],
    tag: &[u8],
) -> Result<(), WolfCryptError> {
    // Owns an initialised wolfCrypt AES context and frees it on drop, so the context
    // is released on every exit path (early return or unwinding).
    struct AesCtx(wolfcrypt_rs::WcAes);

    impl Drop for AesCtx {
        fn drop(&mut self) {
            // SAFETY: an `AesCtx` is only built from a context that `wc_AesInit`
            // initialised successfully.
            unsafe {
                wolfcrypt_rs::wc_AesFree(&mut self.0);
            }
        }
    }

    // Validate the arguments before any wolfCrypt state is created.
    if plaintext.len() < ciphertext.len() {
        return Err(WolfCryptError::InvalidInput);
    }
    if tag.is_empty() || tag.len() > 16 {
        return Err(WolfCryptError::InvalidInput);
    }
    if nonce.len() < wolfcrypt_rs::CCM_NONCE_MIN_SZ as usize
        || nonce.len() > wolfcrypt_rs::CCM_NONCE_MAX_SZ as usize
    {
        return Err(WolfCryptError::InvalidInput);
    }

    let mut aes = wolfcrypt_rs::WcAes::zeroed();
    // SAFETY: `aes` is a live local that is exclusively borrowed for the call.
    let rc = unsafe {
        wolfcrypt_rs::wc_AesInit(
            &mut aes as *mut wolfcrypt_rs::WcAes,
            core::ptr::null_mut::<c_void>(),
            wolfcrypt_rs::INVALID_DEVID,
        )
    };
    // A context whose initialisation failed is not freed.
    check(rc, "wc_AesInit")?;

    // From here on the guard is responsible for calling `wc_AesFree`.
    let mut ctx = AesCtx(aes);
    let aes_ptr: *mut wolfcrypt_rs::WcAes = &mut ctx.0;

    // SAFETY: `aes_ptr` points at the initialised context owned by `ctx`; `key` is a
    // live slice passed together with its length.
    let rc = unsafe {
        wolfcrypt_rs::wc_AesCcmSetKey(aes_ptr, key.as_ptr(), len_as_u32(key.len()))
    };
    if rc != 0 {
        return Err(WolfCryptError::Ffi {
            code: rc,
            func: "wc_AesCcmSetKey",
        });
    }

    let aad_ptr = if aad.is_empty() {
        core::ptr::null()
    } else {
        aad.as_ptr()
    };

    // SAFETY: each pointer comes from a live slice and is passed with that slice's
    // length (or NULL with length 0 for empty `aad`); `plaintext` was checked above
    // to hold at least `ciphertext.len()` bytes.
    let rc = unsafe {
        wolfcrypt_rs::wc_AesCcmDecrypt(
            aes_ptr,
            plaintext.as_mut_ptr(),
            ciphertext.as_ptr(),
            len_as_u32(ciphertext.len()),
            nonce.as_ptr(),
            len_as_u32(nonce.len()),
            tag.as_ptr(),
            len_as_u32(tag.len()),
            aad_ptr,
            len_as_u32(aad.len()),
        )
    };

    // Free the context before the return code is inspected.
    drop(ctx);
    check(rc, "wc_AesCcmDecrypt")?;
    Ok(())
}