use super::{BlockCipher, TagMismatch};
use crate::ct::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater};
/// Errors returned by the key-wrap (`AesKw`) and key-wrap-with-padding
/// (`AesKwp`) operations in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum KwError {
    /// An input or output buffer has a length the operation does not accept.
    InvalidLength,
    /// Unwrapping recovered an integrity value (IV, AIV tag, length or
    /// padding) that does not verify.
    IntegrityCheck,
}
impl core::fmt::Display for KwError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::InvalidLength => "key wrap: invalid length",
            Self::IntegrityCheck => "key wrap: integrity check failed",
        };
        f.write_str(message)
    }
}
impl core::error::Error for KwError {}
impl From<TagMismatch> for KwError {
fn from(_: TagMismatch) -> Self {
KwError::IntegrityCheck
}
}
/// RFC 3394 default initial value: eight `0xA6` bytes, big-endian.
const RFC3394_IV: u64 = u64::from_be_bytes([0xA6; 8]);
/// RFC 5649 constant placed in the high 32 bits of the alternative initial
/// value (AIV); the low 32 bits carry the plaintext length.
const RFC5649_AIV_TAG: u32 = u32::from_be_bytes([0xA6, 0x59, 0x59, 0xA6]);
/// Returns the RFC 3394 ciphertext length for a `plaintext_len`-byte input:
/// the plaintext plus one 8-byte integrity semiblock.
///
/// No validation is done here; `AesKw::wrap` additionally requires the
/// plaintext to be a multiple of 8 bytes and at least 16 bytes long.
#[inline]
pub fn kw_ciphertext_len(plaintext_len: usize) -> usize {
    const SEMIBLOCK: usize = 8;
    plaintext_len + SEMIBLOCK
}
/// Returns the RFC 5649 ciphertext length for a `plaintext_len`-byte input:
/// the plaintext rounded up to whole 8-byte semiblocks, plus one semiblock
/// for the alternative initial value.
///
/// No validation is done here; `AesKwp::wrap` rejects empty plaintexts.
#[inline]
pub fn kwp_ciphertext_len(plaintext_len: usize) -> usize {
    let semiblocks = plaintext_len.div_ceil(8);
    (semiblocks + 1) * 8
}
/// AES Key Wrap (RFC 3394) over an arbitrary [`BlockCipher`].
///
/// Plaintexts must be a multiple of 8 bytes and at least 16 bytes long; the
/// ciphertext is always exactly 8 bytes longer than the plaintext.
#[derive(Clone)]
pub struct AesKw<C: BlockCipher> {
    cipher: C,
}
impl<C: BlockCipher> AesKw<C> {
    /// Creates a key-wrap instance that uses `cipher` as the key-encryption
    /// cipher.
    pub fn new(cipher: C) -> Self {
        Self { cipher }
    }
    /// Wraps `plaintext` into `out` using the RFC 3394 default IV.
    ///
    /// # Errors
    /// [`KwError::InvalidLength`] if `plaintext` is shorter than 16 bytes or
    /// not a multiple of 8, or if `out.len() != plaintext.len() + 8`.
    pub fn wrap(&self, plaintext: &[u8], out: &mut [u8]) -> Result<(), KwError> {
        // RFC 3394 needs at least two 64-bit semiblocks of input.
        match plaintext.len() {
            0..=15 => Err(KwError::InvalidLength),
            _ => wrap_w(&self.cipher, RFC3394_IV, plaintext, out),
        }
    }
    /// Unwraps `ciphertext` into `out` and checks the recovered IV against
    /// the RFC 3394 default IV using a constant-time comparison.
    ///
    /// # Errors
    /// - [`KwError::InvalidLength`] if `ciphertext` is shorter than 24 bytes,
    ///   not a multiple of 8, or `out.len() + 8 != ciphertext.len()`.
    /// - [`KwError::IntegrityCheck`] if the recovered IV does not match; `out`
    ///   is zeroed in that case.
    pub fn unwrap(&self, ciphertext: &[u8], out: &mut [u8]) -> Result<(), KwError> {
        if ciphertext.len() < 24 {
            return Err(KwError::InvalidLength);
        }
        let recovered_iv = unwrap_w(&self.cipher, ciphertext, out)?;
        let iv_matches = recovered_iv.to_be_bytes().ct_eq(&RFC3394_IV.to_be_bytes());
        if !bool::from(iv_matches) {
            // Never hand back data that failed authentication.
            out.fill(0);
            return Err(KwError::IntegrityCheck);
        }
        Ok(())
    }
}
/// RFC 3394 wrapping function `W`: wraps the 8-byte-aligned `plaintext`
/// under `cipher`, starting from integrity register `A = iv`.
///
/// On success `out[..8]` holds the final integrity register `A` and
/// `out[8..]` the `n` wrapped semiblocks, after `6 * n` block encryptions.
/// `out[8..]` doubles as the working register array `R[1..=n]`.
///
/// # Errors
/// [`KwError::InvalidLength`] if `plaintext` is empty, not a multiple of 8
/// bytes, or if `out.len() != plaintext.len() + 8`.
fn wrap_w<C: BlockCipher>(
    cipher: &C,
    iv: u64,
    plaintext: &[u8],
    out: &mut [u8],
) -> Result<(), KwError> {
    let len = plaintext.len();
    let aligned = len != 0 && len.is_multiple_of(8);
    if !aligned || out.len() != len + 8 {
        return Err(KwError::InvalidLength);
    }
    out[8..].copy_from_slice(plaintext);
    let mut a = iv;
    let mut block = [0u8; 16];
    // `t` is the RFC 3394 step counter n*j + i, i.e. 1, 2, ..., 6n in order.
    let mut t = 0u64;
    for _round in 0..6 {
        for r in out[8..].chunks_exact_mut(8) {
            t += 1;
            block[..8].copy_from_slice(&a.to_be_bytes());
            block[8..].copy_from_slice(r);
            cipher.encrypt_block(&mut block);
            a = u64::from_be_bytes(block[..8].try_into().unwrap()) ^ t;
            r.copy_from_slice(&block[8..]);
        }
    }
    out[..8].copy_from_slice(&a.to_be_bytes());
    Ok(())
}
/// RFC 3394 unwrapping function `W^-1`: the inverse of [`wrap_w`].
///
/// Writes the `n` unwrapped semiblocks of `ciphertext[8..]` into `out` and
/// returns the recovered integrity register `A`. It performs no integrity
/// check itself; callers compare `A` against their expected IV or AIV.
///
/// # Errors
/// [`KwError::InvalidLength`] if `ciphertext` is shorter than 16 bytes, not
/// a multiple of 8, or if `out.len() + 8 != ciphertext.len()`.
fn unwrap_w<C: BlockCipher>(cipher: &C, ciphertext: &[u8], out: &mut [u8]) -> Result<u64, KwError> {
    let len = ciphertext.len();
    if len < 16 || !len.is_multiple_of(8) || out.len() + 8 != len {
        return Err(KwError::InvalidLength);
    }
    let (a_bytes, r_bytes) = ciphertext.split_at(8);
    let mut a = u64::from_be_bytes(a_bytes.try_into().unwrap());
    out.copy_from_slice(r_bytes);
    let semiblocks = out.len() / 8;
    // Walk the step counter n*j + i backwards from 6n down to 1.
    let mut t = 6 * semiblocks as u64;
    let mut block = [0u8; 16];
    for _round in 0..6 {
        for r in out.chunks_exact_mut(8).rev() {
            block[..8].copy_from_slice(&(a ^ t).to_be_bytes());
            block[8..].copy_from_slice(r);
            cipher.decrypt_block(&mut block);
            a = u64::from_be_bytes(block[..8].try_into().unwrap());
            r.copy_from_slice(&block[8..]);
            t -= 1;
        }
    }
    Ok(a)
}
/// Capacity of the on-stack staging buffer used by `AesKwp`. Padded
/// plaintexts longer than this are rejected with [`KwError::InvalidLength`].
const KWP_MAX_PADDED_LEN: usize = 4096;
/// AES Key Wrap with Padding (RFC 5649) over an arbitrary [`BlockCipher`].
///
/// Unlike [`AesKw`], any non-empty plaintext length is accepted. The
/// plaintext is zero-padded to a multiple of 8 bytes, and its exact length
/// is authenticated in the low 32 bits of the alternative initial value
/// (AIV). Padding is staged in a fixed 4096-byte stack buffer, so padded
/// plaintexts larger than that are rejected with [`KwError::InvalidLength`].
#[derive(Clone)]
pub struct AesKwp<C: BlockCipher> {
    cipher: C,
}
impl<C: BlockCipher> AesKwp<C> {
    /// Creates a padded key-wrap instance that uses `cipher` as the
    /// key-encryption cipher.
    pub fn new(cipher: C) -> Self {
        AesKwp { cipher }
    }
    /// Wraps `plaintext` into `out` per RFC 5649.
    ///
    /// `out.len()` must equal `kwp_ciphertext_len(plaintext.len())`.
    ///
    /// - A plaintext that pads to a single 8-byte semiblock is encrypted
    ///   directly as the one block `AIV || padded plaintext`.
    /// - Longer plaintexts are wrapped with the RFC 3394 wrapping function,
    ///   using the AIV as the initial value.
    ///
    /// # Errors
    /// [`KwError::InvalidLength`] if `plaintext` is empty, longer than
    /// `u32::MAX` bytes, or pads to more than 4096 bytes, or if `out` has the
    /// wrong length.
    pub fn wrap(&self, plaintext: &[u8], out: &mut [u8]) -> Result<(), KwError> {
        if plaintext.is_empty() || plaintext.len() > u32::MAX as usize {
            return Err(KwError::InvalidLength);
        }
        let padded_len = plaintext.len().div_ceil(8) * 8;
        if out.len() != padded_len + 8 {
            return Err(KwError::InvalidLength);
        }
        // Message Length Indicator (MLI); the range check above makes this lossless.
        let mli = u32::try_from(plaintext.len())
            .expect("plaintext length was checked against u32::MAX");
        let aiv = (u64::from(RFC5649_AIV_TAG) << 32) | u64::from(mli);
        if padded_len == 8 {
            let mut block = [0u8; 16];
            block[..8].copy_from_slice(&aiv.to_be_bytes());
            // Bytes after the plaintext stay zero: that is the RFC 5649 padding.
            block[8..8 + plaintext.len()].copy_from_slice(plaintext);
            self.cipher.encrypt_block(&mut block);
            out.copy_from_slice(&block);
            return Ok(());
        }
        let mut padded = [0u8; KWP_MAX_PADDED_LEN];
        if padded_len > padded.len() {
            return Err(KwError::InvalidLength);
        }
        // `padded` starts zeroed, so everything after the plaintext already is
        // the required zero padding.
        padded[..plaintext.len()].copy_from_slice(plaintext);
        let result = wrap_w(&self.cipher, aiv, &padded[..padded_len], out);
        // Scrub the plaintext copy; black_box discourages eliding the stores.
        padded[..padded_len].fill(0);
        let _ = core::hint::black_box(&padded);
        result
    }
    /// Unwraps an RFC 5649 `ciphertext` into `out` and returns the recovered
    /// plaintext length (the authenticated MLI).
    ///
    /// `out` must hold at least `ciphertext.len() - 8` bytes. On success,
    /// `out[..len]` holds the plaintext and the rest of `out` is zeroed. On any
    /// error, `out` is not written.
    ///
    /// The following are all evaluated with the `crate::ct` primitives
    /// (`ct_eq`, `ct_gt`, `conditional_select`) and combined into a single
    /// accept/reject decision:
    /// - the AIV tag;
    /// - the MLI range;
    /// - the zero padding.
    ///
    /// # Errors
    /// - [`KwError::InvalidLength`] if `ciphertext` is shorter than 16 bytes,
    ///   is not a multiple of 8, or carries more than 4096 padded bytes, or if
    ///   `out` is too small.
    /// - [`KwError::IntegrityCheck`] if the AIV tag, the MLI or the padding is
    ///   invalid.
    pub fn unwrap(&self, ciphertext: &[u8], out: &mut [u8]) -> Result<usize, KwError> {
        if ciphertext.len() < 16
            || !ciphertext.len().is_multiple_of(8)
            || out.len() + 8 < ciphertext.len()
        {
            return Err(KwError::InvalidLength);
        }
        let padded_len = ciphertext.len() - 8;
        let mut scratch = [0u8; KWP_MAX_PADDED_LEN];
        if padded_len > scratch.len() {
            return Err(KwError::InvalidLength);
        }
        let aiv = if ciphertext.len() == 16 {
            // A single semiblock was wrapped with one raw block encryption.
            let mut block = [0u8; 16];
            block.copy_from_slice(ciphertext);
            self.cipher.decrypt_block(&mut block);
            scratch[..8].copy_from_slice(&block[8..]);
            let aiv = u64::from_be_bytes(block[..8].try_into().unwrap());
            // `block` now holds recovered plaintext; scrub it.
            block.fill(0);
            let _ = core::hint::black_box(&block);
            aiv
        } else {
            unwrap_w(&self.cipher, ciphertext, &mut scratch[..padded_len])?
        };
        let padded = &scratch[..padded_len];
        let high = (aiv >> 32) as u32;
        let mli_u32 = aiv as u32;
        // Lossless: padded_len <= KWP_MAX_PADDED_LEN, which fits in u32.
        let padded_len_u32 = padded_len as u32;
        let tag_ok = high.ct_eq(&RFC5649_AIV_TAG);
        let mli_nonzero = !mli_u32.ct_eq(&0u32);
        let mli_in_range = !mli_u32.ct_gt(&padded_len_u32);
        // RFC 5649 requires 8 * (n - 1) < MLI <= 8 * n, i.e. fewer than 8 pad bytes.
        let diff = padded_len_u32.wrapping_sub(mli_u32);
        let pad_short = 8u32.ct_gt(&diff);
        // conditional_select(a, b, c) yields `a` when `c` is false and `b` when
        // it is true. Keep the MLI when it is in range; otherwise clamp it to
        // padded_len so the padding scan and the final copy stay in bounds.
        let mli_clamped =
            u32::conditional_select(&mli_u32, &padded_len_u32, mli_u32.ct_gt(&padded_len_u32));
        // OR together every byte at index >= MLI; all of them must be zero padding.
        let mut pad_acc = 0u8;
        for (i, &b) in padded.iter().enumerate() {
            let in_pad: Choice = !mli_clamped.ct_gt(&(i as u32));
            let mask = 0u8.wrapping_sub(in_pad.unwrap_u8());
            pad_acc |= b & mask;
        }
        let pad_ok = pad_acc.ct_eq(&0u8);
        let ok: Choice = tag_ok & mli_nonzero & mli_in_range & pad_short & pad_ok;
        if !bool::from(ok) {
            scratch.fill(0);
            let _ = core::hint::black_box(&scratch);
            return Err(KwError::IntegrityCheck);
        }
        let mli = mli_clamped as usize;
        out[..mli].copy_from_slice(&padded[..mli]);
        out[mli..].fill(0);
        scratch.fill(0);
        let _ = core::hint::black_box(&scratch);
        Ok(mli)
    }
}
// Concrete instantiations, one pair per AES key-encryption-key size. The KEK
// size is independent of the size of the key being wrapped.
/// RFC 3394 key wrap with an AES-128 key-encryption key.
pub type Aes128Kw = AesKw<super::Aes128>;
/// RFC 5649 padded key wrap with an AES-128 key-encryption key.
pub type Aes128Kwp = AesKwp<super::Aes128>;
/// RFC 3394 key wrap with an AES-192 key-encryption key.
pub type Aes192Kw = AesKw<super::Aes192>;
/// RFC 5649 padded key wrap with an AES-192 key-encryption key.
pub type Aes192Kwp = AesKwp<super::Aes192>;
/// RFC 3394 key wrap with an AES-256 key-encryption key.
pub type Aes256Kw = AesKw<super::Aes256>;
/// RFC 5649 padded key wrap with an AES-256 key-encryption key.
pub type Aes256Kwp = AesKwp<super::Aes256>;
// Known-answer tests against the RFC 3394 and RFC 5649 example vectors, plus
// negative tests checking that tampered or malformed ciphertexts are rejected.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cipher::{Aes128, Aes192, Aes256};
    use crate::test_util::from_hex;
    // RFC 3394 vector: 128-bit KEK wrapping 128 bits of key data; round-trips.
    #[test]
    fn rfc3394_128_kek_128_data() {
        let kek = from_hex::<16>("000102030405060708090A0B0C0D0E0F");
        let pt = from_hex::<16>("00112233445566778899AABBCCDDEEFF");
        let expected = from_hex::<24>("1FA68B0A8112B447AEF34BD8FB5A7B829D3E862371D2CFE5");
        let kw = Aes128Kw::new(Aes128::new(&kek));
        let mut ct = [0u8; 24];
        kw.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut recovered = [0u8; 16];
        kw.unwrap(&ct, &mut recovered).unwrap();
        assert_eq!(recovered, pt);
    }
    // RFC 3394 vector: 192-bit KEK wrapping 128 bits of key data.
    #[test]
    fn rfc3394_192_kek_128_data() {
        let kek = from_hex::<24>("000102030405060708090A0B0C0D0E0F1011121314151617");
        let pt = from_hex::<16>("00112233445566778899AABBCCDDEEFF");
        let expected = from_hex::<24>("96778B25AE6CA435F92B5B97C050AED2468AB8A17AD84E5D");
        let kw = Aes192Kw::new(Aes192::new(&kek));
        let mut ct = [0u8; 24];
        kw.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 16];
        kw.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(rec, pt);
    }
    // RFC 3394 vector: 256-bit KEK wrapping 128 bits of key data.
    #[test]
    fn rfc3394_256_kek_128_data() {
        let kek =
            from_hex::<32>("000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F");
        let pt = from_hex::<16>("00112233445566778899AABBCCDDEEFF");
        let expected = from_hex::<24>("64E8C3F9CE0F5BA263E9777905818A2A93C8191E7D6E8AE7");
        let kw = Aes256Kw::new(Aes256::new(&kek));
        let mut ct = [0u8; 24];
        kw.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 16];
        kw.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(rec, pt);
    }
    // RFC 3394 vector: 192-bit KEK wrapping 192 bits of key data (three semiblocks).
    #[test]
    fn rfc3394_192_kek_192_data() {
        let kek = from_hex::<24>("000102030405060708090A0B0C0D0E0F1011121314151617");
        let pt = from_hex::<24>("00112233445566778899AABBCCDDEEFF0001020304050607");
        let expected =
            from_hex::<32>("031D33264E15D33268F24EC260743EDCE1C6C7DDEE725A936BA814915C6762D2");
        let kw = Aes192Kw::new(Aes192::new(&kek));
        let mut ct = [0u8; 32];
        kw.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 24];
        kw.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(rec, pt);
    }
    // RFC 3394 vector: 256-bit KEK wrapping 192 bits of key data.
    #[test]
    fn rfc3394_256_kek_192_data() {
        let kek =
            from_hex::<32>("000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F");
        let pt = from_hex::<24>("00112233445566778899AABBCCDDEEFF0001020304050607");
        let expected =
            from_hex::<32>("A8F9BC1612C68B3FF6E6F4FBE30E71E4769C8B80A32CB8958CD5D17D6B254DA1");
        let kw = Aes256Kw::new(Aes256::new(&kek));
        let mut ct = [0u8; 32];
        kw.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 24];
        kw.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(rec, pt);
    }
    // RFC 3394 vector: 256-bit KEK wrapping 256 bits of key data (four semiblocks).
    #[test]
    fn rfc3394_256_kek_256_data() {
        let kek =
            from_hex::<32>("000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F");
        let pt = from_hex::<32>("00112233445566778899AABBCCDDEEFF000102030405060708090A0B0C0D0E0F");
        let expected = from_hex::<40>(
            "28C9F404C4B810F4CBCCB35CFB87F8263F5786E2D80ED326CBC7F0E71A99F43BFB988B9B7A02DD21",
        );
        let kw = Aes256Kw::new(Aes256::new(&kek));
        let mut ct = [0u8; 40];
        kw.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 32];
        kw.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(rec, pt);
    }
    // Flipping one ciphertext bit must fail the IV check and leave the
    // output buffer zeroed.
    #[test]
    fn rfc3394_tamper_rejected() {
        let kek = from_hex::<16>("000102030405060708090A0B0C0D0E0F");
        let pt = from_hex::<16>("00112233445566778899AABBCCDDEEFF");
        let kw = Aes128Kw::new(Aes128::new(&kek));
        let mut ct = [0u8; 24];
        kw.wrap(&pt, &mut ct).unwrap();
        ct[0] ^= 1;
        let mut rec = [0u8; 16];
        assert_eq!(kw.unwrap(&ct, &mut rec), Err(KwError::IntegrityCheck));
        assert_eq!(rec, [0u8; 16]);
    }
    // RFC 5649 vector: 20-byte plaintext (multi-block path); unwrap must
    // report the true length 20, not the padded length 24.
    #[test]
    fn rfc5649_20_byte() {
        let kek = from_hex::<24>("5840df6e29b02af1ab493b705bf16ea1ae8338f4dcc176a8");
        let pt = from_hex::<20>("c37b7e6492584340bed12207808941155068f738");
        let expected =
            from_hex::<32>("138bdeaa9b8fa7fc61f97742e72248ee5ae6ae5360d1ae6a5f54f373fa543b6a");
        let kwp = Aes192Kwp::new(Aes192::new(&kek));
        let mut ct = [0u8; 32];
        kwp.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 24];
        let n = kwp.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(n, 20);
        assert_eq!(&rec[..20], &pt[..]);
    }
    // RFC 5649 vector: 7-byte plaintext (single-block path); the byte after
    // the plaintext in the output must be zeroed.
    #[test]
    fn rfc5649_7_byte_single_block() {
        let kek = from_hex::<24>("5840df6e29b02af1ab493b705bf16ea1ae8338f4dcc176a8");
        let pt = from_hex::<7>("466f7250617369");
        let expected = from_hex::<16>("afbeb0f07dfbf5419200f2ccb50bb24f");
        let kwp = Aes192Kwp::new(Aes192::new(&kek));
        let mut ct = [0u8; 16];
        kwp.wrap(&pt, &mut ct).unwrap();
        assert_eq!(ct, expected);
        let mut rec = [0u8; 8];
        let n = kwp.unwrap(&ct, &mut rec).unwrap();
        assert_eq!(n, 7);
        assert_eq!(&rec[..7], &pt[..]);
        assert_eq!(rec[7], 0);
    }
    // Flipping one bit of a padded-wrap ciphertext must be rejected.
    #[test]
    fn rfc5649_tamper_rejected() {
        let kek = from_hex::<24>("5840df6e29b02af1ab493b705bf16ea1ae8338f4dcc176a8");
        let pt = from_hex::<20>("c37b7e6492584340bed12207808941155068f738");
        let kwp = Aes192Kwp::new(Aes192::new(&kek));
        let mut ct = [0u8; 32];
        kwp.wrap(&pt, &mut ct).unwrap();
        ct[5] ^= 1;
        let mut rec = [0u8; 24];
        assert_eq!(kwp.unwrap(&ct, &mut rec), Err(KwError::IntegrityCheck));
    }
    // Builds validly-encrypted ciphertexts whose AIV or padding is
    // deliberately wrong (bypassing AesKwp::wrap via wrap_w) and checks that
    // each individual validation in AesKwp::unwrap rejects its case.
    #[test]
    fn rfc5649_unwrap_validation_branches() {
        let kek = from_hex::<24>("5840df6e29b02af1ab493b705bf16ea1ae8338f4dcc176a8");
        let aes = Aes192::new(&kek);
        let kwp = Aes192Kwp::new(aes.clone());
        let pt = from_hex::<20>("c37b7e6492584340bed12207808941155068f738");
        let good_mli: u32 = 20;
        let good_aiv: u64 = (u64::from(RFC5649_AIV_TAG) << 32) | u64::from(good_mli);
        let build = |aiv: u64, padded: &[u8]| -> [u8; 32] {
            let mut out = [0u8; 32];
            wrap_w(&aes, aiv, padded, &mut out).unwrap();
            out
        };
        // Sanity check: a well-formed AIV and zero padding unwrap correctly.
        let mut padded_ok = [0u8; 24];
        padded_ok[..pt.len()].copy_from_slice(&pt);
        let ct_valid = build(good_aiv, &padded_ok);
        let mut rec = [0u8; 24];
        let n = kwp
            .unwrap(&ct_valid, &mut rec)
            .expect("valid input unwraps");
        assert_eq!(n, good_mli as usize);
        assert_eq!(&rec[..n], &pt[..]);
        // Wrong high-half tag.
        let bad_prefix_aiv: u64 = (0xDEAD_BEEFu64 << 32) | u64::from(good_mli);
        let ct_bad_prefix = build(bad_prefix_aiv, &padded_ok);
        // MLI of zero.
        let bad_mli_zero_aiv: u64 = u64::from(RFC5649_AIV_TAG) << 32;
        let ct_mli_zero = build(bad_mli_zero_aiv, &padded_ok);
        // MLI larger than the padded length.
        let bad_mli_big_aiv: u64 = (u64::from(RFC5649_AIV_TAG) << 32) | 0x0000_0100u64;
        let ct_mli_big = build(bad_mli_big_aiv, &padded_ok);
        // MLI leaving 8 or more pad bytes (12 vs. padded length 24).
        let bad_mli_small_aiv: u64 = (u64::from(RFC5649_AIV_TAG) << 32) | u64::from(12u32);
        let ct_mli_small = build(bad_mli_small_aiv, &padded_ok);
        // Correct AIV but a non-zero padding byte.
        let mut padded_badpad = [0u8; 24];
        padded_badpad[..pt.len()].copy_from_slice(&pt);
        padded_badpad[20] = 0xFF;
        let ct_bad_pad = build(good_aiv, &padded_badpad);
        for (name, ct) in [
            ("bad_prefix", &ct_bad_prefix),
            ("mli_zero", &ct_mli_zero),
            ("mli_big", &ct_mli_big),
            ("mli_small", &ct_mli_small),
            ("bad_pad", &ct_bad_pad),
        ] {
            let mut buf = [0u8; 24];
            let err = kwp.unwrap(ct, &mut buf);
            assert_eq!(
                err,
                Err(KwError::IntegrityCheck),
                "case {name}: expected IntegrityCheck, got {err:?}",
            );
        }
    }
}