#![allow(unsafe_op_in_unsafe_fn)]
use super::ExpandedKey;
use crate::{Block, ParBlocks, field_element::FieldElement};
use core::arch::aarch64::*;
/// Carry-less constant `x^63 + x^62 + x^57` (bits 63, 62 and 57 set, i.e. `0xC2 << 56`).
///
/// For POLYVAL's reduction polynomial (RFC 8452) `P(x) = x^128 + x^127 + x^126 + x^121 + 1`
/// we have `P = x^64 * (x^64 + P1) + 1`, hence `x^-64 == x^64 + P1 (mod P)`.
/// `compute_d` and `reduce_rf` use this identity to divide by `x^64` with one PMULL.
const P1: u64 = 0xC200000000000000;
// Generates the `pmull` module, which performs runtime detection of the `aes` target feature.
// The `aes` feature is the same one that every `#[target_feature]` below enables for the
// `vmull_p64` (PMULL) intrinsics.
cpufeatures::new!(pmull, "aes");
// Expose the detection token generated by the macro above to the rest of the crate.
pub(crate) use pmull::InitToken;
/// Raw 16-byte buffer used to move field elements and blocks in and out of NEON registers.
type ByteArray = [u8; 16];
impl FieldElement {
    /// Spills a NEON register back into a [`FieldElement`].
    ///
    /// The register is reinterpreted as 16 bytes and stored to memory in lane order. The
    /// buffer is then converted with the `[u8; 16] -> FieldElement` conversion, which is the
    /// inverse of [`FieldElement::to_uint64x2_t`].
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports the `neon` and `aes` target features.
    #[target_feature(enable = "neon", enable = "aes")]
    #[inline]
    unsafe fn from_uint64x2_t(reg: uint64x2_t) -> Self {
        let lanes = vreinterpretq_u8_u64(reg);
        let mut buf: ByteArray = [0; 16];
        vst1q_u8(buf.as_mut_ptr(), lanes);
        buf.into()
    }

