use crate::errors::{CrabError, CrabResult};
use aes_kw::{KekAes128, KekAes192, KekAes256};
/// AES Key Wrap (RFC 3394) keyed with a 128-bit key-encryption key (KEK).
///
/// Wrapping is deterministic: for a fixed KEK the same plaintext key always
/// yields the same output, which is always `key.len() + 8` bytes long.
pub struct Kw128 {
    /// AES-128 key-encryption key shared by `wrap_key` and `unwrap_key`.
    kek: KekAes128,
}

impl Kw128 {
    /// Exact KEK length, in bytes, accepted by [`Kw128::new`].
    pub const KEK_SIZE: usize = 16;
    /// Smallest plaintext key length, in bytes, accepted by [`Kw128::wrap_key`].
    pub const MIN_KEY_SIZE: usize = 16;

    /// Creates a wrapper from raw KEK bytes.
    ///
    /// # Errors
    /// Returns an invalid-input error unless `kek` is exactly
    /// [`Self::KEK_SIZE`] bytes long.
    pub fn new(kek: &[u8]) -> CrabResult<Self> {
        let supplied = kek.len();
        if supplied != Self::KEK_SIZE {
            return Err(CrabError::invalid_input(format!(
                "AES-128-KW requires 16-byte KEK, got {}",
                supplied
            )));
        }

        // Cannot fail after the length guard above; kept as a defensive error path.
        let raw = <[u8; 16]>::try_from(kek)
            .map_err(|_| CrabError::invalid_input("Failed to convert KEK to array"))?;

        Ok(Self {
            kek: KekAes128::from(raw),
        })
    }

    /// Produces [`Self::KEK_SIZE`] bytes from `crate::rand::secure_bytes`,
    /// suitable to pass to [`Kw128::new`].
    ///
    /// # Errors
    /// Propagates any error from `crate::rand::secure_bytes`.
    pub fn generate_kek() -> CrabResult<Vec<u8>> {
        crate::rand::secure_bytes(Self::KEK_SIZE)
    }

    /// Wraps `key` under this KEK and returns `key.len() + 8` bytes.
    ///
    /// # Errors
    /// Returns an invalid-input error when `key` is shorter than
    /// [`Self::MIN_KEY_SIZE`] or its length is not a multiple of 8. Returns a
    /// crypto error if the underlying `aes_kw` wrap operation fails.
    pub fn wrap_key(&self, key: &[u8]) -> CrabResult<Vec<u8>> {
        let key_len = key.len();

        // The short-length check runs first, so a 9-byte key reports "too short".
        if key_len < Self::MIN_KEY_SIZE {
            return Err(CrabError::invalid_input(format!(
                "Key must be at least {} bytes, got {}",
                Self::MIN_KEY_SIZE,
                key_len
            )));
        }
        if key_len % 8 != 0 {
            return Err(CrabError::invalid_input(format!(
                "Key length must be multiple of 8 bytes, got {}. Consider using AES-KWP for arbitrary lengths.",
                key_len
            )));
        }

        // The output carries one extra 8-byte block (the integrity check value).
        let mut wrapped = vec![0u8; key_len + 8];
        let wrap_result = self.kek.wrap(key, &mut wrapped);
        wrap_result.map_err(|err| CrabError::crypto_error(format!("Key wrap failed: {:?}", err)))?;
        Ok(wrapped)
    }

