use core::arch::x86_64::{
__m256i, _mm_add_epi64, _mm_cvtsi64_si128, _mm_cvtsi128_si64, _mm_unpackhi_epi64, _mm256_add_epi32, _mm256_add_epi64,
_mm256_and_si256, _mm256_andnot_si256, _mm256_castsi256_si128, _mm256_cmpgt_epi64, _mm256_extracti128_si256,
_mm256_loadu_si256, _mm256_mul_epu32, _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_slli_epi64,
_mm256_srl_epi64, _mm256_srli_epi64, _mm256_storeu_si256, _mm256_sub_epi64,
};
use poulpy_cpu_ref::reference::ntt120::{
mat_vec::BbbMeta,
primes::{PrimeSet, Primes30},
};
pub(crate) const Q_VEC: [u64; 4] = [
Primes30::Q[0] as u64,
Primes30::Q[1] as u64,
Primes30::Q[2] as u64,
Primes30::Q[3] as u64,
];
pub(crate) const OQ: [u64; 4] = {
let mut oq = [0u64; 4];
let mut k = 0usize;
while k < 4 {
let q = Q_VEC[k];
oq[k] = q - (i64::MIN as u64 % q); k += 1;
}
oq
};
pub(crate) const BARRETT_MU: [u64; 4] = {
let mut mu = [0u64; 4];
let mut k = 0usize;
while k < 4 {
mu[k] = (1u64 << 61) / Q_VEC[k];
k += 1;
}
mu
};
pub(crate) const POW32: [u64; 4] = {
let mut p = [0u64; 4];
let mut k = 0usize;
while k < 4 {
p[k] = ((1u128 << 32) % Q_VEC[k] as u128) as u64;
k += 1;
}
p
};
pub(crate) const CRT_VEC: [u64; 4] = [
Primes30::CRT_CST[0] as u64,
Primes30::CRT_CST[1] as u64,
Primes30::CRT_CST[2] as u64,
Primes30::CRT_CST[3] as u64,
];
/// `POW32_CRT[k] = (POW32[k] * CRT_VEC[k]) mod q_k`.
///
/// Multiplier for the upper 32 bits of a lane when the CRT constant is
/// applied in the same pass as the reduction.
pub(crate) const POW32_CRT: [u64; 4] = {
    /// `(a * b) mod q`, evaluated in plain `u64` arithmetic.
    const fn mul_mod(a: u64, b: u64, q: u64) -> u64 {
        (a * b) % q
    }

    [
        mul_mod(POW32[0], CRT_VEC[0], Q_VEC[0]),
        mul_mod(POW32[1], CRT_VEC[1], Q_VEC[1]),
        mul_mod(POW32[2], CRT_VEC[2], Q_VEC[2]),
        mul_mod(POW32[3], CRT_VEC[3], Q_VEC[3]),
    ]
};
/// `POW16_CRT[k] = (2^16 * CRT_VEC[k]) mod q_k`.
///
/// Multiplier for bits 16..32 of a lane when the CRT constant is applied in
/// the same pass as the reduction.
pub(crate) const POW16_CRT: [u64; 4] = {
    /// `(2^16 * c) mod q`, evaluated in plain `u64` arithmetic.
    const fn shifted_mod(c: u64, q: u64) -> u64 {
        ((1u64 << 16) * c) % q
    }

    [
        shifted_mod(CRT_VEC[0], Q_VEC[0]),
        shifted_mod(CRT_VEC[1], Q_VEC[1]),
        shifted_mod(CRT_VEC[2], Q_VEC[2]),
        shifted_mod(CRT_VEC[3], Q_VEC[3]),
    ]
};
/// `QM[k]` is the product of the three primes other than `q_k`, as a `u128`.
const QM: [u128; 4] = {
    let q = [
        Primes30::Q[0] as u128,
        Primes30::Q[1] as u128,
        Primes30::Q[2] as u128,
        Primes30::Q[3] as u128,
    ];
    [
        q[1] * q[2] * q[3],
        q[0] * q[2] * q[3],
        q[0] * q[1] * q[3],
        q[0] * q[1] * q[2],
    ]
};
pub(crate) const QM_HI: [u64; 4] = [
(QM[0] >> 64) as u64,
(QM[1] >> 64) as u64,
(QM[2] >> 64) as u64,
(QM[3] >> 64) as u64,
];
pub(crate) const QM_MID: [u64; 4] = [
((QM[0] >> 32) & 0xFFFF_FFFF) as u64,
((QM[1] >> 32) & 0xFFFF_FFFF) as u64,
((QM[2] >> 32) & 0xFFFF_FFFF) as u64,
((QM[3] >> 32) & 0xFFFF_FFFF) as u64,
];
pub(crate) const QM_LO: [u64; 4] = [
(QM[0] & 0xFFFF_FFFF) as u64,
(QM[1] & 0xFFFF_FFFF) as u64,
(QM[2] & 0xFFFF_FFFF) as u64,
(QM[3] & 0xFFFF_FFFF) as u64,
];
pub(crate) const TOTAL_Q: u128 = {
let q0 = Primes30::Q[0] as u128;
let q1 = Primes30::Q[1] as u128;
let q2 = Primes30::Q[2] as u128;
let q3 = Primes30::Q[3] as u128;
q0 * q1 * q2 * q3
};
pub(crate) const TOTAL_Q_MULT: [u128; 4] = [0, TOTAL_Q, TOTAL_Q * 2, TOTAL_Q * 3];
/// Lane-wise conditional subtraction: returns `x - q` in every 64-bit lane
/// where `x >= q`, and `x` unchanged elsewhere.
///
/// The comparison is the *signed* 64-bit compare provided by AVX2, so the
/// result is only meaningful for lanes that are non-negative as `i64`.
///
/// # Safety
/// The CPU must support AVX2.
#[inline(always)]
pub(crate) unsafe fn cond_sub(x: __m256i, q: __m256i) -> __m256i {
    // SAFETY: register-only AVX2 operations; the caller guarantees AVX2.
    unsafe {
        // All-ones in lanes where `x < q`, i.e. where nothing must be subtracted.
        let keep = _mm256_cmpgt_epi64(q, x);
        // `!keep & q`: `q` in lanes that need the subtraction, zero otherwise.
        let subtrahend = _mm256_andnot_si256(keep, q);
        _mm256_sub_epi64(x, subtrahend)
    }
}
/// Lane-wise Barrett-style reduction of `tmp` modulo `q`.
///
/// The quotient `tmp / q` is estimated as `(tmp * mu) >> 61`, evaluated on
/// the two 32-bit halves of `tmp` separately because the AVX2 multiply only
/// uses the low 32 bits of each lane. The estimate times `q` is subtracted
/// from `tmp`, then two conditional subtractions correct the remainder.
///
/// `mu` is expected to hold `floor(2^61 / q)` per lane (see [`BARRETT_MU`]).
///
/// NOTE(review): two correction steps are enough only if the remainder
/// estimate is below `3q`; the bound on `tmp` that guarantees this is not
/// established in this function — confirm against the callers.
///
/// # Safety
/// The CPU must support AVX2.
#[inline(always)]
pub(crate) unsafe fn barrett_reduce(tmp: __m256i, q: __m256i, mu: __m256i) -> __m256i {
    // SAFETY: register-only AVX2 operations; the caller guarantees AVX2.
    unsafe {
        let low32 = _mm256_set1_epi64x(u32::MAX as i64);

        // Upper half contributes (hi * 2^32 * mu) >> 61 == (hi * mu) >> 29.
        let hi_half = _mm256_srli_epi64::<32>(tmp);
        let quot_from_hi = _mm256_srli_epi64::<29>(_mm256_mul_epu32(hi_half, mu));

        // Lower half contributes (lo * mu) >> 61.
        let lo_half = _mm256_and_si256(tmp, low32);
        let quot_from_lo = _mm256_srli_epi64::<61>(_mm256_mul_epu32(lo_half, mu));

        let quotient = _mm256_add_epi64(quot_from_hi, quot_from_lo);
        let remainder = _mm256_sub_epi64(tmp, _mm256_mul_epu32(quotient, q));

        // The estimate never exceeds the true quotient, so only downward
        // corrections are needed.
        cond_sub(cond_sub(remainder, q), q)
    }
}
/// Horizontal sum of the four 64-bit lanes of `v`, wrapping modulo `2^64`.
///
/// # Safety
/// The CPU must support AVX2.
#[inline(always)]
unsafe fn hadd64(v: __m256i) -> u64 {
    // SAFETY: register-only AVX2/SSE2 operations; the caller guarantees AVX2.
    unsafe {
        // Fold the upper 128 bits onto the lower 128 bits: (l0 + l2, l1 + l3).
        let pair_sums = _mm_add_epi64(_mm256_castsi256_si128(v), _mm256_extracti128_si256::<1>(v));
        // Bring the upper lane down and add it to the lower one.
        let total = _mm_add_epi64(pair_sums, _mm_unpackhi_epi64(pair_sums, pair_sums));
        _mm_cvtsi128_si64(total) as u64
    }
}
/// Computes `sum_k t[k] * QM[k]` as a `u128`, where `QM[k]` is supplied as
/// three limbs per lane: `qm_hi * 2^64 + qm_mid * 2^32 + qm_lo`.
///
/// Each limb is multiplied lane-wise by `t` (only the low 32 bits of every
/// lane take part in the multiplication), the four lanes of each partial
/// product are summed, and the three sums are recombined with their weights.
///
/// # Safety
/// The CPU must support AVX2.
#[inline(always)]
pub(crate) unsafe fn crt_accumulate_avx2(t: __m256i, qm_hi: __m256i, qm_mid: __m256i, qm_lo: __m256i) -> u128 {
    // SAFETY: register-only AVX2 operations; the caller guarantees AVX2.
    unsafe {
        let top = hadd64(_mm256_mul_epu32(t, qm_hi)) as u128;
        let middle = hadd64(_mm256_mul_epu32(t, qm_mid)) as u128;
        let bottom = hadd64(_mm256_mul_epu32(t, qm_lo)) as u128;
        (top << 64) + (middle << 32) + bottom
    }
}
/// Expands the first `nn` signed coefficients of `x` into four `u64` values
/// each (one per prime), written consecutively to `res`.
///
/// Each output lane is `(x & i64::MAX) + (OQ[k] if x < 0 else 0)`: the low
/// 63 bits of the input plus, for negative inputs, the residue of `-2^63`.
/// The stored values are congruent to `x` modulo each prime but are not
/// reduced.
///
/// # Panics
/// If `res.len() < 4 * nn` or `x.len() < nn`.
///
/// # Safety
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn b_from_znx64_avx2(nn: usize, res: &mut [u64], x: &[i64]) {
    assert!(
        res.len() >= 4 * nn,
        "b_from_znx64_avx2: res.len()={} < 4*nn={}",
        res.len(),
        4 * nn
    );
    assert!(x.len() >= nn, "b_from_znx64_avx2: x.len()={} < nn={}", x.len(), nn);
    // SAFETY: every store targets a 4-element chunk of `res`, and `OQ` is a
    // 4-element `u64` array, so all 256-bit accesses are in bounds.
    unsafe {
        let sign_fix = _mm256_loadu_si256(OQ.as_ptr() as *const __m256i);
        let low63 = _mm256_set1_epi64x(i64::MAX);
        let zeros = _mm256_setzero_si256();
        for (out, &coeff) in res.chunks_exact_mut(4).zip(&x[..nn]) {
            let splat = _mm256_set1_epi64x(coeff);
            // All-ones in every lane when the coefficient is negative.
            let is_negative = _mm256_cmpgt_epi64(zeros, splat);
            let magnitude = _mm256_and_si256(splat, low63);
            let correction = _mm256_and_si256(is_negative, sign_fix);
            _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, _mm256_add_epi64(magnitude, correction));
        }
    }
}
/// Same expansion as [`b_from_znx64_avx2`], except that every coefficient is
/// first AND-ed with `mask`.
///
/// Each output lane is `(v & i64::MAX) + (OQ[k] if v < 0 else 0)` with
/// `v = x[i] & mask`.
///
/// # Panics
/// If `res.len() < 4 * nn` or `x.len() < nn`.
///
/// # Safety
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn b_from_znx64_masked_avx2(nn: usize, res: &mut [u64], x: &[i64], mask: i64) {
    assert!(
        res.len() >= 4 * nn,
        "b_from_znx64_masked_avx2: res.len()={} < 4*nn={}",
        res.len(),
        4 * nn
    );
    assert!(x.len() >= nn, "b_from_znx64_masked_avx2: x.len()={} < nn={}", x.len(), nn);
    // SAFETY: every store targets a 4-element chunk of `res`, and `OQ` is a
    // 4-element `u64` array, so all 256-bit accesses are in bounds.
    unsafe {
        let sign_fix = _mm256_loadu_si256(OQ.as_ptr() as *const __m256i);
        let low63 = _mm256_set1_epi64x(i64::MAX);
        let zeros = _mm256_setzero_si256();
        for (out, &coeff) in res.chunks_exact_mut(4).zip(&x[..nn]) {
            let splat = _mm256_set1_epi64x(coeff & mask);
            // All-ones in every lane when the masked coefficient is negative.
            let is_negative = _mm256_cmpgt_epi64(zeros, splat);
            let magnitude = _mm256_and_si256(splat, low63);
            let correction = _mm256_and_si256(is_negative, sign_fix);
            _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, _mm256_add_epi64(magnitude, correction));
        }
    }
}
/// Reduces each 64-bit lane of `x` modulo the matching lane of `q`.
///
/// The lane is split as `x = x_hi * 2^32 + x_lo`; the upper half is passed
/// through one conditional subtraction, multiplied by `pow32` (expected to
/// be `2^32 mod q`, see [`POW32`]), added to the lower half, and the sum is
/// handed to [`barrett_reduce`].
///
/// # Safety
/// The CPU must support AVX2.
#[inline(always)]
pub(crate) unsafe fn reduce_b_to_canonical(x: __m256i, q: __m256i, mu: __m256i, pow32: __m256i) -> __m256i {
    // SAFETY: register-only AVX2 operations; the caller guarantees AVX2.
    unsafe {
        let upper = cond_sub(_mm256_srli_epi64::<32>(x), q);
        let lower = _mm256_and_si256(x, _mm256_set1_epi64x(u32::MAX as i64));
        let folded = _mm256_add_epi64(_mm256_mul_epu32(upper, pow32), lower);
        barrett_reduce(folded, q, mu)
    }
}
/// Reduces each lane of `x` modulo `q` and multiplies it by the lane's CRT
/// constant, using a single Barrett reduction.
///
/// The lane is split into three pieces, `x = x_hi * 2^32 + x_mid * 2^16 + x_low`
/// (`x_hi` is the upper 32 bits after one conditional subtraction, `x_mid`
/// and `x_low` are 16 bits each). Each piece is multiplied by a constant
/// that already includes the CRT factor (`pow32_crt`, `pow16_crt`, `crt`;
/// see [`POW32_CRT`], [`POW16_CRT`], [`CRT_VEC`]), and the sum of the three
/// products is passed to [`barrett_reduce`].
///
/// # Safety
/// The CPU must support AVX2.
#[inline(always)]
pub(crate) unsafe fn reduce_b_and_apply_crt(
    x: __m256i,
    q: __m256i,
    mu: __m256i,
    pow32_crt: __m256i,
    pow16_crt: __m256i,
    crt: __m256i,
) -> __m256i {
    // SAFETY: register-only AVX2 operations; the caller guarantees AVX2.
    unsafe {
        let low32 = _mm256_set1_epi64x(u32::MAX as i64);
        let low16 = _mm256_set1_epi64x(0xFFFF_i64);

        // Bits 32..64, after one conditional subtraction of `q`.
        let piece_hi = cond_sub(_mm256_srli_epi64::<32>(x), q);
        // Bits 0..32, further divided into two 16-bit pieces.
        let lower = _mm256_and_si256(x, low32);
        let piece_mid = _mm256_srli_epi64::<16>(lower);
        let piece_low = _mm256_and_si256(lower, low16);

        let term_hi = _mm256_mul_epu32(piece_hi, pow32_crt);
        let term_mid = _mm256_mul_epu32(piece_mid, pow16_crt);
        let term_low = _mm256_mul_epu32(piece_low, crt);

        let combined = _mm256_add_epi64(_mm256_add_epi64(term_hi, term_mid), term_low);
        barrett_reduce(combined, q, mu)
    }
}
/// Converts `nn` coefficients from four-`u64` form (`a`) to eight-`u32` form
/// (`res`).
///
/// For every prime lane the 64-bit output packs the reduced residue `r` in
/// its low 32 bits and `r * 2^32 mod q` in its high 32 bits.
///
/// # Panics
/// If `res.len() < 8 * nn` or `a.len() < 4 * nn`.
///
/// # Safety
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn c_from_b_avx2(nn: usize, res: &mut [u32], a: &[u64]) {
    assert!(
        res.len() >= 8 * nn,
        "c_from_b_avx2: res.len()={} < 8*nn={}",
        res.len(),
        8 * nn
    );
    assert!(a.len() >= 4 * nn, "c_from_b_avx2: a.len()={} < 4*nn={}", a.len(), 4 * nn);
    // SAFETY: each iteration reads one 4-element `u64` chunk and writes one
    // 8-element `u32` chunk (32 bytes each); the constant tables are
    // 4-element `u64` arrays.
    unsafe {
        let primes = _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i);
        let recip = _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i);
        let two_pow_32 = _mm256_loadu_si256(POW32.as_ptr() as *const __m256i);
        for (out, src) in res.chunks_exact_mut(8).zip(a.chunks_exact(4)).take(nn) {
            let residue = reduce_b_to_canonical(
                _mm256_loadu_si256(src.as_ptr() as *const __m256i),
                primes,
                recip,
                two_pow_32,
            );
            // Second half of the pair: residue * 2^32 mod q.
            let scaled = barrett_reduce(_mm256_mul_epu32(residue, two_pow_32), primes, recip);
            let pair = _mm256_or_si256(residue, _mm256_slli_epi64::<32>(scaled));
            _mm256_storeu_si256(out.as_mut_ptr() as *mut __m256i, pair);
        }
    }
}
/// Gathers one two-vector block from each of `row_count` rows of `a`,
/// reduces every lane with [`reduce_b_to_canonical`], and writes the rows
/// contiguously to `dst`.
///
/// Row `r` is read from `a[8 * blk + r * row_stride ..][..8]` (two 256-bit
/// vectors of four `u64` lanes) and written to `dst[16 * r ..][..16]`; the
/// reduced 64-bit lanes are stored unchanged, so each occupies two `u32`
/// slots of `dst`.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `dst.len() >= 16 * row_count`, and `a` must contain the 8 `u64` values
///   of every row described above. These bounds are only checked by
///   `debug_assert!`.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn pack_left_1blk_x2_avx2(dst: &mut [u32], a: &[u64], row_count: usize, row_stride: usize, blk: usize) {
    debug_assert!(dst.len() >= 16 * row_count);
    debug_assert!(a.len() >= row_stride.saturating_mul(row_count.saturating_sub(1)) + 8 * blk + 8);
    // SAFETY: under the documented preconditions every load and store below
    // is in bounds of `a` / `dst`.
    unsafe {
        let q = _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i);
        let mu = _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i);
        let pow32 = _mm256_loadu_si256(POW32.as_ptr() as *const __m256i);
        let dst_base = dst.as_mut_ptr() as *mut __m256i;
        for row in 0..row_count {
            // Each row pointer is derived from the slice start. Bumping a
            // running pointer by `row_stride` after the last row could move
            // it past the end of `a`, which is undefined behaviour for
            // `ptr::add` even if the pointer is never dereferenced.
            let src = a.as_ptr().add(8 * blk + row * row_stride) as *const __m256i;
            let out = dst_base.add(2 * row);
            let r0 = reduce_b_to_canonical(_mm256_loadu_si256(src), q, mu, pow32);
            _mm256_storeu_si256(out, r0);
            let r1 = reduce_b_to_canonical(_mm256_loadu_si256(src.add(1)), q, mu, pow32);
            _mm256_storeu_si256(out.add(1), r1);
        }
    }
}
/// Copies one two-vector block from each of `row_count` rows of `a` into
/// `dst`, visiting the rows of `a` in reverse order.
///
/// Output row `i` (`dst[16 * i ..][..16]`) receives the 16 `u32` values at
/// `a[16 * blk + (row_count - 1 - i) * row_stride ..][..16]`, unmodified.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `dst.len() >= 16 * row_count`, and `a` must contain the 16 `u32` values
///   of every row described above. These bounds are only checked by
///   `debug_assert!`.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn pack_right_1blk_x2_avx2(dst: &mut [u32], a: &[u32], row_count: usize, row_stride: usize, blk: usize) {
    debug_assert!(dst.len() >= 16 * row_count);
    debug_assert!(a.len() >= row_stride.saturating_mul(row_count.saturating_sub(1)) + 16 * blk + 16);
    // SAFETY: under the documented preconditions every load and store below
    // is in bounds of `a` / `dst`.
    unsafe {
        let dst_base = dst.as_mut_ptr() as *mut __m256i;
        for (out_row, src_row) in (0..row_count).rev().enumerate() {
            // Each row pointer is derived from the slice start. Stepping a
            // running pointer back by `row_stride` after the last row could
            // move it before the start of `a`, which is undefined behaviour
            // for `ptr::sub` even if the pointer is never dereferenced.
            let src = a.as_ptr().add(16 * blk + src_row * row_stride) as *const __m256i;
            let out = dst_base.add(2 * out_row);
            _mm256_storeu_si256(out, _mm256_loadu_si256(src));
            _mm256_storeu_si256(out.add(1), _mm256_loadu_si256(src.add(1)));
        }
    }
}
/// Like [`pack_left_1blk_x2_avx2`], but packs the lane-wise modular sum of
/// two inputs.
///
/// For every row `r`, the two 256-bit vectors at offset
/// `8 * blk + r * row_stride` of `a` and of `b` are each reduced with
/// [`reduce_b_to_canonical`], added, and brought back below `q` with one
/// conditional subtraction. The result is written to `dst[16 * r ..][..16]`.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `dst.len() >= 16 * row_count`, and both `a` and `b` must contain the
///   8 `u64` values of every row described above. These bounds are only
///   checked by `debug_assert!`.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn pairwise_pack_left_1blk_x2_avx2(
    dst: &mut [u32],
    a: &[u64],
    b: &[u64],
    row_count: usize,
    row_stride: usize,
    blk: usize,
) {
    debug_assert!(dst.len() >= 16 * row_count);
    debug_assert!(a.len() >= row_stride.saturating_mul(row_count.saturating_sub(1)) + 8 * blk + 8);
    debug_assert!(b.len() >= row_stride.saturating_mul(row_count.saturating_sub(1)) + 8 * blk + 8);
    // SAFETY: under the documented preconditions every load and store below
    // is in bounds of `a` / `b` / `dst`.
    unsafe {
        let q = _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i);
        let mu = _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i);
        let pow32 = _mm256_loadu_si256(POW32.as_ptr() as *const __m256i);
        let dst_base = dst.as_mut_ptr() as *mut __m256i;
        for row in 0..row_count {
            // Row pointers are derived from the slice starts. Bumping running
            // pointers by `row_stride` after the last row could move them
            // past the end of `a` / `b`, which is undefined behaviour for
            // `ptr::add` even if the pointers are never dereferenced.
            let offset = 8 * blk + row * row_stride;
            let a_row = a.as_ptr().add(offset) as *const __m256i;
            let b_row = b.as_ptr().add(offset) as *const __m256i;
            let out = dst_base.add(2 * row);

            let r0 = reduce_b_to_canonical(_mm256_loadu_si256(a_row), q, mu, pow32);
            let s0 = reduce_b_to_canonical(_mm256_loadu_si256(b_row), q, mu, pow32);
            _mm256_storeu_si256(out, cond_sub(_mm256_add_epi64(r0, s0), q));

            let r1 = reduce_b_to_canonical(_mm256_loadu_si256(a_row.add(1)), q, mu, pow32);
            let s1 = reduce_b_to_canonical(_mm256_loadu_si256(b_row.add(1)), q, mu, pow32);
            _mm256_storeu_si256(out.add(1), cond_sub(_mm256_add_epi64(r1, s1), q));
        }
    }
}
/// Like [`pack_right_1blk_x2_avx2`], but packs the lane-wise 32-bit sum of
/// two inputs.
///
/// Output row `i` (`dst[16 * i ..][..16]`) receives the wrapping `u32` sum of
/// the 16 values at offset `16 * blk + (row_count - 1 - i) * row_stride` of
/// `a` and of `b`. No modular reduction is applied to the sum.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `dst.len() >= 16 * row_count`, and both `a` and `b` must contain the
///   16 `u32` values of every row described above. These bounds are only
///   checked by `debug_assert!`.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn pairwise_pack_right_1blk_x2_avx2(
    dst: &mut [u32],
    a: &[u32],
    b: &[u32],
    row_count: usize,
    row_stride: usize,
    blk: usize,
) {
    debug_assert!(dst.len() >= 16 * row_count);
    debug_assert!(a.len() >= row_stride.saturating_mul(row_count.saturating_sub(1)) + 16 * blk + 16);
    debug_assert!(b.len() >= row_stride.saturating_mul(row_count.saturating_sub(1)) + 16 * blk + 16);
    // SAFETY: under the documented preconditions every load and store below
    // is in bounds of `a` / `b` / `dst`.
    unsafe {
        let dst_base = dst.as_mut_ptr() as *mut __m256i;
        for (out_row, src_row) in (0..row_count).rev().enumerate() {
            // Row pointers are derived from the slice starts. Stepping
            // running pointers back by `row_stride` after the last row could
            // move them before the start of `a` / `b`, which is undefined
            // behaviour for `ptr::sub` even if they are never dereferenced.
            let offset = 16 * blk + src_row * row_stride;
            let a_row = a.as_ptr().add(offset) as *const __m256i;
            let b_row = b.as_ptr().add(offset) as *const __m256i;
            let out = dst_base.add(2 * out_row);
            _mm256_storeu_si256(out, _mm256_add_epi32(_mm256_loadu_si256(a_row), _mm256_loadu_si256(b_row)));
            _mm256_storeu_si256(
                out.add(1),
                _mm256_add_epi32(_mm256_loadu_si256(a_row.add(1)), _mm256_loadu_si256(b_row.add(1))),
            );
        }
    }
}
/// Inner product of `ell` pairs of four-lane `u64` vectors from `x` and `y`,
/// folded back into one four-lane `u64` vector written to `res[..4]`.
///
/// Phase 1 multiplies each pair lane-wise as 64 x 64 -> 128 bits using four
/// 32 x 32 partial products, and accumulates the result in four 64-bit
/// column sums per lane (columns weighted `2^0`, `2^32`, `2^64`, `2^96`)
/// without propagating carries.
///
/// Phase 2 splits every column sum at bit `meta.h` into a low and a high
/// part, multiplies each part by its constant from `meta`
/// (`s1h_pow_red`, `s2l_pow_red`, ... `s4h_pow_red`) and adds everything
/// up; the low part of the first column is added as is. `s1h_pow_red` is a
/// single value broadcast to all lanes, whereas the other tables are loaded
/// as four-lane vectors. The stored sum is not reduced any further here.
///
/// # Panics
/// If `res.len() < 4`, `x.len() < 4 * ell` or `y.len() < 4 * ell`.
///
/// # Safety
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn vec_mat1col_product_bbb_avx2(meta: &BbbMeta<Primes30>, ell: usize, res: &mut [u64], x: &[u64], y: &[u64]) {
    assert!(res.len() >= 4, "vec_mat1col_product_bbb_avx2: res.len()={} < 4", res.len());
    assert!(
        x.len() >= 4 * ell,
        "vec_mat1col_product_bbb_avx2: x.len()={} < 4*ell={}",
        x.len(),
        4 * ell
    );
    assert!(
        y.len() >= 4 * ell,
        "vec_mat1col_product_bbb_avx2: y.len()={} < 4*ell={}",
        y.len(),
        4 * ell
    );
    // SAFETY: the loop reads 4-element chunks of `x` and `y`, the single
    // store covers `res[..4]` (checked above), and the `meta` tables are read
    // with 256-bit loads exactly as in the original implementation.
    unsafe {
        let low32 = _mm256_set1_epi64x(u32::MAX as i64);

        // Column accumulators, weights 2^0, 2^32, 2^64 and 2^96.
        let mut col0 = _mm256_setzero_si256();
        let mut col1 = _mm256_setzero_si256();
        let mut col2 = _mm256_setzero_si256();
        let mut col3 = _mm256_setzero_si256();

        for (xs, ys) in x.chunks_exact(4).zip(y.chunks_exact(4)).take(ell) {
            let xv = _mm256_loadu_si256(xs.as_ptr() as *const __m256i);
            let yv = _mm256_loadu_si256(ys.as_ptr() as *const __m256i);
            let x_lo = _mm256_and_si256(xv, low32);
            let x_hi = _mm256_srli_epi64::<32>(xv);
            let y_lo = _mm256_and_si256(yv, low32);
            let y_hi = _mm256_srli_epi64::<32>(yv);

            // Schoolbook 32 x 32 partial products.
            let lo_lo = _mm256_mul_epu32(x_lo, y_lo);
            let lo_hi = _mm256_mul_epu32(x_lo, y_hi);
            let hi_lo = _mm256_mul_epu32(x_hi, y_lo);
            let hi_hi = _mm256_mul_epu32(x_hi, y_hi);

            col0 = _mm256_add_epi64(col0, _mm256_and_si256(lo_lo, low32));

            col1 = _mm256_add_epi64(col1, _mm256_srli_epi64::<32>(lo_lo));
            col1 = _mm256_add_epi64(col1, _mm256_and_si256(lo_hi, low32));
            col1 = _mm256_add_epi64(col1, _mm256_and_si256(hi_lo, low32));

            col2 = _mm256_add_epi64(col2, _mm256_srli_epi64::<32>(lo_hi));
            col2 = _mm256_add_epi64(col2, _mm256_srli_epi64::<32>(hi_lo));
            col2 = _mm256_add_epi64(col2, _mm256_and_si256(hi_hi, low32));

            col3 = _mm256_add_epi64(col3, _mm256_srli_epi64::<32>(hi_hi));
        }

        // Every column sum is split at bit `meta.h`.
        let split_bits = meta.h;
        let low_part_mask = _mm256_set1_epi64x(((1u64 << split_bits) - 1) as i64);
        let shift_count = _mm_cvtsi64_si128(split_bits as i64);

        let w0_hi = _mm256_set1_epi64x(meta.s1h_pow_red as i64);
        let w1_lo = _mm256_loadu_si256(meta.s2l_pow_red.as_ptr() as *const __m256i);
        let w1_hi = _mm256_loadu_si256(meta.s2h_pow_red.as_ptr() as *const __m256i);
        let w2_lo = _mm256_loadu_si256(meta.s3l_pow_red.as_ptr() as *const __m256i);
        let w2_hi = _mm256_loadu_si256(meta.s3h_pow_red.as_ptr() as *const __m256i);
        let w3_lo = _mm256_loadu_si256(meta.s4l_pow_red.as_ptr() as *const __m256i);
        let w3_hi = _mm256_loadu_si256(meta.s4h_pow_red.as_ptr() as *const __m256i);

        // The low part of column 0 carries weight 1 and is added directly.
        let mut folded = _mm256_and_si256(col0, low_part_mask);
        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_srl_epi64(col0, shift_count), w0_hi));

        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_and_si256(col1, low_part_mask), w1_lo));
        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_srl_epi64(col1, shift_count), w1_hi));

        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_and_si256(col2, low_part_mask), w2_lo));
        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_srl_epi64(col2, shift_count), w2_hi));

        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_and_si256(col3, low_part_mask), w3_lo));
        folded = _mm256_add_epi64(folded, _mm256_mul_epu32(_mm256_srl_epi64(col3, shift_count), w3_hi));

        _mm256_storeu_si256(res.as_mut_ptr() as *mut __m256i, folded);
    }
}
/// Reconstructs `nn` signed 128-bit coefficients from their four-`u64`
/// residue form.
///
/// For each coefficient:
/// 1. every lane is reduced and multiplied by its CRT constant
///    ([`reduce_b_and_apply_crt`]);
/// 2. the lanes are combined as `sum_k t[k] * QM[k]`
///    ([`crt_accumulate_avx2`]);
/// 3. the sum is brought below `TOTAL_Q` by subtracting the multiple
///    selected by its bits 120 and above, then at most one more `TOTAL_Q`;
/// 4. values at or above `ceil(TOTAL_Q / 2)` are mapped to their negative
///    representative by subtracting `TOTAL_Q`.
///
/// # Panics
/// If `res.len() < nn` or `a.len() < 4 * nn`. Also panics if the quotient
/// estimate in step 3 exceeds 3 (out-of-range index into `TOTAL_Q_MULT`).
///
/// # Safety
/// The CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn b_to_znx128_avx2(nn: usize, res: &mut [i128], a: &[u64]) {
    assert!(res.len() >= nn, "b_to_znx128_avx2: res.len()={} < nn={}", res.len(), nn);
    assert!(a.len() >= 4 * nn, "b_to_znx128_avx2: a.len()={} < 4*nn={}", a.len(), 4 * nn);
    let half_q: u128 = TOTAL_Q.div_ceil(2);
    // SAFETY: each iteration reads one 4-element chunk of `a`; the constant
    // tables are 4-element `u64` arrays.
    unsafe {
        let primes = _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i);
        let recip = _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i);
        let w_hi32 = _mm256_loadu_si256(POW32_CRT.as_ptr() as *const __m256i);
        let w_mid16 = _mm256_loadu_si256(POW16_CRT.as_ptr() as *const __m256i);
        let w_low16 = _mm256_loadu_si256(CRT_VEC.as_ptr() as *const __m256i);
        let cofactor_hi = _mm256_loadu_si256(QM_HI.as_ptr() as *const __m256i);
        let cofactor_mid = _mm256_loadu_si256(QM_MID.as_ptr() as *const __m256i);
        let cofactor_lo = _mm256_loadu_si256(QM_LO.as_ptr() as *const __m256i);

        for (out, limbs) in res[..nn].iter_mut().zip(a.chunks_exact(4)) {
            let lanes = _mm256_loadu_si256(limbs.as_ptr() as *const __m256i);
            let scaled = reduce_b_and_apply_crt(lanes, primes, recip, w_hi32, w_mid16, w_low16);
            let lifted = crt_accumulate_avx2(scaled, cofactor_hi, cofactor_mid, cofactor_lo);

            // Bits 120 and above give a lower bound on `lifted / TOTAL_Q`.
            let estimate = (lifted >> 120) as usize;
            let partial = lifted - TOTAL_Q_MULT[estimate];
            let canonical = if partial >= TOTAL_Q { partial - TOTAL_Q } else { partial };

            // Centre the representative around zero.
            let signed = canonical as i128;
            *out = if canonical < half_q { signed } else { signed - TOTAL_Q as i128 };
        }
    }
}
/// Cross-checks of the AVX2 kernels against the scalar reference
/// implementations. Compiled only when AVX2 is enabled at build time.
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use super::*;
    use poulpy_cpu_ref::reference::ntt120::{
        arithmetic::{b_from_znx64_ref, b_to_znx128_ref, c_from_b_ref},
        mat_vec::{BbbMeta, vec_mat1col_product_bbb_ref},
        primes::Primes30,
    };

    /// `i64 -> b` conversion agrees with the reference on mixed-sign input.
    #[test]
    fn b_from_znx64_avx2_vs_ref() {
        let n = 64usize;
        let coeffs: Vec<i64> = (0..n as i64).map(|i| i * 17 - 500).collect();

        let mut expected = vec![0u64; 4 * n];
        b_from_znx64_ref::<Primes30>(n, &mut expected, &coeffs);

        let mut actual = vec![0u64; 4 * n];
        unsafe { b_from_znx64_avx2(n, &mut actual, &coeffs) };

        assert_eq!(actual, expected, "b_from_znx64: AVX2 vs ref mismatch");
    }

    /// `b -> c` conversion agrees with the reference.
    #[test]
    fn c_from_b_avx2_vs_ref() {
        let n = 64usize;
        let coeffs: Vec<i64> = (0..n as i64).map(|i| i * 11 + 3).collect();
        let mut b = vec![0u64; 4 * n];
        b_from_znx64_ref::<Primes30>(n, &mut b, &coeffs);

        let mut expected = vec![0u32; 8 * n];
        c_from_b_ref::<Primes30>(n, &mut expected, &b);

        let mut actual = vec![0u32; 8 * n];
        unsafe { c_from_b_avx2(n, &mut actual, &b) };

        assert_eq!(actual, expected, "c_from_b: AVX2 vs ref mismatch");
    }

    /// The b x b inner product agrees with the reference.
    #[test]
    fn vec_mat1col_product_bbb_avx2_vs_ref() {
        let ell = 16usize;
        let n = 64usize;
        let total = ell * n;
        let meta = BbbMeta::<Primes30>::new();

        let x_i64: Vec<i64> = (0..total).map(|i| (i as i64 * 7 + 1) % 100).collect();
        let y_i64: Vec<i64> = (0..total).map(|i| (i as i64 * 13 + 2) % 100).collect();
        let mut x = vec![0u64; 4 * total];
        let mut y = vec![0u64; 4 * total];
        b_from_znx64_ref::<Primes30>(total, &mut x, &x_i64);
        b_from_znx64_ref::<Primes30>(total, &mut y, &y_i64);

        let mut expected = vec![0u64; 4 * n];
        vec_mat1col_product_bbb_ref::<Primes30>(&meta, ell, &mut expected, &x, &y);

        let mut actual = vec![0u64; 4 * n];
        unsafe { vec_mat1col_product_bbb_avx2(&meta, ell, &mut actual, &x, &y) };

        assert_eq!(actual, expected, "vec_mat1col_product_bbb: AVX2 vs ref mismatch");
    }

    /// The fused reduce-and-CRT step equals a canonical reduction followed by
    /// a separate multiplication with the CRT constant.
    #[test]
    fn reduce_b_and_apply_crt_vs_two_step() {
        let n = 64usize;
        let coeffs: Vec<i64> = (0..n as i64).map(|i| i * 5 - 160).collect();
        let mut b = vec![0u64; 4 * n];
        b_from_znx64_ref::<Primes30>(n, &mut b, &coeffs);

        let (q, mu, pow32, crt, pow32_crt, pow16_crt) = unsafe {
            (
                _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i),
                _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i),
                _mm256_loadu_si256(POW32.as_ptr() as *const __m256i),
                _mm256_loadu_si256(CRT_VEC.as_ptr() as *const __m256i),
                _mm256_loadu_si256(POW32_CRT.as_ptr() as *const __m256i),
                _mm256_loadu_si256(POW16_CRT.as_ptr() as *const __m256i),
            )
        };

        for (j, limbs) in b.chunks_exact(4).enumerate() {
            let mut two_step = [0u64; 4];
            let mut fused = [0u64; 4];
            unsafe {
                let xv = _mm256_loadu_si256(limbs.as_ptr() as *const __m256i);

                let canonical = reduce_b_to_canonical(xv, q, mu, pow32);
                let separate = barrett_reduce(_mm256_mul_epu32(canonical, crt), q, mu);
                _mm256_storeu_si256(two_step.as_mut_ptr() as *mut __m256i, separate);

                let combined = reduce_b_and_apply_crt(xv, q, mu, pow32_crt, pow16_crt, crt);
                _mm256_storeu_si256(fused.as_mut_ptr() as *mut __m256i, combined);
            }
            assert_eq!(fused, two_step, "reduce_b_and_apply_crt mismatch at j={j}");
        }
    }

    /// `b -> i128` reconstruction agrees with the reference.
    #[test]
    fn b_to_znx128_avx2_vs_ref() {
        let n = 64usize;
        let coeffs: Vec<i64> = (0..n as i64).map(|i| i * 5 - 160).collect();
        let mut b = vec![0u64; 4 * n];
        b_from_znx64_ref::<Primes30>(n, &mut b, &coeffs);

        let mut expected = vec![0i128; n];
        b_to_znx128_ref::<Primes30>(n, &mut expected, &b);

        let mut actual = vec![0i128; n];
        unsafe { b_to_znx128_avx2(n, &mut actual, &b) };

        assert_eq!(actual, expected, "b_to_znx128: AVX2 vs ref mismatch");
    }
}