#![cfg(all(
not(feature = "soft"),
target_arch = "aarch64",
target_feature = "neon",
))]
use core::{
arch::aarch64::{
uint8x16_t, uint8x16x4_t, vdupq_n_u8, veorq_u8, vextq_u8, vgetq_lane_u64, vld1q_u8,
vld1q_u8_x4, vmull_p64, vreinterpretq_u64_u8, vreinterpretq_u8_p128, vrev64q_u8, vst1q_u8,
},
array,
};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
use crate::{BLOCK_SIZE, KEY_SIZE};
// Runtime detection of the `aes` target feature, which provides the PMULL
// carry-less multiply instructions used throughout this module.
cpufeatures::new!(have_aes, "aes");
/// Cached result of runtime CPU-feature detection for the `aes` extension.
#[derive(Copy, Clone, Debug)]
pub(super) struct Token {
    // Lazily-initialized detection token generated by `cpufeatures::new!`.
    token: have_aes::InitToken,
}
impl Token {
#[inline]
pub fn new() -> (Self, bool) {
let (token, supported) = have_aes::init_get();
(Self { token }, supported)
}
#[inline]
pub fn supported(&self) -> bool {
self.token.get()
}
}
// Wide backend: keeps 8 key powers so `update_blocks` can fold 8 blocks per
// Montgomery reduction.
pub(super) type Big<const GHASH: bool> = Backend<GHASH, 8>;
// Minimal backend: a single key power; every block is processed individually.
pub(super) type Small<const GHASH: bool> = Backend<GHASH, 1>;
/// POLYVAL/GHASH universal-hash state backed by NEON + PMULL.
///
/// `GHASH` selects GHASH (big-endian blocks) vs POLYVAL (little-endian)
/// semantics; `N` is the number of precomputed key powers available for
/// batched updates.
#[derive(Clone, Debug)]
pub(super) struct Backend<const GHASH: bool, const N: usize> {
    // Running accumulator (the current hash state).
    y: uint8x16_t,
    // Precomputed key powers, highest power first: [h^N, ..., h^2, h^1].
    h: [uint8x16_t; N],
}
impl<const GHASH: bool, const N: usize> Backend<GHASH, N> {
    /// Initializes the state from the raw hash key.
    ///
    /// For GHASH the key is first passed through `super::mulx` so the shared
    /// `polymul` routine yields GHASH results; for POLYVAL the key bytes are
    /// loaded as-is. The `h` table is then filled with the key's successive
    /// powers, highest power first: `[h^N, ..., h^2, h^1]`.
    ///
    /// # Safety
    /// The CPU must support NEON and the `aes` (PMULL) extension.
    #[inline]
    #[target_feature(enable = "neon,aes")]
    #[allow(clippy::undocumented_unsafe_blocks)]
    pub unsafe fn new(key: &[u8; KEY_SIZE]) -> Self {
        const { assert!(N > 0) }
        let h = if GHASH {
            let key = super::mulx(u128::from_be_bytes(*key)).to_le_bytes();
            unsafe { vld1q_u8(key.as_ptr()) }
        } else {
            unsafe { vld1q_u8(key.as_ptr()) }
        };
        let h = {
            // Fill `pow` back-to-front: the last slot receives h^1, and each
            // earlier slot multiplies the previously computed power by h.
            let mut prev = h;
            let mut pow: [uint8x16_t; N] = array::from_fn(|_| unsafe { vdupq_n_u8(0) });
            for (i, v) in pow.iter_mut().rev().enumerate() {
                *v = h;
                if i > 0 {
                    *v = unsafe { polymul(*v, prev) };
                }
                prev = *v;
            }
            pow
        };
        Self {
            // Accumulator starts at zero.
            y: unsafe { vdupq_n_u8(0) },
            h,
        }
    }

