use core::arch::aarch64::*;
use super::Vector;
use crate::xxhash3::{primes::PRIME32_1, SliceBackport as _};
/// Zero-sized token that selects the NEON code paths of [`Vector`].
///
/// The single field is private, so code outside this module cannot build the
/// value directly; holding one stands for "the NEON routines of this module
/// may be called".
#[derive(Clone, Copy)]
pub struct Impl(());
impl Impl {
    /// Creates the NEON implementation token without inspecting the CPU.
    ///
    /// Only compiled when the `std` feature is enabled.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the running CPU supports the `neon`
    /// target feature: the [`Vector`] methods of the returned value call
    /// functions compiled with `#[target_feature(enable = "neon")]`.
    #[cfg(feature = "std")]
    #[inline]
    pub unsafe fn new_unchecked() -> Self {
        Impl(())
    }
}
/// Forwards the [`Vector`] operations to the NEON routines of this module.
impl Vector for Impl {
    /// Runs the scramble step over the eight accumulator lanes, mixing in
    /// the 64 bytes of `secret_end`.
    #[inline]
    fn round_scramble(&self, acc: &mut [u64; 8], secret_end: &[u8; 64]) {
        // SAFETY: the callee's only requirement is the `neon` target feature.
        // `Impl` has a private field and is handed out by the `unsafe`
        // `Impl::new_unchecked`, whose caller is presumed to have verified
        // NEON support.
        unsafe {
            round_scramble_neon(acc, secret_end);
        }
    }

