#![allow(unsafe_code)]
#![allow(clippy::inline_always)]
#![allow(clippy::indexing_slicing)]
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use super::{BLOCK_LEN, K, big_sigma0, big_sigma1, ch, maj};
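/// SHA-512 message-schedule function σ0(x) = ROTR^1(x) ^ ROTR^8(x) ^ SHR^7(x),
/// applied to four u64 lanes at once. AVX2 has no 64-bit rotate, so each
/// rotation is built from a pair of shifts and an OR.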
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn small_sigma0_256(x: __m256i) -> __m256i {
    unsafe {
        let rotr1 = _mm256_or_si256(_mm256_srli_epi64(x, 1), _mm256_slli_epi64(x, 63));
        let rotr8 = _mm256_or_si256(_mm256_srli_epi64(x, 8), _mm256_slli_epi64(x, 56));
        _mm256_xor_si256(_mm256_xor_si256(rotr1, rotr8), _mm256_srli_epi64(x, 7))
    }
}
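/// SHA-512 message-schedule function σ1(x) = ROTR^19(x) ^ ROTR^61(x) ^ SHR^6(x)
/// on four u64 lanes.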
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn small_sigma1_256(x: __m256i) -> __m256i {
    unsafe {
        let rotr19 = _mm256_or_si256(_mm256_srli_epi64(x, 19), _mm256_slli_epi64(x, 45));
        let rotr61 = _mm256_or_si256(_mm256_srli_epi64(x, 61), _mm256_slli_epi64(x, 3));
        _mm256_xor_si256(_mm256_xor_si256(rotr19, rotr61), _mm256_srli_epi64(x, 6))
    }
}
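/// σ0 on two u64 lanes; the SSE counterpart of `small_sigma0_256`.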
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn small_sigma0_128(x: __m128i) -> __m128i {
    unsafe {
        let rotr1 = _mm_or_si128(_mm_srli_epi64(x, 1), _mm_slli_epi64(x, 63));
        let rotr8 = _mm_or_si128(_mm_srli_epi64(x, 8), _mm_slli_epi64(x, 56));
        _mm_xor_si128(_mm_xor_si128(rotr1, rotr8), _mm_srli_epi64(x, 7))
    }
}
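/// σ1 on two u64 lanes; the SSE counterpart of `small_sigma1_256`.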
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn small_sigma1_128(x: __m128i) -> __m128i {
    unsafe {
        let rotr19 = _mm_or_si128(_mm_srli_epi64(x, 19), _mm_slli_epi64(x, 45));
        let rotr61 = _mm_or_si128(_mm_srli_epi64(x, 61), _mm_slli_epi64(x, 3));
        _mm_xor_si128(_mm_xor_si128(rotr19, rotr61), _mm_srli_epi64(x, 6))
    }
}
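/// `pshufb` mask that byte-swaps each of the two u64 lanes, converting the
/// big-endian message words of the input block to native little-endian.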
#[cfg(target_arch = "x86_64")]
static BSWAP64_128: [u8; 16] = [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8];
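/// Loads 16 bytes at `offset` from each of two blocks, byte-swaps both, and
/// packs them into one vector: block 1 in the low 128-bit lane, block 2 in
/// the high lane.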
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn load_two_blocks(blk1: *const u8, blk2: *const u8, offset: usize, bswap: __m128i) -> __m256i {
    unsafe {
        let lo = _mm_shuffle_epi8(_mm_loadu_si128(blk1.add(offset).cast()), bswap);
        let hi = _mm_shuffle_epi8(_mm_loadu_si128(blk2.add(offset).cast()), bswap);
        _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1)
    }
}
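/// Extracts the two u64 words of the low 128-bit lane (block 1).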
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn extract_lo(v: __m256i) -> (u64, u64) {
    unsafe {
        let lo128 = _mm256_castsi256_si128(v);
        (_mm_extract_epi64(lo128, 0) as u64, _mm_extract_epi64(lo128, 1) as u64)
    }
}
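/// Extracts the two u64 words of the high 128-bit lane (block 2).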
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn extract_hi(v: __m256i) -> (u64, u64) {
    unsafe {
        let hi128 = _mm256_extracti128_si256(v, 1);
        (_mm_extract_epi64(hi128, 0) as u64, _mm_extract_epi64(hi128, 1) as u64)
    }
}
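/// Extracts both u64 words of a 128-bit vector.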
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn extract_128(v: __m128i) -> (u64, u64) {
    unsafe { (_mm_extract_epi64(v, 0) as u64, _mm_extract_epi64(v, 1) as u64) }
}
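/// Advances the SHA-512 message schedule by two words for one block.
///
/// `x` holds the sixteen most recent schedule words W[t-16..t] as eight
/// two-lane vectors. Computes the next pair via
/// W[t] = σ1(W[t-2]) + W[t-7] + σ0(W[t-15]) + W[t-16], slides the window
/// down, and returns the new pair with the round constants in `k` added.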
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn schedule_rotate_128(x: &mut [__m128i; 8], k: __m128i) -> __m128i {
unsafe {
let w_tm15 = _mm_alignr_epi8(x[1], x[0], 8);
let w_tm7 = _mm_alignr_epi8(x[5], x[4], 8);
x[0] = _mm_add_epi64(
_mm_add_epi64(x[0], w_tm7),
_mm_add_epi64(small_sigma0_128(w_tm15), small_sigma1_128(x[7])),
);
let new_val = x[0];
x[0] = x[1];
x[1] = x[2];
x[2] = x[3];
x[3] = x[4];
x[4] = x[5];
x[5] = x[6];
x[6] = x[7];
x[7] = new_val;
_mm_add_epi64(x[7], k)
}
}
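/// Same schedule step as `schedule_rotate_128`, carried out for two
/// independent blocks at once, one per 128-bit lane. `_mm256_alignr_epi8`
/// shifts within each 128-bit lane, so the two schedules never mix.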
#[cfg(target_arch = "x86_64")]
#[inline(always)]
unsafe fn schedule_rotate_256(x: &mut [__m256i; 8], k: __m256i) -> __m256i {
unsafe {
let w_tm15 = _mm256_alignr_epi8(x[1], x[0], 8);
let w_tm7 = _mm256_alignr_epi8(x[5], x[4], 8);
x[0] = _mm256_add_epi64(
_mm256_add_epi64(x[0], w_tm7),
_mm256_add_epi64(small_sigma0_256(w_tm15), small_sigma1_256(x[7])),
);
let new_val = x[0];
x[0] = x[1];
x[1] = x[2];
x[2] = x[3];
x[3] = x[4];
x[4] = x[5];
x[5] = x[6];
x[6] = x[7];
x[7] = new_val;
_mm256_add_epi64(x[7], k)
}
}
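/// SHA-512 compression over full blocks, processing pairs of blocks with a
/// "decoupled" schedule: the message schedules of both blocks are computed
/// together with AVX2 while the rounds themselves run on scalar registers
/// (`bmi2` is presumably enabled so the scalar Σ rotations can use `rorx`).
/// W+K values for the first block are consumed from a rolling window as they
/// are produced; those for the second block are spilled to a buffer and
/// replayed once the first block's state update is complete. A trailing odd
/// block falls back to a single-schedule SSE path. `blocks` must contain a
/// whole number of 128-byte blocks.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2 and BMI2.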
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,bmi2")]
pub(crate) unsafe fn compress_blocks_avx2_decoupled(state: &mut [u64; 8], blocks: &[u8]) {
    debug_assert_eq!(blocks.len() % BLOCK_LEN, 0);
    let num_blocks = blocks.len() / BLOCK_LEN;
    if num_blocks == 0 {
        return;
    }
    unsafe {
        let mut ptr = blocks.as_ptr();
        let mut remaining = num_blocks;
        let bswap = _mm_loadu_si128(BSWAP64_128.as_ptr().cast());
        // `rv` holds a rolling 16-entry window of W+K values for the block whose
        // rounds are currently running; `t2_buf` spills all 80 W+K values of the
        // second block so its rounds can run after the first block completes.
        let mut rv = [0u64; 16];
        let mut t2_buf = [0u64; 80];
        while remaining >= 2 {
            if remaining >= 4 {
                // Prefetch the next pair of blocks while this pair is processed.
                _mm_prefetch::<_MM_HINT_T0>(ptr.add(BLOCK_LEN.strict_mul(2)).cast());
                _mm_prefetch::<_MM_HINT_T0>(ptr.add(BLOCK_LEN.strict_mul(2).strict_add(64)).cast());
                _mm_prefetch::<_MM_HINT_T0>(ptr.add(BLOCK_LEN.strict_mul(3)).cast());
                _mm_prefetch::<_MM_HINT_T0>(ptr.add(BLOCK_LEN.strict_mul(3).strict_add(64)).cast());
            }
            let blk1 = ptr;
            let blk2 = ptr.add(BLOCK_LEN);
            // W[0..16] for both blocks: block 1 in the low lanes, block 2 in the high lanes.
            let mut w: [__m256i; 8] = [
                load_two_blocks(blk1, blk2, 0, bswap),
                load_two_blocks(blk1, blk2, 16, bswap),
                load_two_blocks(blk1, blk2, 32, bswap),
                load_two_blocks(blk1, blk2, 48, bswap),
                load_two_blocks(blk1, blk2, 64, bswap),
                load_two_blocks(blk1, blk2, 80, bswap),
                load_two_blocks(blk1, blk2, 96, bswap),
                load_two_blocks(blk1, blk2, 112, bswap),
            ];
            // Precompute W[r] + K[r] for rounds 0..16 of both blocks.
            for (i, &wv) in w.iter().enumerate() {
                let r = i.strict_mul(2);
                let k_pair: __m128i = _mm_loadu_si128(K.as_ptr().add(r).cast());
                let wk = _mm256_add_epi64(wv, _mm256_set_m128i(k_pair, k_pair));
                let (lo0, lo1) = extract_lo(wk);
                let (hi0, hi1) = extract_hi(wk);
                rv[r] = lo0;
                rv[r.strict_add(1)] = lo1;
                t2_buf[r] = hi0;
                t2_buf[r.strict_add(1)] = hi1;
            }
            let mut a = state[0];
            let mut b = state[1];
            let mut c = state[2];
            let mut d = state[3];
            let mut e = state[4];
            let mut f = state[5];
            let mut g = state[6];
            let mut h = state[7];
            // One SHA-512 round; `$wk` is the precomputed W[t] + K[t].
            macro_rules! round {
                ($wk:expr) => {{
                    let t1 = h
                        .wrapping_add(big_sigma1(e))
                        .wrapping_add(ch(e, f, g))
                        .wrapping_add($wk);
                    let t2 = big_sigma0(a).wrapping_add(maj(a, b, c));
                    h = g;
                    g = f;
                    f = e;
                    e = d.wrapping_add(t1);
                    d = c;
                    c = b;
                    b = a;
                    a = t1.wrapping_add(t2);
                }};
            }
            // Rounds 0..64 of block 1, interleaved with the vector schedule that
            // produces W+K for rounds 16..80 of both blocks.
            for outer in 0..4usize {
                for j in 0..8usize {
                    let r = j.strict_mul(2);
                    round!(rv[r]);
                    round!(rv[r.strict_add(1)]);
                    let kr = 8usize.strict_add(outer.strict_mul(8)).strict_add(j).strict_mul(2);
                    let k_pair: __m128i = _mm_loadu_si128(K.as_ptr().add(kr).cast());
                    let wk = schedule_rotate_256(&mut w, _mm256_set_m128i(k_pair, k_pair));
                    let (lo0, lo1) = extract_lo(wk);
                    let (hi0, hi1) = extract_hi(wk);
                    rv[r] = lo0;
                    rv[r.strict_add(1)] = lo1;
                    t2_buf[kr] = hi0;
                    t2_buf[kr.strict_add(1)] = hi1;
                }
            }
            // Final 16 rounds of block 1 from the remaining window, then the feed-forward.
            for &wk in &rv {
                round!(wk);
            }
            state[0] = state[0].wrapping_add(a);
            state[1] = state[1].wrapping_add(b);
            state[2] = state[2].wrapping_add(c);
            state[3] = state[3].wrapping_add(d);
            state[4] = state[4].wrapping_add(e);
            state[5] = state[5].wrapping_add(f);
            state[6] = state[6].wrapping_add(g);
            state[7] = state[7].wrapping_add(h);
            // Block 2: replay its 80 spilled W+K values against the updated state.
            a = state[0];
            b = state[1];
            c = state[2];
            d = state[3];
            e = state[4];
            f = state[5];
            g = state[6];
            h = state[7];
            for &wk in &t2_buf {
                round!(wk);
            }
            state[0] = state[0].wrapping_add(a);
            state[1] = state[1].wrapping_add(b);
            state[2] = state[2].wrapping_add(c);
            state[3] = state[3].wrapping_add(d);
            state[4] = state[4].wrapping_add(e);
            state[5] = state[5].wrapping_add(f);
            state[6] = state[6].wrapping_add(g);
            state[7] = state[7].wrapping_add(h);
            ptr = ptr.add(BLOCK_LEN.strict_mul(2));
            remaining = remaining.strict_sub(2);
        }
        // Odd trailing block: same round structure, but the message schedule
        // runs on 128-bit vectors for a single block.
        if remaining == 1 {
            let mut wv: [__m128i; 8] = [
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(16).cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(32).cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(48).cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(64).cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(80).cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(96).cast()), bswap),
                _mm_shuffle_epi8(_mm_loadu_si128(ptr.add(112).cast()), bswap),
            ];
            // Precompute W[r] + K[r] for rounds 0..16.
            for (i, &v) in wv.iter().enumerate() {
                let r = i.strict_mul(2);
                let k_pair: __m128i = _mm_loadu_si128(K.as_ptr().add(r).cast());
                let wk = _mm_add_epi64(v, k_pair);
                let (w0, w1) = extract_128(wk);
                rv[r] = w0;
                rv[r.strict_add(1)] = w1;
            }
            let mut a = state[0];
            let mut b = state[1];
            let mut c = state[2];
            let mut d = state[3];
            let mut e = state[4];
            let mut f = state[5];
            let mut g = state[6];
            let mut h = state[7];
            // Same round body as `round!` above, which is scoped to the while loop.
            macro_rules! round1 {
                ($wk:expr) => {{
                    let t1 = h
                        .wrapping_add(big_sigma1(e))
                        .wrapping_add(ch(e, f, g))
                        .wrapping_add($wk);
                    let t2 = big_sigma0(a).wrapping_add(maj(a, b, c));
                    h = g;
                    g = f;
                    f = e;
                    e = d.wrapping_add(t1);
                    d = c;
                    c = b;
                    b = a;
                    a = t1.wrapping_add(t2);
                }};
            }
            // Rounds 0..64 interleaved with the schedule, then the final 16.
            for outer in 0..4usize {
                for j in 0..8usize {
                    let r = j.strict_mul(2);
                    round1!(rv[r]);
                    round1!(rv[r.strict_add(1)]);
                    let kr = 8usize.strict_add(outer.strict_mul(8)).strict_add(j).strict_mul(2);
                    let k_pair: __m128i = _mm_loadu_si128(K.as_ptr().add(kr).cast());
                    let wk = schedule_rotate_128(&mut wv, k_pair);
                    let (w0, w1) = extract_128(wk);
                    rv[r] = w0;
                    rv[r.strict_add(1)] = w1;
                }
            }
            for &wk in &rv {
                round1!(wk);
            }
            state[0] = state[0].wrapping_add(a);
            state[1] = state[1].wrapping_add(b);
            state[2] = state[2].wrapping_add(c);
            state[3] = state[3].wrapping_add(d);
            state[4] = state[4].wrapping_add(e);
            state[5] = state[5].wrapping_add(f);
            state[6] = state[6].wrapping_add(g);
            state[7] = state[7].wrapping_add(h);
        }
    }
}
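
// A minimal self-consistency test (sketch), assuming the crate links `std` in
// test builds: compressing two blocks in one call must match compressing them
// one at a time, exercising both the paired AVX2 path and the SSE tail.
#[cfg(all(test, target_arch = "x86_64"))]
mod tests {
    use super::*;

    #[test]
    fn paired_blocks_match_sequential_blocks() {
        if !(std::arch::is_x86_64_feature_detected!("avx2")
            && std::arch::is_x86_64_feature_detected!("bmi2"))
        {
            return; // Skip on hardware without the required features.
        }
        // Deterministic, non-trivial input and starting state.
        let mut blocks = [0u8; BLOCK_LEN * 2];
        for (i, byte) in blocks.iter_mut().enumerate() {
            *byte = (i as u8).wrapping_mul(31).wrapping_add(7);
        }
        let start = [0x0123_4567_89ab_cdef_u64; 8];

        let mut paired = start;
        // SAFETY: AVX2 and BMI2 availability was checked above.
        unsafe { compress_blocks_avx2_decoupled(&mut paired, &blocks) };

        let mut sequential = start;
        unsafe {
            compress_blocks_avx2_decoupled(&mut sequential, &blocks[..BLOCK_LEN]);
            compress_blocks_avx2_decoupled(&mut sequential, &blocks[BLOCK_LEN..]);
        }

        assert_eq!(paired, sequential);
    }
}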