use core::cell::Cell;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use crate::constants::{AES128_KEY_COUNT, AES128_KEY_SIZE, AES256_KEY_COUNT, AES256_KEY_SIZE};
// `__m128i` values are reinterpreted as `u128` / 16-byte arrays in this module, which is only
// sound when both types share the same size and alignment. Checked once at compile time.
const _: () = {
    assert!(size_of::<__m128i>() == size_of::<u128>());
    assert!(align_of::<__m128i>() == align_of::<u128>());
};
/// AES-128 counter-mode generator with a 64-bit block counter.
///
/// The 128-bit counter block keeps the nonce in its upper 64 bits and the block counter in its
/// lower 64 bits. Only the lower half is incremented (wrapping on its own), so the nonce never
/// changes.
#[derive(Clone)]
pub struct Aes128Ctr64 {
    /// Counter block: nonce in the high 64-bit lane, block counter in the low lane.
    counter: Cell<__m128i>,
    /// Expanded AES-128 key schedule.
    round_keys: Cell<[__m128i; AES128_KEY_COUNT]>,
}

impl Drop for Aes128Ctr64 {
    /// Wipes the counter block and round keys before the memory is released.
    fn drop(&mut self) {
        // Ordinary stores into memory that is about to be freed are dead stores the optimizer
        // is allowed to delete; volatile writes must be emitted. The compiler fence then keeps
        // them from being reordered past whatever follows the drop.
        //
        // SAFETY: `Cell::as_ptr` yields valid, aligned pointers; we hold `&mut self`, so no other
        // reference observes the writes; all-zero bits are a valid `__m128i` (and array of them).
        unsafe {
            core::ptr::write_volatile(self.counter.as_ptr(), core::mem::zeroed());
            core::ptr::write_volatile(self.round_keys.as_ptr(), core::mem::zeroed());
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

impl Aes128Ctr64 {
    /// Returns an instance whose counter block and round keys are all zero.
    ///
    /// NOTE(review): only compiled when this type is the selected `tls` generator; presumably a
    /// `const` placeholder that gets seeded before first use — confirm at the call site.
    #[cfg(all(
        feature = "tls",
        not(any(
            feature = "tls_aes128_ctr128",
            feature = "tls_aes256_ctr64",
            feature = "tls_aes256_ctr128"
        ))
    ))]
    pub(crate) const fn zeroed() -> Self {
        Self {
            // SAFETY: all-zero bits are a valid `__m128i` and a valid array of them.
            counter: Cell::new(unsafe { core::mem::zeroed() }),
            round_keys: Cell::new(unsafe { core::mem::zeroed() }),
        }
    }