    /// Folds one 64-byte `stripe`, keyed with 64 bytes of `secret`, into the
    /// eight accumulator lanes.
    #[inline]
    fn accumulate(&self, acc: &mut [u64; 8], stripe: &[u8; 64], secret: &[u8; 64]) {
        // SAFETY: same reasoning as in `round_scramble`: owning an `Impl`
        // stands for NEON being available.
        unsafe {
            accumulate_neon(acc, stripe, secret);
        }
    }
}
/// Scrambles the accumulator two 64-bit lanes at a time.
///
/// Every lane is replaced by
/// `(lane ^ (lane >> 47) ^ secret_lane) * (PRIME32_1 as u32)`, where the
/// multiplication keeps the low 64 bits of the product and `secret_lane` is
/// the 64-bit value loaded from `secret_end` at the same lane index.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[target_feature(enable = "neon")]
#[inline]
unsafe fn round_scramble_neon(acc: &mut [u64; 8], secret_end: &[u8; 64]) {
    let secret_ptr = secret_end.as_ptr().cast::<u64>();
    let (lane_pairs, _) = acc.bp_as_chunks_mut::<2>();

    for (index, pair) in lane_pairs.iter_mut().enumerate() {
        // SAFETY: the caller guarantees NEON. Each 128-bit load/store on
        // `pair` covers exactly its two `u64`s. Eight lanes split in pairs
        // give `index <= 3`, so the secret load touches 64-bit slots
        // `2 * index` and `2 * index + 1`, both inside the 64 bytes of
        // `secret_end` (assumes `bp_as_chunks_mut::<2>` yields consecutive
        // two-element chunks — TODO confirm against `SliceBackport`).
        unsafe {
            let lanes = vld1q_u64(pair.as_ptr());
            let key = vld1q_u64(secret_ptr.add(2 * index));

            // lanes ^ (lanes >> 47)
            let folded = veorq_u64(lanes, vshrq_n_u64::<47>(lanes));
            // ... ^ key
            let keyed = veorq_u64(folded, key);
            // ... * PRIME32_1
            let scrambled = xx_vmulq_u32_u64(keyed, PRIME32_1 as u32);

            vst1q_u64(pair.as_mut_ptr(), scrambled);
        }
    }
}
/// Folds one 64-byte stripe into the accumulator, four 64-bit lanes per
/// loop iteration (two 128-bit vectors).
///
/// With `value[i] = stripe[i] ^ secret[i]` (both read as 64-bit values), each
/// lane receives
/// `acc[i] += stripe[i ^ 1] + half_a(value[i]) * half_b(value[i])`,
/// where `half_a`/`half_b` are the two 32-bit halves of `value[i]`, the
/// product is widened to 64 bits and the additions wrap.
///
/// Working on four lanes at once lets the de-interleave step fill both
/// 128-bit vectors of 32-bit halves completely.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[target_feature(enable = "neon")]
#[inline]
unsafe fn accumulate_neon(acc: &mut [u64; 8], stripe: &[u8; 64], secret: &[u8; 64]) {
    let stripe_ptr = stripe.as_ptr().cast::<u64>();
    let secret_ptr = secret.as_ptr().cast::<u64>();
    let (quads, _) = acc.bp_as_chunks_mut::<4>();

    for (index, quad) in quads.iter_mut().enumerate() {
        // Offsets, in `u64` units, of the two vectors handled this iteration.
        let first = index * 4;
        let second = first + 2;

        // SAFETY: the caller guarantees NEON. `quad` holds four `u64`s, so
        // the accesses at element 0 and element 2 stay inside it. Eight lanes
        // split in fours give `index <= 1`, hence `second + 1 <= 7` and every
        // stripe/secret load stays inside the respective 64-byte array
        // (assumes `bp_as_chunks_mut::<4>` yields consecutive four-element
        // chunks — TODO confirm against `SliceBackport`).
        unsafe {
            let acc_ptr = quad.as_mut_ptr().cast::<u64>();

            let acc_0 = vld1q_u64(acc_ptr);
            let acc_1 = vld1q_u64(acc_ptr.add(2));
            let data_0 = vld1q_u64(stripe_ptr.add(first));
            let data_1 = vld1q_u64(stripe_ptr.add(second));
            let key_0 = vld1q_u64(secret_ptr.add(first));
            let key_1 = vld1q_u64(secret_ptr.add(second));

            // Swap the two 64-bit lanes: swapped[i ^ 1] = data[i].
            let swapped_0 = vextq_u64::<1>(data_0, data_0);
            let swapped_1 = vextq_u64::<1>(data_1, data_1);

            // value[i] = data[i] ^ key[i], viewed as 32-bit lanes.
            let mixed_0 = vreinterpretq_u32_u64(veorq_u64(data_0, key_0));
            let mixed_1 = vreinterpretq_u32_u64(veorq_u64(data_1, key_1));

            // De-interleave: one vector gets the even-indexed 32-bit lanes,
            // the other the odd-indexed ones, i.e. the two halves of every
            // 64-bit value end up in matching positions.
            let even = vuzp1q_u32(mixed_0, mixed_1);
            let odd = vuzp2q_u32(mixed_0, mixed_1);

            // sum[i] = swapped[i] + even[i] * odd[i] (widening multiply).
            let sum_0 = vmlal_u32(swapped_0, vget_low_u32(even), vget_low_u32(odd));
            let sum_1 = vmlal_high_u32(swapped_1, even, odd);

            reordering_barrier(sum_0);
            reordering_barrier(sum_1);

            // acc[i] += sum[i]
            let total_0 = vaddq_u64(acc_0, sum_0);
            let total_1 = vaddq_u64(acc_1, sum_1);

            vst1q_u64(acc_ptr, total_0);
            vst1q_u64(acc_ptr.add(2), total_1);
        }
    }
}
/// Multiplies each 64-bit lane of `input` by the 32-bit `og_factor`, keeping
/// the low 64 bits of every product (a lane-wise wrapping multiply).
///
/// The 64-bit result is assembled from 32-bit operations by long
/// multiplication. Writing a lane as `A * 2^32 + B` and the factor as `C`:
///
/// ```text
/// (A * 2^32 + B) * C  ==  ((A * C) mod 2^32) * 2^32  +  B * C   (mod 2^64)
/// ```
///
/// 1. A 32-bit vector multiply against `[0, C]` per lane produces
///    `(A * C) mod 2^32` in the upper half and zero in the lower half.
/// 2. A widening multiply-accumulate adds the full 64-bit `B * C` on top.
#[inline]
pub fn xx_vmulq_u32_u64(input: uint64x2_t, og_factor: u32) -> uint64x2_t {
    // SAFETY: only register-to-register intrinsics are used on by-value
    // arguments; no memory is read or written.
    // NOTE(review): relies on NEON being available wherever this module is
    // compiled — confirm against the module's `cfg` gating.
    unsafe {
        // Operands of the widening low-half product `B * C`.
        let low_halves = vmovn_u64(input);
        let factor_pair = vmov_n_u32(og_factor);

        // `C` placed in the upper 32 bits of each 64-bit lane, zero below.
        let striped = vmovq_n_u64(u64::from(og_factor) << 32);
        let striped = vreinterpretq_u32_u64(striped);

        // Upper half of each lane: (A * C) mod 2^32; lower half: 0.
        let high_products = vmulq_u32(vreinterpretq_u32_u64(input), striped);
        let high_products = vreinterpretq_u64_u32(high_products);

        vmlal_u32(high_products, low_halves, factor_pair)
    }
}
/// Passes `r` through an inline-assembly block that contains no instructions.
///
/// NOTE(review): judging by the name and its use right after the
/// multiply-accumulate in `accumulate_neon`, this is meant to discourage the
/// optimizer from reordering or fusing the surrounding vector operations —
/// confirm against the upstream xxHash compiler-guard macro.
///
/// # Safety
///
/// The CPU must support the `neon` target feature.
#[inline]
#[target_feature(enable = "neon")]
unsafe fn reordering_barrier(r: uint64x2_t) {
// SAFETY: the template below is only an assembler comment, so no
// instruction is executed; `r` is merely named as an input operand.
unsafe {
core::arch::asm!(
// Assembler comment that mentions the vector register holding `r`.
"/* {r:v} */",
// `r` is an input only, placed in a vector register.
r = in(vreg) r,
// The block is declared to touch neither memory nor the stack.
options(nomem, nostack),
)
}
}