use crate::error::{Error, Result};
use aes::Aes256;
use aes_kw::Kek;
pub fn wrap_key_aes_kw(kek: &[u8; 32], plaintext_key: &[u8]) -> Result<Vec<u8>> {
if plaintext_key.len() < 16 {
return Err(Error::Cryptography(
"Key to wrap must be at least 16 bytes".to_string(),
));
}
if !plaintext_key.len().is_multiple_of(8) {
return Err(Error::Cryptography(
"Key to wrap must be multiple of 8 bytes".to_string(),
));
}
let kek = Kek::<Aes256>::from(*kek);
let mut output = vec![0u8; plaintext_key.len() + 8];
kek.wrap(plaintext_key, &mut output)
.map_err(|e| Error::Cryptography(format!("Key wrap failed: {:?}", e)))?;
Ok(output)
}
pub fn unwrap_key_aes_kw(kek: &[u8; 32], wrapped_key: &[u8]) -> Result<Vec<u8>> {
if wrapped_key.len() < 24 {
return Err(Error::Cryptography(
"Wrapped key must be at least 24 bytes".to_string(),
));
}
if !wrapped_key.len().is_multiple_of(8) {
return Err(Error::Cryptography(
"Wrapped key must be multiple of 8 bytes".to_string(),
));
}
let kek = Kek::<Aes256>::from(*kek);
let mut output = vec![0u8; wrapped_key.len() - 8];
kek.unwrap(wrapped_key, &mut output)
.map_err(|e| Error::Cryptography(format!("Key unwrap failed: {:?}", e)))?;
Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wrap_unwrap_roundtrip() {
        let kek = [0x42u8; 32];
        let plaintext = [0xABu8; 32];
        let wrapped = wrap_key_aes_kw(&kek, &plaintext).unwrap();
        let unwrapped = unwrap_key_aes_kw(&kek, &wrapped).unwrap();
        assert_eq!(&unwrapped[..], &plaintext[..]);
    }

    #[test]
    fn test_wrap_produces_longer_output() {
        let kek = [0x42u8; 32];
        let plaintext = [0xABu8; 32];
        let wrapped = wrap_key_aes_kw(&kek, &plaintext).unwrap();
        assert_eq!(wrapped.len(), plaintext.len() + 8);
    }

    // Known-answer test: RFC 3394 section 4.3, "Wrap 128 bits of Key Data
    // with a 256-bit KEK". Round-trip tests alone cannot catch a
    // self-consistent but non-standard implementation.
    #[test]
    fn test_rfc3394_known_answer_vector() {
        let kek: [u8; 32] = [
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D,
            0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B,
            0x1C, 0x1D, 0x1E, 0x1F,
        ];
        let key_data: [u8; 16] = [
            0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA, 0xBB, 0xCC, 0xDD,
            0xEE, 0xFF,
        ];
        let expected: [u8; 24] = [
            0x64, 0xE8, 0xC3, 0xF9, 0xCE, 0x0F, 0x5B, 0xA2, 0x63, 0xE9, 0x77, 0x79, 0x05, 0x81,
            0x8A, 0x2A, 0x93, 0xC8, 0x19, 0x1E, 0x7D, 0x6E, 0x8A, 0xE7,
        ];

        let wrapped = wrap_key_aes_kw(&kek, &key_data).unwrap();
        assert_eq!(&wrapped[..], &expected[..]);

        let unwrapped = unwrap_key_aes_kw(&kek, &expected).unwrap();
        assert_eq!(&unwrapped[..], &key_data[..]);
    }

    #[test]
    fn test_wrong_kek_fails() {
        let kek1 = [0x42u8; 32];
        let kek2 = [0x43u8; 32];
        let plaintext = [0xABu8; 32];
        let wrapped = wrap_key_aes_kw(&kek1, &plaintext).unwrap();
        assert!(unwrap_key_aes_kw(&kek2, &wrapped).is_err());
    }

    #[test]
    fn test_tampering_detected() {
        let kek = [0x42u8; 32];
        let plaintext = [0xABu8; 32];
        let mut wrapped = wrap_key_aes_kw(&kek, &plaintext).unwrap();
        wrapped[0] ^= 0xFF;
        assert!(unwrap_key_aes_kw(&kek, &wrapped).is_err());
    }

    #[test]
    fn test_short_key_rejected() {
        let kek = [0x42u8; 32];
        let plaintext = [0xABu8; 8];
        assert!(wrap_key_aes_kw(&kek, &plaintext).is_err());
    }

    // 16 bytes is the smallest key the wrapper accepts; make sure the
    // boundary itself is allowed and round-trips.
    #[test]
    fn test_minimum_length_key_accepted() {
        let kek = [0x42u8; 32];
        let plaintext = [0xCDu8; 16];
        let wrapped = wrap_key_aes_kw(&kek, &plaintext).unwrap();
        assert_eq!(wrapped.len(), 24);
        let unwrapped = unwrap_key_aes_kw(&kek, &wrapped).unwrap();
        assert_eq!(&unwrapped[..], &plaintext[..]);
    }

    #[test]
    fn test_non_aligned_key_rejected() {
        let kek = [0x42u8; 32];
        let plaintext = [0xABu8; 17];
        assert!(wrap_key_aes_kw(&kek, &plaintext).is_err());
    }

    // Unwrap-side input validation: below the 24-byte minimum.
    #[test]
    fn test_unwrap_short_input_rejected() {
        let kek = [0x42u8; 32];
        let wrapped = [0xABu8; 16];
        assert!(unwrap_key_aes_kw(&kek, &wrapped).is_err());
    }

    // Unwrap-side input validation: length not a multiple of 8.
    #[test]
    fn test_unwrap_non_aligned_input_rejected() {
        let kek = [0x42u8; 32];
        let wrapped = [0xABu8; 25];
        assert!(unwrap_key_aes_kw(&kek, &wrapped).is_err());
    }
}