    /// Creates a generator from an AES-128 `key`, a `nonce` and a starting block `counter`.
    ///
    /// `nonce` and `counter` are read as little-endian `u64`s; the nonce becomes the upper and
    /// the counter the lower half of the counter block.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn from_seed_impl(key: [u8; 16], nonce: [u8; 8], counter: [u8; 8]) -> Self {
        // The two halves occupy disjoint bits, so OR-ing them assembles the block.
        let block = (u128::from(u64::from_le_bytes(nonce)) << 64)
            | u128::from(u64::from_le_bytes(counter));
        // SAFETY: `u128` and `__m128i` are both 16 bytes of plain data; on little-endian x86 this
        // is the same as loading `block.to_le_bytes()` into the register.
        let block = unsafe { core::mem::transmute::<u128, __m128i>(block) };
        Self {
            counter: Cell::new(block),
            round_keys: Cell::new(aes128_key_expansion(key)),
        }
    }

    /// Re-seeds this generator in place; same interpretation of the inputs as `from_seed_impl`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn seed_impl(&self, key: [u8; 16], nonce: [u8; 8], counter: [u8; 8]) {
        let block = (u128::from(u64::from_le_bytes(nonce)) << 64)
            | u128::from(u64::from_le_bytes(counter));
        // SAFETY: `u128` and `__m128i` are both 16 bytes of plain data.
        let block = unsafe { core::mem::transmute::<u128, __m128i>(block) };
        self.counter.set(block);
        self.round_keys.set(aes128_key_expansion(key));
    }

    /// Always `true`: this backend is built on the AES-NI instructions.
    pub(crate) fn is_hardware_accelerated_impl(&self) -> bool {
        true
    }

    /// Returns the 64-bit block counter, i.e. the low half of the counter block.
    pub(crate) fn counter_impl(&self) -> u64 {
        // SAFETY: `__m128i` and `u128` are both 16 bytes of plain data.
        let block = unsafe { core::mem::transmute::<__m128i, u128>(self.counter.get()) };
        // Truncation keeps the counter and discards the nonce held in the upper 64 bits.
        block as u64
    }

    /// Encrypts the current counter block with AES-128, advances the block counter, and returns
    /// the ciphertext as a little-endian `u128`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn next_impl(&self) -> u128 {
        let counter = self.counter.get();
        // `_mm_add_epi64` adds per 64-bit lane: only the low lane (block counter) is bumped, and
        // it wraps without carrying into the nonce.
        self.counter.set(_mm_add_epi64(counter, _mm_set_epi64x(0, 1)));

        let rks = self.round_keys.as_array_of_cells();
        let last = AES128_KEY_COUNT - 1;
        let mut state = _mm_xor_si128(counter, rks[0].get());
        for rk in &rks[1..last] {
            state = _mm_aesenc_si128(state, rk.get());
        }
        state = _mm_aesenclast_si128(state, rks[last].get());
        // SAFETY: `__m128i` and `u128` are both 16 bytes of plain data; on little-endian x86 this
        // equals `u128::from_le_bytes` of the register's bytes.
        unsafe { core::mem::transmute::<__m128i, u128>(state) }
    }
}
/// AES-128 counter-mode generator with a full 128-bit block counter.
#[derive(Clone)]
pub struct Aes128Ctr128 {
    /// 128-bit block counter; its little-endian bytes form the block that gets encrypted.
    counter: Cell<u128>,
    /// Expanded AES-128 key schedule.
    round_keys: Cell<[__m128i; AES128_KEY_COUNT]>,
}

