use crate::{
aes,
arch::*,
block::{BatchMut, Block, Zeroed},
};
use zeroize::Zeroize;
#[cfg(any(test, feature = "testing"))]
pub mod testing;
/// Expanded AES key material holding both an encryption and a decryption
/// schedule, each made of `ROUNDS` round keys.
#[derive(Zeroize)]
pub struct Key<const ROUNDS: usize> {
/// Schedule served by the `EncryptionKey` trait impls on this type.
pub encrypt: EncryptionKey<ROUNDS>,
/// Schedule served by the `DecryptionKey` trait impls on this type.
pub decrypt: DecryptionKey<ROUNDS>,
}
impl<const N: usize> super::aes128::EncryptionKey for Key<N> {
type Block = __m128i;
type KeyRound = KeyRound;
#[inline(always)]
fn keyround(&self, index: usize) -> &Self::KeyRound {
self.encrypt.keyround(index)
}
}
impl<const N: usize> super::aes128::DecryptionKey for Key<N> {
type Block = __m128i;
type KeyRound = KeyRound;
#[inline(always)]
fn keyround(&self, index: usize) -> &Self::KeyRound {
self.decrypt.keyround(index)
}
}
impl<const N: usize> super::aes256::EncryptionKey for Key<N> {
type Block = __m128i;
type KeyRound = KeyRound;
#[inline(always)]
fn keyround(&self, index: usize) -> &Self::KeyRound {
self.encrypt.keyround(index)
}
}
impl<const N: usize> super::aes256::DecryptionKey for Key<N> {
type Block = __m128i;
type KeyRound = KeyRound;
#[inline(always)]
fn keyround(&self, index: usize) -> &Self::KeyRound {
self.decrypt.keyround(index)
}
}
/// Encryption key schedule: an array of `ROUNDS` round keys, looked up by
/// index through the `keyround` trait methods.
#[derive(Zeroize)]
pub struct EncryptionKey<const ROUNDS: usize>([KeyRound; ROUNDS]);
impl<const N: usize> super::aes128::EncryptionKey for EncryptionKey<N> {
    type Block = __m128i;
    type KeyRound = KeyRound;

    /// Returns the round key stored at `index` without a bounds check.
    ///
    /// `index` must be less than `N`.
    #[inline(always)]
    fn keyround(&self, index: usize) -> &Self::KeyRound {
        let Self(schedule) = self;
        // SAFETY: `get_unchecked` requires `index < N`. This method only
        // states that requirement through `unsafe_assert!`, so callers must
        // keep `index` in range.
        // NOTE(review): `unsafe_assert!` is defined elsewhere in the crate;
        // confirm what it does in release builds.
        unsafe {
            unsafe_assert!(index < N);
            schedule.get_unchecked(index)
        }
    }
}
impl<const N: usize> super::aes256::EncryptionKey for EncryptionKey<N> {
    type Block = __m128i;
    type KeyRound = KeyRound;

    /// Returns the round key stored at `index` without a bounds check.
    ///
    /// `index` must be less than `N`.
    #[inline(always)]
    fn keyround(&self, index: usize) -> &Self::KeyRound {
        let Self(schedule) = self;
        // SAFETY: `get_unchecked` requires `index < N`. This method only
        // states that requirement through `unsafe_assert!`, so callers must
        // keep `index` in range.
        // NOTE(review): `unsafe_assert!` is defined elsewhere in the crate;
        // confirm what it does in release builds.
        unsafe {
            unsafe_assert!(index < N);
            schedule.get_unchecked(index)
        }
    }
}
/// Decryption key schedule: an array of `ROUNDS` round keys, looked up by
/// index through the `keyround` trait methods.
///
/// As built by the `new` constructors in this file, the first and last
/// entries equal the corresponding encryption round keys and every entry in
/// between is the encryption round key passed through `_mm_aesimc_si128`.
#[derive(Zeroize)]
pub struct DecryptionKey<const ROUNDS: usize>([KeyRound; ROUNDS]);
impl<const N: usize> super::aes128::DecryptionKey for DecryptionKey<N> {
    type Block = __m128i;
    type KeyRound = KeyRound;

