use openssl::{
cipher::{Cipher, CipherRef},
cipher_ctx::CipherCtx,
};
use zeroize::Zeroizing;
use crate::error::{CryptoError, result::CryptoResult};
/// Size in bytes of one AES block; used as output-buffer headroom below.
const AES_BLOCK_SIZE: usize = 16;
/// Size in bytes of one RFC 3394 semiblock (half an AES block).
const AES_WRAP_BLOCK_SIZE: usize = 8;
/// Returns the OpenSSL AES key-wrap cipher whose key size matches `kek`.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidSize`] unless `kek` is exactly 16, 24 or
/// 32 bytes long (AES-128, AES-192 or AES-256 respectively).
fn select_cipher(kek: &[u8]) -> CryptoResult<&CipherRef> {
    let candidate = match kek.len() {
        16 => Some(Cipher::aes_128_wrap()),
        24 => Some(Cipher::aes_192_wrap()),
        32 => Some(Cipher::aes_256_wrap()),
        _ => None,
    };
    candidate.ok_or_else(|| {
        CryptoError::InvalidSize("The KEK size should be 16, 24 or 32 bytes".to_owned())
    })
}
/// Wraps `plaintext` under the key-encryption key `kek` with the AES key wrap
/// algorithm of RFC 3394 (no padding).
///
/// On success the returned ciphertext is `plaintext.len() + 8` bytes long.
///
/// # Errors
///
/// - [`CryptoError::InvalidSize`] if `plaintext` is shorter than 16 bytes or
///   its length is not a multiple of 8, or if `kek` is not 16, 24 or 32 bytes.
/// - [`CryptoError::IndexingSlicing`] if the output buffer cannot hold the
///   finalisation step.
/// - Any OpenSSL error raised while initialising or running the cipher.
pub fn rfc3394_wrap(plaintext: &[u8], kek: &[u8]) -> CryptoResult<Vec<u8>> {
    let input_len = plaintext.len();
    let well_formed = input_len.is_multiple_of(AES_WRAP_BLOCK_SIZE)
        && input_len >= 2 * AES_WRAP_BLOCK_SIZE;
    if !well_formed {
        return Err(CryptoError::InvalidSize(
            "The plaintext size should be >= 16 and a multiple of 8".to_owned(),
        ));
    }

    let cipher = select_cipher(kek)?;
    let mut ctx = CipherCtx::new()?;
    ctx.encrypt_init(Some(cipher), Some(kek), None)?;

    // The wrapped output is one semiblock longer than the input; the two
    // extra AES blocks are headroom (presumably for the output-size checks in
    // the `openssl` crate's `cipher_update`/`cipher_final` — confirm there).
    let capacity = input_len + AES_WRAP_BLOCK_SIZE + 2 * AES_BLOCK_SIZE;
    let mut wrapped = vec![0_u8; capacity];

    let update_len = ctx.cipher_update(plaintext, Some(&mut wrapped))?;
    let tail = wrapped.get_mut(update_len..).ok_or_else(|| {
        CryptoError::IndexingSlicing("Buffer too small for cipher_final".to_owned())
    })?;
    let final_len = ctx.cipher_final(tail)?;

    wrapped.truncate(update_len + final_len);
    Ok(wrapped)
}
/// Unwraps an RFC 3394 (no padding) `ciphertext` using the key-encryption
/// key `kek`.
///
/// The recovered key material is returned in a [`Zeroizing`] buffer so that
/// it is wiped when dropped. On success it is `ciphertext.len() - 8` bytes long.
///
/// # Errors
///
/// - [`CryptoError::InvalidSize`] if `ciphertext` is shorter than 24 bytes or
///   its length is not a multiple of 8, or if `kek` is not 16, 24 or 32 bytes.
/// - [`CryptoError::IndexingSlicing`] if the output buffer cannot hold the
///   finalisation step.
/// - Any OpenSSL error raised while initialising or running the cipher
///   (presumably including a failed RFC 3394 integrity check — confirm).
pub fn rfc3394_unwrap(ciphertext: &[u8], kek: &[u8]) -> CryptoResult<Zeroizing<Vec<u8>>> {
    let input_len = ciphertext.len();
    let well_formed = input_len.is_multiple_of(AES_WRAP_BLOCK_SIZE)
        && input_len >= 3 * AES_WRAP_BLOCK_SIZE;
    if !well_formed {
        return Err(CryptoError::InvalidSize(
            "The ciphertext size should be >= 24 and a multiple of 8".to_owned(),
        ));
    }

    let cipher = select_cipher(kek)?;
    let mut ctx = CipherCtx::new()?;
    ctx.decrypt_init(Some(cipher), Some(kek), None)?;

    // The unwrapped output is one semiblock shorter than the input (the
    // length guard above rules out underflow); the two extra AES blocks are
    // headroom for the cipher calls, matching `rfc3394_wrap`.
    let capacity = input_len - AES_WRAP_BLOCK_SIZE + 2 * AES_BLOCK_SIZE;
    let mut unwrapped = Zeroizing::new(vec![0_u8; capacity]);

    let update_len = ctx.cipher_update(ciphertext, Some(&mut unwrapped))?;
    let tail = unwrapped.get_mut(update_len..).ok_or_else(|| {
        CryptoError::IndexingSlicing("Buffer too small for cipher_final".to_owned())
    })?;
    let final_len = ctx.cipher_final(tail)?;

    unwrapped.truncate(update_len + final_len);
    Ok(unwrapped)
}
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use zeroize::Zeroizing;

    use super::*;

    /// Wraps `plaintext_hex` under `kek_hex`, checks the output against the
    /// expected RFC 3394 vector, then unwraps it and checks the round trip
    /// recovers the original plaintext.
    fn test_wrap_unwrap(kek_hex: &str, plaintext_hex: &str, expected_ciphertext_hex: &str) {
        // Keep the provider alive for the whole helper: dropping an
        // `openssl::provider::Provider` unloads it, so a discarded temporary
        // would unload FIPS again before any cipher operation runs.
        #[cfg(not(feature = "non-fips"))]
        let _fips = openssl::provider::Provider::load(None, "fips").unwrap();
        let kek = hex::decode(kek_hex).unwrap();
        let p = hex::decode(plaintext_hex).unwrap();
        let c_expected = hex::decode(expected_ciphertext_hex).unwrap();
        let c = rfc3394_wrap(&p, &kek).unwrap();
        assert_eq!(c, c_expected, "Wrap output mismatch");
        let p_unwrapped = rfc3394_unwrap(&c, &kek).unwrap();
        assert_eq!(p_unwrapped, Zeroizing::from(p), "Unwrap output mismatch");
    }

    #[test]
    fn test_rfc3394_aes128_kek() {
        test_wrap_unwrap(
            "000102030405060708090A0B0C0D0E0F",
            "00112233445566778899AABBCCDDEEFF",
            "1FA68B0A8112B447AEF34BD8FB5A7B829D3E862371D2CFE5",
        );
    }

    #[test]
    fn test_rfc3394_aes192_kek() {
        let kek = "000102030405060708090A0B0C0D0E0F1011121314151617";
        test_wrap_unwrap(
            kek,
            "00112233445566778899AABBCCDDEEFF",
            "96778B25AE6CA435F92B5B97C050AED2468AB8A17AD84E5D",
        );
        test_wrap_unwrap(
            kek,
            "00112233445566778899AABBCCDDEEFF0001020304050607",
            "031D33264E15D33268F24EC260743EDCE1C6C7DDEE725A936BA814915C6762D2",
        );
    }

    #[test]
    fn test_rfc3394_aes256_kek() {
        let kek = "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F";
        test_wrap_unwrap(
            kek,
            "00112233445566778899AABBCCDDEEFF",
            "64E8C3F9CE0F5BA263E9777905818A2A93C8191E7D6E8AE7",
        );
        test_wrap_unwrap(
            kek,
            "00112233445566778899AABBCCDDEEFF0001020304050607",
            "A8F9BC1612C68B3FF6E6F4FBE30E71E4769C8B80A32CB8958CD5D17D6B254DA1",
        );
        test_wrap_unwrap(
            kek,
            "00112233445566778899AABBCCDDEEFF000102030405060708090A0B0C0D0E0F",
            "28C9F404C4B810F4CBCCB35CFB87F8263F5786E2D80ED326CBC7F0E71A99F43BFB988B9B7A02DD21",
        );
    }

    /// A single flipped bit in an otherwise valid RFC 3394 vector must make
    /// unwrapping fail rather than return corrupted key material.
    #[test]
    fn test_unwrap_rejects_tampered_ciphertext() {
        #[cfg(not(feature = "non-fips"))]
        let _fips = openssl::provider::Provider::load(None, "fips").unwrap();
        let kek = hex::decode("000102030405060708090A0B0C0D0E0F").unwrap();
        let mut c =
            hex::decode("1FA68B0A8112B447AEF34BD8FB5A7B829D3E862371D2CFE5").unwrap();
        *c.first_mut().unwrap() ^= 0x01;
        rfc3394_unwrap(&c, &kek).unwrap_err();
    }

    /// Invalid KEK sizes and invalid input lengths must be rejected.
    #[test]
    fn test_errors() {
        // Named binding keeps the FIPS provider loaded for the whole test.
        #[cfg(not(feature = "non-fips"))]
        let _fips = openssl::provider::Provider::load(None, "fips").unwrap();
        let kek_bad = [0x00_u8; 1];
        let p16 = [0x11_u8; 16];
        rfc3394_wrap(&p16, &kek_bad).unwrap_err();
        let c24 = [0x22_u8; 24];
        rfc3394_unwrap(&c24, &kek_bad).unwrap_err();
        let kek16 = [0x01_u8; 16];
        // Not a multiple of 8.
        let p15 = [0x33_u8; 15];
        rfc3394_wrap(&p15, &kek16).unwrap_err();
        // Below the RFC 3394 minimum of two semiblocks.
        let p8 = [0x44_u8; 8];
        rfc3394_wrap(&p8, &kek16).unwrap_err();
        // Below the minimum of three semiblocks for ciphertext.
        let c16 = [0x55_u8; 16];
        rfc3394_unwrap(&c16, &kek16).unwrap_err();
        // Not a multiple of 8.
        let c23 = [0x66_u8; 23];
        rfc3394_unwrap(&c23, &kek16).unwrap_err();
    }
}