use crate::generators::aes_ctr::{
AesBlockCipher, AesKey, AES_CALLS_PER_BATCH, BYTES_PER_AES_CALL, BYTES_PER_BATCH,
};
use std::arch::x86_64::{
__m128i, _mm_aesenc_si128, _mm_aesenclast_si128, _mm_aeskeygenassist_si128, _mm_shuffle_epi32,
_mm_slli_si128, _mm_store_si128, _mm_xor_si128,
};
use std::mem::transmute;
#[derive(Clone)]
pub struct AesniBlockCipher {
round_keys: [__m128i; 11],
}
impl AesBlockCipher for AesniBlockCipher {
fn new(key: AesKey) -> AesniBlockCipher {
let aes_detected = is_x86_feature_detected!("aes");
let sse2_detected = is_x86_feature_detected!("sse2");
if !(aes_detected && sse2_detected) {
panic!(
"The AesniBlockCipher requires both aes and sse2 x86 CPU features.\n\
aes feature available: {aes_detected}\nsse2 feature available: {sse2_detected}\n\
Please consider enabling the SoftwareRandomGenerator with the `software-prng` feature",
)
}
let round_keys = unsafe { generate_round_keys(key) };
AesniBlockCipher { round_keys }
}
fn generate_batch(&mut self, data: [u128; AES_CALLS_PER_BATCH]) -> [u8; BYTES_PER_BATCH] {
#[target_feature(enable = "sse2,aes")]
unsafe fn implementation(
this: &AesniBlockCipher,
data: [u128; AES_CALLS_PER_BATCH],
) -> [u8; BYTES_PER_BATCH] {
si128arr_to_u8arr(aes_encrypt_many(
u128_to_si128(data[0]),
u128_to_si128(data[1]),
u128_to_si128(data[2]),
u128_to_si128(data[3]),
u128_to_si128(data[4]),
u128_to_si128(data[5]),
u128_to_si128(data[6]),
u128_to_si128(data[7]),
&this.round_keys,
))
}
unsafe { implementation(self, data) }
}
fn generate_next(&mut self, data: u128) -> [u8; BYTES_PER_AES_CALL] {
unsafe { transmute(aes_encrypt_one(u128_to_si128(data), &self.round_keys)) }
}
}
#[target_feature(enable = "sse2,aes")]
unsafe fn generate_round_keys(key: AesKey) -> [__m128i; 11] {
let key = u128_to_si128(key.0);
let mut keys: [__m128i; 11] = [u128_to_si128(0); 11];
aes_128_key_expansion(key, &mut keys);
keys
}
#[inline(always)]
fn aes_encrypt_one(message: __m128i, keys: &[__m128i; 11]) -> __m128i {
unsafe {
let mut tmp_1 = _mm_xor_si128(message, keys[0]);
for key in keys.iter().take(10).skip(1) {
tmp_1 = _mm_aesenc_si128(tmp_1, *key);
}
_mm_aesenclast_si128(tmp_1, keys[10])
}
}
#[allow(clippy::too_many_arguments)]
#[inline(always)]
fn aes_encrypt_many(
message_1: __m128i,
message_2: __m128i,
message_3: __m128i,
message_4: __m128i,
message_5: __m128i,
message_6: __m128i,
message_7: __m128i,
message_8: __m128i,
keys: &[__m128i; 11],
) -> [__m128i; 8] {
unsafe {
let mut tmp_1 = _mm_xor_si128(message_1, keys[0]);
let mut tmp_2 = _mm_xor_si128(message_2, keys[0]);
let mut tmp_3 = _mm_xor_si128(message_3, keys[0]);
let mut tmp_4 = _mm_xor_si128(message_4, keys[0]);
let mut tmp_5 = _mm_xor_si128(message_5, keys[0]);
let mut tmp_6 = _mm_xor_si128(message_6, keys[0]);
let mut tmp_7 = _mm_xor_si128(message_7, keys[0]);
let mut tmp_8 = _mm_xor_si128(message_8, keys[0]);
for key in keys.iter().take(10).skip(1) {
tmp_1 = _mm_aesenc_si128(tmp_1, *key);
tmp_2 = _mm_aesenc_si128(tmp_2, *key);
tmp_3 = _mm_aesenc_si128(tmp_3, *key);
tmp_4 = _mm_aesenc_si128(tmp_4, *key);
tmp_5 = _mm_aesenc_si128(tmp_5, *key);
tmp_6 = _mm_aesenc_si128(tmp_6, *key);
tmp_7 = _mm_aesenc_si128(tmp_7, *key);
tmp_8 = _mm_aesenc_si128(tmp_8, *key);
}
tmp_1 = _mm_aesenclast_si128(tmp_1, keys[10]);
tmp_2 = _mm_aesenclast_si128(tmp_2, keys[10]);
tmp_3 = _mm_aesenclast_si128(tmp_3, keys[10]);
tmp_4 = _mm_aesenclast_si128(tmp_4, keys[10]);
tmp_5 = _mm_aesenclast_si128(tmp_5, keys[10]);
tmp_6 = _mm_aesenclast_si128(tmp_6, keys[10]);
tmp_7 = _mm_aesenclast_si128(tmp_7, keys[10]);
tmp_8 = _mm_aesenclast_si128(tmp_8, keys[10]);
[tmp_1, tmp_2, tmp_3, tmp_4, tmp_5, tmp_6, tmp_7, tmp_8]
}
}
fn aes_128_assist(temp1: __m128i, temp2: __m128i) -> __m128i {
let mut temp3: __m128i;
let mut temp2 = temp2;
let mut temp1 = temp1;
unsafe {
temp2 = _mm_shuffle_epi32(temp2, 0xff);
temp3 = _mm_slli_si128(temp1, 0x4);
temp1 = _mm_xor_si128(temp1, temp3);
temp3 = _mm_slli_si128(temp3, 0x4);
temp1 = _mm_xor_si128(temp1, temp3);
temp3 = _mm_slli_si128(temp3, 0x4);
temp1 = _mm_xor_si128(temp1, temp3);
temp1 = _mm_xor_si128(temp1, temp2);
}
temp1
}
#[inline(always)]
fn aes_128_key_expansion(key: __m128i, keys: &mut [__m128i; 11]) {
let (mut temp1, mut temp2): (__m128i, __m128i);
temp1 = key;
unsafe {
_mm_store_si128(keys.as_mut_ptr(), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x01);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(1), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x02);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(2), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x04);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(3), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x08);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(4), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x10);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(5), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x20);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(6), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x40);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(7), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x80);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(8), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x1b);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(9), temp1);
temp2 = _mm_aeskeygenassist_si128(temp1, 0x36);
temp1 = aes_128_assist(temp1, temp2);
_mm_store_si128(keys.as_mut_ptr().offset(10), temp1);
}
}
#[inline(always)]
fn u128_to_si128(input: u128) -> __m128i {
unsafe { transmute(input) }
}
#[allow(unused)] fn si128_to_u128(input: __m128i) -> u128 {
unsafe { transmute(input) }
}
#[inline(always)]
fn si128arr_to_u8arr(input: [__m128i; 8]) -> [u8; BYTES_PER_BATCH] {
unsafe { transmute(input) }
}
#[cfg(test)]
mod test {
use super::*;
const CIPHER_KEY: u128 = u128::from_be(0x000102030405060708090a0b0c0d0e0f);
const KEY_SCHEDULE: [u128; 11] = [
u128::from_be(0x000102030405060708090a0b0c0d0e0f),
u128::from_be(0xd6aa74fdd2af72fadaa678f1d6ab76fe),
u128::from_be(0xb692cf0b643dbdf1be9bc5006830b3fe),
u128::from_be(0xb6ff744ed2c2c9bf6c590cbf0469bf41),
u128::from_be(0x47f7f7bc95353e03f96c32bcfd058dfd),
u128::from_be(0x3caaa3e8a99f9deb50f3af57adf622aa),
u128::from_be(0x5e390f7df7a69296a7553dc10aa31f6b),
u128::from_be(0x14f9701ae35fe28c440adf4d4ea9c026),
u128::from_be(0x47438735a41c65b9e016baf4aebf7ad2),
u128::from_be(0x549932d1f08557681093ed9cbe2c974e),
u128::from_be(0x13111d7fe3944a17f307a78b4d2b30c5),
];
const PLAINTEXT: u128 = u128::from_be(0x00112233445566778899aabbccddeeff);
const CIPHERTEXT: u128 = u128::from_be(0x69c4e0d86a7b0430d8cdb78070b4c55a);
#[test]
fn test_generate_key_schedule() {
let key = u128_to_si128(CIPHER_KEY);
let mut keys: [__m128i; 11] = [u128_to_si128(0); 11];
aes_128_key_expansion(key, &mut keys);
for (expected, actual) in KEY_SCHEDULE.iter().zip(keys.iter()) {
assert_eq!(*expected, si128_to_u128(*actual));
}
}
#[test]
fn test_encrypt_many_messages() {
let message = u128_to_si128(PLAINTEXT);
let key = u128_to_si128(CIPHER_KEY);
let mut keys: [__m128i; 11] = [u128_to_si128(0); 11];
aes_128_key_expansion(key, &mut keys);
let ciphertexts = aes_encrypt_many(
message, message, message, message, message, message, message, message, &keys,
);
for ct in &ciphertexts {
assert_eq!(CIPHERTEXT, si128_to_u128(*ct));
}
}
#[test]
fn test_encrypt_one_message() {
let message = u128_to_si128(PLAINTEXT);
let key = u128_to_si128(CIPHER_KEY);
let mut keys: [__m128i; 11] = [u128_to_si128(0); 11];
aes_128_key_expansion(key, &mut keys);
let ciphertext = aes_encrypt_one(message, &keys);
assert_eq!(CIPHERTEXT, si128_to_u128(ciphertext));
}
}