    /// Unwraps `wrapped_key` and returns the original key, 8 bytes shorter.
    ///
    /// # Errors
    /// Returns an invalid-input error when `wrapped_key` is shorter than 24
    /// bytes or its length is not a multiple of 8. Returns a crypto error when
    /// `aes_kw` rejects the input (for example, a wrong KEK or altered bytes).
    pub fn unwrap_key(&self, wrapped_key: &[u8]) -> CrabResult<Vec<u8>> {
        let wrapped_len = wrapped_key.len();

        if wrapped_len < 24 {
            return Err(CrabError::invalid_input(format!(
                "Wrapped key must be at least 24 bytes, got {}",
                wrapped_len
            )));
        }
        if wrapped_len % 8 != 0 {
            return Err(CrabError::invalid_input(format!(
                "Wrapped key length must be multiple of 8 bytes, got {}",
                wrapped_len
            )));
        }

        let mut plain = vec![0u8; wrapped_len - 8];
        let unwrap_result = self.kek.unwrap(wrapped_key, &mut plain);
        unwrap_result.map_err(|err| {
            CrabError::crypto_error(format!(
                "Key unwrap failed: {:?}. Wrong KEK or tampered data.",
                err
            ))
        })?;
        Ok(plain)
    }
}
/// AES Key Wrap (RFC 3394) keyed with a 256-bit key-encryption key (KEK).
///
/// Wrapping is deterministic: for a fixed KEK the same plaintext key always
/// yields the same output, which is always `key.len() + 8` bytes long.
pub struct Kw256 {
    /// AES-256 key-encryption key shared by `wrap_key` and `unwrap_key`.
    kek: KekAes256,
}

impl Kw256 {
    /// Exact KEK length, in bytes, accepted by [`Kw256::new`].
    pub const KEK_SIZE: usize = 32;
    /// Smallest plaintext key length, in bytes, accepted by [`Kw256::wrap_key`].
    pub const MIN_KEY_SIZE: usize = 16;

    /// Creates a wrapper from raw KEK bytes.
    ///
    /// # Errors
    /// Returns an invalid-input error unless `kek` is exactly
    /// [`Self::KEK_SIZE`] bytes long.
    pub fn new(kek: &[u8]) -> CrabResult<Self> {
        let supplied = kek.len();
        if supplied != Self::KEK_SIZE {
            return Err(CrabError::invalid_input(format!(
                "AES-256-KW requires 32-byte KEK, got {}",
                supplied
            )));
        }

        // Cannot fail after the length guard above; kept as a defensive error path.
        let raw = <[u8; 32]>::try_from(kek)
            .map_err(|_| CrabError::invalid_input("Failed to convert KEK to array"))?;

        Ok(Self {
            kek: KekAes256::from(raw),
        })
    }

    /// Produces [`Self::KEK_SIZE`] bytes from `crate::rand::secure_bytes`,
    /// suitable to pass to [`Kw256::new`].
    ///
    /// # Errors
    /// Propagates any error from `crate::rand::secure_bytes`.
    pub fn generate_kek() -> CrabResult<Vec<u8>> {
        crate::rand::secure_bytes(Self::KEK_SIZE)
    }

    /// Wraps `key` under this KEK and returns `key.len() + 8` bytes.
    ///
    /// # Errors
    /// Returns an invalid-input error when `key` is shorter than
    /// [`Self::MIN_KEY_SIZE`] or its length is not a multiple of 8. Returns a
    /// crypto error if the underlying `aes_kw` wrap operation fails.
    pub fn wrap_key(&self, key: &[u8]) -> CrabResult<Vec<u8>> {
        let key_len = key.len();

        // The short-length check runs first, so a 9-byte key reports "too short".
        if key_len < Self::MIN_KEY_SIZE {
            return Err(CrabError::invalid_input(format!(
                "Key must be at least {} bytes, got {}",
                Self::MIN_KEY_SIZE,
                key_len
            )));
        }
        if key_len % 8 != 0 {
            return Err(CrabError::invalid_input(format!(
                "Key length must be multiple of 8 bytes, got {}. Consider using AES-KWP for arbitrary lengths.",
                key_len
            )));
        }

        // The output carries one extra 8-byte block (the integrity check value).
        let mut wrapped = vec![0u8; key_len + 8];
        let wrap_result = self.kek.wrap(key, &mut wrapped);
        wrap_result.map_err(|err| CrabError::crypto_error(format!("Key wrap failed: {:?}", err)))?;
        Ok(wrapped)
    }

