#![cfg(target_arch = "aarch64")]
use core::arch::aarch64::{
veorq_u8, vld1q_u8, vmull_p64, vreinterpretq_u8_p128, vreinterpretq_u64_u8, vst1q_u8,
};
/// Reverses the order of the eight bits in `b`.
///
/// Bit 0 of the input becomes bit 7 of the output, bit 1 becomes bit 6, and so
/// on (for example `0x01` maps to `0x80`). Applying the function twice yields
/// the original value.
#[inline]
const fn reverse_byte(b: u8) -> u8 {
    // The standard library already provides this operation and it is usable
    // in `const fn`, so there is no need for a hand-written swap network.
    b.reverse_bits()
}
/// Returns a copy of `b` in which every byte has been passed through
/// `reverse_byte`.
///
/// Only the bits inside each byte are affected; the bytes keep their positions
/// in the array.
#[inline]
const fn natural_bytes(b: &[u8; 16]) -> [u8; 16] {
    // Start from a copy of the input and rewrite each byte in place.
    let mut flipped = *b;
    // `for` loops are not available in `const fn`, hence the manual countdown.
    let mut remaining = 16;
    while remaining > 0 {
        remaining -= 1;
        flipped[remaining] = reverse_byte(flipped[remaining]);
    }
    flipped
}
/// Multiplies two 128-bit field elements using the 64x64 carry-less multiply
/// intrinsic `vmull_p64`, reducing the product modulo
/// `x^128 + x^7 + x^2 + x + 1`.
///
/// Both inputs and the result use the bit-reflected byte layout that the
/// function name associates with GHASH: the most significant bit of byte 0 is
/// the coefficient of `x^0`. Internally every byte is bit-reversed (via
/// `natural_bytes`) so that, read as a little-endian integer, bit `i` is the
/// coefficient of `x^i`; the result is converted back the same way.
///
/// The arithmetic consists only of carry-less multiplies, shifts and XORs;
/// there are no branches or table lookups that depend on `h` or `x`.
///
/// # Safety
///
/// The caller must ensure that the CPU executing this function supports the
/// `neon` and `aes` target features enabled by the attribute below.
#[target_feature(enable = "neon,aes")]
// The intrinsic calls are not wrapped in individual `unsafe` blocks: every one
// of them is covered by this function's own safety contract (see `# Safety`).
#[allow(unsafe_op_in_unsafe_fn)]
pub unsafe fn ghash_mul_pmull(h: &[u8; 16], x: &[u8; 16]) -> [u8; 16] {
    // Bring both operands into "natural" bit order.
    let h_n = natural_bytes(h);
    let x_n = natural_bytes(x);

    // Split each operand into its two 64-bit halves (lane 0 = low half).
    let h_u64 = vreinterpretq_u64_u8(vld1q_u8(h_n.as_ptr()));
    let x_u64 = vreinterpretq_u64_u8(vld1q_u8(x_n.as_ptr()));
    let h_lo: u64 = core::arch::aarch64::vgetq_lane_u64(h_u64, 0);
    let h_hi: u64 = core::arch::aarch64::vgetq_lane_u64(h_u64, 1);
    let x_lo: u64 = core::arch::aarch64::vgetq_lane_u64(x_u64, 0);
    let x_hi: u64 = core::arch::aarch64::vgetq_lane_u64(x_u64, 1);

    // Schoolbook multiplication: four 64x64 -> 128 carry-less partial products.
    let t00 = vreinterpretq_u8_p128(vmull_p64(h_lo, x_lo));
    let t01 = vreinterpretq_u8_p128(vmull_p64(h_lo, x_hi));
    let t10 = vreinterpretq_u8_p128(vmull_p64(h_hi, x_lo));
    let t11 = vreinterpretq_u8_p128(vmull_p64(h_hi, x_hi));
    // The two cross terms share the same weight (x^64), so combine them first.
    let middle = veorq_u8(t01, t10);

    // Move the partial products out of the vector registers as integers.
    let mut lanes = [0u8; 16];
    vst1q_u8(lanes.as_mut_ptr(), t00);
    let p_low = u128::from_le_bytes(lanes);
    vst1q_u8(lanes.as_mut_ptr(), t11);
    let p_high = u128::from_le_bytes(lanes);
    vst1q_u8(lanes.as_mut_ptr(), middle);
    let mid = u128::from_le_bytes(lanes);

    // Assemble the 256-bit product as (high, low): the middle term straddles
    // the two halves, contributing its low 64 bits to the top of `low` and its
    // high 64 bits to the bottom of `high`.
    let low = p_low ^ (mid << 64);
    let high = p_high ^ (mid >> 64);

    // Reduction. Bit `i` of `high` stands for x^(128 + i), and
    // x^128 = x^7 + x^2 + x + 1 in this field, so `high` folds into the low
    // half as high * (x^7 + x^2 + x + 1).
    let folded = high ^ (high << 1) ^ (high << 2) ^ (high << 7);
    // The left shifts above drop the bits that would land at x^128 or beyond;
    // collect them here (at most 7 bits) ...
    let spill = (high >> 127) ^ (high >> 126) ^ (high >> 121);
    // ... and fold them once more. Their degree is at most 6, so this second
    // fold stays below x^14 and cannot overflow again.
    let spill_folded = spill ^ (spill << 1) ^ (spill << 2) ^ (spill << 7);
    let reduced = low ^ folded ^ spill_folded;

    // Convert back to the bit-reflected byte layout of the inputs.
    natural_bytes(&reduced.to_le_bytes())
}