    /// Absorbs a single 16-byte block: `y = (y ^ block) * h`.
    ///
    /// GHASH blocks are byte-swapped on load to convert from their big-endian
    /// representation to the internal little-endian one.
    ///
    /// # Safety
    /// The CPU must support NEON and the `aes` (PMULL) extension.
    #[inline]
    #[target_feature(enable = "neon,aes")]
    #[allow(
        clippy::arithmetic_side_effects,
        clippy::indexing_slicing,
        reason = "N - 1 is constant and N > 0"
    )]
    pub unsafe fn update_block(&mut self, block: &[u8; BLOCK_SIZE]) {
        const { assert!(N > 0) }
        unsafe {
            let mut x = vld1q_u8(block.as_ptr());
            if GHASH {
                x = swap_bytes(x);
            }
            // h[N - 1] is the lowest key power, h^1.
            self.y = polymul(veorq_u8(self.y, x), self.h[N - 1]);
        }
    }

    /// Absorbs a sequence of blocks.
    ///
    /// When this backend holds 8 key powers (the `Big` alias), full chunks of
    /// 8 blocks are accumulated with shared Karatsuba partial products and a
    /// single Montgomery reduction per chunk; leftover blocks (and the `Small`
    /// backend) fall back to `update_block`.
    ///
    /// # Safety
    /// The CPU must support NEON and the `aes` (PMULL) extension.
    #[inline]
    #[target_feature(enable = "neon,aes")]
    #[allow(clippy::undocumented_unsafe_blocks)]
    pub unsafe fn update_blocks(&mut self, mut blocks: &[[u8; BLOCK_SIZE]]) {
        const { assert!(N > 0) }
        // Only true when N == 8, which makes the hard-coded indices 0..=7
        // and the two 4-register loads below valid.
        if self.h.len() == 8 {
            let (head, tail) = super::as_chunks::<_, N>(blocks);
            for chunk in head {
                let (lhs, rhs) = chunk.split_at(chunk.len() / 2);
                let uint8x16x4_t(m0, m1, m2, m3) = unsafe { vld1q_u8_x4(lhs.as_ptr().cast()) };
                let uint8x16x4_t(m4, m5, m6, m7) = unsafe { vld1q_u8_x4(rhs.as_ptr().cast()) };
                // Accumulators for the high, middle, and low Karatsuba
                // partial products of the whole chunk.
                let mut h = unsafe { vdupq_n_u8(0) };
                let mut m = unsafe { vdupq_n_u8(0) };
                let mut l = unsafe { vdupq_n_u8(0) };
                // XORs block `$m` times key power `h[$idx]` into the
                // accumulators; the first block (idx 0) also folds in the
                // current state `y`.
                macro_rules! karatsuba_xor {
                    ($m:expr, $idx:expr) => {
                        unsafe {
                            let mut x = if GHASH { swap_bytes($m) } else { $m };
                            if $idx == 0 {
                                x = veorq_u8(x, self.y);
                            }
                            let y = self.h[$idx];
                            let (hh, mm, ll) = karatsuba1(x, y);
                            h = veorq_u8(h, hh);
                            m = veorq_u8(m, mm);
                            l = veorq_u8(l, ll);
                        }
                    };
                }
                // Horner's rule: block i pairs with h[i] = h^(8 - i), so the
                // oldest block carries the highest power.
                karatsuba_xor!(m7, 7);
                karatsuba_xor!(m6, 6);
                karatsuba_xor!(m5, 5);
                karatsuba_xor!(m4, 4);
                karatsuba_xor!(m3, 3);
                karatsuba_xor!(m2, 2);
                karatsuba_xor!(m1, 1);
                karatsuba_xor!(m0, 0);
                // One recombination and one reduction for the whole chunk.
                let (h, l) = unsafe { karatsuba2(h, m, l) };
                self.y = unsafe { mont_reduce(h, l) };
            }
            blocks = tail;
        }
        for block in blocks {
            unsafe { self.update_block(block) }
        }
    }

    /// Returns the current hash value as bytes (byte-swapped back to
    /// big-endian when computing GHASH).
    ///
    /// # Safety
    /// The CPU must support NEON.
    #[inline]
    #[target_feature(enable = "neon")]
    pub unsafe fn tag(&self) -> [u8; 16] {
        let y = if GHASH {
            unsafe { swap_bytes(self.y) }
        } else {
            self.y
        };
        let mut tag = [0u8; 16];
        unsafe { vst1q_u8(tag.as_mut_ptr(), y) }
        tag
    }

    /// Exports the internal accumulator so it can be restored later with
    /// [`Self::reset`].
    #[inline]
    #[cfg(feature = "experimental")]
    pub fn export(&self) -> FieldElement {
        FieldElement(self.y)
    }

    /// Restores an accumulator previously obtained from [`Self::export`].
    #[inline]
    #[cfg(feature = "experimental")]
    pub fn reset(&mut self, y: FieldElement) {
        self.y = y.0;
    }
}
/// Reverses all 16 bytes of a vector (byte 0 <-> byte 15).
///
/// `vrev64q_u8` reverses the bytes within each 64-bit half, then the 8-byte
/// `vextq_u8` rotation swaps the two halves. Used to convert between
/// POLYVAL's little-endian and GHASH's big-endian block representations.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn swap_bytes(x: uint8x16_t) -> uint8x16_t {
    unsafe {
        let x = vrev64q_u8(x);
        vextq_u8(x, x, 8)
    }
}
/// An opaque field element, wrapping the raw NEON vector register type.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct FieldElement(uint8x16_t);
impl FieldElement {
    /// Loads a field element from its 16-byte little-endian encoding.
    #[inline]
    pub fn from_le_bytes(data: &[u8; BLOCK_SIZE]) -> Self {
        // SAFETY: `data` is exactly 16 bytes, matching a `uint8x16_t` load.
        Self(unsafe { vld1q_u8(data.as_ptr()) })
    }