    /// Unwraps `wrapped_key` and returns the original key, 8 bytes shorter.
    ///
    /// # Errors
    /// Returns an invalid-input error when `wrapped_key` is shorter than 24
    /// bytes or its length is not a multiple of 8. Returns a crypto error when
    /// `aes_kw` rejects the input (for example, a wrong KEK or altered bytes).
    pub fn unwrap_key(&self, wrapped_key: &[u8]) -> CrabResult<Vec<u8>> {
        let wrapped_len = wrapped_key.len();

        if wrapped_len < 24 {
            return Err(CrabError::invalid_input(format!(
                "Wrapped key must be at least 24 bytes, got {}",
                wrapped_len
            )));
        }
        if wrapped_len % 8 != 0 {
            return Err(CrabError::invalid_input(format!(
                "Wrapped key length must be multiple of 8 bytes, got {}",
                wrapped_len
            )));
        }

        let mut plain = vec![0u8; wrapped_len - 8];
        let unwrap_result = self.kek.unwrap(wrapped_key, &mut plain);
        unwrap_result.map_err(|err| {
            CrabError::crypto_error(format!(
                "Key unwrap failed: {:?}. Wrong KEK or tampered data.",
                err
            ))
        })?;
        Ok(plain)
    }
}
/// AES Key Wrap (RFC 3394) keyed with a 192-bit key-encryption key (KEK).
///
/// Wrapping is deterministic: for a fixed KEK the same plaintext key always
/// yields the same output, which is always `key.len() + 8` bytes long.
pub struct Kw192 {
    /// AES-192 key-encryption key shared by `wrap_key` and `unwrap_key`.
    kek: KekAes192,
}

impl Kw192 {
    /// Exact KEK length, in bytes, accepted by [`Kw192::new`].
    pub const KEK_SIZE: usize = 24;
    /// Smallest plaintext key length, in bytes, accepted by [`Kw192::wrap_key`].
    pub const MIN_KEY_SIZE: usize = 16;
    /// Bytes added by wrapping (one 64-bit integrity block).
    const SEMIBLOCK_SIZE: usize = 8;

    /// Creates a wrapper from raw KEK bytes.
    ///
    /// # Errors
    /// Returns an invalid-input error unless `kek` is exactly
    /// [`Self::KEK_SIZE`] bytes long.
    pub fn new(kek: &[u8]) -> CrabResult<Self> {
        // `try_into` fails exactly when the length is wrong, so a single fallible
        // conversion replaces the separate length check plus its unreachable error.
        let kek_array: [u8; 24] = kek.try_into().map_err(|_| {
            CrabError::invalid_input(format!(
                "AES-192-KW requires 24-byte KEK, got {}",
                kek.len()
            ))
        })?;
        Ok(Self {
            kek: KekAes192::from(kek_array),
        })
    }

    /// Produces [`Self::KEK_SIZE`] bytes from `crate::rand::secure_bytes`,
    /// suitable to pass to [`Kw192::new`].
    ///
    /// # Errors
    /// Propagates any error from `crate::rand::secure_bytes`.
    pub fn generate_kek() -> CrabResult<Vec<u8>> {
        crate::rand::secure_bytes(Self::KEK_SIZE)
    }

    /// Wraps `key` under this KEK and returns `key.len() + 8` bytes.
    ///
    /// # Errors
    /// Returns an invalid-input error when `key` is shorter than
    /// [`Self::MIN_KEY_SIZE`] or its length is not a multiple of 8. Returns a
    /// crypto error if the underlying `aes_kw` wrap operation fails.
    pub fn wrap_key(&self, key: &[u8]) -> CrabResult<Vec<u8>> {
        if key.len() < Self::MIN_KEY_SIZE {
            return Err(CrabError::invalid_input(format!(
                "Key must be at least {} bytes, got {}",
                Self::MIN_KEY_SIZE,
                key.len()
            )));
        }
        if key.len() % Self::SEMIBLOCK_SIZE != 0 {
            // Same hint as Kw128/Kw256, so callers get consistent guidance.
            return Err(CrabError::invalid_input(format!(
                "Key length must be multiple of 8 bytes, got {}. Consider using AES-KWP for arbitrary lengths.",
                key.len()
            )));
        }
        let mut output = vec![0u8; key.len() + Self::SEMIBLOCK_SIZE];
        self.kek
            .wrap(key, &mut output)
            .map_err(|e| CrabError::crypto_error(format!("Key wrap failed: {:?}", e)))?;
        Ok(output)
    }