    /// Loads this field element's 16-byte encoding into a NEON register.
    ///
    /// # Safety
    /// The caller must ensure the running CPU supports the `neon` and `aes` target features.
    #[target_feature(enable = "neon", enable = "aes")]
    #[inline]
    unsafe fn to_uint64x2_t(self) -> uint64x2_t {
        let buf: ByteArray = self.into();
        vreinterpretq_u64_u8(vld1q_u8(buf.as_ptr()))
    }
}
/// Loads 16 bytes into a NEON register, viewed as two 64-bit lanes.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
unsafe fn load_bytes(bytes: &ByteArray) -> uint64x2_t {
    let raw: uint8x16_t = vld1q_u8(bytes.as_ptr());
    vreinterpretq_u64_u8(raw)
}
/// Absorbs a single block into the POLYVAL accumulator: `y' = dot(y ^ block, H)`.
///
/// `H` is `key.h1`, and `key.d1` is its precomputed companion value used by `gf128_mul_rf`.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
pub(super) unsafe fn proc_block(key: &ExpandedKey, y: FieldElement, block: &Block) -> FieldElement {
    let h = key.h1.to_uint64x2_t();
    let d = key.d1.to_uint64x2_t();

    let mixed = veorq_u64(load_bytes(&block.0), y.to_uint64x2_t());
    let product = gf128_mul_rf(mixed, h, d);

    FieldElement::from_uint64x2_t(product)
}
/// Absorbs four blocks at once into the POLYVAL accumulator.
///
/// Computes `dot(acc ^ m0, H^4) ^ dot(m1, H^3) ^ dot(m2, H^2) ^ dot(m3, H)`. This equals four
/// sequential single-block steps. The four products are summed in unreduced `(r, f)` form and
/// reduced once at the end. That is valid because `reduce_rf` is linear: `r + f * x^-64`.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
pub(super) unsafe fn proc_par_blocks(
    key: &ExpandedKey,
    acc: FieldElement,
    par_blocks: &ParBlocks,
) -> FieldElement {
    // The running accumulator is folded into the first (oldest) block only.
    let inputs = [
        veorq_u64(acc.to_uint64x2_t(), load_bytes(&par_blocks[0].0)),
        load_bytes(&par_blocks[1].0),
        load_bytes(&par_blocks[2].0),
        load_bytes(&par_blocks[3].0),
    ];
    // The oldest block is paired with the highest key power.
    let powers = [
        (key.h4, key.d4),
        (key.h3, key.d3),
        (key.h2, key.d2),
        (key.h1, key.d1),
    ];

    let zero = vdupq_n_u64(0);
    let mut r_sum = zero;
    let mut f_sum = zero;
    for (&input, &(h, d)) in inputs.iter().zip(powers.iter()) {
        let (r, f) = rf_mul_unreduced(input, h.to_uint64x2_t(), d.to_uint64x2_t());
        r_sum = veorq_u64(r_sum, r);
        f_sum = veorq_u64(f_sum, f);
    }

    FieldElement::from_uint64x2_t(reduce_rf(r_sum, f_sum))
}
/// Precomputes the key schedule used by `proc_block` and `proc_par_blocks`.
///
/// `h1..h4` are the POLYVAL powers of the key `H`:
/// - `h1 = H`
/// - `h2 = dot(H, H)`
/// - `h3 = dot(h2, H)`
/// - `h4 = dot(h2, h2)`
///
/// Each `dN` is `compute_d(hN)`, the companion value `gf128_mul_rf` needs for that multiplier.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
pub(super) unsafe fn expand_key(h: &[u8; 16]) -> ExpandedKey {
    let pow1 = load_bytes(h);
    let aux1 = compute_d(pow1);

    let pow2 = gf128_mul_rf(pow1, pow1, aux1);
    let aux2 = compute_d(pow2);

    let pow3 = gf128_mul_rf(pow2, pow1, aux1);
    let pow4 = gf128_mul_rf(pow2, pow2, aux2);
    let aux3 = compute_d(pow3);
    let aux4 = compute_d(pow4);

    ExpandedKey {
        h1: FieldElement::from_uint64x2_t(pow1),
        h2: FieldElement::from_uint64x2_t(pow2),
        h3: FieldElement::from_uint64x2_t(pow3),
        h4: FieldElement::from_uint64x2_t(pow4),
        d1: FieldElement::from_uint64x2_t(aux1),
        d2: FieldElement::from_uint64x2_t(aux2),
        d3: FieldElement::from_uint64x2_t(aux3),
        d4: FieldElement::from_uint64x2_t(aux4),
    }
}
/// Computes `D = h * x^-64 mod P` for a multiplier `h = h0 + h1 * x^64`.
///
/// Since `x^-64 == x^64 + P1`, this works out to `h1 + h0 * x^64 + h0 * P1`. That value is
/// the half-swapped `h` XORed with a single carry-less product `h0 * P1`.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
unsafe fn compute_d(h: uint64x2_t) -> uint64x2_t {
    let low = vgetq_lane_u64(h, 0);
    let correction = vreinterpretq_u64_p128(vmull_p64(low, P1));
    // (h0, h1) -> (h1, h0)
    let swapped = vextq_u64(h, h, 1);
    veorq_u64(correction, swapped)
}
/// Fully reduced POLYVAL product `dot(m, h) = m * h * x^-128 mod P`.
///
/// `d` must be `compute_d(h)`.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
unsafe fn gf128_mul_rf(m: uint64x2_t, h: uint64x2_t, d: uint64x2_t) -> uint64x2_t {
    let unreduced = rf_mul_unreduced(m, h, d);
    reduce_rf(unreduced.0, unreduced.1)
}
/// Unreduced half of `dot(m, h)`, returned as a pair `(r, f)`.
///
/// The pair satisfies `dot(m, h) == r + f * x^-64 (mod P)`, where `d = compute_d(h)`:
/// - `r = m0*d1 ^ m1*h1`
/// - `f = m0*d0 ^ m1*h0`
///
/// All products are 64x64 -> 128-bit carry-less multiplications. Pairs produced for
/// different inputs may be XOR-summed before a single `reduce_rf`.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
unsafe fn rf_mul_unreduced(
    m: uint64x2_t,
    h: uint64x2_t,
    d: uint64x2_t,
) -> (uint64x2_t, uint64x2_t) {
    let (m_lo, m_hi) = (vgetq_lane_u64(m, 0), vgetq_lane_u64(m, 1));
    let (h_lo, h_hi) = (vgetq_lane_u64(h, 0), vgetq_lane_u64(h, 1));
    let (d_lo, d_hi) = (vgetq_lane_u64(d, 0), vgetq_lane_u64(d, 1));

    let r = veorq_u64(
        vreinterpretq_u64_p128(vmull_p64(m_lo, d_hi)),
        vreinterpretq_u64_p128(vmull_p64(m_hi, h_hi)),
    );
    let f = veorq_u64(
        vreinterpretq_u64_p128(vmull_p64(m_lo, d_lo)),
        vreinterpretq_u64_p128(vmull_p64(m_hi, h_lo)),
    );

    (r, f)
}
/// Collapses an unreduced pair into a field element: `r + f * x^-64 (mod P)`.
///
/// With `f = f0 + f1 * x^64` and `x^-64 == x^64 + P1`, the result is
/// `r ^ f1 ^ (f0 * x^64) ^ (f0 * P1)`. Every term fits in 128 bits, so no further
/// reduction is required.
///
/// # Safety
/// The caller must ensure the running CPU supports the `neon` and `aes` target features.
#[target_feature(enable = "neon", enable = "aes")]
#[inline]
unsafe fn reduce_rf(r: uint64x2_t, f: uint64x2_t) -> uint64x2_t {
    let zero = vdupq_n_u64(0);
    // (f0, f1) -> (f1, 0): the f1 term
    let f_hi_down = vextq_u64(f, zero, 1);
    // (f0, f1) -> (0, f0): the f0 * x^64 term
    let f_lo_up = vextq_u64(zero, f, 1);
    // the f0 * P1 term
    let f0_times_p1 = vreinterpretq_u64_p128(vmull_p64(vgetq_lane_u64(f, 0), P1));

    veorq_u64(veorq_u64(r, f0_times_p1), veorq_u64(f_hi_down, f_lo_up))
}