//! AVX2 and AVX-512 implementations of the block compression
//! `G(x, y) = P(x ^ y) ^ x ^ y`, with the permutation `P` applied first to
//! the rows and then to the columns of the 8x16-word block.
#![cfg(target_arch = "x86_64")]
#![allow(clippy::cast_possible_truncation)]
use core::arch::x86_64::{
__m256i, __m512i, _mm_loadu_si128, _mm_storeu_si128, _mm256_add_epi64, _mm256_castsi128_si256,
_mm256_castsi256_si128, _mm256_extracti128_si256, _mm256_inserti128_si256, _mm256_load_si256, _mm256_loadu_si256,
_mm256_mul_epu32, _mm256_or_si256, _mm256_permute4x64_epi64, _mm256_ror_epi64, _mm256_shuffle_epi8,
_mm256_shuffle_epi32, _mm256_slli_epi64, _mm256_srli_epi64, _mm256_storeu_si256, _mm256_xor_si256, _mm512_add_epi64,
_mm512_loadu_si512, _mm512_mul_epu32, _mm512_permutex_epi64, _mm512_ror_epi64, _mm512_shuffle_i64x2,
_mm512_slli_epi64, _mm512_storeu_si512, _mm512_xor_si512,
};
use super::BLOCK_WORDS;
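// `_mm256_shuffle_epi8` picks bytes within each 128-bit lane, so the masks
// below repeat one 16-byte pattern in both lanes; the 32-byte alignment
// permits the aligned `_mm256_load_si256`.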
#[repr(align(32))]
struct Align32([u8; 32]);
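// Byte indices, per 128-bit lane, mapping result byte i to source byte
// (i + 3) % 8 of its 64-bit word: a 24-bit right rotation.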
static ROT24_MASK: Align32 = Align32([
3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10, 3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10,
]);
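// As above with (i + 2) % 8: a 16-bit right rotation.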
static ROT16_MASK: Align32 = Align32([
2, 3, 4, 5, 6, 7, 0, 1, 10, 11, 12, 13, 14, 15, 8, 9, 2, 3, 4, 5, 6, 7, 0, 1, 10, 11, 12, 13, 14, 15, 8, 9,
]);
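/// AVX2 compression: computes `P(x ^ y) ^ x ^ y` and stores it to `dst`,
/// or XORs it into `dst` when `xor_into` is set.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX2.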
#[target_feature(enable = "avx2")]
pub(super) unsafe fn compress_avx2(
dst: &mut [u64; BLOCK_WORDS],
x: &[u64; BLOCK_WORDS],
y: &[u64; BLOCK_WORDS],
xor_into: bool,
) {
unsafe {
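        // `r` keeps the initial `x ^ y` for the final feedback XOR while `q`
        // is permuted in place.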
let mut r = [0u64; BLOCK_WORDS];
let mut q = [0u64; BLOCK_WORDS];
let mut i = 0;
while i < BLOCK_WORDS {
let xv = _mm256_loadu_si256(x.as_ptr().add(i).cast());
let yv = _mm256_loadu_si256(y.as_ptr().add(i).cast());
let rv = _mm256_xor_si256(xv, yv);
_mm256_storeu_si256(r.as_mut_ptr().add(i).cast(), rv);
_mm256_storeu_si256(q.as_mut_ptr().add(i).cast(), rv);
i += 4;
}
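        // Row pass: each of the 8 rows is 16 consecutive words, treated as a
        // 4x4 matrix with one 4-lane vector per matrix row.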
let mut row = 0usize;
while row < 8 {
let base = row * 16;
let mut a = _mm256_loadu_si256(q.as_ptr().add(base).cast());
let mut b = _mm256_loadu_si256(q.as_ptr().add(base + 4).cast());
let mut c = _mm256_loadu_si256(q.as_ptr().add(base + 8).cast());
let mut d = _mm256_loadu_si256(q.as_ptr().add(base + 12).cast());
p_round_avx2(&mut a, &mut b, &mut c, &mut d);
_mm256_storeu_si256(q.as_mut_ptr().add(base).cast(), a);
_mm256_storeu_si256(q.as_mut_ptr().add(base + 4).cast(), b);
_mm256_storeu_si256(q.as_mut_ptr().add(base + 8).cast(), c);
_mm256_storeu_si256(q.as_mut_ptr().add(base + 12).cast(), d);
row += 1;
}
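        // Column pass: column `col` is the 8 word pairs at base + 16 * k for
        // k in 0..8, gathered two pairs per vector.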
let mut col = 0usize;
while col < 8 {
let base = col * 2;
let mut a = load_col_pair_avx2(&q, base, base + 16);
let mut b = load_col_pair_avx2(&q, base + 32, base + 48);
let mut c = load_col_pair_avx2(&q, base + 64, base + 80);
let mut d = load_col_pair_avx2(&q, base + 96, base + 112);
p_round_avx2(&mut a, &mut b, &mut c, &mut d);
store_col_pair_avx2(&mut q, base, base + 16, a);
store_col_pair_avx2(&mut q, base + 32, base + 48, b);
store_col_pair_avx2(&mut q, base + 64, base + 80, c);
store_col_pair_avx2(&mut q, base + 96, base + 112, d);
col += 1;
}
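        // Feedback: the result is P(r) ^ r, stored or XOR-accumulated into dst.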
let mut i = 0;
while i < BLOCK_WORDS {
let qv = _mm256_loadu_si256(q.as_ptr().add(i).cast());
let rv = _mm256_loadu_si256(r.as_ptr().add(i).cast());
let f = _mm256_xor_si256(qv, rv);
if xor_into {
let cur = _mm256_loadu_si256(dst.as_ptr().add(i).cast());
_mm256_storeu_si256(dst.as_mut_ptr().add(i).cast(), _mm256_xor_si256(cur, f));
} else {
_mm256_storeu_si256(dst.as_mut_ptr().add(i).cast(), f);
}
i += 4;
}
}
}
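/// Loads the word pairs at `lo` and `hi` into the low and high 128-bit
/// halves of one vector, gathering one column slice across two rows.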
#[inline(always)]
unsafe fn load_col_pair_avx2(q: &[u64; BLOCK_WORDS], lo: usize, hi: usize) -> __m256i {
unsafe {
let lo_v = _mm_loadu_si128(q.as_ptr().add(lo).cast());
let hi_v = _mm_loadu_si128(q.as_ptr().add(hi).cast());
_mm256_inserti128_si256(_mm256_castsi128_si256(lo_v), hi_v, 1)
}
}
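/// Inverse of [`load_col_pair_avx2`]: scatters the two 128-bit halves back
/// to `lo` and `hi`.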
#[inline(always)]
unsafe fn store_col_pair_avx2(q: &mut [u64; BLOCK_WORDS], lo: usize, hi: usize, v: __m256i) {
unsafe {
_mm_storeu_si128(q.as_mut_ptr().add(lo).cast(), _mm256_castsi256_si128(v));
_mm_storeu_si128(q.as_mut_ptr().add(hi).cast(), _mm256_extracti128_si256(v, 1));
}
}
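/// One permutation round on a 4x4 matrix of 64-bit words held as four row
/// vectors: mix the columns, diagonalize, mix again, un-diagonalize, as in
/// vectorized BLAKE2b rounds.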
#[inline(always)]
unsafe fn p_round_avx2(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
unsafe {
gb_avx2(a, b, c, d);
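        // Diagonalize: rotate the lanes of b, c and d left by 1, 2 and 3.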
*b = _mm256_permute4x64_epi64(*b, 0x39);
*c = _mm256_permute4x64_epi64(*c, 0x4E);
*d = _mm256_permute4x64_epi64(*d, 0x93);
gb_avx2(a, b, c, d);
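        // Undo the diagonalization.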
*b = _mm256_permute4x64_epi64(*b, 0x93);
*c = _mm256_permute4x64_epi64(*c, 0x4E);
*d = _mm256_permute4x64_epi64(*d, 0x39);
}
}
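/// The quarter-round G on four 64-bit lanes at once. Per lane this computes
/// `a = a + b + 2 * lo32(a) * lo32(b)` (`_mm256_mul_epu32` multiplies the
/// low 32 bits of each lane) followed by XOR-rotates of 32, 24, 16 and 63
/// bits: the multiply-hardened BlaMka variant of BLAKE2b's G used by Argon2.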
#[inline(always)]
unsafe fn gb_avx2(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
unsafe {
let p = _mm256_mul_epu32(*a, *b);
*a = _mm256_add_epi64(_mm256_add_epi64(*a, *b), _mm256_slli_epi64(p, 1));
*d = ror32_avx2(_mm256_xor_si256(*d, *a));
let p = _mm256_mul_epu32(*c, *d);
*c = _mm256_add_epi64(_mm256_add_epi64(*c, *d), _mm256_slli_epi64(p, 1));
*b = ror24_avx2(_mm256_xor_si256(*b, *c));
let p = _mm256_mul_epu32(*a, *b);
*a = _mm256_add_epi64(_mm256_add_epi64(*a, *b), _mm256_slli_epi64(p, 1));
*d = ror16_avx2(_mm256_xor_si256(*d, *a));
let p = _mm256_mul_epu32(*c, *d);
*c = _mm256_add_epi64(_mm256_add_epi64(*c, *d), _mm256_slli_epi64(p, 1));
*b = ror63_avx2(_mm256_xor_si256(*b, *c));
}
}
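/// Rotates each 64-bit lane right by 32 by swapping its 32-bit halves.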
#[inline(always)]
unsafe fn ror32_avx2(x: __m256i) -> __m256i {
unsafe { _mm256_shuffle_epi32(x, 0xB1) }
}
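/// Rotates each 64-bit lane right by 24 with a per-lane byte shuffle.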
#[inline(always)]
unsafe fn ror24_avx2(x: __m256i) -> __m256i {
unsafe {
let mask = _mm256_load_si256(ROT24_MASK.0.as_ptr().cast());
_mm256_shuffle_epi8(x, mask)
}
}
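/// Rotates each 64-bit lane right by 16 with a per-lane byte shuffle.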
#[inline(always)]
unsafe fn ror16_avx2(x: __m256i) -> __m256i {
unsafe {
let mask = _mm256_load_si256(ROT16_MASK.0.as_ptr().cast());
_mm256_shuffle_epi8(x, mask)
}
}
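/// Rotates each 64-bit lane right by 63, i.e. left by 1; the add computes
/// `x << 1`.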
#[inline(always)]
unsafe fn ror63_avx2(x: __m256i) -> __m256i {
unsafe { _mm256_or_si256(_mm256_add_epi64(x, x), _mm256_srli_epi64(x, 63)) }
}
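/// AVX-512 compression; same contract as [`compress_avx2`]. The row pass
/// mixes two rows per 512-bit vector and the column pass reuses the 256-bit
/// gather with AVX512VL rotates.
///
/// # Safety
///
/// The caller must ensure the CPU supports AVX512F and AVX512VL.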
#[target_feature(enable = "avx512f,avx512vl")]
pub(super) unsafe fn compress_avx512(
dst: &mut [u64; BLOCK_WORDS],
x: &[u64; BLOCK_WORDS],
y: &[u64; BLOCK_WORDS],
xor_into: bool,
) {
unsafe {
let mut r = [0u64; BLOCK_WORDS];
let mut q = [0u64; BLOCK_WORDS];
let mut i = 0;
while i < BLOCK_WORDS {
let xv = _mm512_loadu_si512(x.as_ptr().add(i).cast());
let yv = _mm512_loadu_si512(y.as_ptr().add(i).cast());
let rv = _mm512_xor_si512(xv, yv);
_mm512_storeu_si512(r.as_mut_ptr().add(i).cast(), rv);
_mm512_storeu_si512(q.as_mut_ptr().add(i).cast(), rv);
i += 8;
}
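        // Row pass, two rows per iteration: `_mm512_shuffle_i64x2` regroups
        // the 128-bit lanes so that a holds words 0..3 of both rows, b words
        // 4..7, c words 8..11 and d words 12..15.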
let mut iter = 0;
while iter < 4 {
let off = iter * 32;
let r0_lo = _mm512_loadu_si512(q.as_ptr().add(off).cast());
let r0_hi = _mm512_loadu_si512(q.as_ptr().add(off + 8).cast());
let r1_lo = _mm512_loadu_si512(q.as_ptr().add(off + 16).cast());
let r1_hi = _mm512_loadu_si512(q.as_ptr().add(off + 24).cast());
let mut a = _mm512_shuffle_i64x2(r0_lo, r1_lo, 0x44);
let mut b = _mm512_shuffle_i64x2(r0_lo, r1_lo, 0xEE);
let mut c = _mm512_shuffle_i64x2(r0_hi, r1_hi, 0x44);
let mut d = _mm512_shuffle_i64x2(r0_hi, r1_hi, 0xEE);
p_round_avx512(&mut a, &mut b, &mut c, &mut d);
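            // Interleave the 128-bit lanes back into row-major order.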
let r0_lo_out = _mm512_shuffle_i64x2(a, b, 0x44);
let r0_hi_out = _mm512_shuffle_i64x2(c, d, 0x44);
let r1_lo_out = _mm512_shuffle_i64x2(a, b, 0xEE);
let r1_hi_out = _mm512_shuffle_i64x2(c, d, 0xEE);
_mm512_storeu_si512(q.as_mut_ptr().add(off).cast(), r0_lo_out);
_mm512_storeu_si512(q.as_mut_ptr().add(off + 8).cast(), r0_hi_out);
_mm512_storeu_si512(q.as_mut_ptr().add(off + 16).cast(), r1_lo_out);
_mm512_storeu_si512(q.as_mut_ptr().add(off + 24).cast(), r1_hi_out);
iter += 1;
}
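        // Column pass: same gather pattern as the AVX2 path, with AVX512VL
        // rotates inside the round.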
let mut col = 0usize;
while col < 8 {
let base = col * 2;
let mut a = load_col_pair_avx2(&q, base, base + 16);
let mut b = load_col_pair_avx2(&q, base + 32, base + 48);
let mut c = load_col_pair_avx2(&q, base + 64, base + 80);
let mut d = load_col_pair_avx2(&q, base + 96, base + 112);
p_round_avx512vl(&mut a, &mut b, &mut c, &mut d);
store_col_pair_avx2(&mut q, base, base + 16, a);
store_col_pair_avx2(&mut q, base + 32, base + 48, b);
store_col_pair_avx2(&mut q, base + 64, base + 80, c);
store_col_pair_avx2(&mut q, base + 96, base + 112, d);
col += 1;
}
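        // Feedback, as in the AVX2 path.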
let mut i = 0;
while i < BLOCK_WORDS {
let qv = _mm512_loadu_si512(q.as_ptr().add(i).cast());
let rv = _mm512_loadu_si512(r.as_ptr().add(i).cast());
let f = _mm512_xor_si512(qv, rv);
if xor_into {
let cur = _mm512_loadu_si512(dst.as_ptr().add(i).cast());
_mm512_storeu_si512(dst.as_mut_ptr().add(i).cast(), _mm512_xor_si512(cur, f));
} else {
_mm512_storeu_si512(dst.as_mut_ptr().add(i).cast(), f);
}
i += 8;
}
}
}
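/// [`p_round_avx2`] lifted to two independent 4x4 matrices, one per 256-bit
/// half of each vector; `_mm512_permutex_epi64` rotates lanes within each
/// 256-bit half separately.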
#[inline(always)]
unsafe fn p_round_avx512(a: &mut __m512i, b: &mut __m512i, c: &mut __m512i, d: &mut __m512i) {
unsafe {
gb_avx512(a, b, c, d);
*b = _mm512_permutex_epi64(*b, 0x39);
*c = _mm512_permutex_epi64(*c, 0x4E);
*d = _mm512_permutex_epi64(*d, 0x93);
gb_avx512(a, b, c, d);
*b = _mm512_permutex_epi64(*b, 0x93);
*c = _mm512_permutex_epi64(*c, 0x4E);
*d = _mm512_permutex_epi64(*d, 0x39);
}
}
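/// [`gb_avx2`] widened to eight lanes, with native `_mm512_ror_epi64`
/// rotates in place of the shuffle tricks.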
#[inline(always)]
unsafe fn gb_avx512(a: &mut __m512i, b: &mut __m512i, c: &mut __m512i, d: &mut __m512i) {
unsafe {
let p = _mm512_mul_epu32(*a, *b);
*a = _mm512_add_epi64(_mm512_add_epi64(*a, *b), _mm512_slli_epi64(p, 1));
*d = _mm512_ror_epi64(_mm512_xor_si512(*d, *a), 32);
let p = _mm512_mul_epu32(*c, *d);
*c = _mm512_add_epi64(_mm512_add_epi64(*c, *d), _mm512_slli_epi64(p, 1));
*b = _mm512_ror_epi64(_mm512_xor_si512(*b, *c), 24);
let p = _mm512_mul_epu32(*a, *b);
*a = _mm512_add_epi64(_mm512_add_epi64(*a, *b), _mm512_slli_epi64(p, 1));
*d = _mm512_ror_epi64(_mm512_xor_si512(*d, *a), 16);
let p = _mm512_mul_epu32(*c, *d);
*c = _mm512_add_epi64(_mm512_add_epi64(*c, *d), _mm512_slli_epi64(p, 1));
*b = _mm512_ror_epi64(_mm512_xor_si512(*b, *c), 63);
}
}
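/// Identical lane permutation to [`p_round_avx2`], paired with the AVX512VL
/// variant of G.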
#[inline(always)]
unsafe fn p_round_avx512vl(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
unsafe {
gb_avx512vl(a, b, c, d);
*b = _mm256_permute4x64_epi64(*b, 0x39);
*c = _mm256_permute4x64_epi64(*c, 0x4E);
*d = _mm256_permute4x64_epi64(*d, 0x93);
gb_avx512vl(a, b, c, d);
*b = _mm256_permute4x64_epi64(*b, 0x93);
*c = _mm256_permute4x64_epi64(*c, 0x4E);
*d = _mm256_permute4x64_epi64(*d, 0x39);
}
}
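/// [`gb_avx2`] with the shuffle-based rotates replaced by AVX512VL's
/// `_mm256_ror_epi64`.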
#[inline(always)]
unsafe fn gb_avx512vl(a: &mut __m256i, b: &mut __m256i, c: &mut __m256i, d: &mut __m256i) {
unsafe {
let p = _mm256_mul_epu32(*a, *b);
*a = _mm256_add_epi64(_mm256_add_epi64(*a, *b), _mm256_slli_epi64(p, 1));
*d = _mm256_ror_epi64(_mm256_xor_si256(*d, *a), 32);
let p = _mm256_mul_epu32(*c, *d);
*c = _mm256_add_epi64(_mm256_add_epi64(*c, *d), _mm256_slli_epi64(p, 1));
*b = _mm256_ror_epi64(_mm256_xor_si256(*b, *c), 24);
let p = _mm256_mul_epu32(*a, *b);
*a = _mm256_add_epi64(_mm256_add_epi64(*a, *b), _mm256_slli_epi64(p, 1));
*d = _mm256_ror_epi64(_mm256_xor_si256(*d, *a), 16);
let p = _mm256_mul_epu32(*c, *d);
*c = _mm256_add_epi64(_mm256_add_epi64(*c, *d), _mm256_slli_epi64(p, 1));
*b = _mm256_ror_epi64(_mm256_xor_si256(*b, *c), 63);
}
}
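// Differential sanity check between the two SIMD paths: a minimal sketch
// that assumes tests build with `std` (for `is_x86_feature_detected!`) and
// skips silently on hardware without AVX-512. The splitmix64-style fill
// constants are arbitrary.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn avx2_and_avx512_agree() {
        if !(std::arch::is_x86_feature_detected!("avx2")
            && std::arch::is_x86_feature_detected!("avx512f")
            && std::arch::is_x86_feature_detected!("avx512vl"))
        {
            return;
        }
        // Deterministic pseudo-random input so any failure reproduces.
        let mut s = 0u64;
        let mut next = move || {
            s = s.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = s;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        };
        let mut x = [0u64; BLOCK_WORDS];
        let mut y = [0u64; BLOCK_WORDS];
        for w in x.iter_mut().chain(y.iter_mut()) {
            *w = next();
        }
        let mut a = [0u64; BLOCK_WORDS];
        let mut b = [0u64; BLOCK_WORDS];
        unsafe {
            // Overwriting store first, then the XOR-accumulate path.
            compress_avx2(&mut a, &x, &y, false);
            compress_avx512(&mut b, &x, &y, false);
            assert_eq!(a, b);
            compress_avx2(&mut a, &y, &x, true);
            compress_avx512(&mut b, &y, &x, true);
        }
        assert_eq!(a, b);
    }
}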