use core::ffi::c_uint;
use core::ptr;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::{check, WolfCryptError};
use wolfcrypt_rs::{AES_KEY, AES_set_encrypt_key, AES_set_decrypt_key, AES_wrap_key, AES_unwrap_key};
/// Owns an `AES_KEY` key schedule and wipes every byte of it on drop, so key material
/// written by `AES_set_encrypt_key` / `AES_set_decrypt_key` does not outlive its use.
struct AesKeyGuard(AES_KEY);

impl Drop for AesKeyGuard {
    fn drop(&mut self) {
        use zeroize::Zeroize;

        let len = core::mem::size_of::<AES_KEY>();
        let base: *mut u8 = (&mut self.0 as *mut AES_KEY).cast();
        // SAFETY: `base` points at `self.0`, which is exclusively borrowed for the duration
        // of this call, and `len` is exactly `size_of::<AES_KEY>()`, so the slice covers only
        // memory owned by this guard.
        // NOTE(review): viewing the struct as `[u8]` assumes `AES_KEY` has no padding bytes
        // (i.e. every byte is initialized) — confirm against the wolfcrypt_rs binding.
        let raw = unsafe { core::slice::from_raw_parts_mut(base, len) };
        raw.zeroize();
    }
}
pub fn aes_wrap_key(kek: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, WolfCryptError> {
if plaintext.len() < 16 || plaintext.len() % 8 != 0 {
return Err(WolfCryptError::INVALID_INPUT);
}
match kek.len() {
16 | 24 | 32 => {}
_ => return Err(WolfCryptError::INVALID_INPUT),
}
unsafe {
let mut guard = AesKeyGuard(AES_KEY::zeroed());
let rc = AES_set_encrypt_key(kek.as_ptr(), (kek.len() * 8) as c_uint, &mut guard.0);
check(rc, "AES_set_encrypt_key")?;
let mut out = vec![0u8; plaintext.len() + 8];
let rc = AES_wrap_key(
&guard.0,
ptr::null(), out.as_mut_ptr(),
plaintext.as_ptr(),
plaintext.len(),
);
if rc <= 0 {
return Err(WolfCryptError::Ffi { code: rc, func: "AES_wrap_key" });
}
let out_len = rc as usize;
if out_len > out.len() {
return Err(WolfCryptError::Ffi { code: -1, func: "AES_wrap_key (output length)" });
}
out.truncate(out_len);
Ok(out)
}
}
pub fn aes_unwrap_key(kek: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>, WolfCryptError> {
if ciphertext.len() < 24 || ciphertext.len() % 8 != 0 {
return Err(WolfCryptError::INVALID_INPUT);
}
match kek.len() {
16 | 24 | 32 => {}
_ => return Err(WolfCryptError::INVALID_INPUT),
}
unsafe {
let mut guard = AesKeyGuard(AES_KEY::zeroed());
let rc = AES_set_decrypt_key(kek.as_ptr(), (kek.len() * 8) as c_uint, &mut guard.0);
check(rc, "AES_set_decrypt_key")?;
let mut out = vec![0u8; ciphertext.len()];
let rc = AES_unwrap_key(
&guard.0,
ptr::null(), out.as_mut_ptr(),
ciphertext.as_ptr(),
ciphertext.len(),
);
if rc <= 0 {
return Err(WolfCryptError::Ffi { code: rc, func: "AES_unwrap_key" });
}
let out_len = rc as usize;
if out_len > out.len() {
return Err(WolfCryptError::Ffi { code: -1, func: "AES_unwrap_key (output length)" });
}
out.truncate(out_len);
Ok(out)
}
}