    /// Unwraps `wrapped_key` and returns the original key, 8 bytes shorter.
    ///
    /// # Errors
    /// Returns an invalid-input error when `wrapped_key` is shorter than
    /// `MIN_KEY_SIZE + 8` (24) bytes or its length is not a multiple of 8.
    /// Returns a crypto error when `aes_kw` rejects the input (for example, a
    /// wrong KEK or altered bytes).
    pub fn unwrap_key(&self, wrapped_key: &[u8]) -> CrabResult<Vec<u8>> {
        // Derived from MIN_KEY_SIZE rather than a bare literal; evaluates to 24.
        let min_wrapped = Self::MIN_KEY_SIZE + Self::SEMIBLOCK_SIZE;
        if wrapped_key.len() < min_wrapped {
            return Err(CrabError::invalid_input(format!(
                "Wrapped key must be at least {} bytes, got {}",
                min_wrapped,
                wrapped_key.len()
            )));
        }
        if wrapped_key.len() % Self::SEMIBLOCK_SIZE != 0 {
            return Err(CrabError::invalid_input(format!(
                "Wrapped key length must be multiple of 8 bytes, got {}",
                wrapped_key.len()
            )));
        }
        let mut output = vec![0u8; wrapped_key.len() - Self::SEMIBLOCK_SIZE];
        self.kek.unwrap(wrapped_key, &mut output).map_err(|e| {
            // Same diagnostic as Kw128/Kw256, so callers get consistent guidance.
            CrabError::crypto_error(format!(
                "Key unwrap failed: {:?}. Wrong KEK or tampered data.",
                e
            ))
        })?;
        Ok(output)
    }
}
#[cfg(test)]
mod tests {
    //! Known-answer vectors from RFC 3394 §4 plus round-trip, tamper and
    //! input-validation tests for all three AES-KW key sizes.
    use super::*;

    #[test]
    fn test_kw128_rfc_vector() {
        // RFC 3394 §4.1: 128 bits of key data wrapped with a 128-bit KEK.
        let kek = hex::decode("000102030405060708090A0B0C0D0E0F").unwrap();
        let key_data = hex::decode("00112233445566778899AABBCCDDEEFF").unwrap();
        let expected = hex::decode("1FA68B0A8112B447AEF34BD8FB5A7B829D3E862371D2CFE5").unwrap();
        let wrapper = Kw128::new(&kek).unwrap();
        let wrapped = wrapper.wrap_key(&key_data).unwrap();
        assert_eq!(wrapped, expected);
        let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
        assert_eq!(unwrapped, key_data);
    }

    #[test]
    fn test_kw192_rfc_vector() {
        // RFC 3394 §4.2: 128 bits of key data wrapped with a 192-bit KEK.
        let kek = hex::decode("000102030405060708090A0B0C0D0E0F1011121314151617").unwrap();
        let key_data = hex::decode("00112233445566778899AABBCCDDEEFF").unwrap();
        let expected = hex::decode("96778B25AE6CA435F92B5B97C050AED2468AB8A17AD84E5D").unwrap();
        let wrapper = Kw192::new(&kek).unwrap();
        let wrapped = wrapper.wrap_key(&key_data).unwrap();
        assert_eq!(wrapped, expected);
        let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
        assert_eq!(unwrapped, key_data);
    }

