#![allow(unsafe_op_in_unsafe_fn)]
use crate::Block;
use cipher::{
array::{Array, ArraySize},
inout::InOut,
};
use core::{arch::aarch64::*, mem};
/// Encrypts a single block using the ARMv8 AES instructions.
///
/// Each of the first `KEYS - 2` round keys is applied with an AES
/// single-round encryption (`vaeseq_u8`) followed by MixColumns
/// (`vaesmcq_u8`). The second-to-last key is applied without MixColumns,
/// and the last key is XORed into the result before it is stored to the
/// output pointer of `block`.
///
/// # Panics
///
/// Panics if `KEYS` is not 11, 13 or 15.
///
/// # Safety
///
/// The caller must ensure the `aes` target feature is available on the
/// CPU executing this function.
#[target_feature(enable = "aes")]
pub(super) unsafe fn encrypt<const KEYS: usize>(
    keys: &[uint8x16_t; KEYS],
    block: InOut<'_, '_, Block>,
) {
    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);

    let (src, dst) = block.into_raw();

    // The two trailing keys are handled separately from the full rounds.
    let penultimate_key = keys[KEYS - 2];
    let final_key = keys[KEYS - 1];

    let mut state = vld1q_u8(src.cast());

    // Full rounds: single-round encryption, then MixColumns.
    for round_key in keys[..KEYS - 2].iter().copied() {
        state = vaesmcq_u8(vaeseq_u8(state, round_key));
    }

    // Last round has no MixColumns; finish with the final key XOR.
    state = veorq_u8(vaeseq_u8(state, penultimate_key), final_key);

    vst1q_u8(dst.cast(), state);
}
/// Decrypts a single block using the ARMv8 AES instructions.
///
/// Each of the first `KEYS - 2` keys is applied with an AES single-round
/// decryption (`vaesdq_u8`) followed by inverse MixColumns
/// (`vaesimcq_u8`). The second-to-last key is applied without inverse
/// MixColumns, and the last key is XORed into the result before it is
/// stored to the output pointer of `block`.
///
/// Keys are consumed in index order, exactly as passed in.
/// NOTE(review): presumably `keys` is already the decryption-side
/// (inverse) key schedule — confirm against the key-expansion code.
///
/// # Panics
///
/// Panics if `KEYS` is not 11, 13 or 15.
///
/// # Safety
///
/// The caller must ensure the `aes` target feature is available on the
/// CPU executing this function.
#[target_feature(enable = "aes")]
pub(super) unsafe fn decrypt<const KEYS: usize>(
    keys: &[uint8x16_t; KEYS],
    block: InOut<'_, '_, Block>,
) {
    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);

    let (src, dst) = block.into_raw();

    // The two trailing keys are handled separately from the full rounds.
    let penultimate_key = keys[KEYS - 2];
    let final_key = keys[KEYS - 1];

    let mut state = vld1q_u8(src.cast());

    // Full rounds: single-round decryption, then inverse MixColumns.
    for round_key in keys[..KEYS - 2].iter().copied() {
        state = vaesimcq_u8(vaesdq_u8(state, round_key));
    }

    // Last round has no inverse MixColumns; finish with the final key XOR.
    state = veorq_u8(vaesdq_u8(state, penultimate_key), final_key);

    vst1q_u8(dst.cast(), state);
}
/// Encrypts `ParBlocks` blocks at once using the ARMv8 AES instructions.
///
/// All input blocks are loaded into vector values first. Every full round
/// (single-round encryption plus MixColumns) is then applied to all of them
/// before moving on to the next round key. The last round omits MixColumns
/// and is followed by an XOR with the final key, after which each block is
/// stored to the output pointer of `blocks`.
///
/// # Panics
///
/// Panics if `KEYS` is not 11, 13 or 15.
///
/// # Safety
///
/// The caller must ensure the `aes` target feature is available on the
/// CPU executing this function.
#[target_feature(enable = "aes")]
pub(super) unsafe fn encrypt_par<const KEYS: usize, ParBlocks: ArraySize>(
    keys: &[uint8x16_t; KEYS],
    blocks: InOut<'_, '_, Array<Block, ParBlocks>>,
) {
    /// Applies one full encryption round (single-round encryption followed
    /// by MixColumns) with `key` to every element of `state`.
    #[inline(always)]
    unsafe fn round_all<N: ArraySize>(key: uint8x16_t, state: &mut Array<uint8x16_t, N>) {
        for lane in state {
            *lane = vaesmcq_u8(vaeseq_u8(*lane, key));
        }
    }

    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);

    let (src, dst) = blocks.into_raw();
    // Re-type the array pointers as pointers to individual blocks.
    let src = src.cast::<Block>();
    let dst = dst.cast::<Block>();

    // Zero-initialised scratch storage; every element is overwritten below.
    let mut state: Array<uint8x16_t, ParBlocks> = mem::zeroed();
    for (i, lane) in state.iter_mut().enumerate() {
        *lane = vld1q_u8(src.add(i).cast());
    }

    // The rounds are spelled out one by one rather than looped.
    // NOTE(review): presumably to guarantee each round is inlined — keep
    // this unrolled unless benchmarks show a loop is equivalent.
    // Nine full rounds are common to every supported key count.
    round_all(keys[0], &mut state);
    round_all(keys[1], &mut state);
    round_all(keys[2], &mut state);
    round_all(keys[3], &mut state);
    round_all(keys[4], &mut state);
    round_all(keys[5], &mut state);
    round_all(keys[6], &mut state);
    round_all(keys[7], &mut state);
    round_all(keys[8], &mut state);
    // Two extra full rounds when there are 13 or 15 keys.
    if KEYS >= 13 {
        round_all(keys[9], &mut state);
        round_all(keys[10], &mut state);
    }
    // Two more full rounds when there are 15 keys.
    if KEYS == 15 {
        round_all(keys[11], &mut state);
        round_all(keys[12], &mut state);
    }

    let penultimate_key = keys[KEYS - 2];
    let final_key = keys[KEYS - 1];

    // Last round (no MixColumns), final key XOR, then store each block.
    for (i, lane) in state.iter().enumerate() {
        let output = veorq_u8(vaeseq_u8(*lane, penultimate_key), final_key);
        vst1q_u8(dst.add(i).cast(), output);
    }
}
/// Decrypts `ParBlocks` blocks at once using the ARMv8 AES instructions.
///
/// All input blocks are loaded into vector values first. Every full round
/// (single-round decryption plus inverse MixColumns) is then applied to all
/// of them before moving on to the next key. The last round omits inverse
/// MixColumns and is followed by an XOR with the final key, after which
/// each block is stored to the output pointer of `blocks`.
///
/// Keys are consumed in index order, exactly as passed in.
/// NOTE(review): presumably `keys` is already the decryption-side
/// (inverse) key schedule — confirm against the key-expansion code.
///
/// # Panics
///
/// Panics if `KEYS` is not 11, 13 or 15.
///
/// # Safety
///
/// The caller must ensure the `aes` target feature is available on the
/// CPU executing this function.
#[target_feature(enable = "aes")]
pub(super) unsafe fn decrypt_par<const KEYS: usize, ParBlocks: ArraySize>(
    keys: &[uint8x16_t; KEYS],
    blocks: InOut<'_, '_, Array<Block, ParBlocks>>,
) {
    /// Applies one full decryption round (single-round decryption followed
    /// by inverse MixColumns) with `key` to every element of `state`.
    #[inline(always)]
    unsafe fn round_all<N: ArraySize>(key: uint8x16_t, state: &mut Array<uint8x16_t, N>) {
        for lane in state {
            *lane = vaesimcq_u8(vaesdq_u8(*lane, key));
        }
    }

    assert!(KEYS == 11 || KEYS == 13 || KEYS == 15);

    let (src, dst) = blocks.into_raw();
    // Re-type the array pointers as pointers to individual blocks.
    let src = src.cast::<Block>();
    let dst = dst.cast::<Block>();

    // Zero-initialised scratch storage; every element is overwritten below.
    let mut state: Array<uint8x16_t, ParBlocks> = mem::zeroed();
    for (i, lane) in state.iter_mut().enumerate() {
        *lane = vld1q_u8(src.add(i).cast());
    }

    // The rounds are spelled out one by one rather than looped.
    // NOTE(review): presumably to guarantee each round is inlined — keep
    // this unrolled unless benchmarks show a loop is equivalent.
    // Nine full rounds are common to every supported key count.
    round_all(keys[0], &mut state);
    round_all(keys[1], &mut state);
    round_all(keys[2], &mut state);
    round_all(keys[3], &mut state);
    round_all(keys[4], &mut state);
    round_all(keys[5], &mut state);
    round_all(keys[6], &mut state);
    round_all(keys[7], &mut state);
    round_all(keys[8], &mut state);
    // Two extra full rounds when there are 13 or 15 keys.
    if KEYS >= 13 {
        round_all(keys[9], &mut state);
        round_all(keys[10], &mut state);
    }
    // Two more full rounds when there are 15 keys.
    if KEYS == 15 {
        round_all(keys[11], &mut state);
        round_all(keys[12], &mut state);
    }

    let penultimate_key = keys[KEYS - 2];
    let final_key = keys[KEYS - 1];

    // Last round (no inverse MixColumns), final key XOR, then store each block.
    for (i, lane) in state.iter().enumerate() {
        let output = veorq_u8(vaesdq_u8(*lane, penultimate_key), final_key);
        vst1q_u8(dst.add(i).cast(), output);
    }
}