/// Converts each `i64` of `a` into the `f64` holding the same value and stores it in `res`.
///
/// Elements are processed four at a time with 256-bit SIMD using a bit trick:
/// `2^51` is added to the integer, the bit pattern of the double `2^52` is OR-ed in,
/// and `3 * 2^51` is subtracted as a double. The trailing `res.len() % 4` elements are
/// delegated to `poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref`.
///
/// # Correctness
/// Inputs are expected to satisfy `|a[i]| <= 2^50 - 1`; this bound is verified in debug
/// builds only.
///
/// # Panics
/// Panics if `res.len() != a.len()`. In debug builds, also panics if any input is outside
/// the bound above.
///
/// NOTE(review): only `fma` is enabled by the attribute, yet the body uses
/// `_mm256_add_epi64`, which is an AVX2 intrinsic. Callers must make sure the CPU supports
/// AVX2 as well; consider enabling `avx2` here once all call sites are confirmed to call
/// this function from an `unsafe` context.
#[target_feature(enable = "fma")]
pub fn reim_from_znx_i64_bnd50_fma(res: &mut [f64], a: &[i64]) {
    // Checked in every build profile: the SIMD loop below pairs up chunks of `res` and `a`,
    // so a shorter `a` must be rejected instead of being read out of bounds.
    assert_eq!(res.len(), a.len());

    #[cfg(debug_assertions)]
    {
        const BOUND: i64 = (1i64 << 50) - 1;
        for (i, &val) in a.iter().enumerate() {
            // `unsigned_abs` cannot overflow, unlike `abs` on `i64::MIN`.
            assert!(
                val.unsigned_abs() <= BOUND.unsigned_abs(),
                "Input a[{}] = {} exceeds bound 2^50-1 ({})",
                i,
                val,
                BOUND
            );
        }
    }

    // Largest multiple of four that fits; the rest goes through the scalar reference.
    let head_len: usize = res.len() & !3;
    let (res_head, res_tail) = res.split_at_mut(head_len);
    let (a_head, a_tail) = a.split_at(head_len);

    // SAFETY: every load/store touches exactly one 4-element (32-byte) chunk obtained from
    // `chunks_exact{,_mut}(4)`, and the unaligned load/store intrinsics are used. The
    // intrinsics additionally require AVX/AVX2 support on the executing CPU (see the
    // review note in the function documentation).
    unsafe {
        use std::arch::x86_64::{
            __m256d, __m256i, _mm256_add_epi64, _mm256_castsi256_pd, _mm256_loadu_si256, _mm256_or_pd,
            _mm256_set1_epi64x, _mm256_set1_pd, _mm256_storeu_pd, _mm256_sub_pd,
        };

        // Bit pattern of 2^52: exponent set, mantissa zero, so the OR below places the
        // biased integer directly into the mantissa.
        let expo_256: __m256d = _mm256_set1_pd((1i64 << 52) as f64);
        // Bias making the bounded input non-negative before it is used as mantissa bits.
        let add_cst_256: __m256i = _mm256_set1_epi64x(1i64 << 51);
        // Removes both 2^52 (from the exponent) and the 2^51 bias in one subtraction.
        let sub_cst_256: __m256d = _mm256_set1_pd((3i64 << 51) as f64);

        for (dst, src) in res_head.chunks_exact_mut(4).zip(a_head.chunks_exact(4)) {
            let raw: __m256i = _mm256_loadu_si256(src.as_ptr() as *const __m256i);
            let biased: __m256i = _mm256_add_epi64(raw, add_cst_256);
            let tagged: __m256d = _mm256_or_pd(_mm256_castsi256_pd(biased), expo_256);
            _mm256_storeu_pd(dst.as_mut_ptr(), _mm256_sub_pd(tagged, sub_cst_256));
        }
    }

    if !res_tail.is_empty() {
        use poulpy_hal::reference::fft64::reim::reim_from_znx_i64_ref;
        reim_from_znx_i64_ref(res_tail, a_tail)
    }
}
/// Converts each `f64` of `a`, divided by `divisor`, into an `i64` stored in `res`.
///
/// For every element the code:
/// 1. adds `divisor / 2` carrying the sign of the input (rounding offset away from zero),
/// 2. extracts the exponent and the 52-bit mantissa (with the implicit leading bit set),
/// 3. shifts the mantissa left or right by the difference between the element's exponent
///    and the exponent of `divisor * 2^52` (the out-of-range direction shifts to zero),
/// 4. negates the magnitude when the input had its sign bit set.
///
/// Elements are processed four at a time; the trailing `res.len() % 4` elements are
/// delegated to `poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref`.
///
/// NOTE(review): the raw bits of `divisor * 2^52` are subtracted from the masked exponent
/// bits without masking the divisor's mantissa, so `divisor` is presumably expected to be
/// a power of two — confirm against callers.
///
/// # Panics
/// Panics if `res.len() != a.len()`.
#[allow(dead_code)]
#[target_feature(enable = "avx2,fma")]
pub fn reim_to_znx_i64_bnd63_avx2_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
    // Checked in every build profile: the SIMD loop pairs up chunks of `res` and `a`, so a
    // shorter `a` must be rejected instead of being read out of bounds.
    assert_eq!(res.len(), a.len());

    let sign_mask: u64 = 0x8000000000000000u64;
    let expo_mask: u64 = 0x7FF0000000000000u64;
    let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
    // Implicit leading one of a normalized double's mantissa.
    let mantissa_msb: u64 = 0x0010000000000000u64;
    let divi_bits: f64 = divisor * (1i64 << 52) as f64;
    let offset: f64 = divisor / 2.;

    // Largest multiple of four that fits; the rest goes through the scalar reference.
    let head_len: usize = res.len() & !3;
    let (res_head, res_tail) = res.split_at_mut(head_len);
    let (a_head, a_tail) = a.split_at(head_len);

    // SAFETY: every load/store touches exactly one 4-element (32-byte) chunk obtained from
    // `chunks_exact{,_mut}(4)` and uses the unaligned intrinsics; AVX2 is enabled for this
    // function through `#[target_feature]`.
    unsafe {
        use std::arch::x86_64::{
            __m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256,
            _mm256_castsi256_pd, _mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x,
            _mm256_set1_pd, _mm256_sllv_epi64, _mm256_srli_epi64, _mm256_srlv_epi64, _mm256_storeu_si256,
            _mm256_sub_epi64, _mm256_xor_si256,
        };

        let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
        let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
        let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
        let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
        let offset_256: __m256d = _mm256_set1_pd(offset);
        let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));

        for (dst, src) in res_head.chunks_exact_mut(4).zip(a_head.chunks_exact(4)) {
            let x: __m256d = _mm256_loadu_pd(src.as_ptr());

            // Add the rounding offset with the same sign as the input.
            let sign_bits: __m256d = _mm256_and_pd(x, sign_mask_256);
            let shifted: __m256d = _mm256_add_pd(x, _mm256_or_pd(sign_bits, offset_256));
            let shifted_bits: __m256i = _mm256_castpd_si256(shifted);

            // All-ones lanes where the input had its sign bit set, zero elsewhere.
            let neg_mask: __m256i = _mm256_sub_epi64(
                _mm256_set1_epi64x(0),
                _mm256_srli_epi64(_mm256_castpd_si256(sign_bits), 63),
            );

            // Shift amounts in both directions; the wrong direction wraps to a count
            // of at least 64, for which the variable shifts produce zero.
            let exp_bits: __m256i = _mm256_and_si256(shifted_bits, expo_mask_256);
            let lsh: __m256i = _mm256_srli_epi64(_mm256_sub_epi64(exp_bits, divi_bits_256), 52);
            let rsh: __m256i = _mm256_srli_epi64(_mm256_sub_epi64(divi_bits_256, exp_bits), 52);

            // Mantissa with the implicit leading bit restored.
            let mantissa: __m256i =
                _mm256_or_si256(_mm256_and_si256(shifted_bits, mantissa_mask_256), mantissa_msb_256);
            let magnitude: __m256i =
                _mm256_or_si256(_mm256_sllv_epi64(mantissa, lsh), _mm256_srlv_epi64(mantissa, rsh));

            // Conditional two's-complement negation: (v ^ m) - m.
            let out: __m256i = _mm256_sub_epi64(_mm256_xor_si256(magnitude, neg_mask), neg_mask);
            _mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, out);
        }
    }

    if !res_tail.is_empty() {
        use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
        reim_to_znx_i64_ref(res_tail, divisor, a_tail)
    }
}
/// In-place variant of the `f64` to `i64` conversion: every `f64` of `res`, divided by
/// `divisor`, is replaced by the bit pattern of the resulting `i64`.
///
/// After the call the slice still has type `[f64]`, but each element's 64 bits hold an
/// `i64` value written by `_mm256_storeu_si256`; the caller has to reinterpret them.
///
/// For every element the code:
/// 1. adds `divisor / 2` carrying the sign of the input (rounding offset away from zero),
/// 2. extracts the exponent and the 52-bit mantissa (with the implicit leading bit set),
/// 3. shifts the mantissa left or right by the difference between the element's exponent
///    and the exponent of `divisor * 2^52` (the out-of-range direction shifts to zero),
/// 4. negates the magnitude when the input had its sign bit set.
///
/// Elements are processed four at a time; the trailing `res.len() % 4` elements are
/// delegated to `poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref`.
///
/// NOTE(review): the raw bits of `divisor * 2^52` are subtracted from the masked exponent
/// bits without masking the divisor's mantissa, so `divisor` is presumably expected to be
/// a power of two — confirm against callers.
#[target_feature(enable = "avx2,fma")]
pub fn reim_to_znx_i64_inplace_bnd63_avx2_fma(res: &mut [f64], divisor: f64) {
    let sign_mask: u64 = 0x8000000000000000u64;
    let expo_mask: u64 = 0x7FF0000000000000u64;
    let mantissa_mask: u64 = (i64::MAX as u64) ^ expo_mask;
    // Implicit leading one of a normalized double's mantissa.
    let mantissa_msb: u64 = 0x0010000000000000u64;
    let divi_bits: f64 = divisor * (1i64 << 52) as f64;
    let offset: f64 = divisor / 2.;

    // Largest multiple of four that fits; the rest goes through the scalar reference.
    let head_len: usize = res.len() & !3;
    let (head, tail) = res.split_at_mut(head_len);

    // SAFETY: every load/store touches exactly one 4-element (32-byte) chunk obtained from
    // `chunks_exact_mut(4)` and uses the unaligned intrinsics; AVX2 is enabled for this
    // function through `#[target_feature]`. The load and the store of a chunk go through
    // the SAME raw pointer, so no second mutable reborrow of the slice can invalidate it.
    unsafe {
        use std::arch::x86_64::{
            __m256d, __m256i, _mm256_add_pd, _mm256_and_pd, _mm256_and_si256, _mm256_castpd_si256,
            _mm256_castsi256_pd, _mm256_loadu_pd, _mm256_or_pd, _mm256_or_si256, _mm256_set1_epi64x,
            _mm256_set1_pd, _mm256_sllv_epi64, _mm256_srli_epi64, _mm256_srlv_epi64, _mm256_storeu_si256,
            _mm256_sub_epi64, _mm256_xor_si256,
        };

        use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_inplace_ref;

        let sign_mask_256: __m256d = _mm256_castsi256_pd(_mm256_set1_epi64x(sign_mask as i64));
        let expo_mask_256: __m256i = _mm256_set1_epi64x(expo_mask as i64);
        let mantissa_mask_256: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);
        let mantissa_msb_256: __m256i = _mm256_set1_epi64x(mantissa_msb as i64);
        let offset_256: __m256d = _mm256_set1_pd(offset);
        let divi_bits_256: __m256i = _mm256_castpd_si256(_mm256_set1_pd(divi_bits));

        for chunk in head.chunks_exact_mut(4) {
            let lane_ptr: *mut f64 = chunk.as_mut_ptr();
            let x: __m256d = _mm256_loadu_pd(lane_ptr);

            // Add the rounding offset with the same sign as the input.
            let sign_bits: __m256d = _mm256_and_pd(x, sign_mask_256);
            let shifted: __m256d = _mm256_add_pd(x, _mm256_or_pd(sign_bits, offset_256));
            let shifted_bits: __m256i = _mm256_castpd_si256(shifted);

            // All-ones lanes where the input had its sign bit set, zero elsewhere.
            let neg_mask: __m256i = _mm256_sub_epi64(
                _mm256_set1_epi64x(0),
                _mm256_srli_epi64(_mm256_castpd_si256(sign_bits), 63),
            );

            // Shift amounts in both directions; the wrong direction wraps to a count
            // of at least 64, for which the variable shifts produce zero.
            let exp_bits: __m256i = _mm256_and_si256(shifted_bits, expo_mask_256);
            let lsh: __m256i = _mm256_srli_epi64(_mm256_sub_epi64(exp_bits, divi_bits_256), 52);
            let rsh: __m256i = _mm256_srli_epi64(_mm256_sub_epi64(divi_bits_256, exp_bits), 52);

            // Mantissa with the implicit leading bit restored.
            let mantissa: __m256i =
                _mm256_or_si256(_mm256_and_si256(shifted_bits, mantissa_mask_256), mantissa_msb_256);
            let magnitude: __m256i =
                _mm256_or_si256(_mm256_sllv_epi64(mantissa, lsh), _mm256_srlv_epi64(mantissa, rsh));

            // Conditional two's-complement negation: (v ^ m) - m.
            let out: __m256i = _mm256_sub_epi64(_mm256_xor_si256(magnitude, neg_mask), neg_mask);

            // Overwrite the four doubles just read with the four integers.
            _mm256_storeu_si256(lane_ptr.cast::<__m256i>(), out);
        }

        if !tail.is_empty() {
            reim_to_znx_i64_inplace_ref(tail, divisor)
        }
    }
}
/// Converts each `f64` of `a`, divided by `divisor`, into an `i64` stored in `res`, using
/// a mantissa-extraction trick.
///
/// For every element, `divisor * 3 * 2^51` is added as a double, the low 52 mantissa bits
/// of the sum are kept, and `2^51` is subtracted as an integer. Rounding is whatever the
/// floating-point addition performs.
///
/// Elements are processed four at a time; the trailing `res.len() % 4` elements are
/// delegated to `poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref`.
///
/// NOTE(review): the function name suggests results are only meaningful while
/// `|a[i] / divisor|` stays below roughly `2^50`, and the trick presumably requires
/// `divisor` to be a power of two — confirm against callers.
///
/// NOTE(review): only `fma` is enabled by the attribute, yet the body uses
/// `_mm256_and_si256` and `_mm256_sub_epi64`, which are AVX2 intrinsics. Callers must make
/// sure the CPU supports AVX2 as well; consider enabling `avx2` here once all call sites
/// are confirmed to call this function from an `unsafe` context.
///
/// # Panics
/// Panics if `res.len() != a.len()`.
#[target_feature(enable = "fma")]
#[allow(dead_code)]
pub fn reim_to_znx_i64_avx2_bnd50_fma(res: &mut [i64], divisor: f64, a: &[f64]) {
    // Checked in every build profile: the SIMD loop pairs up chunks of `res` and `a`, so a
    // shorter `a` must be rejected instead of being read out of bounds.
    assert_eq!(res.len(), a.len());

    // Largest multiple of four that fits; the rest goes through the scalar reference.
    let head_len: usize = res.len() & !3;
    let (res_head, res_tail) = res.split_at_mut(head_len);
    let (a_head, a_tail) = a.split_at(head_len);

    // SAFETY: every load/store touches exactly one 4-element (32-byte) chunk obtained from
    // `chunks_exact{,_mut}(4)`, and the unaligned load/store intrinsics are used. The
    // intrinsics additionally require AVX/AVX2 support on the executing CPU (see the
    // review note in the function documentation).
    unsafe {
        use std::arch::x86_64::{
            __m256d, __m256i, _mm256_add_pd, _mm256_and_si256, _mm256_castpd_si256, _mm256_loadu_pd,
            _mm256_set1_epi64x, _mm256_set1_pd, _mm256_storeu_si256, _mm256_sub_epi64,
        };

        let mantissa_mask: u64 = 0x000FFFFFFFFFFFFFu64;
        // Integer bias removed after the mantissa has been extracted.
        let sub_cst_4: __m256i = _mm256_set1_epi64x(1i64 << 51);
        // Adding this constant moves the scaled value into the mantissa of the sum.
        let add_cst_4: __m256d = _mm256_set1_pd(divisor * (3i64 << 51) as f64);
        let mantissa_mask_4: __m256i = _mm256_set1_epi64x(mantissa_mask as i64);

        for (dst, src) in res_head.chunks_exact_mut(4).zip(a_head.chunks_exact(4)) {
            let sum: __m256d = _mm256_add_pd(_mm256_loadu_pd(src.as_ptr()), add_cst_4);
            let mantissa: __m256i = _mm256_and_si256(_mm256_castpd_si256(sum), mantissa_mask_4);
            let out: __m256i = _mm256_sub_epi64(mantissa, sub_cst_4);
            _mm256_storeu_si256(dst.as_mut_ptr() as *mut __m256i, out);
        }
    }

    if !res_tail.is_empty() {
        use poulpy_hal::reference::fft64::reim::reim_to_znx_i64_ref;
        reim_to_znx_i64_ref(res_tail, divisor, a_tail)
    }
}
// Compares the SIMD conversion kernels against the scalar reference implementations.
#[cfg(all(test, target_feature = "avx2"))]
mod tests {
    use poulpy_hal::reference::fft64::reim::{reim_from_znx_i64_ref, reim_to_znx_i64_ref};