    #[test]
    fn test_kw256_rfc_vector() {
        // RFC 3394 §4.6: 256 bits of key data wrapped with a 256-bit KEK.
        let kek = hex::decode("000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F")
            .unwrap();
        let key_data =
            hex::decode("00112233445566778899AABBCCDDEEFF000102030405060708090A0B0C0D0E0F")
                .unwrap();
        let expected = hex::decode(
            "28C9F404C4B810F4CBCCB35CFB87F8263F5786E2D80ED326CBC7F0E71A99F43BFB988B9B7A02DD21",
        )
        .unwrap();
        let wrapper = Kw256::new(&kek).unwrap();
        let wrapped = wrapper.wrap_key(&key_data).unwrap();
        assert_eq!(wrapped, expected);
        let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
        assert_eq!(unwrapped, key_data);
    }

    #[test]
    fn test_kw128_generate_kek() {
        let kek = Kw128::generate_kek().unwrap();
        assert_eq!(kek.len(), 16);
    }

    #[test]
    fn test_kw192_generate_kek() {
        let kek = Kw192::generate_kek().unwrap();
        assert_eq!(kek.len(), 24);
    }

    #[test]
    fn test_kw256_generate_kek() {
        let kek = Kw256::generate_kek().unwrap();
        assert_eq!(kek.len(), 32);
    }

    #[test]
    fn test_kw128_roundtrip() {
        let kek = Kw128::generate_kek().unwrap();
        let wrapper = Kw128::new(&kek).unwrap();
        let key = [0x42u8; 32];
        let wrapped = wrapper.wrap_key(&key).unwrap();
        let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
        assert_eq!(unwrapped, key);
        assert_eq!(wrapped.len(), key.len() + 8);
    }

    #[test]
    fn test_kw256_roundtrip() {
        let kek = Kw256::generate_kek().unwrap();
        let wrapper = Kw256::new(&kek).unwrap();
        let key = [0xAAu8; 24];
        let wrapped = wrapper.wrap_key(&key).unwrap();
        let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
        assert_eq!(unwrapped, key);
        assert_eq!(wrapped.len(), key.len() + 8);
    }

    #[test]
    fn test_kw192_roundtrip() {
        let kek = Kw192::generate_kek().unwrap();
        let wrapper = Kw192::new(&kek).unwrap();
        let key = [0x77u8; 16];
        let wrapped = wrapper.wrap_key(&key).unwrap();
        let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
        assert_eq!(unwrapped, key);
        assert_eq!(wrapped.len(), key.len() + 8);
    }

    #[test]
    fn test_kw256_wrong_kek_fails() {
        let kek1 = Kw256::generate_kek().unwrap();
        let kek2 = Kw256::generate_kek().unwrap();
        let wrapper1 = Kw256::new(&kek1).unwrap();
        let wrapper2 = Kw256::new(&kek2).unwrap();
        let key = [0x55u8; 32];
        let wrapped = wrapper1.wrap_key(&key).unwrap();
        let result = wrapper2.unwrap_key(&wrapped);
        assert!(result.is_err());
    }

    #[test]
    fn test_kw256_invalid_key_size() {
        let kek = Kw256::generate_kek().unwrap();
        let wrapper = Kw256::new(&kek).unwrap();
        let result = wrapper.wrap_key(&[0u8; 8]);
        assert!(result.is_err());
        let result = wrapper.wrap_key(&[0u8; 17]);
        assert!(result.is_err());
    }

    #[test]
    fn test_kw192_invalid_key_size() {
        let kek = Kw192::generate_kek().unwrap();
        let wrapper = Kw192::new(&kek).unwrap();
        assert!(wrapper.wrap_key(&[0u8; 8]).is_err());
        assert!(wrapper.wrap_key(&[0u8; 20]).is_err());
    }

    #[test]
    fn test_kw128_invalid_kek_size() {
        let result = Kw128::new(&[0u8; 15]);
        assert!(result.is_err());
        let result = Kw128::new(&[0u8; 17]);
        assert!(result.is_err());
    }

