/// Lane-wise 64-bit multiply of `a` and `b`, keeping the low 64 bits of each product.
///
/// The product is assembled from 32x32 -> 64 bit unsigned multiplies. With
/// `a = a_hi * 2^32 + a_lo` and `b = b_hi * 2^32 + b_lo`:
///
/// `a * b mod 2^64 = a_lo * b_lo + ((a_hi * b_lo + a_lo * b_hi) << 32)`
///
/// The `a_hi * b_hi` term is shifted out entirely and is therefore never computed.
///
/// # Safety
/// The CPU executing this function must support AVX2.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn mullo_epi64_avx2(a: core::arch::x86_64::__m256i, b: core::arch::x86_64::__m256i) -> core::arch::x86_64::__m256i {
    use core::arch::x86_64::{_mm256_add_epi64, _mm256_mul_epu32, _mm256_slli_epi64, _mm256_srli_epi64};

    // Move the upper 32 bits of every lane into the lower half, where
    // `_mm256_mul_epu32` reads its operands.
    let a_hi = _mm256_srli_epi64(a, 32);
    let b_hi = _mm256_srli_epi64(b, 32);

    // a_lo * b_lo
    let low_product = _mm256_mul_epu32(a, b);

    // a_hi * b_lo + a_lo * b_hi; only the low 32 bits of this sum survive the shift below.
    let hi_lo = _mm256_mul_epu32(a_hi, b);
    let lo_hi = _mm256_mul_epu32(a, b_hi);
    let mixed = _mm256_add_epi64(hi_lo, lo_hi);

    _mm256_add_epi64(low_product, _mm256_slli_epi64(mixed, 32))
}
/// Element-wise wrapping product: `res[i] = a[i] * b[i] (mod 2^64)`.
///
/// Full groups of four elements are multiplied with `mullo_epi64_avx2`; the
/// remaining `res.len() % 4` elements use scalar `wrapping_mul`.
///
/// # Panics
/// In debug builds, panics if `a` or `b` does not have the same length as `res`.
///
/// # Safety
/// The CPU must support AVX2. `a` and `b` must be at least as long as `res`:
/// the length checks are debug-only, so in release builds shorter inputs are
/// read out of bounds by the vector loop.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn znx_hadamard_product_i64_avx2(res: &mut [i64], a: &[i64], b: &[i64]) {
    use core::arch::x86_64::{__m256i, _mm256_loadu_si256, _mm256_storeu_si256};

    debug_assert_eq!(res.len(), a.len());
    debug_assert_eq!(res.len(), b.len());

    let len = res.len();
    // Number of complete 4 x i64 (256-bit) groups.
    let vec_count = len / 4;

    let dst = res.as_mut_ptr() as *mut __m256i;
    let lhs = a.as_ptr() as *const __m256i;
    let rhs = b.as_ptr() as *const __m256i;

    for block in 0..vec_count {
        // SAFETY: `block < len / 4`, so the four elements addressed here lie
        // inside `res`, and inside `a` and `b` under the caller's length
        // contract. Unaligned load/store intrinsics are used, so no alignment
        // is required.
        unsafe {
            let x = _mm256_loadu_si256(lhs.add(block));
            let y = _mm256_loadu_si256(rhs.add(block));
            _mm256_storeu_si256(dst.add(block), mullo_epi64_avx2(x, y));
        }
    }

    // Scalar tail for the elements that do not fill a whole vector.
    let tail_start = vec_count * 4;
    for i in tail_start..len {
        res[i] = a[i].wrapping_mul(b[i]);
    }
}
/// Writes `a[i] * 2^k` into `res[i]` for every element.
///
/// * `k == 0`: delegates to `znx_copy_ref(res, a)`.
/// * `k > 0`: every element is shifted left by `k` bits; bits shifted out are lost.
/// * `k < 0`: every element is divided by `2^(-k)`, rounded to nearest with ties
///   away from zero. This is computed as `(x + bias) >> -k` (arithmetic shift) with
///   `bias = 2^(-k - 1) - sign_bit(x)`; the bias addition wraps on overflow.
///
/// Full groups of four elements are processed with AVX2. The remaining
/// `res.len() % 4` elements are delegated to `znx_mul_power_of_two_ref`.
///
/// # Panics
/// Panics if `res` is non-empty and `k` is outside `[-63, 63]`. In debug builds,
/// also panics if `res` and `a` differ in length.
///
/// # Safety
/// The CPU must support AVX2. `a` must be at least as long as `res`: the length
/// check is debug-only, so in release builds a shorter `a` is read out of bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    use core::arch::x86_64::{
        __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
        _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
        _mm256_storeu_si256, _mm256_sub_epi64,
    };

    let n: usize = res.len();
    if n == 0 {
        return;
    }

    if k == 0 {
        use poulpy_cpu_ref::reference::znx::znx_copy_ref;
        znx_copy_ref(res, a);
        return;
    }

    // Number of complete 4 x i64 (256-bit) groups.
    let span: usize = n >> 2;

    // SAFETY: the loops below touch exactly the first `span * 4 <= n` elements of
    // `res` and `a`. That range is in bounds for `res`, and for `a` under the
    // caller's length contract. Unaligned load/store intrinsics are used.
    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

        if k > 0 {
            // Checked in every build profile, like the negative branch: without it,
            // `k as i32` below would silently truncate an out-of-range shift count.
            assert!(k <= 63, "k must be in [1, 63], got {}", k);
            let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);

            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(aa);
                let y: __m256i = _mm256_sll_epi64(x, cnt128);
                _mm256_storeu_si256(rr, y);
                rr = rr.add(1);
                aa = aa.add(1);
            }
        } else {
            let kp = -k;
            assert!((1..=63).contains(&kp), "kp must be in [1, 63], got {}", kp);

            let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
            let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
            // The top `kp` bits, used to sign-extend the logical shift below.
            let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp));
            let zero: __m256i = _mm256_setzero_si256();

            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(aa);

                // bias = 2^(kp-1) - sign_bit(x)
                let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
                let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
                let t: __m256i = _mm256_add_epi64(x, bias);

                // Arithmetic right shift of `t`, built from a logical shift plus
                // a fill of the vacated high bits for negative lanes.
                let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
                let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
                let fill: __m256i = _mm256_and_si256(neg, top_mask);
                let y: __m256i = _mm256_or_si256(lsr, fill);

                _mm256_storeu_si256(rr, y);
                rr = rr.add(1);
                aa = aa.add(1);
            }
        }
    }

    // Tail: elements that do not fill a whole vector go through the reference kernel.
    if !n.is_multiple_of(4) {
        use poulpy_cpu_ref::reference::znx::znx_mul_power_of_two_ref;
        znx_mul_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_power_of_two_assign_avx(k: i64, res: &mut [i64]) {
use core::arch::x86_64::{
__m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
_mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
_mm256_storeu_si256, _mm256_sub_epi64,
};
let n: usize = res.len();
if n == 0 {
return;
}
if k == 0 {
return;
}
let span: usize = n >> 2;
unsafe {
let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
if k > 0 {
#[cfg(debug_assertions)]
{
debug_assert!(k <= 63);
}
let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);
for _ in 0..span {
let x: __m256i = _mm256_loadu_si256(rr);
let y: __m256i = _mm256_sll_epi64(x, cnt128);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
if !n.is_multiple_of(4) {
use poulpy_cpu_ref::reference::znx::znx_mul_power_of_two_assign_ref;
znx_mul_power_of_two_assign_ref(k, &mut res[span << 2..]);
}
return;
}
let kp = -k;
assert!((1..=63).contains(&kp), "kp must be in [1, 63], got {}", kp);
let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp)); let zero: __m256i = _mm256_setzero_si256();
for _ in 0..span {
let x = _mm256_loadu_si256(rr);
let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
let t: __m256i = _mm256_add_epi64(x, bias);
let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
let fill: __m256i = _mm256_and_si256(neg, top_mask);
let y: __m256i = _mm256_or_si256(lsr, fill);
_mm256_storeu_si256(rr, y);
rr = rr.add(1);
}
}
if !n.is_multiple_of(4) {
use poulpy_cpu_ref::reference::znx::znx_mul_power_of_two_assign_ref;
znx_mul_power_of_two_assign_ref(k, &mut res[span << 2..]);
}
}
/// Accumulates `a[i] * 2^k` into `res[i]` for every element (wrapping addition).
///
/// * `k == 0`: delegates to `znx_add_assign_avx(res, a)`.
/// * `k > 0`: `a[i]` is shifted left by `k` bits before being added; bits shifted
///   out are lost.
/// * `k < 0`: `a[i]` is divided by `2^(-k)`, rounded to nearest with ties away from
///   zero, before being added. The quotient is computed as `(x + bias) >> -k`
///   (arithmetic shift) with `bias = 2^(-k - 1) - sign_bit(x)`; the bias addition
///   wraps on overflow.
///
/// Full groups of four elements are processed with AVX2. The remaining
/// `res.len() % 4` elements are delegated to `znx_mul_add_power_of_two_ref`.
///
/// # Panics
/// Panics if `res` is non-empty and `k` is outside `[-63, 63]`. In debug builds,
/// also panics if `res` and `a` differ in length.
///
/// # Safety
/// The CPU must support AVX2. `a` must be at least as long as `res`: the length
/// check is debug-only, so in release builds a shorter `a` is read out of bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn znx_mul_add_power_of_two_avx(k: i64, res: &mut [i64], a: &[i64]) {
    #[cfg(debug_assertions)]
    {
        assert_eq!(res.len(), a.len());
    }

    use core::arch::x86_64::{
        __m128i, __m256i, _mm_cvtsi32_si128, _mm256_add_epi64, _mm256_and_si256, _mm256_cmpgt_epi64, _mm256_loadu_si256,
        _mm256_or_si256, _mm256_set1_epi64x, _mm256_setzero_si256, _mm256_sll_epi64, _mm256_srl_epi64, _mm256_srli_epi64,
        _mm256_storeu_si256, _mm256_sub_epi64,
    };

    let n: usize = res.len();
    if n == 0 {
        return;
    }

    if k == 0 {
        use crate::znx_avx::znx_add_assign_avx;
        znx_add_assign_avx(res, a);
        return;
    }

    // Number of complete 4 x i64 (256-bit) groups.
    let span: usize = n >> 2;

    // SAFETY: the loops below touch exactly the first `span * 4 <= n` elements of
    // `res` and `a`. That range is in bounds for `res`, and for `a` under the
    // caller's length contract. Unaligned load/store intrinsics are used.
    unsafe {
        let mut rr: *mut __m256i = res.as_mut_ptr() as *mut __m256i;
        let mut aa: *const __m256i = a.as_ptr() as *const __m256i;

        if k > 0 {
            // Checked in every build profile, like the negative branch: without it,
            // `k as i32` below would silently truncate an out-of-range shift count.
            assert!(k <= 63, "k must be in [1, 63], got {}", k);
            let cnt128: __m128i = _mm_cvtsi32_si128(k as i32);

            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(aa);
                let y: __m256i = _mm256_loadu_si256(rr);
                _mm256_storeu_si256(rr, _mm256_add_epi64(y, _mm256_sll_epi64(x, cnt128)));
                rr = rr.add(1);
                aa = aa.add(1);
            }
        } else {
            let kp = -k;
            assert!((1..=63).contains(&kp), "kp must be in [1, 63], got {}", kp);

            let cnt_right: __m128i = _mm_cvtsi32_si128(kp as i32);
            let bias_base: __m256i = _mm256_set1_epi64x(1_i64 << (kp - 1));
            // The top `kp` bits, used to sign-extend the logical shift below.
            let top_mask: __m256i = _mm256_set1_epi64x(-1_i64 << (64 - kp));
            let zero: __m256i = _mm256_setzero_si256();

            for _ in 0..span {
                let x: __m256i = _mm256_loadu_si256(aa);
                let y: __m256i = _mm256_loadu_si256(rr);

                // bias = 2^(kp-1) - sign_bit(x)
                let sign_bit_x: __m256i = _mm256_srli_epi64(x, 63);
                let bias: __m256i = _mm256_sub_epi64(bias_base, sign_bit_x);
                let t: __m256i = _mm256_add_epi64(x, bias);

                // Arithmetic right shift of `t`, built from a logical shift plus
                // a fill of the vacated high bits for negative lanes.
                let lsr: __m256i = _mm256_srl_epi64(t, cnt_right);
                let neg: __m256i = _mm256_cmpgt_epi64(zero, t);
                let fill: __m256i = _mm256_and_si256(neg, top_mask);
                let out: __m256i = _mm256_or_si256(lsr, fill);

                _mm256_storeu_si256(rr, _mm256_add_epi64(y, out));
                rr = rr.add(1);
                aa = aa.add(1);
            }
        }
    }

    // Tail: elements that do not fill a whole vector go through the reference kernel.
    if !n.is_multiple_of(4) {
        use poulpy_cpu_ref::reference::znx::znx_mul_add_power_of_two_ref;
        znx_mul_add_power_of_two_ref(k, &mut res[span << 2..], &a[span << 2..]);
    }
}