    use super::*;

    /// Lengths exercised by every test: a multiple of four (SIMD loop only) and a
    /// non-multiple of four (SIMD loop followed by the scalar tail).
    const LENGTHS: [usize; 2] = [64, 67];

    /// The module is only gated on compile-time AVX2, while the kernels under test also
    /// enable `fma`; check at run time that both are really available before calling them.
    fn simd_available() -> bool {
        std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma")
    }

    #[test]
    fn reim_from_znx_i64_avx2_vs_ref() {
        if !simd_available() {
            eprintln!("skipping reim_from_znx_i64_avx2_vs_ref: avx2/fma not available");
            return;
        }
        for n in LENGTHS {
            let a: Vec<i64> = (0..n as i64).map(|i| i * 997 - 32000).collect();
            let mut res_avx = vec![0f64; n];
            let mut res_ref = vec![0f64; n];

            // SAFETY: the required CPU features were detected above.
            unsafe { reim_from_znx_i64_bnd50_fma(&mut res_avx, &a) };
            reim_from_znx_i64_ref(&mut res_ref, &a);

            assert_eq!(res_avx, res_ref, "reim_from_znx_i64: AVX2 vs ref mismatch (n = {n})");
        }
    }

    #[test]
    fn reim_to_znx_i64_avx2_vs_ref() {
        if !simd_available() {
            eprintln!("skipping reim_to_znx_i64_avx2_vs_ref: avx2/fma not available");
            return;
        }
        let divisor = 4.0f64;
        for n in LENGTHS {
            // Exact multiples of `divisor`, so the expected quotient needs no rounding.
            let a: Vec<f64> = (0..n).map(|i| (i as f64 * 100.0 - 3000.0) * divisor).collect();
            let mut res_avx = vec![0i64; n];
            let mut res_ref = vec![0i64; n];

            // SAFETY: the required CPU features were detected above.
            unsafe { reim_to_znx_i64_bnd63_avx2_fma(&mut res_avx, divisor, &a) };
            reim_to_znx_i64_ref(&mut res_ref, divisor, &a);

            assert_eq!(res_avx, res_ref, "reim_to_znx_i64: AVX2 vs ref mismatch (n = {n})");
        }
    }
}