#![cfg(target_arch = "aarch64")]
use core::arch::aarch64::{
    uint8x16_t, uint64x2_t, vaddq_u64, vdupq_n_u64, veorq_u64, vextq_u64, vld1q_u8, vld1q_u64,
    vmovn_u64, vmull_u32, vqtbl1q_u8, vreinterpretq_u8_u64, vreinterpretq_u32_u64,
    vreinterpretq_u64_u8, vreinterpretq_u64_u32, vrev64q_u32, vshlq_n_u64, vsriq_n_u64, vst1q_u64,
};
use super::BLOCK_WORDS;
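/// NEON implementation of the two-pass block compression G(X, Y), matching
/// the BlaMka-based construction used by Argon2: R = X ^ Y, the permutation P
/// is applied to each of the eight rows and then to each of the eight columns
/// of R viewed as an 8x8 matrix of 128-bit vectors, and the result is fed
/// forward as Z = P_cols(P_rows(R)) ^ R. With `xor_into` set, Z is XORed into
/// `dst` instead of overwriting it.
///
/// A minimal usage sketch (hypothetical caller, assuming `BLOCK_WORDS == 128`):
/// ```ignore
/// let (x, y) = ([1u64; BLOCK_WORDS], [2u64; BLOCK_WORDS]);
/// let mut z = [0u64; BLOCK_WORDS];
/// // Safety: NEON is always available on AArch64.
/// unsafe { compress_neon(&mut z, &x, &y, false) };
/// ```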
#[target_feature(enable = "neon")]
pub(super) unsafe fn compress_neon(
dst: &mut [u64; BLOCK_WORDS],
x: &[u64; BLOCK_WORDS],
y: &[u64; BLOCK_WORDS],
xor_into: bool,
) {
    // r holds R = X ^ Y for the final feed-forward; q is the working copy that P permutes.
    let mut r = unsafe { [vdupq_n_u64(0); 64] };
    let mut q = r;
    for i in 0..64 {
        let (xv, yv) = unsafe { (vld1q_u64(x.as_ptr().add(2 * i)), vld1q_u64(y.as_ptr().add(2 * i))) };
        r[i] = unsafe { veorq_u64(xv, yv) };
        q[i] = r[i];
    }
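    // Row pass: apply P to each row (16 u64 words = four (lo, hi) vector pairs).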
for row in 0..8 {
let base = row * 8;
let (a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi) = (
q[base],
q[base + 1],
q[base + 2],
q[base + 3],
q[base + 4],
q[base + 5],
q[base + 6],
q[base + 7],
);
let (a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi) =
unsafe { p_round_neon(a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi) };
q[base] = a_lo;
q[base + 1] = a_hi;
q[base + 2] = b_lo;
q[base + 3] = b_hi;
q[base + 4] = c_lo;
q[base + 5] = c_hi;
q[base + 6] = d_lo;
q[base + 7] = d_hi;
}
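    // Column pass: apply P down each column, i.e. vector `col` of every row (stride 8 through q).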
for col in 0..8 {
let (a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi) = (
q[col],
q[col + 8],
q[col + 16],
q[col + 24],
q[col + 32],
q[col + 40],
q[col + 48],
q[col + 56],
);
let (a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi) =
unsafe { p_round_neon(a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi) };
q[col] = a_lo;
q[col + 8] = a_hi;
q[col + 16] = b_lo;
q[col + 24] = b_hi;
q[col + 32] = c_lo;
q[col + 40] = c_hi;
q[col + 48] = d_lo;
q[col + 56] = d_hi;
}
    for i in 0..64 {
        // Feed-forward: Z = Q ^ R.
        let final_v = unsafe { veorq_u64(q[i], r[i]) };
unsafe {
if xor_into {
let cur = vld1q_u64(dst.as_ptr().add(2 * i));
vst1q_u64(dst.as_mut_ptr().add(2 * i), veorq_u64(cur, final_v));
} else {
vst1q_u64(dst.as_mut_ptr().add(2 * i), final_v);
}
}
}
}
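/// One round of the permutation P over 16 u64 words held as four (lo, hi)
/// vector pairs: GB on the columns, diagonalize, GB on the diagonals,
/// undiagonalize (the BLAKE2b round structure).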
#[inline(always)]
#[allow(clippy::too_many_arguments)]
unsafe fn p_round_neon(
mut a_lo: uint64x2_t,
mut a_hi: uint64x2_t,
mut b_lo: uint64x2_t,
mut b_hi: uint64x2_t,
mut c_lo: uint64x2_t,
mut c_hi: uint64x2_t,
mut d_lo: uint64x2_t,
mut d_hi: uint64x2_t,
) -> (
uint64x2_t,
uint64x2_t,
uint64x2_t,
uint64x2_t,
uint64x2_t,
uint64x2_t,
uint64x2_t,
uint64x2_t,
) {
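    // GB applied to the four columns of the 4x4 state.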
unsafe {
gb_neon(
&mut a_lo, &mut a_hi, &mut b_lo, &mut b_hi, &mut c_lo, &mut c_hi, &mut d_lo, &mut d_hi,
);
}
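    // Diagonalize: rotate row b left by one u64 lane, row c by two, row d by three.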
unsafe {
let b_lo2 = vextq_u64::<1>(b_lo, b_hi);
let b_hi2 = vextq_u64::<1>(b_hi, b_lo);
b_lo = b_lo2;
b_hi = b_hi2;
let c_lo2 = c_hi;
let c_hi2 = c_lo;
c_lo = c_lo2;
c_hi = c_hi2;
let d_lo2 = vextq_u64::<1>(d_hi, d_lo);
let d_hi2 = vextq_u64::<1>(d_lo, d_hi);
d_lo = d_lo2;
d_hi = d_hi2;
}
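    // GB applied to the diagonals.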
unsafe {
gb_neon(
&mut a_lo, &mut a_hi, &mut b_lo, &mut b_hi, &mut c_lo, &mut c_hi, &mut d_lo, &mut d_hi,
);
}
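    // Undiagonalize: rotate the lanes back into place.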
unsafe {
let b_lo2 = vextq_u64::<1>(b_hi, b_lo);
let b_hi2 = vextq_u64::<1>(b_lo, b_hi);
b_lo = b_lo2;
b_hi = b_hi2;
let c_lo2 = c_hi;
let c_hi2 = c_lo;
c_lo = c_lo2;
c_hi = c_hi2;
let d_lo2 = vextq_u64::<1>(d_lo, d_hi);
let d_hi2 = vextq_u64::<1>(d_hi, d_lo);
d_lo = d_lo2;
d_hi = d_hi2;
}
(a_lo, a_hi, b_lo, b_hi, c_lo, c_hi, d_lo, d_hi)
}
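/// The BlaMka G function on two independent 128-bit lanes at once. Identical
/// to BLAKE2b's G with rotations 32, 24, 16, 63, except that every addition
/// a + b becomes a + b + 2 * lo32(a) * lo32(b) via `bla_mul`.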
#[inline(always)]
#[allow(clippy::too_many_arguments)]
unsafe fn gb_neon(
a_lo: &mut uint64x2_t,
a_hi: &mut uint64x2_t,
b_lo: &mut uint64x2_t,
b_hi: &mut uint64x2_t,
c_lo: &mut uint64x2_t,
c_hi: &mut uint64x2_t,
d_lo: &mut uint64x2_t,
d_hi: &mut uint64x2_t,
) {
unsafe {
let p_lo = bla_mul(*a_lo, *b_lo);
let p_hi = bla_mul(*a_hi, *b_hi);
*a_lo = vaddq_u64(vaddq_u64(*a_lo, *b_lo), p_lo);
*a_hi = vaddq_u64(vaddq_u64(*a_hi, *b_hi), p_hi);
*d_lo = ror32(veorq_u64(*d_lo, *a_lo));
*d_hi = ror32(veorq_u64(*d_hi, *a_hi));
let p_lo = bla_mul(*c_lo, *d_lo);
let p_hi = bla_mul(*c_hi, *d_hi);
*c_lo = vaddq_u64(vaddq_u64(*c_lo, *d_lo), p_lo);
*c_hi = vaddq_u64(vaddq_u64(*c_hi, *d_hi), p_hi);
*b_lo = ror24(veorq_u64(*b_lo, *c_lo));
*b_hi = ror24(veorq_u64(*b_hi, *c_hi));
let p_lo = bla_mul(*a_lo, *b_lo);
let p_hi = bla_mul(*a_hi, *b_hi);
*a_lo = vaddq_u64(vaddq_u64(*a_lo, *b_lo), p_lo);
*a_hi = vaddq_u64(vaddq_u64(*a_hi, *b_hi), p_hi);
*d_lo = ror16(veorq_u64(*d_lo, *a_lo));
*d_hi = ror16(veorq_u64(*d_hi, *a_hi));
let p_lo = bla_mul(*c_lo, *d_lo);
let p_hi = bla_mul(*c_hi, *d_hi);
*c_lo = vaddq_u64(vaddq_u64(*c_lo, *d_lo), p_lo);
*c_hi = vaddq_u64(vaddq_u64(*c_hi, *d_hi), p_hi);
*b_lo = ror63(veorq_u64(*b_lo, *c_lo));
*b_hi = ror63(veorq_u64(*b_hi, *c_hi));
}
}
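/// Per-lane BlaMka multiplication term 2 * lo32(a) * lo32(b): `vmovn_u64`
/// narrows each lane to its low 32 bits, `vmull_u32` widens the product back
/// to 64 bits, and the left shift doubles it (mod 2^64).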
#[inline(always)]
unsafe fn bla_mul(a: uint64x2_t, b: uint64x2_t) -> uint64x2_t {
unsafe {
let al = vmovn_u64(a);
let bl = vmovn_u64(b);
vshlq_n_u64::<1>(vmull_u32(al, bl))
}
}
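/// Rotate each 64-bit lane right by 32 by swapping its two 32-bit halves.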
#[inline(always)]
unsafe fn ror32(v: uint64x2_t) -> uint64x2_t {
unsafe { vreinterpretq_u64_u32(vrev64q_u32(vreinterpretq_u32_u64(v))) }
}
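/// Rotate each 64-bit lane right by 24 (a 3-byte rotate) via a table lookup.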
#[inline(always)]
unsafe fn ror24(v: uint64x2_t) -> uint64x2_t {
static IDX: [u8; 16] = [3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10];
unsafe {
let table: uint8x16_t = vld1q_u8(IDX.as_ptr());
vreinterpretq_u64_u8(vqtbl1q_u8(vreinterpretq_u8_u64(v), table))
}
}
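/// Rotate each 64-bit lane right by 16 (a 2-byte rotate) via a table lookup.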
#[inline(always)]
unsafe fn ror16(v: uint64x2_t) -> uint64x2_t {
static IDX: [u8; 16] = [2, 3, 4, 5, 6, 7, 0, 1, 10, 11, 12, 13, 14, 15, 8, 9];
unsafe {
let table: uint8x16_t = vld1q_u8(IDX.as_ptr());
vreinterpretq_u64_u8(vqtbl1q_u8(vreinterpretq_u8_u64(v), table))
}
}
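/// Rotate each 64-bit lane right by 63, i.e. left by 1: shift left, then
/// insert the carried-out top bit with `vsriq_n_u64`.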
#[inline(always)]
unsafe fn ror63(v: uint64x2_t) -> uint64x2_t {
unsafe { vsriq_n_u64::<63>(vshlq_n_u64::<1>(v), v) }
}