use core::arch::x86_64::{
__m256i, _mm_cvtsi64_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_andnot_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_mul_epu32, _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_slli_epi64, _mm256_srl_epi64,
_mm256_srli_epi64, _mm256_storeu_si256, _mm256_sub_epi64,
};
use poulpy_hal::reference::ntt120::{
mat_vec::BbbMeta,
primes::{PrimeSet, Primes30},
};
const Q_VEC: [u64; 4] = [
Primes30::Q[0] as u64,
Primes30::Q[1] as u64,
Primes30::Q[2] as u64,
Primes30::Q[3] as u64,
];
/// Per-modulus sign correction: `OQ[k] = q_k - (2^63 mod q_k)`.
///
/// A negative `i64` equals its low 63 bits minus `2^63`, so adding `OQ[k]` to the
/// low 63 bits yields a non-negative value congruent to the original modulo `q_k`.
const OQ: [u64; 4] = {
    // Bit pattern of `i64::MIN` reinterpreted as unsigned, i.e. 2^63.
    const SIGN_WEIGHT: u64 = i64::MIN as u64;
    let mut table = [0u64; 4];
    let mut lane = 0usize;
    while lane < 4 {
        let modulus = Q_VEC[lane];
        table[lane] = modulus - SIGN_WEIGHT % modulus;
        lane += 1;
    }
    table
};
const BARRETT_MU: [u64; 4] = {
let mut mu = [0u64; 4];
let mut k = 0usize;
while k < 4 {
mu[k] = (1u64 << 61) / Q_VEC[k];
k += 1;
}
mu
};
const POW32: [u64; 4] = {
let mut p = [0u64; 4];
let mut k = 0usize;
while k < 4 {
p[k] = ((1u128 << 32) % Q_VEC[k] as u128) as u64;
k += 1;
}
p
};
const CRT_VEC: [u64; 4] = [
Primes30::CRT_CST[0] as u64,
Primes30::CRT_CST[1] as u64,
Primes30::CRT_CST[2] as u64,
Primes30::CRT_CST[3] as u64,
];
/// Lane-wise conditional subtraction on four 64-bit lanes.
///
/// Returns `x - q` in every lane where `x >= q` and `x` unchanged in every lane
/// where `x < q`. The comparison is the *signed* 64-bit compare provided by
/// `_mm256_cmpgt_epi64`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[inline(always)]
unsafe fn cond_sub(x: __m256i, q: __m256i) -> __m256i {
    // SAFETY: only register-to-register AVX2 intrinsics are used; AVX2
    // availability is the caller's obligation.
    unsafe {
        // All-ones in the lanes where x is strictly below q.
        let below = _mm256_cmpgt_epi64(q, x);
        // q in the lanes that must be reduced, zero elsewhere.
        let subtrahend = _mm256_andnot_si256(below, q);
        _mm256_sub_epi64(x, subtrahend)
    }
}
/// Lane-wise Barrett-style reduction of `tmp` modulo `q`.
///
/// The quotient is estimated as `(hi * mu >> 29) + (lo * mu >> 61)`, where `hi`
/// and `lo` are the upper and lower 32-bit halves of `tmp`; this approximates
/// `tmp * mu / 2^61`. The estimate times `q` is subtracted from `tmp`, and two
/// conditional subtractions of `q` finish the reduction.
///
/// `_mm256_mul_epu32` only reads the low 32 bits of each 64-bit lane, so `mu`,
/// `q` and the quotient estimate must each fit in 32 bits.
///
/// NOTE(review): the result lies in `[0, q)` only if the estimate undershoots the
/// true quotient by at most two; this presumably holds for `mu = floor(2^61 / q)`
/// and the input ranges used in this file — confirm before reusing elsewhere.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[inline(always)]
unsafe fn barrett_reduce(tmp: __m256i, q: __m256i, mu: __m256i) -> __m256i {
    // SAFETY: only register-to-register AVX2 intrinsics are used; AVX2
    // availability is the caller's obligation.
    unsafe {
        let low_word = _mm256_set1_epi64x(u32::MAX as i64);

        // Multiply each 32-bit half of `tmp` by `mu` separately.
        let upper_product = _mm256_mul_epu32(_mm256_srli_epi64::<32>(tmp), mu);
        let lower_product = _mm256_mul_epu32(_mm256_and_si256(tmp, low_word), mu);

        // The upper half carries a weight of 2^32, hence the shift by 61 - 32 = 29.
        let quotient = _mm256_add_epi64(
            _mm256_srli_epi64::<29>(upper_product),
            _mm256_srli_epi64::<61>(lower_product),
        );

        let remainder = _mm256_sub_epi64(tmp, _mm256_mul_epu32(quotient, q));
        cond_sub(cond_sub(remainder, q), q)
    }
}
/// Expands `nn` signed 64-bit coefficients into four non-negative `u64` residues each.
///
/// For coefficient `x[i]`, the four words `res[4*i..4*i + 4]` receive the low
/// 63 bits of `x[i]`, plus `OQ[k]` in lane `k` when `x[i]` is negative. The
/// outputs are congruent to `x[i]` modulo each modulus but are not reduced.
///
/// # Panics
///
/// Panics if `res.len() < 4 * nn` or `x.len() < nn`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn b_from_znx64_avx2(nn: usize, res: &mut [u64], x: &[i64]) {
    assert!(
        res.len() >= 4 * nn,
        "b_from_znx64_avx2: res.len()={} < 4*nn={}",
        res.len(),
        4 * nn
    );
    assert!(x.len() >= nn, "b_from_znx64_avx2: x.len()={} < nn={}", x.len(), nn);
    // SAFETY: `OQ` is a `[u64; 4]`, exactly one unaligned 256-bit load. Each
    // store writes words `4*i..4*i + 4` of `res` with `i < nn`, which the first
    // assertion proves in bounds.
    unsafe {
        let sign_fix = _mm256_loadu_si256(OQ.as_ptr() as *const __m256i);
        let clear_sign = _mm256_set1_epi64x(i64::MAX);
        let zeros = _mm256_setzero_si256();
        let out = res.as_mut_ptr() as *mut __m256i;
        for (i, &coeff) in x[..nn].iter().enumerate() {
            let splat = _mm256_set1_epi64x(coeff);
            // All-ones in every lane when the coefficient is negative.
            let negative = _mm256_cmpgt_epi64(zeros, splat);
            let low63 = _mm256_and_si256(splat, clear_sign);
            let correction = _mm256_and_si256(negative, sign_fix);
            _mm256_storeu_si256(out.add(i), _mm256_add_epi64(low63, correction));
        }
    }
}
/// Reduces each 64-bit lane of `x` modulo the matching lane of `q`.
///
/// The word is split as `x = hi * 2^32 + lo`. `hi` goes through one conditional
/// subtraction of `q`, then `hi * pow32 + lo` is handed to [`barrett_reduce`].
/// With `pow32 = 2^32 mod q` that sum is congruent to `x` modulo `q`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[inline(always)]
unsafe fn reduce_b_to_canonical(x: __m256i, q: __m256i, mu: __m256i, pow32: __m256i) -> __m256i {
    // SAFETY: only register-to-register AVX2 intrinsics and helpers with the
    // same AVX2 requirement are used.
    unsafe {
        let low_word = _mm256_set1_epi64x(u32::MAX as i64);
        let upper = cond_sub(_mm256_srli_epi64::<32>(x), q);
        let lower = _mm256_and_si256(x, low_word);
        let folded = _mm256_add_epi64(_mm256_mul_epu32(upper, pow32), lower);
        barrett_reduce(folded, q, mu)
    }
}
/// Converts `nn` coefficients from four `u64` residues each into packed `u32` pairs.
///
/// For every coefficient and every lane `k`, the input word is reduced to a
/// canonical residue `r`, and `r * POW32[k]` is reduced to `r_shift`. The output
/// 64-bit lane holds `r` in its low 32 bits and `r_shift` in its high 32 bits,
/// so each coefficient occupies eight consecutive `u32` entries of `res`.
///
/// # Panics
///
/// Panics if `8 * nn` or `4 * nn` overflows `usize`, if `res.len() < 8 * nn`,
/// or if `a.len() < 4 * nn`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn c_from_b_avx2(nn: usize, res: &mut [u32], a: &[u64]) {
    // The length checks are the only guard for the raw-pointer loop below, so
    // the required lengths must not wrap around (a plain `8 * nn` wraps silently
    // when overflow checks are disabled, which would let the checks pass).
    let res_needed = nn.checked_mul(8).expect("c_from_b_avx2: 8*nn overflows usize");
    let a_needed = nn.checked_mul(4).expect("c_from_b_avx2: 4*nn overflows usize");
    assert!(
        res.len() >= res_needed,
        "c_from_b_avx2: res.len()={} < 8*nn={}",
        res.len(),
        res_needed
    );
    assert!(a.len() >= a_needed, "c_from_b_avx2: a.len()={} < 4*nn={}", a.len(), a_needed);
    // SAFETY: the constant tables are `[u64; 4]`, exactly one unaligned 256-bit
    // load each. Iteration `i < nn` reads words `4*i..4*i + 4` of `a` and writes
    // entries `8*i..8*i + 8` of `res`; both ranges are in bounds by the
    // overflow-free assertions above.
    unsafe {
        let q = _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i);
        let mu = _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i);
        let pow32 = _mm256_loadu_si256(POW32.as_ptr() as *const __m256i);
        let src = a.as_ptr() as *const __m256i;
        let dst = res.as_mut_ptr() as *mut __m256i;
        for i in 0..nn {
            let word = _mm256_loadu_si256(src.add(i));
            let r = reduce_b_to_canonical(word, q, mu, pow32);
            let r_shift = barrett_reduce(_mm256_mul_epu32(r, pow32), q, mu);
            // Both halves are reduced residues that fit in 32 bits, so OR-ing
            // the shifted value packs them without overlap.
            let packed = _mm256_or_si256(r, _mm256_slli_epi64::<32>(r_shift));
            _mm256_storeu_si256(dst.add(i), packed);
        }
    }
}
/// Computes a lazily reduced inner product of `ell` residue quadruples.
///
/// `x` and `y` each hold `ell` groups of four `u64` words (one per lane). For
/// every lane the function evaluates `sum_i x[i] * y[i]` as follows:
///
/// 1. Each operand is split into 32-bit halves and the four cross products
///    `a = xl*yl`, `b = xl*yh`, `c = xh*yl`, `d = xh*yh` are formed, so that
///    `x*y = a + (b + c) * 2^32 + d * 2^64`.
/// 2. The 32-bit halves of those products are accumulated into `s1..s4`, giving
///    `sum = s1 + s2 * 2^32 + s3 * 2^64 + s4 * 2^96`.
/// 3. Each accumulator is split at bit `meta.h` into a low and a high part; the
///    parts are multiplied by the constants stored in `meta` and summed.
///
/// The four resulting words are written to `res[0..4]` without a final
/// reduction.
///
/// NOTE(review): the `*_pow_red` fields of `meta` are presumably the matching
/// powers of two reduced modulo each prime — confirm against `BbbMeta`.
/// The accumulators are plain 64-bit sums, so `ell` must be small enough that
/// they do not overflow; no such bound is checked here.
///
/// # Panics
///
/// Panics if `res.len() < 4`, if `4 * ell` overflows `usize`, or if `x` or `y`
/// holds fewer than `4 * ell` words.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn vec_mat1col_product_bbb_avx2(meta: &BbbMeta<Primes30>, ell: usize, res: &mut [u64], x: &[u64], y: &[u64]) {
    assert!(res.len() >= 4, "vec_mat1col_product_bbb_avx2: res.len()={} < 4", res.len());
    // The length checks are the only guard for the raw-pointer loads below, so
    // the required length must not wrap around (a plain `4 * ell` wraps silently
    // when overflow checks are disabled, which would let the checks pass).
    let needed = ell
        .checked_mul(4)
        .expect("vec_mat1col_product_bbb_avx2: 4*ell overflows usize");
    assert!(
        x.len() >= needed,
        "vec_mat1col_product_bbb_avx2: x.len()={} < 4*ell={}",
        x.len(),
        needed
    );
    assert!(
        y.len() >= needed,
        "vec_mat1col_product_bbb_avx2: y.len()={} < 4*ell={}",
        y.len(),
        needed
    );
    // SAFETY: iteration `i < ell` loads words `4*i..4*i + 4` of `x` and `y`,
    // in bounds by the overflow-free assertions above; the final store writes
    // `res[0..4]`, in bounds by the first assertion. The loads from `meta`
    // read 32 bytes from each `*_pow_red` table.
    unsafe {
        let mask32 = _mm256_set1_epi64x(u32::MAX as i64);
        let mut s1 = _mm256_setzero_si256();
        let mut s2 = _mm256_setzero_si256();
        let mut s3 = _mm256_setzero_si256();
        let mut s4 = _mm256_setzero_si256();
        let x_ptr = x.as_ptr() as *const __m256i;
        let y_ptr = y.as_ptr() as *const __m256i;
        for i in 0..ell {
            let xv = _mm256_loadu_si256(x_ptr.add(i));
            let xl = _mm256_and_si256(xv, mask32);
            let xh = _mm256_srli_epi64::<32>(xv);
            let yv = _mm256_loadu_si256(y_ptr.add(i));
            let yl = _mm256_and_si256(yv, mask32);
            let yh = _mm256_srli_epi64::<32>(yv);

            // 32x32 -> 64-bit cross products.
            let a = _mm256_mul_epu32(xl, yl);
            let b = _mm256_mul_epu32(xl, yh);
            let c = _mm256_mul_epu32(xh, yl);
            let d = _mm256_mul_epu32(xh, yh);

            // Weight 2^0: low half of a.
            s1 = _mm256_add_epi64(s1, _mm256_and_si256(a, mask32));
            // Weight 2^32: high half of a, low halves of b and c.
            s2 = _mm256_add_epi64(s2, _mm256_srli_epi64::<32>(a));
            s2 = _mm256_add_epi64(s2, _mm256_and_si256(b, mask32));
            s2 = _mm256_add_epi64(s2, _mm256_and_si256(c, mask32));
            // Weight 2^64: high halves of b and c, low half of d.
            s3 = _mm256_add_epi64(s3, _mm256_srli_epi64::<32>(b));
            s3 = _mm256_add_epi64(s3, _mm256_srli_epi64::<32>(c));
            s3 = _mm256_add_epi64(s3, _mm256_and_si256(d, mask32));
            // Weight 2^96: high half of d.
            s4 = _mm256_add_epi64(s4, _mm256_srli_epi64::<32>(d));
        }

        // Split every accumulator at bit `h2` so each part fits the 32-bit
        // operand width of `_mm256_mul_epu32`.
        let h2 = meta.h;
        let mask_h2 = _mm256_set1_epi64x(((1u64 << h2) - 1) as i64);
        let h2_cnt = _mm_cvtsi64_si128(h2 as i64);

        // `s1h_pow_red` is a single scalar broadcast to all lanes; the other
        // constants are per-lane tables.
        let s1h_pow = _mm256_set1_epi64x(meta.s1h_pow_red as i64);
        let s2l_pow = _mm256_loadu_si256(meta.s2l_pow_red.as_ptr() as *const __m256i);
        let s2h_pow = _mm256_loadu_si256(meta.s2h_pow_red.as_ptr() as *const __m256i);
        let s3l_pow = _mm256_loadu_si256(meta.s3l_pow_red.as_ptr() as *const __m256i);
        let s3h_pow = _mm256_loadu_si256(meta.s3h_pow_red.as_ptr() as *const __m256i);
        let s4l_pow = _mm256_loadu_si256(meta.s4l_pow_red.as_ptr() as *const __m256i);
        let s4h_pow = _mm256_loadu_si256(meta.s4h_pow_red.as_ptr() as *const __m256i);

        let s1l = _mm256_and_si256(s1, mask_h2);
        let s1h = _mm256_srl_epi64(s1, h2_cnt);
        let s2l = _mm256_and_si256(s2, mask_h2);
        let s2h = _mm256_srl_epi64(s2, h2_cnt);
        let s3l = _mm256_and_si256(s3, mask_h2);
        let s3h = _mm256_srl_epi64(s3, h2_cnt);
        let s4l = _mm256_and_si256(s4, mask_h2);
        let s4h = _mm256_srl_epi64(s4, h2_cnt);

        // Weighted recombination; the low part of s1 has weight 1.
        let mut t = s1l;
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s1h, s1h_pow));
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s2l, s2l_pow));
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s2h, s2h_pow));
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s3l, s3l_pow));
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s3h, s3h_pow));
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s4l, s4l_pow));
        t = _mm256_add_epi64(t, _mm256_mul_epu32(s4h, s4h_pow));
        _mm256_storeu_si256(res.as_mut_ptr() as *mut __m256i, t);
    }
}
/// Recombines `nn` residue quadruples into centered `i128` integers.
///
/// For each coefficient the four input words are reduced to canonical residues,
/// multiplied lane-wise by [`CRT_VEC`] and reduced again. Each result is then
/// multiplied by the product of the three *other* moduli, the four terms are
/// summed modulo the product `Q` of all four moduli, and values at or above
/// `(Q + 1) / 2` are shifted down by `Q` to obtain a centered representative.
///
/// # Panics
///
/// Panics if `res.len() < nn` or `a.len() < 4 * nn`.
///
/// # Safety
///
/// The executing CPU must support AVX2.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn b_to_znx128_avx2(nn: usize, res: &mut [i128], a: &[u64]) {
    assert!(res.len() >= nn, "b_to_znx128_avx2: res.len()={} < nn={}", res.len(), nn);
    assert!(a.len() >= 4 * nn, "b_to_znx128_avx2: a.len()={} < 4*nn={}", a.len(), 4 * nn);

    let [q0, q1, q2, q3]: [i128; 4] = Primes30::Q.map(|qi| qi as i128);
    let modulus: i128 = q0 * q1 * q2 * q3;
    // cofactors[k] is the product of every modulus except the k-th.
    let cofactors: [i128; 4] = [q1 * q2 * q3, q0 * q2 * q3, q0 * q1 * q3, q0 * q1 * q2];
    let threshold = (modulus + 1) / 2;

    // SAFETY: the constant tables are `[u64; 4]`, exactly one unaligned 256-bit
    // load each. Iteration `i < nn` loads words `4*i..4*i + 4` of `a`, in bounds
    // by the second assertion; the scratch array is exactly 32 bytes.
    unsafe {
        let q_lanes = _mm256_loadu_si256(Q_VEC.as_ptr() as *const __m256i);
        let mu_lanes = _mm256_loadu_si256(BARRETT_MU.as_ptr() as *const __m256i);
        let pow32_lanes = _mm256_loadu_si256(POW32.as_ptr() as *const __m256i);
        let crt_lanes = _mm256_loadu_si256(CRT_VEC.as_ptr() as *const __m256i);
        let src = a.as_ptr() as *const __m256i;

        for (i, out) in res[..nn].iter_mut().enumerate() {
            let word = _mm256_loadu_si256(src.add(i));
            let canonical = reduce_b_to_canonical(word, q_lanes, mu_lanes, pow32_lanes);
            let scaled = barrett_reduce(_mm256_mul_epu32(canonical, crt_lanes), q_lanes, mu_lanes);

            // Spill the four lanes so the wide recombination can run in scalar code.
            let mut residues = [0u64; 4];
            _mm256_storeu_si256(residues.as_mut_ptr() as *mut __m256i, scaled);

            let combined = residues
                .iter()
                .zip(cofactors.iter())
                .map(|(&residue, &cofactor)| residue as i128 * cofactor)
                .sum::<i128>()
                % modulus;

            *out = if combined >= threshold { combined - modulus } else { combined };
        }
    }
}
/// Cross-checks every AVX2 kernel against its scalar reference implementation.
///
/// The module is only compiled when the crate itself is built with the `avx2`
/// target feature enabled, so the kernels can be called unconditionally.
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use super::*;
    use poulpy_hal::reference::ntt120::{
        arithmetic::{b_from_znx64_ref, b_to_znx128_ref, c_from_b_ref},
        mat_vec::{BbbMeta, vec_mat1col_product_bbb_ref},
        primes::Primes30,
    };

    /// Number of coefficients used by every test vector.
    const N: usize = 64;

    /// Encodes `coeffs` with the scalar reference, four `u64` words per coefficient.
    fn encode_b(coeffs: &[i64]) -> Vec<u64> {
        let mut encoded = vec![0u64; 4 * coeffs.len()];
        b_from_znx64_ref::<Primes30>(coeffs.len(), &mut encoded, coeffs);
        encoded
    }

    #[test]
    fn b_from_znx64_avx2_vs_ref() {
        // Mixes negative and positive inputs so both sign paths are exercised.
        let input: Vec<i64> = (0..N as i64).map(|i| i * 17 - 500).collect();
        let mut got = vec![0u64; 4 * N];
        let mut want = vec![0u64; 4 * N];
        unsafe { b_from_znx64_avx2(N, &mut got, &input) };
        b_from_znx64_ref::<Primes30>(N, &mut want, &input);
        assert_eq!(got, want, "b_from_znx64: AVX2 vs ref mismatch");
    }

    #[test]
    fn c_from_b_avx2_vs_ref() {
        let input: Vec<i64> = (0..N as i64).map(|i| i * 11 + 3).collect();
        let encoded = encode_b(&input);
        let mut got = vec![0u32; 8 * N];
        let mut want = vec![0u32; 8 * N];
        unsafe { c_from_b_avx2(N, &mut got, &encoded) };
        c_from_b_ref::<Primes30>(N, &mut want, &encoded);
        assert_eq!(got, want, "c_from_b: AVX2 vs ref mismatch");
    }

    #[test]
    fn vec_mat1col_product_bbb_avx2_vs_ref() {
        let ell = 16usize;
        let meta = BbbMeta::<Primes30>::new();
        let lhs_raw: Vec<i64> = (0..ell * N).map(|i| (i as i64 * 7 + 1) % 100).collect();
        let rhs_raw: Vec<i64> = (0..ell * N).map(|i| (i as i64 * 13 + 2) % 100).collect();
        let lhs = encode_b(&lhs_raw);
        let rhs = encode_b(&rhs_raw);
        let mut got = vec![0u64; 4 * N];
        let mut want = vec![0u64; 4 * N];
        unsafe { vec_mat1col_product_bbb_avx2(&meta, ell, &mut got, &lhs, &rhs) };
        vec_mat1col_product_bbb_ref::<Primes30>(&meta, ell, &mut want, &lhs, &rhs);
        assert_eq!(got, want, "vec_mat1col_product_bbb: AVX2 vs ref mismatch");
    }

    #[test]
    fn b_to_znx128_avx2_vs_ref() {
        let input: Vec<i64> = (0..N as i64).map(|i| i * 5 - 160).collect();
        let encoded = encode_b(&input);
        let mut got = vec![0i128; N];
        let mut want = vec![0i128; N];
        unsafe { b_to_znx128_avx2(N, &mut got, &encoded) };
        b_to_znx128_ref::<Primes30>(N, &mut want, &encoded);
        assert_eq!(got, want, "b_to_znx128: AVX2 vs ref mismatch");
    }
}