use core::arch::x86_64::{
__m256i, _mm256_add_epi64, _mm256_andnot_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256, _mm256_set1_epi64x,
_mm256_storeu_si256, _mm256_sub_epi64, _mm256_xor_si256,
};
use poulpy_hal::reference::ntt120::{
NttAdd, NttAddInplace, NttCFromB, NttCopy, NttDFTExecute, NttExtract1BlkContiguous, NttFromZnx64, NttMulBbb, NttMulBbc,
NttMulBbc1ColX2, NttMulBbc2ColsX2, NttNegate, NttNegateInplace, NttSub, NttSubInplace, NttSubNegateInplace, NttToZnx128,
NttZero,
mat_vec::{BbbMeta, BbcMeta, extract_1blk_from_contiguous_q120b_ref},
ntt::{NttTable, NttTableInv},
primes::Primes30,
types::Q_SHIFTED,
};
use super::arithmetic_avx::{b_from_znx64_avx2, b_to_znx128_avx2, c_from_b_avx2, vec_mat1col_product_bbb_avx2};
use super::mat_vec_avx::{vec_mat1col_product_bbc_avx2, vec_mat1col_product_x2_bbc_avx2, vec_mat2cols_product_x2_bbc_avx2};
use super::ntt::{intt_avx2, ntt_avx2};
use super::NTT120Avx;
/// Subtracts `q_s` once from every 64-bit lane of `x` that is `>= q_s` (unsigned).
///
/// Per lane: `if x >= q { x - q } else { x }`. AVX2 only provides a *signed* 64-bit
/// compare (`_mm256_cmpgt_epi64`), so both operands are XORed with `msb` first; with
/// `msb = i64::MIN` in every lane (as all callers here pass), that flips the sign bit
/// and the signed compare then orders the lanes as unsigned values.
///
/// # Safety
///
/// Executes AVX2 instructions, so the CPU must support AVX2. It is marked
/// `#[inline(always)]` rather than `#[target_feature]` so it can be inlined into the
/// `#[target_feature(enable = "avx2")]` kernels that call it.
#[inline(always)]
unsafe fn lazy_reduce(x: __m256i, q_s: __m256i, msb: __m256i) -> __m256i {
    // SAFETY: the caller guarantees AVX2 is available; these are pure register ops.
    unsafe {
        let x_biased = _mm256_xor_si256(x, msb);
        let q_biased = _mm256_xor_si256(q_s, msb);
        // All-ones in lanes where x < q (unsigned), zero elsewhere.
        let x_below_q = _mm256_cmpgt_epi64(q_biased, x_biased);
        // Keep q only in lanes where x >= q; other lanes subtract zero.
        let correction = _mm256_andnot_si256(x_below_q, q_s);
        _mm256_sub_epi64(x, correction)
    }
}
/// Lane-wise `res = a + b` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// Each input lane is first passed through `lazy_reduce` against the matching
/// `Q_SHIFTED` lane, then the two are added with a plain 64-bit add. The result is
/// therefore not fully reduced. NOTE(review): this relies on `2 * Q_SHIFTED`
/// fitting in a `u64` lane; confirm against the constant.
///
/// # Panics
///
/// Panics if `res`, `a` or `b` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_add_avx2(n: usize, res: &mut [u64], a: &[u64], b: &[u64]) {
    let len = 4 * n;
    // Sub-slicing up front turns a too-short buffer into a panic instead of an
    // out-of-bounds SIMD load/store; the safe trait wrapper never checks `a`/`b`.
    let (res, a, b) = (&mut res[..len], &a[..len], &b[..len]);
    // SAFETY: AVX2 is guaranteed by the caller; every chunk below is exactly 4 u64
    // (32 bytes) and the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for ((r, x), y) in res.chunks_exact_mut(4).zip(a.chunks_exact(4)).zip(b.chunks_exact(4)) {
            let xv = lazy_reduce(_mm256_loadu_si256(x.as_ptr() as *const __m256i), q_s, msb);
            let yv = lazy_reduce(_mm256_loadu_si256(y.as_ptr() as *const __m256i), q_s, msb);
            _mm256_storeu_si256(r.as_mut_ptr() as *mut __m256i, _mm256_add_epi64(xv, yv));
        }
    }
}
/// Lane-wise in-place `res = res + a` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// Both operands are passed through `lazy_reduce` against `Q_SHIFTED` before a plain
/// 64-bit add, so the stored result is not fully reduced.
///
/// # Panics
///
/// Panics if `res` or `a` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_add_inplace_avx2(n: usize, res: &mut [u64], a: &[u64]) {
    let len = 4 * n;
    // Bounds-check once so a short `a` panics instead of being read out of bounds.
    let (res, a) = (&mut res[..len], &a[..len]);
    // SAFETY: AVX2 is guaranteed by the caller; each chunk is exactly 32 bytes and
    // the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for (r, x) in res.chunks_exact_mut(4).zip(a.chunks_exact(4)) {
            let r_ptr = r.as_mut_ptr() as *mut __m256i;
            let rv = lazy_reduce(_mm256_loadu_si256(r_ptr as *const __m256i), q_s, msb);
            let xv = lazy_reduce(_mm256_loadu_si256(x.as_ptr() as *const __m256i), q_s, msb);
            _mm256_storeu_si256(r_ptr, _mm256_add_epi64(rv, xv));
        }
    }
}
/// Lane-wise `res = a - b` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// After `lazy_reduce`, `b` is negated as `Q_SHIFTED - b` and added to `a`. This
/// avoids unsigned underflow; the result is congruent to `a - b` but not fully
/// reduced.
///
/// # Panics
///
/// Panics if `res`, `a` or `b` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_sub_avx2(n: usize, res: &mut [u64], a: &[u64], b: &[u64]) {
    let len = 4 * n;
    // Bounds-check once so short inputs panic instead of being read out of bounds.
    let (res, a, b) = (&mut res[..len], &a[..len], &b[..len]);
    // SAFETY: AVX2 is guaranteed by the caller; each chunk is exactly 32 bytes and
    // the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for ((r, x), y) in res.chunks_exact_mut(4).zip(a.chunks_exact(4)).zip(b.chunks_exact(4)) {
            let xv = lazy_reduce(_mm256_loadu_si256(x.as_ptr() as *const __m256i), q_s, msb);
            let yv = lazy_reduce(_mm256_loadu_si256(y.as_ptr() as *const __m256i), q_s, msb);
            let diff = _mm256_add_epi64(xv, _mm256_sub_epi64(q_s, yv));
            _mm256_storeu_si256(r.as_mut_ptr() as *mut __m256i, diff);
        }
    }
}
/// Lane-wise in-place `res = res - a` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// Computed as `res + (Q_SHIFTED - a)` after `lazy_reduce` on both operands, so no
/// lane underflows and the result is not fully reduced.
///
/// # Panics
///
/// Panics if `res` or `a` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_sub_inplace_avx2(n: usize, res: &mut [u64], a: &[u64]) {
    let len = 4 * n;
    // Bounds-check once so a short `a` panics instead of being read out of bounds.
    let (res, a) = (&mut res[..len], &a[..len]);
    // SAFETY: AVX2 is guaranteed by the caller; each chunk is exactly 32 bytes and
    // the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for (r, x) in res.chunks_exact_mut(4).zip(a.chunks_exact(4)) {
            let r_ptr = r.as_mut_ptr() as *mut __m256i;
            let rv = lazy_reduce(_mm256_loadu_si256(r_ptr as *const __m256i), q_s, msb);
            let xv = lazy_reduce(_mm256_loadu_si256(x.as_ptr() as *const __m256i), q_s, msb);
            _mm256_storeu_si256(r_ptr, _mm256_add_epi64(rv, _mm256_sub_epi64(q_s, xv)));
        }
    }
}
/// Lane-wise in-place `res = a - res` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// Computed as `a + (Q_SHIFTED - res)` after `lazy_reduce` on both operands, so no
/// lane underflows and the result is not fully reduced.
///
/// # Panics
///
/// Panics if `res` or `a` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_sub_negate_inplace_avx2(n: usize, res: &mut [u64], a: &[u64]) {
    let len = 4 * n;
    // Bounds-check once so a short `a` panics instead of being read out of bounds.
    let (res, a) = (&mut res[..len], &a[..len]);
    // SAFETY: AVX2 is guaranteed by the caller; each chunk is exactly 32 bytes and
    // the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for (r, x) in res.chunks_exact_mut(4).zip(a.chunks_exact(4)) {
            let r_ptr = r.as_mut_ptr() as *mut __m256i;
            let rv = lazy_reduce(_mm256_loadu_si256(r_ptr as *const __m256i), q_s, msb);
            let xv = lazy_reduce(_mm256_loadu_si256(x.as_ptr() as *const __m256i), q_s, msb);
            _mm256_storeu_si256(r_ptr, _mm256_add_epi64(xv, _mm256_sub_epi64(q_s, rv)));
        }
    }
}
/// Lane-wise `res = -a` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// Each lane of `a` goes through `lazy_reduce` and is then negated as
/// `Q_SHIFTED - a`, which never underflows.
///
/// # Panics
///
/// Panics if `res` or `a` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_negate_avx2(n: usize, res: &mut [u64], a: &[u64]) {
    let len = 4 * n;
    // Bounds-check once so a short `a` panics instead of being read out of bounds.
    let (res, a) = (&mut res[..len], &a[..len]);
    // SAFETY: AVX2 is guaranteed by the caller; each chunk is exactly 32 bytes and
    // the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for (r, x) in res.chunks_exact_mut(4).zip(a.chunks_exact(4)) {
            let xv = lazy_reduce(_mm256_loadu_si256(x.as_ptr() as *const __m256i), q_s, msb);
            _mm256_storeu_si256(r.as_mut_ptr() as *mut __m256i, _mm256_sub_epi64(q_s, xv));
        }
    }
}
/// Lane-wise in-place `res = -res` over `n` 4-lane (256-bit) groups, using AVX2.
///
/// Each lane goes through `lazy_reduce` and is then replaced by `Q_SHIFTED - res`.
///
/// # Panics
///
/// Panics if `res` holds fewer than `4 * n` words.
///
/// # Safety
///
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
unsafe fn ntt_negate_inplace_avx2(n: usize, res: &mut [u64]) {
    // Bounds-check once so an oversized `n` panics instead of running past `res`.
    let res = &mut res[..4 * n];
    // SAFETY: AVX2 is guaranteed by the caller; each chunk is exactly 32 bytes and
    // the unaligned load/store intrinsics need no alignment.
    unsafe {
        let q_s = _mm256_loadu_si256(Q_SHIFTED.as_ptr() as *const __m256i);
        let msb = _mm256_set1_epi64x(i64::MIN);
        for r in res.chunks_exact_mut(4) {
            let r_ptr = r.as_mut_ptr() as *mut __m256i;
            let rv = lazy_reduce(_mm256_loadu_si256(r_ptr as *const __m256i), q_s, msb);
            _mm256_storeu_si256(r_ptr, _mm256_sub_epi64(q_s, rv));
        }
    }
}
impl NttDFTExecute<NttTable<Primes30>> for NTT120Avx {
    /// Applies the transform described by `table` to `data` through `ntt_avx2`.
    #[inline(always)]
    fn ntt_dft_execute(table: &NttTable<Primes30>, data: &mut [u64]) {
        // SAFETY (assumed, not checked here): the CPU supports the features `ntt_avx2`
        // needs whenever this backend is in use, and `data` matches `table`'s size.
        // NOTE(review): confirm where AVX2 availability is enforced for NTT120Avx.
        unsafe {
            ntt_avx2::<Primes30>(table, data);
        }
    }
}
impl NttDFTExecute<NttTableInv<Primes30>> for NTT120Avx {
    /// Applies the inverse transform described by `table` to `data` through `intt_avx2`.
    #[inline(always)]
    fn ntt_dft_execute(table: &NttTableInv<Primes30>, data: &mut [u64]) {
        // SAFETY (assumed, not checked here): the CPU supports the features `intt_avx2`
        // needs whenever this backend is in use, and `data` matches `table`'s size.
        // NOTE(review): confirm where AVX2 availability is enforced for NTT120Avx.
        unsafe {
            intt_avx2::<Primes30>(table, data);
        }
    }
}
impl NttFromZnx64 for NTT120Avx {
    /// Converts the `i64` coefficients in `a` into `res` via `b_from_znx64_avx2`.
    #[inline(always)]
    fn ntt_from_znx64(res: &mut [u64], a: &[i64]) {
        let coeff_count = a.len();
        // SAFETY (assumed): AVX2 is available for this backend.
        // NOTE(review): `res.len()` is not checked against `coeff_count` here; confirm
        // that `b_from_znx64_avx2` validates it or that callers guarantee it.
        unsafe { b_from_znx64_avx2(coeff_count, res, a) }
    }
}
impl NttToZnx128 for NTT120Avx {
    /// Converts `a` back to `i128` coefficients in `res` via `b_to_znx128_avx2`.
    #[inline(always)]
    fn ntt_to_znx128(res: &mut [i128], divisor_is_n: usize, a: &[u64]) {
        // SAFETY (assumed): AVX2 is available for this backend and the slice lengths
        // satisfy `b_to_znx128_avx2`'s contract.
        // NOTE(review): unlike the sibling wrappers, the first argument here is
        // `divisor_is_n`, not a count derived from a slice length; confirm it matches
        // the kernel's first parameter.
        let divisor = divisor_is_n;
        unsafe { b_to_znx128_avx2(divisor, res, a) }
    }
}
impl NttAdd for NTT120Avx {
    /// Lane-wise `res = a + b` (lazily reduced) through `ntt_add_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed; `a` and `b` are
    /// expected to be at least that long.
    #[inline(always)]
    fn ntt_add(res: &mut [u64], a: &[u64], b: &[u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed, not checked in this wrapper): AVX2 is available for this
        // backend and `a`/`b` each hold at least `4 * vectors` words.
        unsafe { ntt_add_avx2(vectors, res, a, b) }
    }
}
impl NttAddInplace for NTT120Avx {
    /// Lane-wise in-place `res += a` (lazily reduced) through `ntt_add_inplace_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed; `a` is expected to
    /// be at least that long.
    #[inline(always)]
    fn ntt_add_inplace(res: &mut [u64], a: &[u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed, not checked in this wrapper): AVX2 is available for this
        // backend and `a` holds at least `4 * vectors` words.
        unsafe { ntt_add_inplace_avx2(vectors, res, a) }
    }
}
impl NttSub for NTT120Avx {
    /// Lane-wise `res = a - b` (lazily reduced) through `ntt_sub_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed; `a` and `b` are
    /// expected to be at least that long.
    #[inline(always)]
    fn ntt_sub(res: &mut [u64], a: &[u64], b: &[u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed, not checked in this wrapper): AVX2 is available for this
        // backend and `a`/`b` each hold at least `4 * vectors` words.
        unsafe { ntt_sub_avx2(vectors, res, a, b) }
    }
}
impl NttSubInplace for NTT120Avx {
    /// Lane-wise in-place `res -= a` (lazily reduced) through `ntt_sub_inplace_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed; `a` is expected to
    /// be at least that long.
    #[inline(always)]
    fn ntt_sub_inplace(res: &mut [u64], a: &[u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed, not checked in this wrapper): AVX2 is available for this
        // backend and `a` holds at least `4 * vectors` words.
        unsafe { ntt_sub_inplace_avx2(vectors, res, a) }
    }
}
impl NttSubNegateInplace for NTT120Avx {
    /// Lane-wise in-place `res = a - res` (lazily reduced) through
    /// `ntt_sub_negate_inplace_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed; `a` is expected to
    /// be at least that long.
    #[inline(always)]
    fn ntt_sub_negate_inplace(res: &mut [u64], a: &[u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed, not checked in this wrapper): AVX2 is available for this
        // backend and `a` holds at least `4 * vectors` words.
        unsafe { ntt_sub_negate_inplace_avx2(vectors, res, a) }
    }
}
impl NttNegate for NTT120Avx {
    /// Lane-wise `res = -a` through `ntt_negate_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed; `a` is expected to
    /// be at least that long.
    #[inline(always)]
    fn ntt_negate(res: &mut [u64], a: &[u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed, not checked in this wrapper): AVX2 is available for this
        // backend and `a` holds at least `4 * vectors` words.
        unsafe { ntt_negate_avx2(vectors, res, a) }
    }
}
impl NttNegateInplace for NTT120Avx {
    /// Lane-wise in-place `res = -res` through `ntt_negate_inplace_avx2`.
    ///
    /// Only the first `4 * (res.len() / 4)` words are processed.
    #[inline(always)]
    fn ntt_negate_inplace(res: &mut [u64]) {
        let vectors = res.len() / 4;
        // SAFETY (assumed): AVX2 is available for this backend; `vectors` is derived
        // from `res` itself, so the kernel stays within `res`.
        unsafe { ntt_negate_inplace_avx2(vectors, res) }
    }
}
impl NttZero for NTT120Avx {
    /// Sets every word of `res` to zero (no SIMD kernel needed).
    #[inline(always)]
    fn ntt_zero(res: &mut [u64]) {
        <[u64]>::fill(res, 0);
    }
}
impl NttCopy for NTT120Avx {
    /// Copies `a` into `res`; panics if the two lengths differ (`copy_from_slice`).
    #[inline(always)]
    fn ntt_copy(res: &mut [u64], a: &[u64]) {
        <[u64]>::copy_from_slice(res, a);
    }
}
impl NttMulBbb for NTT120Avx {
    /// Vector × single-column product over `ell` rows through
    /// `vec_mat1col_product_bbb_avx2`, with `meta` carrying the reduction parameters.
    #[inline(always)]
    fn ntt_mul_bbb(meta: &BbbMeta<Primes30>, ell: usize, res: &mut [u64], a: &[u64], b: &[u64]) {
        // SAFETY (assumed, not checked here): AVX2 is available for this backend and the
        // slice lengths satisfy the kernel's contract for `ell`.
        unsafe {
            vec_mat1col_product_bbb_avx2(meta, ell, res, a, b);
        }
    }
}
impl NttMulBbc for NTT120Avx {
    /// Vector × single-column product over `ell` rows through
    /// `vec_mat1col_product_bbc_avx2`, taking `u32` operands `ntt_coeff` and `prepared`.
    #[inline(always)]
    fn ntt_mul_bbc(meta: &BbcMeta<Primes30>, ell: usize, res: &mut [u64], ntt_coeff: &[u32], prepared: &[u32]) {
        // SAFETY (assumed, not checked here): AVX2 is available for this backend and the
        // slice lengths satisfy the kernel's contract for `ell`.
        unsafe {
            vec_mat1col_product_bbc_avx2(meta, ell, res, ntt_coeff, prepared);
        }
    }
}
impl NttCFromB for NTT120Avx {
    /// Converts `n` entries of the `u64` layout in `a` into the `u32` layout in `res`
    /// through `c_from_b_avx2`.
    #[inline(always)]
    fn ntt_c_from_b(n: usize, res: &mut [u32], a: &[u64]) {
        // SAFETY (assumed, not checked here): AVX2 is available for this backend and
        // `res`/`a` are large enough for `n` entries.
        unsafe {
            c_from_b_avx2(n, res, a);
        }
    }
}
impl NttMulBbc1ColX2 for NTT120Avx {
    /// Single-column, two-way ("x2") product over `ell` rows through
    /// `vec_mat1col_product_x2_bbc_avx2`.
    #[inline(always)]
    fn ntt_mul_bbc_1col_x2(meta: &BbcMeta<Primes30>, ell: usize, res: &mut [u64], a: &[u32], b: &[u32]) {
        // SAFETY (assumed, not checked here): AVX2 is available for this backend and the
        // slice lengths satisfy the kernel's contract for `ell`.
        unsafe {
            vec_mat1col_product_x2_bbc_avx2(meta, ell, res, a, b);
        }
    }
}
impl NttMulBbc2ColsX2 for NTT120Avx {
    /// Two-column, two-way ("x2") product over `ell` rows through
    /// `vec_mat2cols_product_x2_bbc_avx2`.
    #[inline(always)]
    fn ntt_mul_bbc_2cols_x2(meta: &BbcMeta<Primes30>, ell: usize, res: &mut [u64], a: &[u32], b: &[u32]) {
        // SAFETY (assumed, not checked here): AVX2 is available for this backend and the
        // slice lengths satisfy the kernel's contract for `ell`.
        unsafe {
            vec_mat2cols_product_x2_bbc_avx2(meta, ell, res, a, b);
        }
    }
}
impl NttExtract1BlkContiguous for NTT120Avx {
    /// Extracts block `blk` from the contiguous `src` layout into `dst`.
    ///
    /// This backend has no AVX2 specialization for the operation; it forwards to the
    /// portable reference implementation from `poulpy_hal::reference`.
    #[inline(always)]
    fn ntt_extract_1blk_contiguous(n: usize, row_max: usize, blk: usize, dst: &mut [u64], src: &[u64]) {
        extract_1blk_from_contiguous_q120b_ref(n, row_max, blk, dst, src)
    }
}