use core::arch::aarch64::*;
/// AES-256 round-key schedule laid out for the ARMv8 Crypto Extension path:
/// 15 round keys, each stored as one 128-bit NEON vector (`uint8x16_t`).
///
/// NOTE(review): the type is `Copy`, so implicit copies of key material can exist;
/// `zeroize` only wipes the instance it is called on.
#[derive(Clone, Copy)]
#[repr(C, align(16))]
pub(in crate::aead) struct CeRoundKeys {
    rk: [uint8x16_t; 15],
}
impl CeRoundKeys {
    /// Overwrites every byte of the round-key schedule with zero using
    /// `crate::traits::ct::zeroize`.
    pub(super) fn zeroize(&mut self) {
        // 15 vectors * 16 bytes each.
        let byte_len = core::mem::size_of_val(&self.rk);
        let base = self.rk.as_mut_ptr().cast::<u8>();
        // SAFETY: `base` points at the start of `self.rk`, which is `byte_len` bytes of
        // initialized memory borrowed exclusively through `&mut self`. `u8` has alignment 1,
        // and every byte pattern is a valid `uint8x16_t` lane value.
        let bytes = unsafe { core::slice::from_raw_parts_mut(base, byte_len) };
        crate::traits::ct::zeroize(bytes);
    }
}
/// Repacks a word-form AES-256 key schedule into NEON round-key vectors.
///
/// `rk` holds 60 32-bit words. Each run of four consecutive words becomes one
/// 16-byte round key, and each word is serialized big-endian (most significant byte first).
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `neon` target feature.
#[target_feature(enable = "neon")]
#[inline]
pub(super) unsafe fn from_portable_core(rk: &[u32; 60]) -> CeRoundKeys {
    // SAFETY: the only pointer-based operation is `vld1q_u8` on `block`, a 16-byte stack
    // array, so the 16-byte load stays in bounds; NEON availability is the caller's contract.
    unsafe {
        let mut vectors = [vdupq_n_u8(0); 15];
        for (vector, words) in vectors.iter_mut().zip(rk.chunks_exact(4)) {
            let mut block = [0u8; 16];
            for (dst, word) in block.chunks_exact_mut(4).zip(words) {
                dst.copy_from_slice(&word.to_be_bytes());
            }
            *vector = vld1q_u8(block.as_ptr());
        }
        CeRoundKeys { rk: vectors }
    }
}
#[target_feature(enable = "aes,neon")]
#[inline]
pub(super) unsafe fn expand_key_hw(key: &[u8; 32]) -> CeRoundKeys {
unsafe {
#[target_feature(enable = "aes,neon")]
#[inline]
unsafe fn sub_word_hw(w: u32) -> u32 {
let state = vreinterpretq_u8_u32(vdupq_n_u32(w));
let zero = vdupq_n_u8(0);
let result = vaeseq_u8(state, zero);
vgetq_lane_u32(vreinterpretq_u32_u8(result), 0)
}
let mut rk = [0u32; super::EXPANDED_KEY_WORDS];
let mut i = 0usize;
while i < 8 {
let base = i.strict_mul(4);
rk[i] = u32::from_be_bytes([
key[base],
key[base.strict_add(1)],
key[base.strict_add(2)],
key[base.strict_add(3)],
]);
i = i.strict_add(1);
}
i = 8;
while i < super::EXPANDED_KEY_WORDS {
let mut temp = rk[i.strict_sub(1)];
if i.strict_rem(8) == 0 {
temp = sub_word_hw(super::rot_word(temp)) ^ super::RCON[i.strict_div(8).strict_sub(1)];
} else if i.strict_rem(8) == 4 {
temp = sub_word_hw(temp);
}
rk[i] = rk[i.strict_sub(8)] ^ temp;
i = i.strict_add(1);
}
from_portable_core(&rk)
}
}
/// Builds the Crypto Extension round-key schedule for a 256-bit key.
///
/// Thin entry point that forwards to `expand_key_hw`. Unlike `expand_key_hw`, it
/// carries no `#[inline]` hint.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `aes` and `neon` target features.
#[target_feature(enable = "aes,neon")]
pub(super) unsafe fn expand_key(key: &[u8; 32]) -> CeRoundKeys {
    // SAFETY: this function enables the same `aes,neon` features that `expand_key_hw`
    // requires, and the caller guarantees the CPU supports them.
    unsafe { expand_key_hw(key) }
}
/// Encrypts one 16-byte block in place with the AES-256 schedule in `keys`.
///
/// Round keys 0..=12 are each consumed by an `AESE` (AddRoundKey, ShiftRows,
/// SubBytes) followed by `AESMC` (MixColumns). Round key 13 goes through `AESE`
/// without MixColumns, as the final AES round requires. Round key 14 is applied
/// as a plain XOR.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `aes` and `neon` target features.
#[target_feature(enable = "aes,neon")]
#[inline]
pub(super) unsafe fn encrypt_block_core(keys: &CeRoundKeys, block: &mut [u8; 16]) {
    // SAFETY: `vld1q_u8`/`vst1q_u8` access exactly the 16 bytes of `block`; the
    // AES/NEON intrinsics rely on the caller's target-feature guarantee.
    unsafe {
        let round_keys = &keys.rk;
        let mut s = vld1q_u8(block.as_ptr());
        for &round_key in &round_keys[..13] {
            s = vaesmcq_u8(vaeseq_u8(s, round_key));
        }
        s = veorq_u8(vaeseq_u8(s, round_keys[13]), round_keys[14]);
        vst1q_u8(block.as_mut_ptr(), s);
    }
}
/// Encrypts one 16-byte block in place using the Crypto Extension round keys.
///
/// Thin entry point that forwards to `encrypt_block_core`. Unlike
/// `encrypt_block_core`, it carries no `#[inline]` hint.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `aes` and `neon` target features.
#[target_feature(enable = "aes,neon")]
pub(super) unsafe fn encrypt_block(keys: &CeRoundKeys, block: &mut [u8; 16]) {
    // SAFETY: this function enables the same `aes,neon` features that
    // `encrypt_block_core` requires, and the caller guarantees the CPU supports them.
    unsafe { encrypt_block_core(keys, block) }
}