    /// Returns the round key stored at `index` without a bounds check.
    ///
    /// `index` must be less than `N`.
    #[inline(always)]
    fn keyround(&self, index: usize) -> &Self::KeyRound {
        let Self(schedule) = self;
        // SAFETY: `get_unchecked` requires `index < N`. This method only
        // states that requirement through `unsafe_assert!`, so callers must
        // keep `index` in range.
        // NOTE(review): `unsafe_assert!` is defined elsewhere in the crate;
        // confirm what it does in release builds.
        unsafe {
            unsafe_assert!(index < N);
            schedule.get_unchecked(index)
        }
    }
}
impl<const N: usize> super::aes256::DecryptionKey for DecryptionKey<N> {
    type Block = __m128i;
    type KeyRound = KeyRound;

    /// Returns the round key stored at `index` without a bounds check.
    ///
    /// `index` must be less than `N`.
    #[inline(always)]
    fn keyround(&self, index: usize) -> &Self::KeyRound {
        let Self(schedule) = self;
        // SAFETY: `get_unchecked` requires `index < N`. This method only
        // states that requirement through `unsafe_assert!`, so callers must
        // keep `index` in range.
        // NOTE(review): `unsafe_assert!` is defined elsewhere in the crate;
        // confirm what it does in release builds.
        unsafe {
            unsafe_assert!(index < N);
            schedule.get_unchecked(index)
        }
    }
}
// Assigns `_mm_aeskeygenassist_si128($input, $rcon)` to `$out`.
//
// The argument order mirrors the assembly mnemonic (immediate first,
// destination last). `$rcon` has to be a constant expression because the
// intrinsic takes it as an immediate operand.
macro_rules! aeskeygenassist {
    ($rcon:expr, $input:expr, $out:ident) => {
        $out = _mm_aeskeygenassist_si128($input, $rcon);
    };
}
// Folds byte-shifted copies of `$dest` back into `$dest`.
//
// Writing `d` for the incoming value and assuming `Block::xor` is a bitwise
// XOR, the result is `d ^ (d << 4B) ^ (d << 8B) ^ (d << 12B)`: the first step
// yields `d ^ (d << 4B)` and the second combines that with itself shifted by
// a further 8 bytes.
macro_rules! key_shuffle {
    ($dest:ident) => {{
        let shifted_4 = _mm_slli_si128($dest, 4);
        $dest = shifted_4.xor($dest);
        let shifted_8 = _mm_slli_si128($dest, 8);
        $dest = shifted_8.xor($dest);
    }};
}
pub mod aes128 {
use super::*;
use crate::aes::aes128::KEY_LEN;
const ROUNDS: usize = aes::aes128::ROUNDS + 1;
pub type Key = super::Key<ROUNDS>;
pub type EncryptionKey = super::EncryptionKey<ROUNDS>;
impl Key {
#[inline(always)]
#[allow(unknown_lints, clippy::needless_late_init)]
pub fn new(key: [u8; KEY_LEN]) -> Self {
let mut enc = [KeyRound(__m128i::zeroed()); ROUNDS];
let mut dec = [KeyRound(__m128i::zeroed()); ROUNDS];
unsafe {
debug_assert!(Avx2::is_supported());
let mut xmm0;
let mut xmm1;
macro_rules! key_expansion_128 {
($idx:expr) => {{
key_shuffle!(xmm0);
xmm1 = _mm_shuffle_epi32(xmm1, 0b11111111);
xmm0 = xmm1.xor(xmm0);
enc[$idx] = KeyRound(xmm0);
}};
}
xmm0 = __m128i::from_array(key);
enc[0] = KeyRound(xmm0);
aeskeygenassist!(0x1, xmm0, xmm1);
key_expansion_128!(1);
aeskeygenassist!(0x2, xmm0, xmm1);
key_expansion_128!(2);
aeskeygenassist!(0x4, xmm0, xmm1);
key_expansion_128!(3);
aeskeygenassist!(0x8, xmm0, xmm1);
key_expansion_128!(4);
aeskeygenassist!(0x10, xmm0, xmm1);
key_expansion_128!(5);
aeskeygenassist!(0x20, xmm0, xmm1);
key_expansion_128!(6);
aeskeygenassist!(0x40, xmm0, xmm1);
key_expansion_128!(7);
aeskeygenassist!(0x80, xmm0, xmm1);
key_expansion_128!(8);
aeskeygenassist!(0x1b, xmm0, xmm1);
key_expansion_128!(9);
aeskeygenassist!(0x36, xmm0, xmm1);
key_expansion_128!(10);
dec[10] = enc[10];
dec[9] = enc[9].inv_mix_columns();
dec[8] = enc[8].inv_mix_columns();
dec[7] = enc[7].inv_mix_columns();
dec[6] = enc[6].inv_mix_columns();
dec[5] = enc[5].inv_mix_columns();
dec[4] = enc[4].inv_mix_columns();
dec[3] = enc[3].inv_mix_columns();
dec[2] = enc[2].inv_mix_columns();
dec[1] = enc[1].inv_mix_columns();
dec[0] = enc[0];
}
Self {
encrypt: EncryptionKey(enc),
decrypt: DecryptionKey(dec),
}
}
}
}
pub mod aes256 {
use super::*;
use crate::aes::aes256::KEY_LEN;
const ROUNDS: usize = aes::aes256::ROUNDS + 1;
pub type Key = super::Key<ROUNDS>;
pub type EncryptionKey = super::EncryptionKey<ROUNDS>;
impl Key {
#[inline(always)]
#[allow(unknown_lints, clippy::needless_late_init)]
pub fn new(key: [u8; KEY_LEN]) -> Self {
let mut enc = [KeyRound(__m128i::zeroed()); ROUNDS];
let mut dec = [KeyRound(__m128i::zeroed()); ROUNDS];
unsafe {
debug_assert!(Avx2::is_supported());
let mut xmm0;
let mut xmm1;
let mut xmm2;
macro_rules! key_expansion_256a {
($idx:expr) => {{
key_shuffle!(xmm0);
xmm1 = _mm_shuffle_epi32(xmm1, 0b11111111);
xmm0 = xmm0.xor(xmm1);
enc[$idx] = KeyRound(xmm0);
}};
}
macro_rules! key_expansion_256b {
($idx:expr) => {{
key_shuffle!(xmm2);
xmm1 = _mm_shuffle_epi32(xmm1, 0b10101010);
xmm2 = xmm2.xor(xmm1);
enc[$idx] = KeyRound(xmm2);
}};
}
let key = key.as_ptr() as *const __m128i;
xmm0 = _mm_loadu_si128(key);
enc[0] = KeyRound(xmm0);
xmm2 = _mm_loadu_si128(key.offset(1));
enc[1] = KeyRound(xmm2);
aeskeygenassist!(0x1, xmm2, xmm1);
key_expansion_256a!(2);
aeskeygenassist!(0x1, xmm0, xmm1);
key_expansion_256b!(3);
aeskeygenassist!(0x2, xmm2, xmm1);
key_expansion_256a!(4);
aeskeygenassist!(0x2, xmm0, xmm1);
key_expansion_256b!(5);
aeskeygenassist!(0x4, xmm2, xmm1);
key_expansion_256a!(6);
aeskeygenassist!(0x4, xmm0, xmm1);
key_expansion_256b!(7);
aeskeygenassist!(0x8, xmm2, xmm1);
key_expansion_256a!(8);
aeskeygenassist!(0x8, xmm0, xmm1);
key_expansion_256b!(9);
aeskeygenassist!(0x10, xmm2, xmm1);
key_expansion_256a!(10);
aeskeygenassist!(0x10, xmm0, xmm1);
key_expansion_256b!(11);
aeskeygenassist!(0x20, xmm2, xmm1);
key_expansion_256a!(12);
aeskeygenassist!(0x20, xmm0, xmm1);
key_expansion_256b!(13);
aeskeygenassist!(0x40, xmm2, xmm1);
key_expansion_256a!(14);
dec[14] = enc[14];
dec[13] = enc[13].inv_mix_columns();
dec[12] = enc[12].inv_mix_columns();
dec[11] = enc[11].inv_mix_columns();
dec[10] = enc[10].inv_mix_columns();
dec[9] = enc[9].inv_mix_columns();
dec[8] = enc[8].inv_mix_columns();
dec[7] = enc[7].inv_mix_columns();
dec[6] = enc[6].inv_mix_columns();
dec[5] = enc[5].inv_mix_columns();
dec[4] = enc[4].inv_mix_columns();
dec[3] = enc[3].inv_mix_columns();
dec[2] = enc[2].inv_mix_columns();
dec[1] = enc[1].inv_mix_columns();
dec[0] = enc[0];
}
Self {
encrypt: EncryptionKey(enc),
decrypt: DecryptionKey(dec),
}
}
}
}
/// A single round key, stored as one `__m128i` value.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `__m128i`.
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct KeyRound(__m128i);
impl KeyRound {
#[inline(always)]
fn inv_mix_columns(self) -> Self {
unsafe {
debug_assert!(Avx2::is_supported());
Self(_mm_aesimc_si128(self.0))
}
}
}
impl Default for KeyRound {
    /// Returns a round key wrapping `__m128i::zeroed()`.
    #[inline(always)]
    fn default() -> Self {
        let zero = __m128i::zeroed();
        Self(zero)
    }
}
// Marker impl telling `zeroize` that overwriting a `KeyRound` with
// `KeyRound::default()` zeroes it; `default()` wraps `__m128i::zeroed()`.
impl zeroize::DefaultIsZeroes for KeyRound {}
impl super::KeyRound for KeyRound {
    type Block = __m128i;

