#![allow(unsafe_code)]
use super::blake3::{CHUNK_END, CHUNK_LEN, CHUNK_START, IV, MSG_PERMUTATION};
/// Number of chunks hashed in parallel by the AVX2 kernel: one per 32-bit lane
/// of a 256-bit register (256 / 32 = 8).
pub(super) const DEGREE: usize = 256 / u32::BITS as usize;
/// Returns `true` when the running CPU reports AVX2 support, i.e. when the
/// AVX2 chunk-hashing kernel in this file may be executed.
#[cfg(target_arch = "x86_64")]
pub(super) fn supported() -> bool {
    std::arch::is_x86_feature_detected!("avx2")
}
#[cfg(target_arch = "x86_64")]
pub(super) fn hash_chunks8(
input: &[u8],
key: &[u32; 8],
counter_base: u64,
flags: u32,
) -> [[u32; 8]; 8] {
debug_assert_eq!(input.len(), DEGREE * CHUNK_LEN);
unsafe { avx2::hash_chunks8(input, key, counter_base, flags) }
}
#[cfg(target_arch = "x86_64")]
mod avx2 {
use super::*;
use core::arch::x86_64::*;
const BLOCK_LEN: u32 = 64;
#[inline(always)]
unsafe fn rotr16(x: __m256i) -> __m256i {
unsafe { _mm256_or_si256(_mm256_srli_epi32::<16>(x), _mm256_slli_epi32::<16>(x)) }
}
#[inline(always)]
unsafe fn rotr12(x: __m256i) -> __m256i {
unsafe { _mm256_or_si256(_mm256_srli_epi32::<12>(x), _mm256_slli_epi32::<20>(x)) }
}
#[inline(always)]
unsafe fn rotr8(x: __m256i) -> __m256i {
unsafe { _mm256_or_si256(_mm256_srli_epi32::<8>(x), _mm256_slli_epi32::<24>(x)) }
}
#[inline(always)]
unsafe fn rotr7(x: __m256i) -> __m256i {
unsafe { _mm256_or_si256(_mm256_srli_epi32::<7>(x), _mm256_slli_epi32::<25>(x)) }
}
#[inline(always)]
unsafe fn add(a: __m256i, b: __m256i) -> __m256i {
unsafe { _mm256_add_epi32(a, b) }
}
#[inline(always)]
unsafe fn xor(a: __m256i, b: __m256i) -> __m256i {
unsafe { _mm256_xor_si256(a, b) }
}
#[inline(always)]
#[allow(clippy::too_many_arguments)]
unsafe fn g(
v: &mut [__m256i; 16],
a: usize,
b: usize,
c: usize,
d: usize,
mx: __m256i,
my: __m256i,
) {
unsafe {
v[a] = add(add(v[a], v[b]), mx);
v[d] = rotr16(xor(v[d], v[a]));
v[c] = add(v[c], v[d]);
v[b] = rotr12(xor(v[b], v[c]));
v[a] = add(add(v[a], v[b]), my);
v[d] = rotr8(xor(v[d], v[a]));
v[c] = add(v[c], v[d]);
v[b] = rotr7(xor(v[b], v[c]));
}
}
#[inline(always)]
unsafe fn round(v: &mut [__m256i; 16], m: &[__m256i; 16]) {
unsafe {
g(v, 0, 4, 8, 12, m[0], m[1]);
g(v, 1, 5, 9, 13, m[2], m[3]);
g(v, 2, 6, 10, 14, m[4], m[5]);
g(v, 3, 7, 11, 15, m[6], m[7]);
g(v, 0, 5, 10, 15, m[8], m[9]);
g(v, 1, 6, 11, 12, m[10], m[11]);
g(v, 2, 7, 8, 13, m[12], m[13]);
g(v, 3, 4, 9, 14, m[14], m[15]);
}
}
#[inline(always)]
unsafe fn transpose8(rows: &mut [__m256i; 8]) {
unsafe {
let t0 = _mm256_unpacklo_epi32(rows[0], rows[1]);
let t1 = _mm256_unpackhi_epi32(rows[0], rows[1]);
let t2 = _mm256_unpacklo_epi32(rows[2], rows[3]);
let t3 = _mm256_unpackhi_epi32(rows[2], rows[3]);
let t4 = _mm256_unpacklo_epi32(rows[4], rows[5]);
let t5 = _mm256_unpackhi_epi32(rows[4], rows[5]);
let t6 = _mm256_unpacklo_epi32(rows[6], rows[7]);
let t7 = _mm256_unpackhi_epi32(rows[6], rows[7]);
let s0 = _mm256_unpacklo_epi64(t0, t2);
let s1 = _mm256_unpackhi_epi64(t0, t2);
let s2 = _mm256_unpacklo_epi64(t1, t3);
let s3 = _mm256_unpackhi_epi64(t1, t3);
let s4 = _mm256_unpacklo_epi64(t4, t6);
let s5 = _mm256_unpackhi_epi64(t4, t6);
let s6 = _mm256_unpacklo_epi64(t5, t7);
let s7 = _mm256_unpackhi_epi64(t5, t7);
rows[0] = _mm256_permute2x128_si256(s0, s4, 0x20);
rows[1] = _mm256_permute2x128_si256(s1, s5, 0x20);
rows[2] = _mm256_permute2x128_si256(s2, s6, 0x20);
rows[3] = _mm256_permute2x128_si256(s3, s7, 0x20);
rows[4] = _mm256_permute2x128_si256(s0, s4, 0x31);
rows[5] = _mm256_permute2x128_si256(s1, s5, 0x31);
rows[6] = _mm256_permute2x128_si256(s2, s6, 0x31);
rows[7] = _mm256_permute2x128_si256(s3, s7, 0x31);
}
}
#[inline(always)]
unsafe fn load_msg(input: &[u8], b: usize) -> [__m256i; 16] {
unsafe {
let mut lo = [_mm256_setzero_si256(); 8];
let mut hi = [_mm256_setzero_si256(); 8];
for (lane, (l, h)) in lo.iter_mut().zip(hi.iter_mut()).enumerate() {
let p = input.as_ptr().add(lane * CHUNK_LEN + b * 64);
*l = _mm256_loadu_si256(p as *const __m256i); *h = _mm256_loadu_si256(p.add(32) as *const __m256i); }
transpose8(&mut lo);
transpose8(&mut hi);
let mut m = [_mm256_setzero_si256(); 16];
m[..8].copy_from_slice(&lo);
m[8..].copy_from_slice(&hi);
m
}
}
#[inline(always)]
unsafe fn permute(m: &[__m256i; 16]) -> [__m256i; 16] {
let mut out = [unsafe { _mm256_setzero_si256() }; 16];
for (i, &p) in MSG_PERMUTATION.iter().enumerate() {
out[i] = m[p];
}
out
}
#[target_feature(enable = "avx2")]
pub(super) unsafe fn hash_chunks8(
input: &[u8],
key: &[u32; 8],
counter_base: u64,
flags: u32,
) -> [[u32; 8]; 8] {
unsafe {
let mut clo = [0u32; 8];
let mut chi = [0u32; 8];
for (k, (lo, hi)) in clo.iter_mut().zip(chi.iter_mut()).enumerate() {
let c = counter_base.wrapping_add(k as u64);
*lo = c as u32;
*hi = (c >> 32) as u32;
}
let counter_lo = _mm256_loadu_si256(clo.as_ptr() as *const __m256i);
let counter_hi = _mm256_loadu_si256(chi.as_ptr() as *const __m256i);
let block_len = _mm256_set1_epi32(BLOCK_LEN as i32);
let mut h = [_mm256_setzero_si256(); 8];
for (hi, &kw) in h.iter_mut().zip(key.iter()) {
*hi = _mm256_set1_epi32(kw as i32);
}
for b in 0..16usize {
let block_flags = {
let mut f = flags;
if b == 0 {
f |= CHUNK_START;
}
if b == 15 {
f |= CHUNK_END;
}
_mm256_set1_epi32(f as i32)
};
let msg = load_msg(input, b);
let mut v = [
h[0],
h[1],
h[2],
h[3],
h[4],
h[5],
h[6],
h[7],
_mm256_set1_epi32(IV[0] as i32),
_mm256_set1_epi32(IV[1] as i32),
_mm256_set1_epi32(IV[2] as i32),
_mm256_set1_epi32(IV[3] as i32),
counter_lo,
counter_hi,
block_len,
block_flags,
];
let mut m = msg;
for r in 0..7 {
round(&mut v, &m);
if r < 6 {
m = permute(&m);
}
}
for i in 0..8 {
h[i] = xor(v[i], v[i + 8]);
}
}
transpose8(&mut h);
let mut out = [[0u32; 8]; 8];
for (lane, hi) in h.iter().enumerate() {
_mm256_storeu_si256(out[lane].as_mut_ptr() as *mut __m256i, *hi);
}
out
}
}
}