use core::{convert::TryInto as _, marker::PhantomData, mem};
use aes::{
cipher::{
generic_array::typenum::Unsigned as _,
{BlockDecrypt, BlockEncrypt, KeyInit, KeySizeUser},
},
Aes128, Aes192, Aes256,
};
use crate::{Error, Result};
/// AES Key Wrap instantiated with AES-128 as the underlying block cipher.
pub type Aes128Kw<'key> = AesKeyWrap<'key, Aes128>;
/// AES Key Wrap instantiated with AES-192 as the underlying block cipher.
pub type Aes192Kw<'key> = AesKeyWrap<'key, Aes192>;
/// AES Key Wrap instantiated with AES-256 as the underlying block cipher.
pub type Aes256Kw<'key> = AesKeyWrap<'key, Aes256>;
/// Size in bytes of one AES Key Wrap semiblock (64 bits).
pub const BLOCK: usize = mem::size_of::<u64>();
/// Default initial value from RFC 3394 section 2.2.3.1: every byte is `0xA6`.
pub const DIV: u64 = u64::from_be_bytes([0xA6; BLOCK]);
/// AES Key Wrap (RFC 3394) context, generic over the block cipher `T`.
///
/// Only a borrowed key-encryption key is stored; `T` is a phantom parameter,
/// so no cipher state lives in this value.
pub struct AesKeyWrap<'a, T> {
    /// Key-encryption key (KEK) bytes.
    key: &'a [u8],
    /// Marker for the block cipher type; no `T` value is stored.
    cipher: PhantomData<T>,
}

// Hand-written instead of derived: `#[derive]` would require `T: Clone` / `T: Copy`
// even though only a slice reference and a `PhantomData` are stored.
impl<T> Clone for AesKeyWrap<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for AesKeyWrap<'_, T> {}

