use core::arch::x86_64::{
__m256i, _mm_cvtsi64_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_loadu_si256, _mm256_mul_epu32, _mm256_set1_epi64x,
_mm256_setzero_si256, _mm256_srl_epi64, _mm256_srli_epi64, _mm256_storeu_si256,
};
use poulpy_hal::reference::ntt120::{mat_vec::BbcMeta, primes::Primes30};
/// Folds a low/high accumulator pair into one vector of four 64-bit lanes.
///
/// For every 64-bit lane the result is
/// `s_lo + lo32(s_hi & mask_h2) * lo32(s2l) + lo32(s_hi >> h2) * lo32(s2h)`,
/// where `lo32` is the low 32 bits of the lane (the only part
/// `_mm256_mul_epu32` reads) and the additions wrap modulo 2^64.
///
/// `h2` is the logical right-shift count applied to `s_hi`; `mask_h2` is
/// supplied by the caller (the callers in this file pass `(1 << h2) - 1`).
///
/// # Safety
/// The CPU must support AVX2. This function carries no `target_feature`
/// attribute of its own and is meant to be inlined into AVX2-enabled callers.
#[inline(always)]
unsafe fn reduce_bbc(
    s_lo: __m256i,
    s_hi: __m256i,
    mask_h2: __m256i,
    h2: u64,
    s2l: __m256i,
    s2h: __m256i,
) -> __m256i {
    // SAFETY: only register-to-register AVX2 intrinsics are used; the caller
    // guarantees AVX2 is available.
    unsafe {
        // The variable-count shift takes its count from the low 64 bits of an XMM register.
        let shift = _mm_cvtsi64_si128(h2 as i64);

        // Split `s_hi` at bit `h2` and weight each part with its own multiplier.
        let below_split = _mm256_mul_epu32(_mm256_and_si256(s_hi, mask_h2), s2l);
        let above_split = _mm256_mul_epu32(_mm256_srl_epi64(s_hi, shift), s2h);

        let partial = _mm256_add_epi64(s_lo, below_split);
        _mm256_add_epi64(partial, above_split)
    }
}
/// Accumulates `ell` lane-wise 32x32-bit products of `x` and `y` and writes
/// four `u64` results to `res[0..4]`.
///
/// Each iteration loads one 256-bit vector (eight `u32`) from `x` and one from
/// `y`. Every 64-bit lane contributes two 64-bit products: low half times low
/// half, and high half times high half. The low 32 bits of both products are
/// summed into one accumulator and the high 32 bits into another; after the
/// loop the two accumulators are combined by `reduce_bbc` using `meta.h`,
/// `meta.s2l_pow_red` and `meta.s2h_pow_red`.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `x` and `y` must each hold at least `8 * ell` elements and `res` at least
///   4 (checked with `debug_assert!` in debug builds only).
/// - `meta.s2l_pow_red` and `meta.s2h_pow_red` must each be readable for
///   32 bytes, since a full 256-bit vector is loaded from each.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn vec_mat1col_product_bbc_avx2(
    meta: &BbcMeta<Primes30>,
    ell: usize,
    res: &mut [u64],
    x: &[u32],
    y: &[u32],
) {
    // Division instead of `8 * ell` keeps the checks themselves overflow-free.
    debug_assert!(x.len() / 8 >= ell, "x must hold at least 8 * ell u32 values");
    debug_assert!(y.len() / 8 >= ell, "y must hold at least 8 * ell u32 values");
    debug_assert!(res.len() >= 4, "res must hold at least 4 u64 values");

    // SAFETY: AVX2 availability and the slice-length requirements above are
    // part of this function's contract; all loads/stores are unaligned variants.
    unsafe {
        let mask32 = _mm256_set1_epi64x(u32::MAX as i64);
        let mut acc_lo = _mm256_setzero_si256();
        let mut acc_hi = _mm256_setzero_si256();
        let mut x_ptr = x.as_ptr() as *const __m256i;
        let mut y_ptr = y.as_ptr() as *const __m256i;

        for _ in 0..ell {
            let xv = _mm256_loadu_si256(x_ptr);
            let yv = _mm256_loadu_si256(y_ptr);

            // `_mm256_mul_epu32` reads only the low 32 bits of each 64-bit lane,
            // so the low halves need no masking before the multiply.
            let prod_lo = _mm256_mul_epu32(xv, yv);
            let prod_hi = _mm256_mul_epu32(_mm256_srli_epi64::<32>(xv), _mm256_srli_epi64::<32>(yv));

            // Low 32 bits of both products go to `acc_lo`, high 32 bits to `acc_hi`.
            acc_lo = _mm256_add_epi64(acc_lo, _mm256_and_si256(prod_lo, mask32));
            acc_lo = _mm256_add_epi64(acc_lo, _mm256_and_si256(prod_hi, mask32));
            acc_hi = _mm256_add_epi64(acc_hi, _mm256_srli_epi64::<32>(prod_lo));
            acc_hi = _mm256_add_epi64(acc_hi, _mm256_srli_epi64::<32>(prod_hi));

            x_ptr = x_ptr.add(1);
            y_ptr = y_ptr.add(1);
        }

        let mask_h2 = _mm256_set1_epi64x(((1u64 << meta.h) - 1) as i64);
        let s2l_pow_red = _mm256_loadu_si256(meta.s2l_pow_red.as_ptr() as *const __m256i);
        let s2h_pow_red = _mm256_loadu_si256(meta.s2h_pow_red.as_ptr() as *const __m256i);
        let reduced = reduce_bbc(acc_lo, acc_hi, mask_h2, meta.h, s2l_pow_red, s2h_pow_red);
        _mm256_storeu_si256(res.as_mut_ptr() as *mut __m256i, reduced);
    }
}
/// Runs two independent accumulate-and-reduce chains side by side and writes
/// eight `u64` results to `res[0..8]`.
///
/// Each iteration consumes two consecutive 256-bit vectors from `x` and two
/// from `y` (16 `u32` each). The first x/y vector pair feeds the accumulators
/// reduced into `res[0..4]`, the second pair those reduced into `res[4..8]`.
/// Within a pair, every 64-bit lane contributes the product of the low halves
/// and the product of the high halves; low and high 32 bits of those products
/// are accumulated separately and finally combined by `reduce_bbc`.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `x` and `y` must each hold at least `16 * ell` elements and `res` at
///   least 8 (checked with `debug_assert!` in debug builds only).
/// - `meta.s2l_pow_red` and `meta.s2h_pow_red` must each be readable for
///   32 bytes, since a full 256-bit vector is loaded from each.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn vec_mat1col_product_x2_bbc_avx2(
    meta: &BbcMeta<Primes30>,
    ell: usize,
    res: &mut [u64],
    x: &[u32],
    y: &[u32],
) {
    // Division instead of `16 * ell` keeps the checks themselves overflow-free.
    debug_assert!(x.len() / 16 >= ell, "x must hold at least 16 * ell u32 values");
    debug_assert!(y.len() / 16 >= ell, "y must hold at least 16 * ell u32 values");
    debug_assert!(res.len() >= 8, "res must hold at least 8 u64 values");

    // SAFETY: AVX2 availability and the slice-length requirements above are
    // part of this function's contract; all loads/stores are unaligned variants.
    unsafe {
        let mask32 = _mm256_set1_epi64x(u32::MAX as i64);
        let mut s0 = _mm256_setzero_si256();
        let mut s1 = _mm256_setzero_si256();
        let mut s2 = _mm256_setzero_si256();
        let mut s3 = _mm256_setzero_si256();
        let mut x_ptr = x.as_ptr() as *const __m256i;
        let mut y_ptr = y.as_ptr() as *const __m256i;

        for _ in 0..ell {
            // First pair: accumulates into (s0, s1).
            let xa = _mm256_loadu_si256(x_ptr);
            let xa_hi = _mm256_srli_epi64::<32>(xa);
            let ya = _mm256_loadu_si256(y_ptr);
            let ya_hi = _mm256_srli_epi64::<32>(ya);
            // `_mm256_mul_epu32` reads only the low 32 bits of each 64-bit lane.
            let prod_a_lo = _mm256_mul_epu32(xa, ya);
            let prod_a_hi = _mm256_mul_epu32(xa_hi, ya_hi);
            s0 = _mm256_add_epi64(s0, _mm256_and_si256(prod_a_lo, mask32));
            s0 = _mm256_add_epi64(s0, _mm256_and_si256(prod_a_hi, mask32));
            s1 = _mm256_add_epi64(s1, _mm256_srli_epi64::<32>(prod_a_lo));
            s1 = _mm256_add_epi64(s1, _mm256_srli_epi64::<32>(prod_a_hi));

            // Second pair: accumulates into (s2, s3).
            let xb = _mm256_loadu_si256(x_ptr.add(1));
            let xb_hi = _mm256_srli_epi64::<32>(xb);
            let yb = _mm256_loadu_si256(y_ptr.add(1));
            let yb_hi = _mm256_srli_epi64::<32>(yb);
            let prod_b_lo = _mm256_mul_epu32(xb, yb);
            let prod_b_hi = _mm256_mul_epu32(xb_hi, yb_hi);
            s2 = _mm256_add_epi64(s2, _mm256_and_si256(prod_b_lo, mask32));
            s2 = _mm256_add_epi64(s2, _mm256_and_si256(prod_b_hi, mask32));
            s3 = _mm256_add_epi64(s3, _mm256_srli_epi64::<32>(prod_b_lo));
            s3 = _mm256_add_epi64(s3, _mm256_srli_epi64::<32>(prod_b_hi));

            x_ptr = x_ptr.add(2);
            y_ptr = y_ptr.add(2);
        }

        let mask_h2 = _mm256_set1_epi64x(((1u64 << meta.h) - 1) as i64);
        let s2l_pow_red = _mm256_loadu_si256(meta.s2l_pow_red.as_ptr() as *const __m256i);
        let s2h_pow_red = _mm256_loadu_si256(meta.s2h_pow_red.as_ptr() as *const __m256i);
        let res_ptr = res.as_mut_ptr() as *mut __m256i;
        _mm256_storeu_si256(res_ptr, reduce_bbc(s0, s1, mask_h2, meta.h, s2l_pow_red, s2h_pow_red));
        _mm256_storeu_si256(res_ptr.add(1), reduce_bbc(s2, s3, mask_h2, meta.h, s2l_pow_red, s2h_pow_red));
    }
}
/// Runs four accumulate-and-reduce chains that share two `x` vectors per
/// iteration and writes sixteen `u64` results to `res[0..16]`.
///
/// Each iteration loads two consecutive 256-bit vectors from `x` (`xa`, `xb`)
/// and four from `y` (`y0..y3`), and accumulates:
///
/// | inputs      | output        |
/// |-------------|---------------|
/// | `xa` · `y0` | `res[0..4]`   |
/// | `xb` · `y1` | `res[4..8]`   |
/// | `xa` · `y2` | `res[8..12]`  |
/// | `xb` · `y3` | `res[12..16]` |
///
/// For every chain, each 64-bit lane contributes the product of the low
/// halves and the product of the high halves; low and high 32 bits of those
/// products are accumulated separately and finally combined by `reduce_bbc`.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `x` must hold at least `16 * ell` elements, `y` at least `32 * ell`, and
///   `res` at least 16 (checked with `debug_assert!` in debug builds only).
/// - `meta.s2l_pow_red` and `meta.s2h_pow_red` must each be readable for
///   32 bytes, since a full 256-bit vector is loaded from each.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn vec_mat2cols_product_x2_bbc_avx2(
    meta: &BbcMeta<Primes30>,
    ell: usize,
    res: &mut [u64],
    x: &[u32],
    y: &[u32],
) {
    // Division instead of multiplication keeps the checks themselves overflow-free.
    debug_assert!(x.len() / 16 >= ell, "x must hold at least 16 * ell u32 values");
    debug_assert!(y.len() / 32 >= ell, "y must hold at least 32 * ell u32 values");
    debug_assert!(res.len() >= 16, "res must hold at least 16 u64 values");

    // SAFETY: AVX2 availability and the slice-length requirements above are
    // part of this function's contract; all loads/stores are unaligned variants.
    unsafe {
        let mask32 = _mm256_set1_epi64x(u32::MAX as i64);
        let mut s0 = _mm256_setzero_si256();
        let mut s1 = _mm256_setzero_si256();
        let mut s2 = _mm256_setzero_si256();
        let mut s3 = _mm256_setzero_si256();
        let mut s4 = _mm256_setzero_si256();
        let mut s5 = _mm256_setzero_si256();
        let mut s6 = _mm256_setzero_si256();
        let mut s7 = _mm256_setzero_si256();
        let mut x_ptr = x.as_ptr() as *const __m256i;
        let mut y_ptr = y.as_ptr() as *const __m256i;

        for _ in 0..ell {
            // Both x vectors are loaded once and reused by two chains each.
            let xa = _mm256_loadu_si256(x_ptr);
            let xa_hi = _mm256_srli_epi64::<32>(xa);
            let xb = _mm256_loadu_si256(x_ptr.add(1));
            let xb_hi = _mm256_srli_epi64::<32>(xb);

            // xa * y0 -> (s0, s1). `_mm256_mul_epu32` reads only the low 32 bits of each lane.
            let yc0a = _mm256_loadu_si256(y_ptr);
            let yc0a_hi = _mm256_srli_epi64::<32>(yc0a);
            let p0a_lo = _mm256_mul_epu32(xa, yc0a);
            let p0a_hi = _mm256_mul_epu32(xa_hi, yc0a_hi);
            s0 = _mm256_add_epi64(s0, _mm256_and_si256(p0a_lo, mask32));
            s0 = _mm256_add_epi64(s0, _mm256_and_si256(p0a_hi, mask32));
            s1 = _mm256_add_epi64(s1, _mm256_srli_epi64::<32>(p0a_lo));
            s1 = _mm256_add_epi64(s1, _mm256_srli_epi64::<32>(p0a_hi));

            // xb * y1 -> (s2, s3).
            let yc0b = _mm256_loadu_si256(y_ptr.add(1));
            let yc0b_hi = _mm256_srli_epi64::<32>(yc0b);
            let p0b_lo = _mm256_mul_epu32(xb, yc0b);
            let p0b_hi = _mm256_mul_epu32(xb_hi, yc0b_hi);
            s2 = _mm256_add_epi64(s2, _mm256_and_si256(p0b_lo, mask32));
            s2 = _mm256_add_epi64(s2, _mm256_and_si256(p0b_hi, mask32));
            s3 = _mm256_add_epi64(s3, _mm256_srli_epi64::<32>(p0b_lo));
            s3 = _mm256_add_epi64(s3, _mm256_srli_epi64::<32>(p0b_hi));

            // xa * y2 -> (s4, s5).
            let yc1a = _mm256_loadu_si256(y_ptr.add(2));
            let yc1a_hi = _mm256_srli_epi64::<32>(yc1a);
            let p1a_lo = _mm256_mul_epu32(xa, yc1a);
            let p1a_hi = _mm256_mul_epu32(xa_hi, yc1a_hi);
            s4 = _mm256_add_epi64(s4, _mm256_and_si256(p1a_lo, mask32));
            s4 = _mm256_add_epi64(s4, _mm256_and_si256(p1a_hi, mask32));
            s5 = _mm256_add_epi64(s5, _mm256_srli_epi64::<32>(p1a_lo));
            s5 = _mm256_add_epi64(s5, _mm256_srli_epi64::<32>(p1a_hi));

            // xb * y3 -> (s6, s7).
            let yc1b = _mm256_loadu_si256(y_ptr.add(3));
            let yc1b_hi = _mm256_srli_epi64::<32>(yc1b);
            let p1b_lo = _mm256_mul_epu32(xb, yc1b);
            let p1b_hi = _mm256_mul_epu32(xb_hi, yc1b_hi);
            s6 = _mm256_add_epi64(s6, _mm256_and_si256(p1b_lo, mask32));
            s6 = _mm256_add_epi64(s6, _mm256_and_si256(p1b_hi, mask32));
            s7 = _mm256_add_epi64(s7, _mm256_srli_epi64::<32>(p1b_lo));
            s7 = _mm256_add_epi64(s7, _mm256_srli_epi64::<32>(p1b_hi));

            x_ptr = x_ptr.add(2);
            y_ptr = y_ptr.add(4);
        }

        let mask_h2 = _mm256_set1_epi64x(((1u64 << meta.h) - 1) as i64);
        let s2l_pow_red = _mm256_loadu_si256(meta.s2l_pow_red.as_ptr() as *const __m256i);
        let s2h_pow_red = _mm256_loadu_si256(meta.s2h_pow_red.as_ptr() as *const __m256i);
        let res_ptr = res.as_mut_ptr() as *mut __m256i;
        _mm256_storeu_si256(res_ptr, reduce_bbc(s0, s1, mask_h2, meta.h, s2l_pow_red, s2h_pow_red));
        _mm256_storeu_si256(res_ptr.add(1), reduce_bbc(s2, s3, mask_h2, meta.h, s2l_pow_red, s2h_pow_red));
        _mm256_storeu_si256(res_ptr.add(2), reduce_bbc(s4, s5, mask_h2, meta.h, s2l_pow_red, s2h_pow_red));
        _mm256_storeu_si256(res_ptr.add(3), reduce_bbc(s6, s7, mask_h2, meta.h, s2l_pow_red, s2h_pow_red));
    }
}
// Compiled only for test builds whose compile-time target features include AVX2.
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use super::*;
    use poulpy_hal::reference::ntt120::{
        arithmetic::{b_from_znx64_ref, c_from_b_ref},
        mat_vec::{BbcMeta, vec_mat1col_product_bbc_ref, vec_mat1col_product_x2_bbc_ref, vec_mat2cols_product_x2_bbc_ref},
        primes::Primes30,
    };

    /// Splits every `u64` into two `u32` halves, low half first.
    fn b_to_u32(b: &[u64]) -> Vec<u32> {
        let mut halves = Vec::with_capacity(2 * b.len());
        for &word in b {
            halves.push(word as u32);
            halves.push((word >> 32) as u32);
        }
        halves
    }

    /// Deterministic coefficients `(i * seed + offset) % 50 + 1` for `i` in `0..count`.
    fn small_coeffs(count: usize, seed: i64, offset: i64) -> Vec<i64> {
        (0..count).map(|i| (i as i64 * seed + offset) % 50 + 1).collect()
    }

    /// Builds `ell * n` coefficients, converts them with `b_from_znx64_ref`
    /// and returns the resulting `u64` words viewed as `u32` halves.
    fn make_q120b_u32(ell: usize, n: usize, seed: i64) -> Vec<u32> {
        let count = ell * n;
        let coeffs = small_coeffs(count, seed, 1);
        let mut b = vec![0u64; 4 * count];
        b_from_znx64_ref::<Primes30>(count, &mut b, &coeffs);
        b_to_u32(&b)
    }

    /// Builds `ell * n` coefficients, converts them with `b_from_znx64_ref`
    /// and then with `c_from_b_ref`, returning the `u32` output of the latter.
    fn make_q120c_u32(ell: usize, n: usize, seed: i64) -> Vec<u32> {
        let count = ell * n;
        let coeffs = small_coeffs(count, seed, 2);
        let mut b = vec![0u64; 4 * count];
        b_from_znx64_ref::<Primes30>(count, &mut b, &coeffs);
        let mut c = vec![0u32; 8 * count];
        c_from_b_ref::<Primes30>(count, &mut c, &b);
        c
    }

    /// For each row `i` in `0..ell`, appends the 8-element block `8*i..8*i+8`
    /// of every part in order, producing the interleaved layout the x2 kernels read.
    fn interleave_blocks(ell: usize, parts: &[Vec<u32>]) -> Vec<u32> {
        let mut out = Vec::with_capacity(8 * ell * parts.len());
        for row in 0..ell {
            let span = 8 * row..8 * row + 8;
            for part in parts {
                out.extend_from_slice(&part[span.clone()]);
            }
        }
        out
    }

    #[test]
    fn vec_mat1col_product_bbc_avx2_vs_ref() {
        let ell = 8usize;
        let n = 1usize;
        let meta = BbcMeta::<Primes30>::new();
        let x = make_q120b_u32(ell, n, 7);
        let y = make_q120c_u32(ell, n, 13);

        let mut res_avx = vec![0u64; 4];
        let mut res_ref = vec![0u64; 4];
        unsafe { vec_mat1col_product_bbc_avx2(&meta, ell, &mut res_avx, &x, &y) };
        vec_mat1col_product_bbc_ref::<Primes30>(&meta, ell, &mut res_ref, &x, &y);

        assert_eq!(res_avx, res_ref, "vec_mat1col_product_bbc: AVX2 vs ref mismatch");
    }

    #[test]
    fn vec_mat1col_product_x2_bbc_avx2_vs_ref() {
        let ell = 8usize;
        let n = 1usize;
        let meta = BbcMeta::<Primes30>::new();
        let x = interleave_blocks(ell, &[make_q120b_u32(ell, n, 5), make_q120b_u32(ell, n, 11)]);
        let y = interleave_blocks(ell, &[make_q120c_u32(ell, n, 3), make_q120c_u32(ell, n, 17)]);

        let mut res_avx = vec![0u64; 8];
        let mut res_ref = vec![0u64; 8];
        unsafe { vec_mat1col_product_x2_bbc_avx2(&meta, ell, &mut res_avx, &x, &y) };
        vec_mat1col_product_x2_bbc_ref::<Primes30>(&meta, ell, &mut res_ref, &x, &y);

        assert_eq!(res_avx, res_ref, "vec_mat1col_product_x2_bbc: AVX2 vs ref mismatch");
    }

    #[test]
    fn vec_mat2cols_product_x2_bbc_avx2_vs_ref() {
        let ell = 8usize;
        let n = 1usize;
        let meta = BbcMeta::<Primes30>::new();
        let x = interleave_blocks(ell, &[make_q120b_u32(ell, n, 7), make_q120b_u32(ell, n, 19)]);
        let y = interleave_blocks(
            ell,
            &[
                make_q120c_u32(ell, n, 2),
                make_q120c_u32(ell, n, 9),
                make_q120c_u32(ell, n, 23),
                make_q120c_u32(ell, n, 31),
            ],
        );

        let mut res_avx = vec![0u64; 16];
        let mut res_ref = vec![0u64; 16];
        unsafe { vec_mat2cols_product_x2_bbc_avx2(&meta, ell, &mut res_avx, &x, &y) };
        vec_mat2cols_product_x2_bbc_ref::<Primes30>(&meta, ell, &mut res_ref, &x, &y);

        assert_eq!(res_avx, res_ref, "vec_mat2cols_product_x2_bbc: AVX2 vs ref mismatch");
    }
}