impl Drop for Aes128Ctr128 {
    /// Wipes the counter and round keys before the memory is released.
    fn drop(&mut self) {
        // Volatile writes cannot be removed as dead stores; the fence keeps them from being
        // reordered past whatever follows the drop.
        //
        // SAFETY: `Cell::as_ptr` yields valid, aligned pointers; we hold `&mut self`; all-zero
        // bits are a valid `u128` and a valid array of `__m128i`.
        unsafe {
            core::ptr::write_volatile(self.counter.as_ptr(), 0);
            core::ptr::write_volatile(self.round_keys.as_ptr(), core::mem::zeroed());
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

impl Aes128Ctr128 {
    /// Returns an instance whose counter and round keys are all zero.
    ///
    /// NOTE(review): only compiled for the `tls_aes128_ctr128` thread-local generator; presumably
    /// a `const` placeholder that gets seeded before first use — confirm at the call site.
    #[cfg(all(feature = "tls", feature = "tls_aes128_ctr128"))]
    pub(crate) const fn zeroed() -> Self {
        Self {
            counter: Cell::new(0),
            // SAFETY: all-zero bits are a valid array of `__m128i`.
            round_keys: Cell::new(unsafe { core::mem::zeroed() }),
        }
    }

    /// Returns a copy of the current generator and advances `self` by 2^64 blocks.
    ///
    /// The counter wraps modulo 2^128, exactly like `next_impl`, so a counter seeded near
    /// `u128::MAX` cannot trigger an arithmetic-overflow panic in debug builds.
    pub(crate) fn jump_impl(&self) -> Self {
        let clone = self.clone();
        self.counter.set(self.counter.get().wrapping_add(1 << 64));
        clone
    }

    /// Returns a copy of the current generator and advances `self` by 2^96 blocks.
    ///
    /// Wraps modulo 2^128, like `jump_impl` and `next_impl`.
    pub(crate) fn long_jump_impl(&self) -> Self {
        let clone = self.clone();
        self.counter.set(self.counter.get().wrapping_add(1 << 96));
        clone
    }

    /// Creates a generator from an AES-128 `key` and a little-endian 128-bit starting `counter`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn from_seed_impl(key: [u8; 16], counter: [u8; 16]) -> Self {
        Self {
            counter: Cell::new(u128::from_le_bytes(counter)),
            round_keys: Cell::new(aes128_key_expansion(key)),
        }
    }

    /// Re-seeds this generator in place; same interpretation of the inputs as `from_seed_impl`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn seed_impl(&self, key: [u8; 16], counter: [u8; 16]) {
        self.counter.set(u128::from_le_bytes(counter));
        self.round_keys.set(aes128_key_expansion(key));
    }

    /// Always `true`: this backend is built on the AES-NI instructions.
    pub(crate) fn is_hardware_accelerated_impl(&self) -> bool {
        true
    }

    /// Returns the current 128-bit block counter.
    pub(crate) fn counter_impl(&self) -> u128 {
        self.counter.get()
    }

    /// Encrypts the current counter with AES-128, advances it by one (wrapping), and returns the
    /// ciphertext as a little-endian `u128`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn next_impl(&self) -> u128 {
        let counter = self.counter.get();
        self.counter.set(counter.wrapping_add(1));

        // SAFETY: `u128` and `__m128i` are both 16 bytes of plain data; on little-endian x86 this
        // equals loading `counter.to_le_bytes()` into the register.
        let block = unsafe { core::mem::transmute::<u128, __m128i>(counter) };
        let rks = self.round_keys.as_array_of_cells();
        let last = AES128_KEY_COUNT - 1;
        let mut state = _mm_xor_si128(block, rks[0].get());
        for rk in &rks[1..last] {
            state = _mm_aesenc_si128(state, rk.get());
        }
        state = _mm_aesenclast_si128(state, rks[last].get());
        // SAFETY: `__m128i` and `u128` are both 16 bytes of plain data.
        unsafe { core::mem::transmute::<__m128i, u128>(state) }
    }
}
/// AES-256 counter-mode generator with a 64-bit block counter.
///
/// The 128-bit counter block keeps the nonce in its upper 64 bits and the block counter in its
/// lower 64 bits. Only the lower half is incremented (wrapping on its own), so the nonce never
/// changes.
#[derive(Clone)]
pub struct Aes256Ctr64 {
    /// Counter block: nonce in the high 64-bit lane, block counter in the low lane.
    counter: Cell<__m128i>,
    /// Expanded AES-256 key schedule.
    round_keys: Cell<[__m128i; AES256_KEY_COUNT]>,
}