// Hand-written instead of derived so the key-encryption key is never printed via
// `{:?}` (e.g. into logs), and so `T` does not need to implement `Debug`.
impl<T> core::fmt::Debug for AesKeyWrap<'_, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("AesKeyWrap")
            .field("key", &"<redacted>")
            .field("cipher", &self.cipher)
            .finish()
    }
}
impl<'a, T> AesKeyWrap<'a, T> {
pub const BLOCK: usize = BLOCK;
pub fn new(key: &'a [u8]) -> Self {
Self {
key,
cipher: PhantomData,
}
}
}
impl<T: KeyInit> AesKeyWrap<'_, T> {
    /// Key-encryption key length, in bytes, required by the block cipher `T`.
    pub const KEY_LENGTH: usize = <T as KeySizeUser>::KeySize::USIZE;
}
impl<'a, T> AesKeyWrap<'a, T>
where
    T: BlockEncrypt + BlockDecrypt + KeyInit,
{
    /// Wraps `plaintext` with the AES Key Wrap algorithm (RFC 3394 section 2.2.1).
    ///
    /// On success the wrapped key (`plaintext.len() + BLOCK` bytes) is written to the
    /// front of `ciphertext`. Any bytes past that length are left untouched.
    ///
    /// Buffer sizes are validated with `assert_buffer_gte!` (defined elsewhere in the
    /// crate) before any other work is done.
    ///
    /// # Errors
    ///
    /// Returns `Error::CipherError` when:
    /// - `plaintext.len()` is not a multiple of `BLOCK`, or is shorter than two blocks
    ///   (RFC 3394 requires `n >= 2`);
    /// - the key given to `new` has the wrong length for `T`.
    // Uppercase single-letter names follow the RFC 3394 notation (A, B, N, R).
    #[allow(non_snake_case)]
    pub fn wrap_key(&self, plaintext: &[u8], ciphertext: &mut [u8]) -> Result<()> {
        assert_buffer_gte!(ciphertext.len(), plaintext.len() + BLOCK, "ciphertext");

        if plaintext.len() % BLOCK != 0 || plaintext.len() < 2 * BLOCK {
            return Err(Error::CipherError { alg: "AES Key Wrap" });
        }

        let cipher: T = Self::__new_cipher(self.key)?;
        let N: usize = plaintext.len() / BLOCK;
        // Slice exactly `plaintext.len()` bytes after the IV slot: the size check above
        // accepts larger buffers, and `copy_from_slice` panics on a length mismatch.
        let R: &mut [u8] = &mut ciphertext[BLOCK..BLOCK + plaintext.len()];
        let mut A: u64 = DIV;
        let mut B: [u8; BLOCK << 1] = [0; BLOCK << 1];

        R.copy_from_slice(plaintext);

        for j in 0..=5 {
            for (i, r) in R.chunks_exact_mut(BLOCK).enumerate() {
                // t = n*j + i, with the RFC's 1-based block index i.
                let t: u64 = ((N * j) + i + 1) as u64;
                B[..BLOCK].copy_from_slice(&A.to_be_bytes());
                B[BLOCK..].copy_from_slice(r);
                cipher.encrypt_block((&mut B[..]).into());
                A = Self::__read_u64(&B[..BLOCK]) ^ t;
                r.copy_from_slice(&B[BLOCK..]);
            }
        }

        ciphertext[..BLOCK].copy_from_slice(&A.to_be_bytes());
        Ok(())
    }

    /// Unwraps `ciphertext` with the AES Key Wrap algorithm (RFC 3394 section 2.2.2).
    /// It then checks the recovered integrity value against `DIV`.
    ///
    /// On success the key (`ciphertext.len() - BLOCK` bytes) is written to the front
    /// of `plaintext`. Any bytes past that length are left untouched. If the integrity
    /// check fails, those bytes are zeroed so no unauthenticated data is handed back.
    ///
    /// Buffer sizes are validated with `assert_buffer_gte!` (defined elsewhere in the
    /// crate) before any other work is done.
    ///
    /// # Errors
    ///
    /// Returns `Error::CipherError` when:
    /// - `ciphertext.len()` is not a multiple of `BLOCK`, or is shorter than three
    ///   blocks (the IV plus RFC 3394's minimum of two key blocks);
    /// - the key given to `new` has the wrong length for `T`;
    /// - the integrity check fails (wrong key, or corrupted or forged input).
    // Uppercase single-letter names follow the RFC 3394 notation (A, B, N, R).
    #[allow(non_snake_case)]
    pub fn unwrap_key(&self, ciphertext: &[u8], plaintext: &mut [u8]) -> Result<()> {
        assert_buffer_gte!(ciphertext.len(), BLOCK, "ciphertext");
        assert_buffer_gte!(plaintext.len(), ciphertext.len() - BLOCK, "plaintext");

        // Rejecting fewer than two key blocks also closes the n = 0 case, in which an
        // input of `DIV` alone would "verify" without the key ever being used.
        if ciphertext.len() % BLOCK != 0 || ciphertext.len() < 3 * BLOCK {
            return Err(Error::CipherError { alg: "AES Key Wrap" });
        }

        let cipher: T = Self::__new_cipher(self.key)?;
        let N: usize = (ciphertext.len() / BLOCK) - 1;
        let (iv, blocks) = ciphertext.split_at(BLOCK);
        // Slice exactly the key length: the size check above accepts larger buffers,
        // and `copy_from_slice` panics on a length mismatch.
        let R: &mut [u8] = &mut plaintext[..blocks.len()];
        let mut A: u64 = Self::__read_u64(iv);
        let mut B: [u8; BLOCK << 1] = [0; BLOCK << 1];

        R.copy_from_slice(blocks);

        for j in (0..=5).rev() {
            for (i, r) in R.chunks_exact_mut(BLOCK).enumerate().rev() {
                // t = n*j + i, with the RFC's 1-based block index i.
                let t: u64 = ((N * j) + i + 1) as u64;
                B[..BLOCK].copy_from_slice(&(A ^ t).to_be_bytes());
                B[BLOCK..].copy_from_slice(r);
                cipher.decrypt_block((&mut B[..]).into());
                A = Self::__read_u64(&B[..BLOCK]);
                r.copy_from_slice(&B[BLOCK..]);
            }
        }

        if A == DIV {
            Ok(())
        } else {
            // Do not leave unauthenticated key material in the caller's buffer.
            R.fill(0);
            Err(Error::CipherError { alg: "AES Key Wrap" })
        }
    }

    /// Builds the block cipher from the key-encryption key. A key of the wrong
    /// length is reported as an error rather than causing a panic.
    fn __new_cipher(key: &[u8]) -> Result<T> {
        T::new_from_slice(key).map_err(|_| Error::CipherError { alg: "AES Key Wrap" })
    }

    /// Decodes one big-endian 64-bit semiblock.
    ///
    /// # Panics
    ///
    /// Panics if `slice` is not exactly `BLOCK` bytes long.
    fn __read_u64(slice: &[u8]) -> u64 {
        assert_eq!(slice.len(), BLOCK);
        u64::from_be_bytes(slice.try_into().unwrap())
    }
}