use crate::generators::aes_ctr::{
AesBlockCipher, AesKey, AES_CALLS_PER_BATCH, BYTES_PER_AES_CALL, BYTES_PER_BATCH,
};
use core::arch::aarch64::{
uint8x16_t, vaeseq_u8, vaesmcq_u8, vdupq_n_u32, vdupq_n_u8, veorq_u8, vgetq_lane_u32,
vreinterpretq_u32_u8, vreinterpretq_u8_u32,
};
use std::arch::is_aarch64_feature_detected;
use std::mem::transmute;
const RCONS: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36];
const NUM_WORDS_IN_KEY: usize = 4;
const NUM_ROUNDS: usize = 10;
const NUM_ROUND_KEYS: usize = NUM_ROUNDS + 1;
#[derive(Clone)]
pub struct ArmAesBlockCipher {
round_keys: [uint8x16_t; NUM_ROUND_KEYS],
}
impl AesBlockCipher for ArmAesBlockCipher {
fn new(key: AesKey) -> ArmAesBlockCipher {
let aes_detected = is_aarch64_feature_detected!("aes");
let neon_detected = is_aarch64_feature_detected!("neon");
if !(aes_detected && neon_detected) {
panic!(
"The ArmAesBlockCipher requires both aes and neon aarch64 CPU features.\n\
aes feature available: {aes_detected}\nneon feature available: {neon_detected}\n\
Please consider enabling the SoftwareRandomGenerator with the `software-prng` feature",
)
}
let round_keys = unsafe { generate_round_keys(key) };
ArmAesBlockCipher { round_keys }
}
fn generate_batch(&mut self, data: [u128; AES_CALLS_PER_BATCH]) -> [u8; BYTES_PER_BATCH] {
#[target_feature(enable = "aes,neon")]
unsafe fn implementation(
this: &ArmAesBlockCipher,
data: [u128; AES_CALLS_PER_BATCH],
) -> [u8; BYTES_PER_BATCH] {
let mut output = [0u8; BYTES_PER_BATCH];
for (input, out) in data.iter().copied().zip(output.chunks_exact_mut(16)) {
let encrypted = encrypt(input, &this.round_keys);
out.copy_from_slice(&encrypted.to_ne_bytes());
}
output
}
unsafe { implementation(self, data) }
}
fn generate_next(&mut self, data: u128) -> [u8; BYTES_PER_AES_CALL] {
unsafe { encrypt(data, &self.round_keys) }.to_ne_bytes()
}
}
#[inline(always)]
unsafe fn sub_word(word: u32) -> u32 {
let data = vreinterpretq_u8_u32(vdupq_n_u32(word));
let zero_key = vdupq_n_u8(0u8);
let temp = vaeseq_u8(data, zero_key);
vgetq_lane_u32::<0>(vreinterpretq_u32_u8(temp))
}
#[inline(always)]
fn uint8x16_t_to_u128(input: uint8x16_t) -> u128 {
unsafe { transmute(input) }
}
#[inline(always)]
fn u128_to_uint8x16_t(input: u128) -> uint8x16_t {
unsafe { transmute(input) }
}
#[target_feature(enable = "aes,neon")]
unsafe fn generate_round_keys(key: AesKey) -> [uint8x16_t; NUM_ROUND_KEYS] {
let mut round_keys: [uint8x16_t; NUM_ROUND_KEYS] = std::mem::zeroed();
round_keys[0] = u128_to_uint8x16_t(key.0);
let words = std::slice::from_raw_parts_mut(
round_keys.as_mut_ptr() as *mut u32,
NUM_ROUND_KEYS * NUM_WORDS_IN_KEY,
);
debug_assert_eq!(words.len(), 44);
for i in NUM_WORDS_IN_KEY..words.len() {
if (i % NUM_WORDS_IN_KEY) == 0 {
words[i] = words[i - NUM_WORDS_IN_KEY]
^ sub_word(words[i - 1]).rotate_right(8)
^ RCONS[(i / NUM_WORDS_IN_KEY) - 1];
} else {
words[i] = words[i - NUM_WORDS_IN_KEY] ^ words[i - 1];
}
}
round_keys
}
#[inline(always)]
unsafe fn encrypt(message: u128, keys: &[uint8x16_t; NUM_ROUND_KEYS]) -> u128 {
let mut data: uint8x16_t = u128_to_uint8x16_t(message);
for &key in keys.iter().take(NUM_ROUNDS - 1) {
data = vaesmcq_u8(vaeseq_u8(data, key));
}
data = vaeseq_u8(data, keys[NUM_ROUNDS - 1]);
data = veorq_u8(data, keys[NUM_ROUND_KEYS - 1]);
uint8x16_t_to_u128(data)
}
#[cfg(test)]
mod test {
use super::*;
const CIPHER_KEY: u128 = u128::from_be(0x000102030405060708090a0b0c0d0e0f);
const KEY_SCHEDULE: [u128; 11] = [
u128::from_be(0x000102030405060708090a0b0c0d0e0f),
u128::from_be(0xd6aa74fdd2af72fadaa678f1d6ab76fe),
u128::from_be(0xb692cf0b643dbdf1be9bc5006830b3fe),
u128::from_be(0xb6ff744ed2c2c9bf6c590cbf0469bf41),
u128::from_be(0x47f7f7bc95353e03f96c32bcfd058dfd),
u128::from_be(0x3caaa3e8a99f9deb50f3af57adf622aa),
u128::from_be(0x5e390f7df7a69296a7553dc10aa31f6b),
u128::from_be(0x14f9701ae35fe28c440adf4d4ea9c026),
u128::from_be(0x47438735a41c65b9e016baf4aebf7ad2),
u128::from_be(0x549932d1f08557681093ed9cbe2c974e),
u128::from_be(0x13111d7fe3944a17f307a78b4d2b30c5),
];
const PLAINTEXT: u128 = u128::from_be(0x00112233445566778899aabbccddeeff);
const CIPHERTEXT: u128 = u128::from_be(0x69c4e0d86a7b0430d8cdb78070b4c55a);
#[test]
fn test_generate_key_schedule() {
let key = AesKey(CIPHER_KEY);
let keys = unsafe { generate_round_keys(key) };
for (expected, actual) in KEY_SCHEDULE.iter().zip(keys.iter()) {
assert_eq!(*expected, uint8x16_t_to_u128(*actual));
}
}
#[test]
fn test_encrypt_message() {
let message = PLAINTEXT;
let key = AesKey(CIPHER_KEY);
let keys = unsafe { generate_round_keys(key) };
let ciphertext = unsafe { encrypt(message, &keys) };
assert_eq!(CIPHERTEXT, ciphertext);
}
}