impl Drop for Aes256Ctr64 {
    /// Wipes the counter block and round keys before the memory is released.
    fn drop(&mut self) {
        // Ordinary stores into memory that is about to be freed are dead stores the optimizer
        // is allowed to delete; volatile writes must be emitted. The compiler fence then keeps
        // them from being reordered past whatever follows the drop.
        //
        // SAFETY: `Cell::as_ptr` yields valid, aligned pointers; we hold `&mut self`, so no other
        // reference observes the writes; all-zero bits are a valid `__m128i` (and array of them).
        unsafe {
            core::ptr::write_volatile(self.counter.as_ptr(), core::mem::zeroed());
            core::ptr::write_volatile(self.round_keys.as_ptr(), core::mem::zeroed());
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

impl Aes256Ctr64 {
    /// Returns an instance whose counter block and round keys are all zero.
    ///
    /// NOTE(review): only compiled for the `tls_aes256_ctr64` thread-local generator; presumably
    /// a `const` placeholder that gets seeded before first use — confirm at the call site.
    #[cfg(all(feature = "tls", feature = "tls_aes256_ctr64"))]
    pub(crate) const fn zeroed() -> Self {
        Self {
            // SAFETY: all-zero bits are a valid `__m128i` and a valid array of them.
            counter: Cell::new(unsafe { core::mem::zeroed() }),
            round_keys: Cell::new(unsafe { core::mem::zeroed() }),
        }
    }

    /// Creates a generator from an AES-256 `key`, a `nonce` and a starting block `counter`.
    ///
    /// `nonce` and `counter` are read as little-endian `u64`s; the nonce becomes the upper and
    /// the counter the lower half of the counter block.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn from_seed_impl(key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) -> Self {
        // The two halves occupy disjoint bits, so OR-ing them assembles the block.
        let block = (u128::from(u64::from_le_bytes(nonce)) << 64)
            | u128::from(u64::from_le_bytes(counter));
        // SAFETY: `u128` and `__m128i` are both 16 bytes of plain data; on little-endian x86 this
        // is the same as loading `block.to_le_bytes()` into the register.
        let block = unsafe { core::mem::transmute::<u128, __m128i>(block) };
        Self {
            counter: Cell::new(block),
            round_keys: Cell::new(aes256_key_expansion(key)),
        }
    }

    /// Re-seeds this generator in place; same interpretation of the inputs as `from_seed_impl`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn seed_impl(&self, key: [u8; 32], nonce: [u8; 8], counter: [u8; 8]) {
        let block = (u128::from(u64::from_le_bytes(nonce)) << 64)
            | u128::from(u64::from_le_bytes(counter));
        // SAFETY: `u128` and `__m128i` are both 16 bytes of plain data.
        let block = unsafe { core::mem::transmute::<u128, __m128i>(block) };
        self.counter.set(block);
        self.round_keys.set(aes256_key_expansion(key));
    }

    /// Always `true`: this backend is built on the AES-NI instructions.
    pub(crate) fn is_hardware_accelerated_impl(&self) -> bool {
        true
    }

    /// Returns the 64-bit block counter, i.e. the low half of the counter block.
    pub(crate) fn counter_impl(&self) -> u64 {
        // SAFETY: `__m128i` and `u128` are both 16 bytes of plain data.
        let block = unsafe { core::mem::transmute::<__m128i, u128>(self.counter.get()) };
        // Truncation keeps the counter and discards the nonce held in the upper 64 bits.
        block as u64
    }

    /// Encrypts the current counter block with AES-256, advances the block counter, and returns
    /// the ciphertext as a little-endian `u128`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn next_impl(&self) -> u128 {
        let counter = self.counter.get();
        // `_mm_add_epi64` adds per 64-bit lane: only the low lane (block counter) is bumped, and
        // it wraps without carrying into the nonce.
        self.counter.set(_mm_add_epi64(counter, _mm_set_epi64x(0, 1)));

        let rks = self.round_keys.as_array_of_cells();
        let last = AES256_KEY_COUNT - 1;
        let mut state = _mm_xor_si128(counter, rks[0].get());
        for rk in &rks[1..last] {
            state = _mm_aesenc_si128(state, rk.get());
        }
        state = _mm_aesenclast_si128(state, rks[last].get());
        // SAFETY: `__m128i` and `u128` are both 16 bytes of plain data; on little-endian x86 this
        // equals `u128::from_le_bytes` of the register's bytes.
        unsafe { core::mem::transmute::<__m128i, u128>(state) }
    }
}
/// AES-256 counter-mode generator with a full 128-bit block counter.
#[derive(Clone)]
pub struct Aes256Ctr128 {
    /// 128-bit block counter; its little-endian bytes form the block that gets encrypted.
    counter: Cell<u128>,
    /// Expanded AES-256 key schedule.
    round_keys: Cell<[__m128i; AES256_KEY_COUNT]>,
}

impl Drop for Aes256Ctr128 {
    /// Wipes the counter and round keys before the memory is released.
    fn drop(&mut self) {
        // Volatile writes cannot be removed as dead stores; the fence keeps them from being
        // reordered past whatever follows the drop.
        //
        // SAFETY: `Cell::as_ptr` yields valid, aligned pointers; we hold `&mut self`; all-zero
        // bits are a valid `u128` and a valid array of `__m128i`.
        unsafe {
            core::ptr::write_volatile(self.counter.as_ptr(), 0);
            core::ptr::write_volatile(self.round_keys.as_ptr(), core::mem::zeroed());
        }
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
}

impl Aes256Ctr128 {
    /// Returns an instance whose counter and round keys are all zero.
    ///
    /// NOTE(review): only compiled for the `tls_aes256_ctr128` thread-local generator; presumably
    /// a `const` placeholder that gets seeded before first use — confirm at the call site.
    #[cfg(all(feature = "tls", feature = "tls_aes256_ctr128"))]
    pub(crate) const fn zeroed() -> Self {
        Self {
            counter: Cell::new(0),
            // SAFETY: all-zero bits are a valid array of `__m128i`.
            round_keys: Cell::new(unsafe { core::mem::zeroed() }),
        }
    }