    /// Serializes the field element to its 16-byte little-endian encoding.
    #[inline]
    pub fn to_le_bytes(self) -> [u8; BLOCK_SIZE] {
        let mut bytes = [0u8; BLOCK_SIZE];
        // SAFETY: `bytes` is exactly 16 bytes, matching a `uint8x16_t` store.
        unsafe { vst1q_u8(bytes.as_mut_ptr(), self.0) }
        bytes
    }
}
impl Default for FieldElement {
    /// Returns the all-zero element.
    #[inline]
    fn default() -> Self {
        // SAFETY: splatting a constant has no preconditions beyond NEON
        // support, which this module is compiled for.
        Self(unsafe { vdupq_n_u8(0) })
    }
}
#[cfg(feature = "zeroize")]
impl Zeroize for FieldElement {
    /// Clears the wrapped register by delegating to `uint8x16_t`'s `Zeroize`
    /// implementation.
    fn zeroize(&mut self) {
        self.0.zeroize();
    }
}
// Equality is only needed (and only defined) for tests. Comparison goes
// through the byte encoding because `uint8x16_t` itself has no `PartialEq`;
// this is not constant-time, which is fine for test-only code.
#[cfg(test)]
impl Eq for FieldElement {}
#[cfg(test)]
impl PartialEq for FieldElement {
    fn eq(&self, other: &Self) -> bool {
        self.to_le_bytes() == other.to_le_bytes()
    }
}
/// Multiplies two field elements: a full 256-bit carry-less product via
/// Karatsuba, followed by a Montgomery reduction back to 128 bits.
///
/// # Safety
/// The CPU must support NEON and the `aes` (PMULL) extension.
#[inline]
#[target_feature(enable = "neon,aes")]
#[allow(clippy::undocumented_unsafe_blocks, reason = "Too many unsafe blocks.")]
unsafe fn polymul(x: uint8x16_t, y: uint8x16_t) -> uint8x16_t {
    unsafe {
        let (hi, mid, lo) = karatsuba1(x, y);
        let (hi, lo) = karatsuba2(hi, mid, lo);
        mont_reduce(hi, lo)
    }
}
/// First Karatsuba step: computes the three 128-bit partial carry-less
/// products `h = x.hi * y.hi`, `l = x.lo * y.lo`, and
/// `m = (x.hi ^ x.lo) * (y.hi ^ y.lo)`.
///
/// # Safety
/// The CPU must support NEON and the `aes` (PMULL) extension.
#[inline]
#[target_feature(enable = "neon,aes")]
#[allow(clippy::undocumented_unsafe_blocks, reason = "Too many unsafe blocks.")]
unsafe fn karatsuba1(x: uint8x16_t, y: uint8x16_t) -> (uint8x16_t, uint8x16_t, uint8x16_t) {
    unsafe {
        // Fold each operand's halves together (an 8-byte rotate XORed with the
        // original), then take the low-lane product of the folded values.
        let x_folded = veorq_u8(x, vextq_u8(x, x, 8));
        let y_folded = veorq_u8(y, vextq_u8(y, y, 8));
        let m = pmull(x_folded, y_folded);
        let h = pmull2(x, y);
        let l = pmull(x, y);
        (h, m, l)
    }
}
/// Second Karatsuba step: recombines the partial products from `karatsuba1`
/// into the full 256-bit result, returned as `(high, low)` 128-bit halves.
///
/// # Safety
/// The CPU must support NEON.
#[inline]
#[target_feature(enable = "neon")]
#[allow(clippy::undocumented_unsafe_blocks, reason = "Too many unsafe blocks.")]
unsafe fn karatsuba2(h: uint8x16_t, m: uint8x16_t, l: uint8x16_t) -> (uint8x16_t, uint8x16_t) {
    unsafe {
        // The middle term of the recombination: m ^ (l.hi : h.lo) ^ h ^ l.
        let t = {
            let t0 = veorq_u8(m, vextq_u8(l, h, 8));
            let t1 = veorq_u8(h, l);
            veorq_u8(t0, t1)
        };
        // Splice the middle term between the low and high partial products.
        let x01 = vextq_u8(vextq_u8(l, l, 8), t, 8);
        let x23 = vextq_u8(t, vextq_u8(h, h, 8), 8);
        (x23, x01)
    }
}
/// Montgomery reduction: folds the 256-bit product `(x23 : x01)` back to a
/// single 128-bit field element.
///
/// `poly` packs the same reduction constant (bits 57, 62, 63) into both
/// 64-bit halves of one register, so `pmull` consumes the low half and
/// `pmull2` the high half of the same value.
#[inline]
#[target_feature(enable = "neon,aes")]
#[allow(clippy::undocumented_unsafe_blocks, reason = "Too many unsafe blocks.")]
unsafe fn mont_reduce(x23: uint8x16_t, x01: uint8x16_t) -> uint8x16_t {
    let poly = unsafe {
        vreinterpretq_u8_p128(1 << 127 | 1 << 126 | 1 << 121 | 1 << 63 | 1 << 62 | 1 << 57)
    };
    // First fold of the low 128 bits.
    let a = unsafe { pmull(x01, poly) };
    let b = unsafe { veorq_u8(x01, vextq_u8(a, a, 8)) };
    // Second fold, then combine with the high 128 bits.
    let c = unsafe { pmull2(b, poly) };
    unsafe { veorq_u8(x23, veorq_u8(c, b)) }
}
/// Carry-less multiplication (PMULL) of the *low* 64-bit lanes of `a` and
/// `b`, producing a 128-bit polynomial product.
#[inline]
#[target_feature(enable = "neon,aes")]
unsafe fn pmull(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
    unsafe {
        let p = vmull_p64(
            vgetq_lane_u64(vreinterpretq_u64_u8(a), 0),
            vgetq_lane_u64(vreinterpretq_u64_u8(b), 0),
        );
        vreinterpretq_u8_p128(p)
    }
}
/// Carry-less multiplication (PMULL2) of the *high* 64-bit lanes of `a` and
/// `b`, producing a 128-bit polynomial product.
#[inline]
#[target_feature(enable = "neon,aes")]
unsafe fn pmull2(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
    unsafe {
        let p = vmull_p64(
            vgetq_lane_u64(vreinterpretq_u64_u8(a), 1),
            vgetq_lane_u64(vreinterpretq_u64_u8(b), 1),
        );
        vreinterpretq_u8_p128(p)
    }
}
#[cfg(test)]
#[allow(clippy::undocumented_unsafe_blocks)]
mod tests {
    use core::ops::BitXor;
    use hex_literal::hex;
    use super::*;
    // Builds a `FieldElement` from a hex string literal.
    macro_rules! fe {
        ($s:expr) => {{
            FieldElement::from_le_bytes(&hex!($s))
        }};
    }
    impl FieldElement {
        /// Field multiplication, exposed on `FieldElement` for tests only.
        ///
        /// # Safety
        /// The CPU must support NEON and the `aes` (PMULL) extension.
        #[inline]
        #[must_use]
        #[target_feature(enable = "neon,aes")]
        unsafe fn polymul(self, rhs: Self) -> Self {
            let fe = unsafe { polymul(self.0, rhs.0) };
            Self(fe)
        }
    }
    impl BitXor for FieldElement {
        type Output = Self;
        // Field addition is bytewise XOR.
        fn bitxor(self, rhs: Self) -> Self::Output {
            let fe = unsafe { veorq_u8(self.0, rhs.0) };
            Self(fe)
        }
    }
    #[test]
    fn test_fe_ops() {
        let a = fe!("66e94bd4ef8a2c3b884cfa59ca342b2e");
        let b = fe!("ff000000000000000000000000000000");
        let want = fe!("99e94bd4ef8a2c3b884cfa59ca342b2e");
        // Addition (XOR) is commutative.
        assert_eq!(a ^ b, want);
        assert_eq!(b ^ a, want);
        // Multiplication needs PMULL, so gate on runtime feature detection.
        if have_aes::get() {
            let want = fe!("ebe563401e7e91ea3ad6426b8140c394");
            assert_eq!(unsafe { a.polymul(b) }, want);
            assert_eq!(unsafe { b.polymul(a) }, want);
        }
    }
}