use core::arch::x86_64::*;
use crate::{
params::*,
consts::*,
symmetric::*
};
pub(crate) const REJ_UNIFORM_AVX_NBLOCKS: usize =
(12*KYBER_N/8*(1 << 12)/KYBER_Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES;
const REJ_UNIFORM_AVX_BUFLEN: usize = REJ_UNIFORM_AVX_NBLOCKS*XOF_BLOCKBYTES;
pub unsafe fn _mm256_cmpge_epu16(a: __m256i, b: __m256i) -> __m256i {
_mm256_cmpeq_epi16(_mm256_max_epu16(a, b), a)
}
pub unsafe fn _mm_cmpge_epu16(a: __m128i, b: __m128i) -> __m128i {
_mm_cmpeq_epi16(_mm_max_epu16(a, b), a)
}
pub unsafe fn rej_uniform_avx(r: &mut[i16], buf: &[u8]) -> usize {
let mut ctr = 0;
let mut pos = 0;
let mut good: usize;
let (mut val0, mut val1);
let (mut f0, mut f1, mut g0, mut g1, mut g2, mut g3);
let (mut f, mut t, mut pilo, mut pihi);
let qdata_ptr = QDATA.coeffs[_16XQ..].as_ptr();
let bound = _mm256_load_si256(qdata_ptr as *const __m256i);
let ones = _mm256_set1_epi8(1);
let mask = _mm256_set1_epi16(0xFFF);
let idx8 = _mm256_set_epi8(
15,14,14,13,12,11,11,10,
9, 8, 8, 7, 6, 5, 5, 4,
11,10,10, 9, 8, 7, 7, 6,
5, 4, 4, 3, 2, 1, 1, 0
);
while ctr <= KYBER_N - 32 && pos <= REJ_UNIFORM_AVX_BUFLEN - 48 {
f0 = _mm256_loadu_si256(buf[pos..].as_ptr() as *const __m256i);
f1 = _mm256_loadu_si256(buf[pos+24..].as_ptr() as *const __m256i);
f0 = _mm256_permute4x64_epi64(f0, 0x94);
f1 = _mm256_permute4x64_epi64(f1, 0x94);
f0 = _mm256_shuffle_epi8(f0, idx8);
f1 = _mm256_shuffle_epi8(f1, idx8);
g0 = _mm256_srli_epi16(f0, 4);
g1 = _mm256_srli_epi16(f1, 4);
f0 = _mm256_blend_epi16(f0, g0, 0xAA);
f1 = _mm256_blend_epi16(f1, g1, 0xAA);
f0 = _mm256_and_si256(f0, mask);
f1 = _mm256_and_si256(f1, mask);
pos += 48;
g0 = _mm256_cmpgt_epi16(bound, f0);
g1 = _mm256_cmpgt_epi16(bound, f1);
g0 = _mm256_packs_epi16(g0, g1);
good = _mm256_movemask_epi8(g0) as usize;
let mut l0 = _mm_loadl_epi64(IDX[(good >> 0) & 0xFF].as_ptr() as * const __m128i);
g0 = _mm256_castsi128_si256(l0);
let mut l1 = _mm_loadl_epi64(IDX[(good >> 8) & 0xFF].as_ptr() as *const __m128i);
g1 = _mm256_castsi128_si256(l1);
l0 = _mm_loadl_epi64(IDX[(good >> 16) & 0xFF].as_ptr() as *const __m128i);
g0 = _mm256_inserti128_si256(g0, l0, 1);
l1 = _mm_loadl_epi64(IDX[(good >> 24) & 0xFF].as_ptr() as *const __m128i);
g1 = _mm256_inserti128_si256(g1, l1, 1);
g2 = _mm256_add_epi8(g0, ones);
g3 = _mm256_add_epi8(g1, ones);
g0 = _mm256_unpacklo_epi8(g0, g2);
g1 = _mm256_unpacklo_epi8(g1, g3);
f0 = _mm256_shuffle_epi8(f0, g0);
f1 = _mm256_shuffle_epi8(f1, g1);
_mm_storeu_si128(r[ctr..].as_mut_ptr() as *mut __m128i, _mm256_castsi256_si128(f0));
ctr += _popcnt32(((good >> 0) & 0xFF) as i32) as usize;
_mm_storeu_si128(r[ctr..].as_mut_ptr() as *mut __m128i, _mm256_extracti128_si256(f0, 1));
ctr += _popcnt32(((good >> 16) & 0xFF) as i32) as usize;
_mm_storeu_si128(r[ctr..].as_mut_ptr() as *mut __m128i, _mm256_castsi256_si128(f1));
ctr += _popcnt32(((good >> 8) & 0xFF) as i32) as usize;
_mm_storeu_si128(r[ctr..].as_mut_ptr() as *mut __m128i, _mm256_extracti128_si256(f1, 1));
ctr += _popcnt32(((good >> 24) & 0xFF) as i32) as usize;
}
while ctr <= KYBER_N - 8 && pos <= REJ_UNIFORM_AVX_BUFLEN - 12 {
f = _mm_loadu_si128(buf[pos..].as_ptr() as *const __m128i);
f = _mm_shuffle_epi8(f, _mm256_castsi256_si128(idx8));
t = _mm_srli_epi16(f, 4);
f = _mm_blend_epi16(f, t, 0xAA);
f = _mm_and_si128(f, _mm256_castsi256_si128(mask));
pos += 12;
t = _mm_cmpgt_epi16(_mm256_castsi256_si128(bound), f);
good = _mm_movemask_epi8(t) as usize;
let good = _pext_u32(good as u32, 0x5555) as usize;
pilo = _mm_loadl_epi64(IDX[good][..].as_ptr() as *const __m128i);
pihi = _mm_add_epi8(pilo, _mm256_castsi256_si128(ones));
pilo = _mm_unpacklo_epi8(pilo, pihi);
f = _mm_shuffle_epi8(f, pilo);
_mm_storeu_si128(r[ctr..].as_mut_ptr() as *mut __m128i, f);
ctr += _popcnt32(good as i32) as usize;
}
while ctr < KYBER_N && pos <= REJ_UNIFORM_AVX_BUFLEN - 3 {
val0 = (buf[pos+0] >> 0) as u16 | ((buf[pos+1] as u16) << 8) & 0xFFF;
val1 = (buf[pos+1] >> 4) as u16 | ((buf[pos+2] as u16) << 4);
pos += 3;
if (val0 as usize) < KYBER_Q {
r[ctr] = val0 as i16;
ctr += 1;
}
if (val1 as usize) < KYBER_Q && ctr < KYBER_N {
r[ctr] = val1 as i16;
ctr += 1;
}
}
ctr
}
const IDX: [[i8; 8]; 256] = [
[-1, -1, -1, -1, -1, -1, -1, -1],
[ 0, -1, -1, -1, -1, -1, -1, -1],
[ 2, -1, -1, -1, -1, -1, -1, -1],
[ 0, 2, -1, -1, -1, -1, -1, -1],
[ 4, -1, -1, -1, -1, -1, -1, -1],
[ 0, 4, -1, -1, -1, -1, -1, -1],
[ 2, 4, -1, -1, -1, -1, -1, -1],
[ 0, 2, 4, -1, -1, -1, -1, -1],
[ 6, -1, -1, -1, -1, -1, -1, -1],
[ 0, 6, -1, -1, -1, -1, -1, -1],
[ 2, 6, -1, -1, -1, -1, -1, -1],
[ 0, 2, 6, -1, -1, -1, -1, -1],
[ 4, 6, -1, -1, -1, -1, -1, -1],
[ 0, 4, 6, -1, -1, -1, -1, -1],
[ 2, 4, 6, -1, -1, -1, -1, -1],
[ 0, 2, 4, 6, -1, -1, -1, -1],
[ 8, -1, -1, -1, -1, -1, -1, -1],
[ 0, 8, -1, -1, -1, -1, -1, -1],
[ 2, 8, -1, -1, -1, -1, -1, -1],
[ 0, 2, 8, -1, -1, -1, -1, -1],
[ 4, 8, -1, -1, -1, -1, -1, -1],
[ 0, 4, 8, -1, -1, -1, -1, -1],
[ 2, 4, 8, -1, -1, -1, -1, -1],
[ 0, 2, 4, 8, -1, -1, -1, -1],
[ 6, 8, -1, -1, -1, -1, -1, -1],
[ 0, 6, 8, -1, -1, -1, -1, -1],
[ 2, 6, 8, -1, -1, -1, -1, -1],
[ 0, 2, 6, 8, -1, -1, -1, -1],
[ 4, 6, 8, -1, -1, -1, -1, -1],
[ 0, 4, 6, 8, -1, -1, -1, -1],
[ 2, 4, 6, 8, -1, -1, -1, -1],
[ 0, 2, 4, 6, 8, -1, -1, -1],
[10, -1, -1, -1, -1, -1, -1, -1],
[ 0, 10, -1, -1, -1, -1, -1, -1],
[ 2, 10, -1, -1, -1, -1, -1, -1],
[ 0, 2, 10, -1, -1, -1, -1, -1],
[ 4, 10, -1, -1, -1, -1, -1, -1],
[ 0, 4, 10, -1, -1, -1, -1, -1],
[ 2, 4, 10, -1, -1, -1, -1, -1],
[ 0, 2, 4, 10, -1, -1, -1, -1],
[ 6, 10, -1, -1, -1, -1, -1, -1],
[ 0, 6, 10, -1, -1, -1, -1, -1],
[ 2, 6, 10, -1, -1, -1, -1, -1],
[ 0, 2, 6, 10, -1, -1, -1, -1],
[ 4, 6, 10, -1, -1, -1, -1, -1],
[ 0, 4, 6, 10, -1, -1, -1, -1],
[ 2, 4, 6, 10, -1, -1, -1, -1],
[ 0, 2, 4, 6, 10, -1, -1, -1],
[ 8, 10, -1, -1, -1, -1, -1, -1],
[ 0, 8, 10, -1, -1, -1, -1, -1],
[ 2, 8, 10, -1, -1, -1, -1, -1],
[ 0, 2, 8, 10, -1, -1, -1, -1],
[ 4, 8, 10, -1, -1, -1, -1, -1],
[ 0, 4, 8, 10, -1, -1, -1, -1],
[ 2, 4, 8, 10, -1, -1, -1, -1],
[ 0, 2, 4, 8, 10, -1, -1, -1],
[ 6, 8, 10, -1, -1, -1, -1, -1],
[ 0, 6, 8, 10, -1, -1, -1, -1],
[ 2, 6, 8, 10, -1, -1, -1, -1],
[ 0, 2, 6, 8, 10, -1, -1, -1],
[ 4, 6, 8, 10, -1, -1, -1, -1],
[ 0, 4, 6, 8, 10, -1, -1, -1],
[ 2, 4, 6, 8, 10, -1, -1, -1],
[ 0, 2, 4, 6, 8, 10, -1, -1],
[12, -1, -1, -1, -1, -1, -1, -1],
[ 0, 12, -1, -1, -1, -1, -1, -1],
[ 2, 12, -1, -1, -1, -1, -1, -1],
[ 0, 2, 12, -1, -1, -1, -1, -1],
[ 4, 12, -1, -1, -1, -1, -1, -1],
[ 0, 4, 12, -1, -1, -1, -1, -1],
[ 2, 4, 12, -1, -1, -1, -1, -1],
[ 0, 2, 4, 12, -1, -1, -1, -1],
[ 6, 12, -1, -1, -1, -1, -1, -1],
[ 0, 6, 12, -1, -1, -1, -1, -1],
[ 2, 6, 12, -1, -1, -1, -1, -1],
[ 0, 2, 6, 12, -1, -1, -1, -1],
[ 4, 6, 12, -1, -1, -1, -1, -1],
[ 0, 4, 6, 12, -1, -1, -1, -1],
[ 2, 4, 6, 12, -1, -1, -1, -1],
[ 0, 2, 4, 6, 12, -1, -1, -1],
[ 8, 12, -1, -1, -1, -1, -1, -1],
[ 0, 8, 12, -1, -1, -1, -1, -1],
[ 2, 8, 12, -1, -1, -1, -1, -1],
[ 0, 2, 8, 12, -1, -1, -1, -1],
[ 4, 8, 12, -1, -1, -1, -1, -1],
[ 0, 4, 8, 12, -1, -1, -1, -1],
[ 2, 4, 8, 12, -1, -1, -1, -1],
[ 0, 2, 4, 8, 12, -1, -1, -1],
[ 6, 8, 12, -1, -1, -1, -1, -1],
[ 0, 6, 8, 12, -1, -1, -1, -1],
[ 2, 6, 8, 12, -1, -1, -1, -1],
[ 0, 2, 6, 8, 12, -1, -1, -1],
[ 4, 6, 8, 12, -1, -1, -1, -1],
[ 0, 4, 6, 8, 12, -1, -1, -1],
[ 2, 4, 6, 8, 12, -1, -1, -1],
[ 0, 2, 4, 6, 8, 12, -1, -1],
[10, 12, -1, -1, -1, -1, -1, -1],
[ 0, 10, 12, -1, -1, -1, -1, -1],
[ 2, 10, 12, -1, -1, -1, -1, -1],
[ 0, 2, 10, 12, -1, -1, -1, -1],
[ 4, 10, 12, -1, -1, -1, -1, -1],
[ 0, 4, 10, 12, -1, -1, -1, -1],
[ 2, 4, 10, 12, -1, -1, -1, -1],
[ 0, 2, 4, 10, 12, -1, -1, -1],
[ 6, 10, 12, -1, -1, -1, -1, -1],
[ 0, 6, 10, 12, -1, -1, -1, -1],
[ 2, 6, 10, 12, -1, -1, -1, -1],
[ 0, 2, 6, 10, 12, -1, -1, -1],
[ 4, 6, 10, 12, -1, -1, -1, -1],
[ 0, 4, 6, 10, 12, -1, -1, -1],
[ 2, 4, 6, 10, 12, -1, -1, -1],
[ 0, 2, 4, 6, 10, 12, -1, -1],
[ 8, 10, 12, -1, -1, -1, -1, -1],
[ 0, 8, 10, 12, -1, -1, -1, -1],
[ 2, 8, 10, 12, -1, -1, -1, -1],
[ 0, 2, 8, 10, 12, -1, -1, -1],
[ 4, 8, 10, 12, -1, -1, -1, -1],
[ 0, 4, 8, 10, 12, -1, -1, -1],
[ 2, 4, 8, 10, 12, -1, -1, -1],
[ 0, 2, 4, 8, 10, 12, -1, -1],
[ 6, 8, 10, 12, -1, -1, -1, -1],
[ 0, 6, 8, 10, 12, -1, -1, -1],
[ 2, 6, 8, 10, 12, -1, -1, -1],
[ 0, 2, 6, 8, 10, 12, -1, -1],
[ 4, 6, 8, 10, 12, -1, -1, -1],
[ 0, 4, 6, 8, 10, 12, -1, -1],
[ 2, 4, 6, 8, 10, 12, -1, -1],
[ 0, 2, 4, 6, 8, 10, 12, -1],
[14, -1, -1, -1, -1, -1, -1, -1],
[ 0, 14, -1, -1, -1, -1, -1, -1],
[ 2, 14, -1, -1, -1, -1, -1, -1],
[ 0, 2, 14, -1, -1, -1, -1, -1],
[ 4, 14, -1, -1, -1, -1, -1, -1],
[ 0, 4, 14, -1, -1, -1, -1, -1],
[ 2, 4, 14, -1, -1, -1, -1, -1],
[ 0, 2, 4, 14, -1, -1, -1, -1],
[ 6, 14, -1, -1, -1, -1, -1, -1],
[ 0, 6, 14, -1, -1, -1, -1, -1],
[ 2, 6, 14, -1, -1, -1, -1, -1],
[ 0, 2, 6, 14, -1, -1, -1, -1],
[ 4, 6, 14, -1, -1, -1, -1, -1],
[ 0, 4, 6, 14, -1, -1, -1, -1],
[ 2, 4, 6, 14, -1, -1, -1, -1],
[ 0, 2, 4, 6, 14, -1, -1, -1],
[ 8, 14, -1, -1, -1, -1, -1, -1],
[ 0, 8, 14, -1, -1, -1, -1, -1],
[ 2, 8, 14, -1, -1, -1, -1, -1],
[ 0, 2, 8, 14, -1, -1, -1, -1],
[ 4, 8, 14, -1, -1, -1, -1, -1],
[ 0, 4, 8, 14, -1, -1, -1, -1],
[ 2, 4, 8, 14, -1, -1, -1, -1],
[ 0, 2, 4, 8, 14, -1, -1, -1],
[ 6, 8, 14, -1, -1, -1, -1, -1],
[ 0, 6, 8, 14, -1, -1, -1, -1],
[ 2, 6, 8, 14, -1, -1, -1, -1],
[ 0, 2, 6, 8, 14, -1, -1, -1],
[ 4, 6, 8, 14, -1, -1, -1, -1],
[ 0, 4, 6, 8, 14, -1, -1, -1],
[ 2, 4, 6, 8, 14, -1, -1, -1],
[ 0, 2, 4, 6, 8, 14, -1, -1],
[10, 14, -1, -1, -1, -1, -1, -1],
[ 0, 10, 14, -1, -1, -1, -1, -1],
[ 2, 10, 14, -1, -1, -1, -1, -1],
[ 0, 2, 10, 14, -1, -1, -1, -1],
[ 4, 10, 14, -1, -1, -1, -1, -1],
[ 0, 4, 10, 14, -1, -1, -1, -1],
[ 2, 4, 10, 14, -1, -1, -1, -1],
[ 0, 2, 4, 10, 14, -1, -1, -1],
[ 6, 10, 14, -1, -1, -1, -1, -1],
[ 0, 6, 10, 14, -1, -1, -1, -1],
[ 2, 6, 10, 14, -1, -1, -1, -1],
[ 0, 2, 6, 10, 14, -1, -1, -1],
[ 4, 6, 10, 14, -1, -1, -1, -1],
[ 0, 4, 6, 10, 14, -1, -1, -1],
[ 2, 4, 6, 10, 14, -1, -1, -1],
[ 0, 2, 4, 6, 10, 14, -1, -1],
[ 8, 10, 14, -1, -1, -1, -1, -1],
[ 0, 8, 10, 14, -1, -1, -1, -1],
[ 2, 8, 10, 14, -1, -1, -1, -1],
[ 0, 2, 8, 10, 14, -1, -1, -1],
[ 4, 8, 10, 14, -1, -1, -1, -1],
[ 0, 4, 8, 10, 14, -1, -1, -1],
[ 2, 4, 8, 10, 14, -1, -1, -1],
[ 0, 2, 4, 8, 10, 14, -1, -1],
[ 6, 8, 10, 14, -1, -1, -1, -1],
[ 0, 6, 8, 10, 14, -1, -1, -1],
[ 2, 6, 8, 10, 14, -1, -1, -1],
[ 0, 2, 6, 8, 10, 14, -1, -1],
[ 4, 6, 8, 10, 14, -1, -1, -1],
[ 0, 4, 6, 8, 10, 14, -1, -1],
[ 2, 4, 6, 8, 10, 14, -1, -1],
[ 0, 2, 4, 6, 8, 10, 14, -1],
[12, 14, -1, -1, -1, -1, -1, -1],
[ 0, 12, 14, -1, -1, -1, -1, -1],
[ 2, 12, 14, -1, -1, -1, -1, -1],
[ 0, 2, 12, 14, -1, -1, -1, -1],
[ 4, 12, 14, -1, -1, -1, -1, -1],
[ 0, 4, 12, 14, -1, -1, -1, -1],
[ 2, 4, 12, 14, -1, -1, -1, -1],
[ 0, 2, 4, 12, 14, -1, -1, -1],
[ 6, 12, 14, -1, -1, -1, -1, -1],
[ 0, 6, 12, 14, -1, -1, -1, -1],
[ 2, 6, 12, 14, -1, -1, -1, -1],
[ 0, 2, 6, 12, 14, -1, -1, -1],
[ 4, 6, 12, 14, -1, -1, -1, -1],
[ 0, 4, 6, 12, 14, -1, -1, -1],
[ 2, 4, 6, 12, 14, -1, -1, -1],
[ 0, 2, 4, 6, 12, 14, -1, -1],
[ 8, 12, 14, -1, -1, -1, -1, -1],
[ 0, 8, 12, 14, -1, -1, -1, -1],
[ 2, 8, 12, 14, -1, -1, -1, -1],
[ 0, 2, 8, 12, 14, -1, -1, -1],
[ 4, 8, 12, 14, -1, -1, -1, -1],
[ 0, 4, 8, 12, 14, -1, -1, -1],
[ 2, 4, 8, 12, 14, -1, -1, -1],
[ 0, 2, 4, 8, 12, 14, -1, -1],
[ 6, 8, 12, 14, -1, -1, -1, -1],
[ 0, 6, 8, 12, 14, -1, -1, -1],
[ 2, 6, 8, 12, 14, -1, -1, -1],
[ 0, 2, 6, 8, 12, 14, -1, -1],
[ 4, 6, 8, 12, 14, -1, -1, -1],
[ 0, 4, 6, 8, 12, 14, -1, -1],
[ 2, 4, 6, 8, 12, 14, -1, -1],
[ 0, 2, 4, 6, 8, 12, 14, -1],
[10, 12, 14, -1, -1, -1, -1, -1],
[ 0, 10, 12, 14, -1, -1, -1, -1],
[ 2, 10, 12, 14, -1, -1, -1, -1],
[ 0, 2, 10, 12, 14, -1, -1, -1],
[ 4, 10, 12, 14, -1, -1, -1, -1],
[ 0, 4, 10, 12, 14, -1, -1, -1],
[ 2, 4, 10, 12, 14, -1, -1, -1],
[ 0, 2, 4, 10, 12, 14, -1, -1],
[ 6, 10, 12, 14, -1, -1, -1, -1],
[ 0, 6, 10, 12, 14, -1, -1, -1],
[ 2, 6, 10, 12, 14, -1, -1, -1],
[ 0, 2, 6, 10, 12, 14, -1, -1],
[ 4, 6, 10, 12, 14, -1, -1, -1],
[ 0, 4, 6, 10, 12, 14, -1, -1],
[ 2, 4, 6, 10, 12, 14, -1, -1],
[ 0, 2, 4, 6, 10, 12, 14, -1],
[ 8, 10, 12, 14, -1, -1, -1, -1],
[ 0, 8, 10, 12, 14, -1, -1, -1],
[ 2, 8, 10, 12, 14, -1, -1, -1],
[ 0, 2, 8, 10, 12, 14, -1, -1],
[ 4, 8, 10, 12, 14, -1, -1, -1],
[ 0, 4, 8, 10, 12, 14, -1, -1],
[ 2, 4, 8, 10, 12, 14, -1, -1],
[ 0, 2, 4, 8, 10, 12, 14, -1],
[ 6, 8, 10, 12, 14, -1, -1, -1],
[ 0, 6, 8, 10, 12, 14, -1, -1],
[ 2, 6, 8, 10, 12, 14, -1, -1],
[ 0, 2, 6, 8, 10, 12, 14, -1],
[ 4, 6, 8, 10, 12, 14, -1, -1],
[ 0, 4, 6, 8, 10, 12, 14, -1],
[ 2, 4, 6, 8, 10, 12, 14, -1],
[ 0, 2, 4, 6, 8, 10, 12, 14]
];