use crate::simd::avx2::{encoding, rejection_sample::shuffle_table::SHUFFLE_TABLE, Eta};
use libcrux_intrinsics::avx2::*;
#[inline(always)]
fn shift_interval<const ETA: usize>(coefficients: Vec256) -> Vec256 {
match ETA as u8 {
2 => {
let quotient = mm256_mullo_epi32(coefficients, mm256_set1_epi32(26));
let quotient = mm256_srai_epi32::<7>(quotient);
let quotient = mm256_mullo_epi32(quotient, mm256_set1_epi32(5));
let coefficients_mod_5 = mm256_sub_epi32(coefficients, quotient);
mm256_sub_epi32(mm256_set1_epi32(ETA as i32), coefficients_mod_5)
}
4 => mm256_sub_epi32(mm256_set1_epi32(ETA as i32), coefficients),
_ => unreachable!(),
}
}
#[inline(always)]
pub(crate) fn sample<const ETA: usize>(input: &[u8], output: &mut [i32]) -> usize {
let potential_coefficients = encoding::error::deserialize_to_unsigned(Eta::Four, input);
let interval_boundary: i32 = match ETA as u8 {
2 => 15,
4 => 9,
_ => unreachable!(),
};
let compare_with_interval_boundary =
mm256_cmpgt_epi32(mm256_set1_epi32(interval_boundary), potential_coefficients);
let good = mm256_movemask_ps(mm256_castsi256_ps(compare_with_interval_boundary));
let good_lower_half = good & 0x0F;
let good_upper_half = good >> 4;
let shifted = shift_interval::<ETA>(potential_coefficients);
let lower_shuffles = SHUFFLE_TABLE[good_lower_half as usize];
let lower_shuffles = mm_loadu_si128(&lower_shuffles);
let lower_coefficients = mm256_castsi256_si128(shifted);
let lower_coefficients = mm_shuffle_epi8(lower_coefficients, lower_shuffles);
mm_storeu_si128_i32(&mut output[0..4], lower_coefficients);
let sampled_count = good_lower_half.count_ones() as usize;
let upper_shuffles = SHUFFLE_TABLE[good_upper_half as usize];
let upper_shuffles = mm_loadu_si128(&upper_shuffles);
let upper_coefficients = mm256_extracti128_si256::<1>(shifted);
let upper_coefficients = mm_shuffle_epi8(upper_coefficients, upper_shuffles);
mm_storeu_si128_i32(
&mut output[sampled_count..sampled_count + 4],
upper_coefficients,
);
sampled_count + (good_upper_half.count_ones() as usize)
}