use lib_q_intrinsics::*;
use crate::simd::avx2::rejection_sample::shuffle_table::SHUFFLE_TABLE;
use crate::simd::traits::FIELD_MODULUS;
/// Unpacks 24 bytes into eight candidate coefficients, one per 32-bit lane.
///
/// Candidate `i` is built from bytes `3*i`, `3*i + 1` and `3*i + 2` of
/// `serialized`, read little-endian. Its top bit is then cleared, so every
/// lane holds a value in `0..2^23`. The values are not yet reduced or
/// rejected; the caller decides which of them to keep.
///
/// `serialized` is expected to be exactly 24 bytes long. The slice copy
/// panics otherwise.
#[inline(always)]
fn bytestream_to_potential_coefficients(serialized: &[u8]) -> Vec256 {
    debug_assert_eq!(serialized.len(), 24);

    const COEFFICIENT_MASK: i32 = (1 << 23) - 1;

    // Move dwords 0..=2 (bytes 0..12) into the low 128-bit lane and dwords
    // 3..=5 (bytes 12..24) into the high lane. `vpshufb` cannot cross lanes,
    // so each lane must already contain its own four 3-byte groups.
    let dword_permutation = mm256_set_epi32(0, 5, 4, 3, 0, 2, 1, 0);

    // Within each lane, spread the 3-byte groups into 32-bit slots.
    // A -1 index zeroes the destination byte.
    let byte_spread = mm256_set_epi8(
        // high 128-bit lane
        -1, 11, 10, 9, -1, 8, 7, 6, -1, 5, 4, 3, -1, 2, 1, 0,
        // low 128-bit lane
        -1, 11, 10, 9, -1, 8, 7, 6, -1, 5, 4, 3, -1, 2, 1, 0,
    );
    let low_23_bits = mm256_set1_epi32(COEFFICIENT_MASK);

    // Zero-pad to a full 256-bit register so the load stays in bounds.
    let mut padded = [0u8; 32];
    padded[..24].copy_from_slice(serialized);

    let loaded = mm256_loadu_si256_u8(&padded);
    let lane_split = mm256_permutevar8x32_epi32(loaded, dword_permutation);
    let spread = mm256_shuffle_epi8(lane_split, byte_spread);
    mm256_and_si256(spread, low_23_bits)
}
/// Rejection-samples coefficients below `FIELD_MODULUS` from one 24-byte chunk.
///
/// `input` is unpacked into eight candidate 23-bit values by
/// `bytestream_to_potential_coefficients`. A candidate is accepted when it is
/// strictly less than `FIELD_MODULUS`. Accepted values are packed to the front
/// of `output`, one 128-bit half at a time, using `SHUFFLE_TABLE`.
/// `SHUFFLE_TABLE` is defined elsewhere and is assumed to map a 4-bit
/// acceptance mask to a byte shuffle that moves the selected 32-bit lanes to
/// the front.
///
/// Returns the number of accepted coefficients, in `0..=8`. Only
/// `output[..returned]` is meaningful. Both halves are stored as full 4-lane
/// writes, so elements after the returned count, up to `output[7]`, may be
/// overwritten.
///
/// # Panics
///
/// - Panics if `input.len() != 24`; the helper's slice copy enforces this.
/// - Panics if `output` is too short for either 4-lane store. The second
///   store's position depends on the input, so a short buffer would otherwise
///   fail only for some inputs. Debug builds therefore require
///   `output.len() >= 8` up front.
#[inline(always)]
pub(crate) fn sample(input: &[u8], output: &mut [i32]) -> usize {
    debug_assert!(
        output.len() >= 8,
        "sample writes up to 8 lanes; output must hold at least 8 elements"
    );

    let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
    let candidates = bytestream_to_potential_coefficients(input);

    // All candidates are in 0..2^23, so the signed 32-bit compare is safe.
    // movemask_ps gathers the per-lane all-ones/all-zeros results into an
    // 8-bit mask, where bit i set means lane i is accepted.
    let accepted = mm256_cmpgt_epi32(field_modulus, candidates);
    let accept_mask = mm256_movemask_ps(mm256_castsi256_ps(accepted));
    let (low_mask, high_mask) = (accept_mask & 0x0F, accept_mask >> 4);

    // Lower four lanes: compact and store at the front of `output`.
    let low_packed = mm_shuffle_epi8(
        mm256_castsi256_si128(candidates),
        mm_loadu_si128(&SHUFFLE_TABLE[low_mask as usize]),
    );
    mm_storeu_si128_i32(&mut output[0..4], low_packed);
    let low_count = low_mask.count_ones() as usize;

    // Upper four lanes: compact and store directly after the accepted lower
    // lanes, overwriting the rejected garbage the first store left there.
    let high_packed = mm_shuffle_epi8(
        mm256_extracti128_si256::<1>(candidates),
        mm_loadu_si128(&SHUFFLE_TABLE[high_mask as usize]),
    );
    mm_storeu_si128_i32(&mut output[low_count..low_count + 4], high_packed);

    low_count + high_mask.count_ones() as usize
}