use core::ops::{Add, Rem};
use crate::{Error, IV_LEN, IvLen, ctx::Ctx, error::IntegrityCheckFailed};
use aes::cipher::{
Array, Block, BlockCipherDecrypt, BlockCipherEncrypt,
array::ArraySize,
common::{InnerInit, InnerUser},
typenum::{Mod, NonZero, Sum, U16, Zero},
};
/// Default initial value from RFC 3394 §2.2.3.1: every byte is `0xA6`.
const IV: [u8; IV_LEN] = [0xA6; IV_LEN];

/// Wrapped form of an `N`-byte key: the key data plus one extra
/// `IvLen`-sized semiblock that carries the integrity value.
pub type KwWrappedKey<N> = Array<u8, Sum<N, IvLen>>;
/// Key wrap (RFC 3394 style) over a 128-bit block cipher `C`.
///
/// `C` is expected to be a block cipher (AES, per the type name) that has
/// already been keyed with the key-encryption key; `AesKw` adds no state of
/// its own beyond that cipher.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct AesKw<C> {
    /// Keyed block cipher used for every wrap/unwrap round.
    cipher: C,
}

impl<C> InnerUser for AesKw<C> {
    /// The wrapped block cipher type.
    type Inner = C;
}

impl<C> InnerInit for AesKw<C> {
    /// Builds an `AesKw` around an already-constructed cipher.
    #[inline]
    fn inner_init(cipher: Self::Inner) -> Self {
        Self { cipher }
    }
}
impl<C: BlockCipherEncrypt<BlockSize = U16>> AesKw<C> {
    /// Wraps `key` into `buf` without validating any lengths.
    ///
    /// Callers must guarantee that `key.len()` is a non-zero multiple of
    /// [`IV_LEN`] and that `buf.len() == key.len() + IV_LEN`. The copy of `key`
    /// into `buf` panics on a length mismatch, and a non-multiple key length
    /// would silently truncate `blocks_len`.
    fn wrap_key_trusted(&self, key: &[u8], buf: &mut [u8]) {
        let blocks_len = key.len() / IV_LEN;

        // Register A starts as the default IV. It is kept in the first IV_LEN
        // bytes of a full cipher-block scratch buffer.
        let block = &mut Block::<C>::default();
        block[..IV_LEN].copy_from_slice(&IV);

        // R[1..=n] start as the plaintext key semiblocks. They sit after the
        // slot that is reserved for the final A value.
        buf[IV_LEN..].copy_from_slice(key);

        // The wrapping rounds themselves are implemented by `Ctx` (defined
        // elsewhere in the crate).
        self.cipher.encrypt_with_backend(Ctx {
            blocks_len,
            block,
            buf,
        });

        // C[0] is the final value of A.
        buf[..IV_LEN].copy_from_slice(&block[..IV_LEN]);
    }

    /// Wraps `key` and writes the result into the front of `buf`.
    ///
    /// Returns the prefix of `buf` that holds the wrapped key. That prefix is
    /// exactly `key.len() + IV_LEN` bytes long; any extra bytes of `buf` are
    /// left untouched.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidDataSize`] if `key` is empty or its length is not a
    ///   multiple of [`IV_LEN`].
    /// - [`Error::InvalidOutputSize`] if `buf` is shorter than
    ///   `key.len() + IV_LEN`.
    #[inline]
    pub fn wrap_key<'a>(&self, key: &[u8], buf: &'a mut [u8]) -> Result<&'a [u8], Error> {
        // An empty key has to be rejected explicitly. RFC 3394 performs 6 * n
        // cipher calls, so with n = 0 the KEK is never used, and the "wrapped"
        // output is just the public IV constant, which anyone can forge. This
        // matches the `NonZero` bound on `wrap_fixed_key`.
        if key.is_empty() || key.len() % IV_LEN != 0 {
            return Err(Error::InvalidDataSize);
        }
        let expected_len = key.len() + IV_LEN;
        let buf = buf
            .get_mut(..expected_len)
            .ok_or(Error::InvalidOutputSize { expected_len })?;
        self.wrap_key_trusted(key, buf);
        Ok(buf)
    }

    /// Wraps a key whose length `N` is fixed at compile time.
    ///
    /// The trait bounds require at the type level that `N` is non-zero and a
    /// multiple of `IvLen` (assumed to be the type-level counterpart of
    /// [`IV_LEN`]). Because of that, this function cannot fail.
    #[inline]
    pub fn wrap_fixed_key<N>(&self, key: &Array<u8, N>) -> KwWrappedKey<N>
    where
        N: ArraySize + NonZero + Add<IvLen> + Rem<IvLen>,
        Sum<N, IvLen>: ArraySize,
        Mod<N, IvLen>: Zero,
    {
        let mut buf = KwWrappedKey::<N>::default();
        self.wrap_key_trusted(key, &mut buf);
        buf
    }
}
impl<C: BlockCipherDecrypt<BlockSize = U16>> AesKw<C> {
fn unwrap_key_trusted<'a>(
&self,
wkey: &[u8],
buf: &'a mut [u8],
) -> Result<&'a [u8], IntegrityCheckFailed> {
let blocks_len = buf.len() / IV_LEN;
let block = &mut Block::<C>::default();
block[..IV_LEN].copy_from_slice(&wkey[..IV_LEN]);
buf.copy_from_slice(&wkey[IV_LEN..]);
self.cipher.decrypt_with_backend(Ctx {
blocks_len,
block,
buf,
});
let expected_iv = u64::from_ne_bytes(IV);
let calc_iv = u64::from_ne_bytes(block[..IV_LEN].try_into().unwrap());
if calc_iv == expected_iv {
Ok(buf)
} else {
buf.fill(0);
Err(IntegrityCheckFailed)
}
}
#[inline]
pub fn unwrap_key<'a>(&self, wkey: &[u8], buf: &'a mut [u8]) -> Result<&'a [u8], Error> {
let blocks_len = wkey.len() / IV_LEN;
let blocks_rem = wkey.len() % IV_LEN;
if blocks_rem != 0 || blocks_len < 1 {
return Err(Error::InvalidDataSize);
}
let blocks_len = blocks_len - 1;
let expected_len = blocks_len * IV_LEN;
let buf = buf
.get_mut(..expected_len)
.ok_or(Error::InvalidOutputSize { expected_len })?;
self.unwrap_key_trusted(wkey, buf)
.map_err(|_| Error::IntegrityCheckFailed)?;
Ok(buf)
}
#[inline]
pub fn unwrap_fixed_key<N>(
&self,
wkey: &KwWrappedKey<N>,
) -> Result<Array<u8, N>, IntegrityCheckFailed>
where
N: ArraySize + NonZero + Add<IvLen> + Rem<IvLen>,
Sum<N, IvLen>: ArraySize,
Mod<N, IvLen>: Zero,
{
let mut buf = Array::<u8, N>::default();
self.unwrap_key_trusted(wkey, &mut buf)?;
Ok(buf)
}
}
/// `AesKw` stores nothing but its cipher. It is therefore zeroize-on-drop
/// exactly when the cipher type itself is.
#[cfg(feature = "zeroize")]
impl<C> zeroize::ZeroizeOnDrop for AesKw<C>
where
    C: zeroize::ZeroizeOnDrop,
{
}