    /// Returns a copy of the current generator and advances `self` by 2^64 blocks.
    ///
    /// The counter wraps modulo 2^128, exactly like `next_impl`, so a counter seeded near
    /// `u128::MAX` cannot trigger an arithmetic-overflow panic in debug builds.
    pub(crate) fn jump_impl(&self) -> Self {
        let clone = self.clone();
        self.counter.set(self.counter.get().wrapping_add(1 << 64));
        clone
    }

    /// Returns a copy of the current generator and advances `self` by 2^96 blocks.
    ///
    /// Wraps modulo 2^128, like `jump_impl` and `next_impl`.
    pub(crate) fn long_jump_impl(&self) -> Self {
        let clone = self.clone();
        self.counter.set(self.counter.get().wrapping_add(1 << 96));
        clone
    }

    /// Creates a generator from an AES-256 `key` and a little-endian 128-bit starting `counter`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn from_seed_impl(key: [u8; 32], counter: [u8; 16]) -> Self {
        let round_keys: [__m128i; AES256_KEY_COUNT] = aes256_key_expansion(key);
        Self {
            counter: Cell::new(u128::from_le_bytes(counter)),
            round_keys: Cell::new(round_keys),
        }
    }

    /// Re-seeds this generator in place; same interpretation of the inputs as `from_seed_impl`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn seed_impl(&self, key: [u8; 32], counter: [u8; 16]) {
        let round_keys: [__m128i; AES256_KEY_COUNT] = aes256_key_expansion(key);
        self.counter.set(u128::from_le_bytes(counter));
        self.round_keys.set(round_keys);
    }

    /// Always `true`: this backend is built on the AES-NI instructions.
    pub(crate) fn is_hardware_accelerated_impl(&self) -> bool {
        true
    }

    /// Returns the current 128-bit block counter.
    pub(crate) fn counter_impl(&self) -> u128 {
        self.counter.get()
    }

    /// Encrypts the current counter with AES-256, advances it by one (wrapping), and returns the
    /// ciphertext as a little-endian `u128`.
    #[target_feature(enable = "sse2", enable = "aes")]
    pub(crate) fn next_impl(&self) -> u128 {
        let counter = self.counter.get();
        self.counter.set(counter.wrapping_add(1));

        // SAFETY: `u128` and `__m128i` are both 16 bytes of plain data; on little-endian x86 this
        // equals loading `counter.to_le_bytes()` into the register.
        let block = unsafe { core::mem::transmute::<u128, __m128i>(counter) };
        let rks = self.round_keys.as_array_of_cells();
        let last = AES256_KEY_COUNT - 1;
        let mut state = _mm_xor_si128(block, rks[0].get());
        for rk in &rks[1..last] {
            state = _mm_aesenc_si128(state, rk.get());
        }
        state = _mm_aesenclast_si128(state, rks[last].get());
        // SAFETY: `__m128i` and `u128` are both 16 bytes of plain data.
        unsafe { core::mem::transmute::<__m128i, u128>(state) }
    }
}
/// Expands a 16-byte AES-128 key into its round-key schedule.
///
/// Entry 0 is the key itself. Entries 1 through 10 are each derived from the previous entry
/// with `AESKEYGENASSIST` and the AES round constants 0x01, 0x02, ..., 0x80, 0x1B, 0x36.
#[target_feature(enable = "sse2", enable = "aes")]
pub fn aes128_key_expansion(key: [u8; AES128_KEY_SIZE]) -> [__m128i; AES128_KEY_COUNT] {
    /// Fills `expanded_keys[ROUND]` from `expanded_keys[ROUND - 1]` using round constant `RCON`.
    #[target_feature(enable = "sse2", enable = "aes")]
    fn generate_round_key<const RCON: i32, const ROUND: usize>(
        expanded_keys: &mut [__m128i; AES128_KEY_COUNT],
    ) {
        let prev_key = expanded_keys[ROUND - 1];
        // The top 32-bit lane of `aeskeygenassist` is RotWord(SubWord(w3)) ^ RCON, where w3 is
        // the last word of the previous round key. Shuffle 0xFF broadcasts that lane to all lanes.
        let mut temp = _mm_aeskeygenassist_si128::<RCON>(prev_key);
        temp = _mm_shuffle_epi32::<0xFF>(temp);
        // Three shift-left-by-4-bytes / XOR steps turn the words [w0, w1, w2, w3] into their
        // running XOR [w0, w0^w1, w0^w1^w2, w0^w1^w2^w3].
        let mut key = _mm_xor_si128(prev_key, _mm_slli_si128::<0x4>(prev_key));
        key = _mm_xor_si128(key, _mm_slli_si128::<0x4>(key));
        key = _mm_xor_si128(key, _mm_slli_si128::<0x4>(key));
        expanded_keys[ROUND] = _mm_xor_si128(key, temp);
    }
    // SAFETY: all-zero bits are a valid `__m128i`; entries 0..=10 are overwritten below.
    let mut expanded_keys: [__m128i; AES128_KEY_COUNT] = unsafe { core::mem::zeroed() };
    // SAFETY: `key` provides 16 readable bytes and `_mm_loadu_si128` has no alignment requirement.
    expanded_keys[0] = unsafe { _mm_loadu_si128(key.as_ptr().cast()) };
    // RCON must be a compile-time constant for `aeskeygenassist`, hence one call per round.
    generate_round_key::<0x01, 1>(&mut expanded_keys);
    generate_round_key::<0x02, 2>(&mut expanded_keys);
    generate_round_key::<0x04, 3>(&mut expanded_keys);
    generate_round_key::<0x08, 4>(&mut expanded_keys);
    generate_round_key::<0x10, 5>(&mut expanded_keys);
    generate_round_key::<0x20, 6>(&mut expanded_keys);
    generate_round_key::<0x40, 7>(&mut expanded_keys);
    generate_round_key::<0x80, 8>(&mut expanded_keys);
    generate_round_key::<0x1B, 9>(&mut expanded_keys);
    generate_round_key::<0x36, 10>(&mut expanded_keys);
    expanded_keys
}
/// Expands a 32-byte AES-256 key into its 15-entry round-key schedule.
///
/// Entries 0 and 1 are the two 16-byte halves of the key. Each
/// `generate_round_keys::<RCON, RNUM>` call then derives entries `2 * RNUM + 2` and, except on
/// the last call, `2 * RNUM + 3`.
#[target_feature(enable = "sse2", enable = "aes")]
pub fn aes256_key_expansion(key: [u8; AES256_KEY_SIZE]) -> [__m128i; AES256_KEY_COUNT] {
    /// Derives entries `2 * RNUM + 2` and (if `RNUM < 6`) `2 * RNUM + 3` from the two entries
    /// before them, applying round constant `RCON` to the first of the pair.
    #[target_feature(enable = "sse2", enable = "aes")]
    fn generate_round_keys<const RCON: i32, const RNUM: usize>(
        expanded_keys: &mut [__m128i; AES256_KEY_COUNT],
    ) {
        let prev_key_0 = expanded_keys[RNUM * 2];
        let prev_key_1 = expanded_keys[(RNUM * 2) + 1];
        // Even entry: the top lane of `aeskeygenassist(prev_key_1)` is RotWord(SubWord(w3)) ^
        // RCON. Shuffle 0xFF broadcasts it to all lanes.
        let mut temp = _mm_aeskeygenassist_si128::<RCON>(prev_key_1);
        temp = _mm_shuffle_epi32::<0xFF>(temp);
        // Running XOR of prev_key_0's words via three shift-by-4-bytes / XOR steps, then mix in
        // the transformed word.
        let mut key = _mm_xor_si128(prev_key_0, _mm_slli_si128::<0x4>(prev_key_0));
        key = _mm_xor_si128(key, _mm_slli_si128::<0x4>(key));
        key = _mm_xor_si128(key, _mm_slli_si128::<0x4>(key));
        key = _mm_xor_si128(temp, key);
        expanded_keys[(RNUM * 2) + 2] = key;
        // The final call (RNUM == 6) writes entry 14, the last slot, so it skips the odd entry.
        if RNUM < 6 {
            // Odd entry: lane 2 of `aeskeygenassist(key, 0)` is SubWord(w3) of the entry just
            // produced, with no rotation and no round constant. Shuffle 0xAA broadcasts it.
            let mut temp = _mm_aeskeygenassist_si128::<0x00>(key);
            temp = _mm_shuffle_epi32::<0xAA>(temp);
            let mut key = _mm_xor_si128(prev_key_1, _mm_slli_si128::<4>(prev_key_1));
            key = _mm_xor_si128(key, _mm_slli_si128::<0x4>(key));
            key = _mm_xor_si128(key, _mm_slli_si128::<0x4>(key));
            key = _mm_xor_si128(temp, key);
            expanded_keys[(RNUM * 2) + 3] = key;
        }
    }
    // SAFETY: all-zero bits are a valid `__m128i`; entries 0..=14 are overwritten below.
    let mut expanded_keys: [__m128i; AES256_KEY_COUNT] = unsafe { core::mem::zeroed() };
    // SAFETY: `key` is 32 bytes, so both 16-byte unaligned loads stay in bounds.
    expanded_keys[0] = unsafe { _mm_loadu_si128(key.as_ptr().cast()) };
    expanded_keys[1] = unsafe { _mm_loadu_si128(key[16..].as_ptr().cast()) };
    // RCON must be a compile-time constant for `aeskeygenassist`, hence one call per pair.
    generate_round_keys::<0x01, 0>(&mut expanded_keys);
    generate_round_keys::<0x02, 1>(&mut expanded_keys);
    generate_round_keys::<0x04, 2>(&mut expanded_keys);
    generate_round_keys::<0x08, 3>(&mut expanded_keys);
    generate_round_keys::<0x10, 4>(&mut expanded_keys);
    generate_round_keys::<0x20, 5>(&mut expanded_keys);
    generate_round_keys::<0x40, 6>(&mut expanded_keys);
    expanded_keys
}
#[cfg(all(
    test,
    target_arch = "x86_64",
    target_feature = "sse2",
    target_feature = "aes",
    not(feature = "verification")
))]
mod tests {
    use super::*;
    use crate::constants::AES_BLOCK_SIZE;
    use crate::tests::{aes128_key_expansion_test, aes256_key_expansion_test};