    /// Replaces every block in the batch with `block.xor(round_key)`.
    #[inline(always)]
    fn xor<B: BatchMut<Block = __m128i>>(&self, block: &mut B) {
        let Self(round_key) = *self;
        block.update(|_, state| {
            *state = state.xor(round_key);
        });
    }

    /// Applies `_mm_aesenc_si128` (one AES encryption round) with this round
    /// key to every block in the batch.
    #[inline(always)]
    fn encrypt<B: BatchMut<Block = __m128i>>(&self, block: &mut B) {
        let Self(round_key) = *self;
        // SAFETY: the intrinsic needs CPU support for the AES instructions,
        // which is only checked by the `debug_assert!` below.
        unsafe {
            debug_assert!(Avx2::is_supported());
            block.update(|_, state| {
                *state = _mm_aesenc_si128(*state, round_key);
            });
        }
    }

    /// Applies `_mm_aesenclast_si128` (the final AES encryption round) with
    /// this round key to every block in the batch.
    #[inline(always)]
    fn encrypt_finish<B: BatchMut<Block = __m128i>>(&self, block: &mut B) {
        let Self(round_key) = *self;
        // SAFETY: the intrinsic needs CPU support for the AES instructions,
        // which is only checked by the `debug_assert!` below.
        unsafe {
            debug_assert!(Avx2::is_supported());
            block.update(|_, state| {
                *state = _mm_aesenclast_si128(*state, round_key);
            });
        }
    }

    /// Applies `_mm_aesdec_si128` (one AES decryption round) with this round
    /// key to every block in the batch.
    #[inline(always)]
    fn decrypt<B: BatchMut<Block = __m128i>>(&self, block: &mut B) {
        let Self(round_key) = *self;
        // SAFETY: the intrinsic needs CPU support for the AES instructions,
        // which is only checked by the `debug_assert!` below.
        unsafe {
            debug_assert!(Avx2::is_supported());
            block.update(|_, state| {
                *state = _mm_aesdec_si128(*state, round_key);
            });
        }
    }

    /// Applies `_mm_aesdeclast_si128` (the final AES decryption round) with
    /// this round key to every block in the batch.
    #[inline(always)]
    fn decrypt_finish<B: BatchMut<Block = __m128i>>(&self, block: &mut B) {
        let Self(round_key) = *self;
        // SAFETY: the intrinsic needs CPU support for the AES instructions,
        // which is only checked by the `debug_assert!` below.
        unsafe {
            debug_assert!(Avx2::is_supported());
            block.update(|_, state| {
                *state = _mm_aesdeclast_si128(*state, round_key);
            });
        }
    }
}