    #[test]
    fn test_kw192_invalid_kek_size() {
        assert!(Kw192::new(&[0u8; 23]).is_err());
        assert!(Kw192::new(&[0u8; 25]).is_err());
        assert!(Kw192::new(&[]).is_err());
    }

    #[test]
    fn test_kw256_invalid_kek_size() {
        let result = Kw256::new(&[0u8; 31]);
        assert!(result.is_err());
        let result = Kw256::new(&[0u8; 33]);
        assert!(result.is_err());
    }

    #[test]
    fn test_unwrap_rejects_malformed_lengths() {
        // Too short: wrapping the minimum 16-byte key yields 24 bytes.
        let w128 = Kw128::new(&Kw128::generate_kek().unwrap()).unwrap();
        assert!(w128.unwrap_key(&[0u8; 16]).is_err());
        assert!(w128.unwrap_key(&[]).is_err());
        // Not a multiple of 8 bytes.
        let w192 = Kw192::new(&Kw192::generate_kek().unwrap()).unwrap();
        assert!(w192.unwrap_key(&[0u8; 25]).is_err());
        let w256 = Kw256::new(&Kw256::generate_kek().unwrap()).unwrap();
        assert!(w256.unwrap_key(&[0u8; 31]).is_err());
    }

    #[test]
    fn test_kw256_tampered_data_fails() {
        let kek = Kw256::generate_kek().unwrap();
        let wrapper = Kw256::new(&kek).unwrap();
        let key = [0x42u8; 32];
        let mut wrapped = wrapper.wrap_key(&key).unwrap();
        wrapped[10] ^= 0xFF;
        let result = wrapper.unwrap_key(&wrapped);
        assert!(result.is_err());
    }

    #[test]
    fn test_kw192_tampered_data_fails() {
        let kek = Kw192::generate_kek().unwrap();
        let wrapper = Kw192::new(&kek).unwrap();
        let mut wrapped = wrapper.wrap_key(&[0x42u8; 24]).unwrap();
        let last = wrapped.len() - 1;
        wrapped[last] ^= 0x01;
        assert!(wrapper.unwrap_key(&wrapped).is_err());
    }

    #[test]
    fn test_kw256_different_sizes() {
        let kek = Kw256::generate_kek().unwrap();
        let wrapper = Kw256::new(&kek).unwrap();
        for size in [16, 24, 32, 40, 48, 56, 64] {
            let key = vec![0x99u8; size];
            let wrapped = wrapper.wrap_key(&key).unwrap();
            let unwrapped = wrapper.unwrap_key(&wrapped).unwrap();
            assert_eq!(unwrapped, key);
            assert_eq!(wrapped.len(), size + 8);
        }
    }

    #[test]
    fn test_kw128_deterministic() {
        let kek = Kw128::generate_kek().unwrap();
        let wrapper = Kw128::new(&kek).unwrap();
        let key = [0x33u8; 16];
        let wrapped1 = wrapper.wrap_key(&key).unwrap();
        let wrapped2 = wrapper.wrap_key(&key).unwrap();
        assert_eq!(wrapped1, wrapped2);
    }

    #[test]
    fn test_kw256_multiple_keys() {
        let kek = Kw256::generate_kek().unwrap();
        let wrapper = Kw256::new(&kek).unwrap();
        let key1 = [0x11u8; 32];
        let key2 = [0x22u8; 32];
        let key3 = [0x33u8; 32];
        let wrapped1 = wrapper.wrap_key(&key1).unwrap();
        let wrapped2 = wrapper.wrap_key(&key2).unwrap();
        let wrapped3 = wrapper.wrap_key(&key3).unwrap();
        assert_eq!(wrapper.unwrap_key(&wrapped1).unwrap(), key1);
        assert_eq!(wrapper.unwrap_key(&wrapped2).unwrap(), key2);
        assert_eq!(wrapper.unwrap_key(&wrapped3).unwrap(), key3);
        assert_ne!(wrapped1, wrapped2);
        assert_ne!(wrapped2, wrapped3);
    }
}