    /// Feeds the AES-128 schedule, viewed as raw byte blocks, to the shared
    /// `aes128_key_expansion_test` helper (presumably checks reference vectors).
    #[test]
    fn test_aes128_key_expansion() {
        aes128_key_expansion_test(|key| {
            // SAFETY: this module only builds with `sse2` and `aes` statically enabled.
            let round_keys: [__m128i; AES128_KEY_COUNT] = unsafe { aes128_key_expansion(key) };
            // SAFETY: each `__m128i` is 16 bytes of plain data, so the arrays have equal size and
            // every bit pattern is a valid byte array.
            unsafe {
                core::mem::transmute::<
                    [__m128i; AES128_KEY_COUNT],
                    [[u8; AES_BLOCK_SIZE]; AES128_KEY_COUNT],
                >(round_keys)
            }
        });
    }

    /// Feeds the AES-256 schedule, viewed as raw byte blocks, to the shared
    /// `aes256_key_expansion_test` helper (presumably checks reference vectors).
    #[test]
    fn test_aes256_key_expansion() {
        aes256_key_expansion_test(|key| {
            // SAFETY: this module only builds with `sse2` and `aes` statically enabled.
            let round_keys: [__m128i; AES256_KEY_COUNT] = unsafe { aes256_key_expansion(key) };
            // SAFETY: each `__m128i` is 16 bytes of plain data, so the arrays have equal size and
            // every bit pattern is a valid byte array.
            unsafe {
                core::mem::transmute::<
                    [__m128i; AES256_KEY_COUNT],
                    [[u8; AES_BLOCK_SIZE]; AES256_KEY_COUNT],
                >(round_keys)
            }
        });
    }
}