use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use std::mem;
use std::alloc::{alloc, Layout};
use std::ptr;
use std::f32;
use std::f64;
/// Computes `result[i] = a[i] + b[i]` for every index of `a` using AVX-512F.
///
/// Sixteen `f32` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector are handled by a scalar loop. Only the
/// first `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `b` or `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_add_f32(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    let len = a.len();
    // The vector loop goes through raw pointers, so an undersized operand would be read or
    // written out of bounds instead of panicking. Check once, up front.
    assert!(
        b.len() >= len && result.len() >= len,
        "avx512_add_f32: `b` and `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and all three slices hold at least
        // `len` elements, so every 16-lane access stays in bounds.
        let lhs = _mm512_loadu_ps(a.as_ptr().add(offset));
        let rhs = _mm512_loadu_ps(b.as_ptr().add(offset));
        _mm512_storeu_ps(result.as_mut_ptr().add(offset), _mm512_add_ps(lhs, rhs));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i] + b[i];
    }
}
/// Computes `result[i] = a[i] + b[i]` for every index of `a` using AVX-512F.
///
/// Eight `f64` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector are handled by a scalar loop. Only the
/// first `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `b` or `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_add_f64(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let len = a.len();
    // The vector loop goes through raw pointers, so an undersized operand would be read or
    // written out of bounds instead of panicking. Check once, up front.
    assert!(
        b.len() >= len && result.len() >= len,
        "avx512_add_f64: `b` and `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and all three slices hold at least
        // `len` elements, so every 8-lane access stays in bounds.
        let lhs = _mm512_loadu_pd(a.as_ptr().add(offset));
        let rhs = _mm512_loadu_pd(b.as_ptr().add(offset));
        _mm512_storeu_pd(result.as_mut_ptr().add(offset), _mm512_add_pd(lhs, rhs));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i] + b[i];
    }
}
/// Computes `result[i] = a[i] * b[i]` for every index of `a` using AVX-512F.
///
/// Sixteen `f32` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector are handled by a scalar loop. Only the
/// first `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `b` or `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_mul_f32(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    let len = a.len();
    // The vector loop goes through raw pointers, so an undersized operand would be read or
    // written out of bounds instead of panicking. Check once, up front.
    assert!(
        b.len() >= len && result.len() >= len,
        "avx512_mul_f32: `b` and `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and all three slices hold at least
        // `len` elements, so every 16-lane access stays in bounds.
        let lhs = _mm512_loadu_ps(a.as_ptr().add(offset));
        let rhs = _mm512_loadu_ps(b.as_ptr().add(offset));
        _mm512_storeu_ps(result.as_mut_ptr().add(offset), _mm512_mul_ps(lhs, rhs));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i] * b[i];
    }
}
/// Computes `result[i] = a[i] * b[i]` for every index of `a` using AVX-512F.
///
/// Eight `f64` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector are handled by a scalar loop. Only the
/// first `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `b` or `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_mul_f64(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let len = a.len();
    // The vector loop goes through raw pointers, so an undersized operand would be read or
    // written out of bounds instead of panicking. Check once, up front.
    assert!(
        b.len() >= len && result.len() >= len,
        "avx512_mul_f64: `b` and `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and all three slices hold at least
        // `len` elements, so every 8-lane access stays in bounds.
        let lhs = _mm512_loadu_pd(a.as_ptr().add(offset));
        let rhs = _mm512_loadu_pd(b.as_ptr().add(offset));
        _mm512_storeu_pd(result.as_mut_ptr().add(offset), _mm512_mul_pd(lhs, rhs));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i] * b[i];
    }
}
/// Computes `result[i] = a[i] / b[i]` for every index of `a` using AVX-512F.
///
/// Sixteen `f32` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector are handled by a scalar loop. Only the
/// first `a.len()` elements of `result` are written. Division by zero is not special-cased;
/// both paths perform plain floating-point division.
///
/// # Panics
///
/// Panics if `b` or `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_div_f32(a: &[f32], b: &[f32], result: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    let len = a.len();
    // The vector loop goes through raw pointers, so an undersized operand would be read or
    // written out of bounds instead of panicking. Check once, up front.
    assert!(
        b.len() >= len && result.len() >= len,
        "avx512_div_f32: `b` and `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and all three slices hold at least
        // `len` elements, so every 16-lane access stays in bounds.
        let numer = _mm512_loadu_ps(a.as_ptr().add(offset));
        let denom = _mm512_loadu_ps(b.as_ptr().add(offset));
        _mm512_storeu_ps(result.as_mut_ptr().add(offset), _mm512_div_ps(numer, denom));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i] / b[i];
    }
}
/// Computes `result[i] = a[i] / b[i]` for every index of `a` using AVX-512F.
///
/// Eight `f64` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector are handled by a scalar loop. Only the
/// first `a.len()` elements of `result` are written. Division by zero is not special-cased;
/// both paths perform plain floating-point division.
///
/// # Panics
///
/// Panics if `b` or `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_div_f64(a: &[f64], b: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let len = a.len();
    // The vector loop goes through raw pointers, so an undersized operand would be read or
    // written out of bounds instead of panicking. Check once, up front.
    assert!(
        b.len() >= len && result.len() >= len,
        "avx512_div_f64: `b` and `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and all three slices hold at least
        // `len` elements, so every 8-lane access stays in bounds.
        let numer = _mm512_loadu_pd(a.as_ptr().add(offset));
        let denom = _mm512_loadu_pd(b.as_ptr().add(offset));
        _mm512_storeu_pd(result.as_mut_ptr().add(offset), _mm512_div_pd(numer, denom));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i] / b[i];
    }
}
/// Computes `result[i] = sqrt(a[i])` for every index of `a` using AVX-512F.
///
/// Sixteen `f32` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector go through `f32::sqrt`. Only the first
/// `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_sqrt_f32(a: &[f32], result: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    let len = a.len();
    // The vector loop stores through a raw pointer, so an undersized output would be
    // overrun instead of panicking. Check once, up front.
    assert!(
        result.len() >= len,
        "avx512_sqrt_f32: `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and both slices hold at least
        // `len` elements, so every 16-lane access stays in bounds.
        let input = _mm512_loadu_ps(a.as_ptr().add(offset));
        _mm512_storeu_ps(result.as_mut_ptr().add(offset), _mm512_sqrt_ps(input));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i].sqrt();
    }
}
/// Computes `result[i] = sqrt(a[i])` for every index of `a` using AVX-512F.
///
/// Eight `f64` lanes are processed per iteration with unaligned loads and stores; the
/// elements left over after the last full vector go through `f64::sqrt`. Only the first
/// `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `result` holds fewer elements than `a`.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_sqrt_f64(a: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let len = a.len();
    // The vector loop stores through a raw pointer, so an undersized output would be
    // overrun instead of panicking. Check once, up front.
    assert!(
        result.len() >= len,
        "avx512_sqrt_f64: `result` must be at least as long as `a`"
    );

    let vector_len = len - len % LANES;
    for offset in (0..vector_len).step_by(LANES) {
        // SAFETY: `offset + LANES <= vector_len <= len`, and both slices hold at least
        // `len` elements, so every 8-lane access stays in bounds.
        let input = _mm512_loadu_pd(a.as_ptr().add(offset));
        _mm512_storeu_pd(result.as_mut_ptr().add(offset), _mm512_sqrt_pd(input));
    }

    // Scalar tail for the final `len % LANES` elements.
    for i in vector_len..len {
        result[i] = a[i].sqrt();
    }
}
/// Returns the sum of all elements of `a`, accumulated with AVX-512F.
///
/// Full 16-element blocks are added lane-wise into a single vector accumulator, which is
/// then reduced horizontally; the remaining elements are added to that total one by one.
/// Because of this grouping the result may differ in the last bits from a strictly
/// left-to-right scalar sum.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_sum_f32(a: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    let blocks = a.chunks_exact(LANES);
    let tail = blocks.remainder();

    let mut lane_sums = _mm512_setzero_ps();
    for block in blocks {
        // SAFETY: `chunks_exact` only yields slices of exactly `LANES` elements, so the
        // unaligned 16-lane load stays inside `block`.
        lane_sums = _mm512_add_ps(lane_sums, _mm512_loadu_ps(block.as_ptr()));
    }

    // Collapse the lanes first, then fold in the leftover elements in order.
    tail.iter()
        .fold(_mm512_reduce_add_ps(lane_sums), |total, &value| total + value)
}
/// Returns the sum of all elements of `a`, accumulated with AVX-512F.
///
/// Full 8-element blocks are added lane-wise into a single vector accumulator, which is
/// then reduced horizontally; the remaining elements are added to that total one by one.
/// Because of this grouping the result may differ in the last bits from a strictly
/// left-to-right scalar sum.
///
/// # Safety
///
/// The CPU executing this function must support the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_sum_f64(a: &[f64]) -> f64 {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    let blocks = a.chunks_exact(LANES);
    let tail = blocks.remainder();

    let mut lane_sums = _mm512_setzero_pd();
    for block in blocks {
        // SAFETY: `chunks_exact` only yields slices of exactly `LANES` elements, so the
        // unaligned 8-lane load stays inside `block`.
        lane_sums = _mm512_add_pd(lane_sums, _mm512_loadu_pd(block.as_ptr()));
    }

    // Collapse the lanes first, then fold in the leftover elements in order.
    tail.iter()
        .fold(_mm512_reduce_add_pd(lane_sums), |total, &value| total + value)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_exp_f32(a: &[f32], result: &mut [f32]) {
use std::arch::x86_64::*;
let simd_width = 16;
let simd_chunks = a.len() / simd_width;
let log2e = _mm512_set1_ps(1.442695040888963f32);
let c1 = _mm512_set1_ps(1.0f32);
let c2 = _mm512_set1_ps(1.0f32);
let c3 = _mm512_set1_ps(0.5f32);
let c4 = _mm512_set1_ps(0.1666666666f32);
let c5 = _mm512_set1_ps(0.0416666666f32);
let half = _mm512_set1_ps(0.5f32);
let neg_half = _mm512_set1_ps(-0.5f32);
for i in 0..simd_chunks {
let idx = i * simd_width;
let a_vec = _mm512_loadu_ps(a.as_ptr().add(idx));
let x_log2e = _mm512_mul_ps(a_vec, log2e);
let n = _mm512_roundscale_ps::<(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)>(x_log2e);
let f = _mm512_sub_ps(x_log2e, n);
let pow2n = {
let n_int = _mm512_cvttps_epi32(n);
let biased_n = _mm512_add_epi32(n_int, _mm512_set1_epi32(127 << 23));
let biased_n_shifted = _mm512_slli_epi32(biased_n, 23);
_mm512_castsi512_ps(biased_n_shifted)
};
let f2 = _mm512_mul_ps(f, f);
let f3 = _mm512_mul_ps(f2, f);
let f4 = _mm512_mul_ps(f2, f2);
let poly = _mm512_add_ps(
c1, _mm512_add_ps(
f, _mm512_add_ps(
_mm512_mul_ps(c3, f2), _mm512_add_ps(
_mm512_mul_ps(c4, f3),
_mm512_mul_ps(c5, f4)
)
)
)
);
let exp_x = _mm512_mul_ps(pow2n, poly);
_mm512_storeu_ps(result.as_mut_ptr().add(idx), exp_x);
}
let remainder_start = simd_chunks * simd_width;
for i in remainder_start..a.len() {
result[i] = a[i].exp();
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_exp_f64(a: &[f64], result: &mut [f64]) {
use std::arch::x86_64::*;
let simd_width = 8;
let simd_chunks = a.len() / simd_width;
let log2e = _mm512_set1_pd(1.442695040888963);
let c1 = _mm512_set1_pd(1.0);
let c2 = _mm512_set1_pd(1.0);
let c3 = _mm512_set1_pd(0.5);
let c4 = _mm512_set1_pd(0.1666666666666667);
let c5 = _mm512_set1_pd(0.0416666666666667);
let c6 = _mm512_set1_pd(0.0083333333333333);
for i in 0..simd_chunks {
let idx = i * simd_width;
let a_vec = _mm512_loadu_pd(a.as_ptr().add(idx));
let x_log2e = _mm512_mul_pd(a_vec, log2e);
let n = _mm512_roundscale_pd::<(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)>(x_log2e);
let f = _mm512_sub_pd(x_log2e, n);
let pow2n = {
let n_int = _mm512_cvttpd_epi32(n);
let n_int64 = _mm512_cvtepi32_epi64(_mm256_castsi256_si128(n_int));
let biased_n = _mm512_add_epi64(n_int64, _mm512_set1_epi64(1023 << 52));
let biased_n_shifted = _mm512_slli_epi64(biased_n, 52);
_mm512_castsi512_pd(biased_n_shifted)
};
let f2 = _mm512_mul_pd(f, f);
let f3 = _mm512_mul_pd(f2, f);
let f4 = _mm512_mul_pd(f2, f2);
let f5 = _mm512_mul_pd(f4, f);
let poly = _mm512_add_pd(
c1, _mm512_add_pd(
f, _mm512_add_pd(
_mm512_mul_pd(c3, f2), _mm512_add_pd(
_mm512_mul_pd(c4, f3), _mm512_add_pd(
_mm512_mul_pd(c5, f4),
_mm512_mul_pd(c6, f5)
)
)
)
)
);
let exp_x = _mm512_mul_pd(pow2n, poly);
_mm512_storeu_pd(result.as_mut_ptr().add(idx), exp_x);
}
let remainder_start = simd_chunks * simd_width;
for i in remainder_start..a.len() {
result[i] = a[i].exp();
}
}
/// Adds two `f32` arrays element-wise and returns the sum reshaped to `a.shape()`.
///
/// Both inputs are copied out with `to_vec`. When the crate is built with `avx512f`
/// enabled the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and the AVX-512 kernel, the
/// AVX2 kernel, or a plain scalar loop is used, in that order of preference.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when `a.shape() != b.shape()`.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_add_f32(a: &Array<f32>, b: &Array<f32>) -> Result<Array<f32>> {
    if a.shape() != b.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: a.shape(),
            actual: b.shape(),
        });
    }

    let lhs = a.to_vec();
    let rhs = b.to_vec();
    let mut sums = vec![0.0f32; lhs.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_add_f32(&lhs, &rhs, &mut sums);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_add_f32(&lhs, &rhs, &mut sums) };
        } else if cpu.avx2 {
            // SAFETY: relies on `detect_cpu_features` reporting AVX2 for this CPU.
            unsafe { crate::simd_optimize::avx2_ops::avx2_add_f32(&lhs, &rhs, &mut sums) };
        } else {
            // Portable fallback when neither vector extension is reported.
            for ((dst, &x), &y) in sums.iter_mut().zip(lhs.iter()).zip(rhs.iter()) {
                *dst = x + y;
            }
        }
    }

    Ok(Array::from_vec(sums).reshape(&a.shape()))
}
/// Adds two `f64` arrays element-wise and returns the sum reshaped to `a.shape()`.
///
/// Both inputs are copied out with `to_vec`. When the crate is built with `avx512f`
/// enabled the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and the AVX-512 kernel, the
/// AVX2 kernel, or a plain scalar loop is used, in that order of preference.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when `a.shape() != b.shape()`.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_add_f64(a: &Array<f64>, b: &Array<f64>) -> Result<Array<f64>> {
    if a.shape() != b.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: a.shape(),
            actual: b.shape(),
        });
    }

    let lhs = a.to_vec();
    let rhs = b.to_vec();
    let mut sums = vec![0.0f64; lhs.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_add_f64(&lhs, &rhs, &mut sums);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_add_f64(&lhs, &rhs, &mut sums) };
        } else if cpu.avx2 {
            // SAFETY: relies on `detect_cpu_features` reporting AVX2 for this CPU.
            unsafe { crate::simd_optimize::avx2_ops::avx2_add_f64(&lhs, &rhs, &mut sums) };
        } else {
            // Portable fallback when neither vector extension is reported.
            for ((dst, &x), &y) in sums.iter_mut().zip(lhs.iter()).zip(rhs.iter()) {
                *dst = x + y;
            }
        }
    }

    Ok(Array::from_vec(sums).reshape(&a.shape()))
}
/// Multiplies two `f32` arrays element-wise and returns the product reshaped to
/// `a.shape()`.
///
/// Both inputs are copied out with `to_vec`. When the crate is built with `avx512f`
/// enabled the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and the AVX-512 kernel, the
/// AVX2 kernel, or a plain scalar loop is used, in that order of preference.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when `a.shape() != b.shape()`.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_mul_f32(a: &Array<f32>, b: &Array<f32>) -> Result<Array<f32>> {
    if a.shape() != b.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: a.shape(),
            actual: b.shape(),
        });
    }

    let lhs = a.to_vec();
    let rhs = b.to_vec();
    let mut products = vec![0.0f32; lhs.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_mul_f32(&lhs, &rhs, &mut products);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_mul_f32(&lhs, &rhs, &mut products) };
        } else if cpu.avx2 {
            // SAFETY: relies on `detect_cpu_features` reporting AVX2 for this CPU.
            unsafe {
                crate::simd_optimize::avx2_ops::avx2_mul_f32(&lhs, &rhs, &mut products)
            };
        } else {
            // Portable fallback when neither vector extension is reported.
            for ((dst, &x), &y) in products.iter_mut().zip(lhs.iter()).zip(rhs.iter()) {
                *dst = x * y;
            }
        }
    }

    Ok(Array::from_vec(products).reshape(&a.shape()))
}
/// Multiplies two `f64` arrays element-wise and returns the product reshaped to
/// `a.shape()`.
///
/// Both inputs are copied out with `to_vec`. When the crate is built with `avx512f`
/// enabled the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and the AVX-512 kernel, the
/// AVX2 kernel, or a plain scalar loop is used, in that order of preference.
///
/// # Errors
///
/// Returns `NumRs2Error::ShapeMismatch` when `a.shape() != b.shape()`.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_mul_f64(a: &Array<f64>, b: &Array<f64>) -> Result<Array<f64>> {
    if a.shape() != b.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: a.shape(),
            actual: b.shape(),
        });
    }

    let lhs = a.to_vec();
    let rhs = b.to_vec();
    let mut products = vec![0.0f64; lhs.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_mul_f64(&lhs, &rhs, &mut products);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_mul_f64(&lhs, &rhs, &mut products) };
        } else if cpu.avx2 {
            // SAFETY: relies on `detect_cpu_features` reporting AVX2 for this CPU.
            unsafe {
                crate::simd_optimize::avx2_ops::avx2_mul_f64(&lhs, &rhs, &mut products)
            };
        } else {
            // Portable fallback when neither vector extension is reported.
            for ((dst, &x), &y) in products.iter_mut().zip(lhs.iter()).zip(rhs.iter()) {
                *dst = x * y;
            }
        }
    }

    Ok(Array::from_vec(products).reshape(&a.shape()))
}
/// Takes the element-wise square root of an `f32` array and returns it reshaped to
/// `a.shape()`.
///
/// The input is copied out with `to_vec`. When the crate is built with `avx512f` enabled
/// the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and the AVX-512 kernel, the
/// AVX2 kernel, or a scalar `f32::sqrt` loop is used, in that order of preference.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_sqrt_f32(a: &Array<f32>) -> Array<f32> {
    let input = a.to_vec();
    let mut roots = vec![0.0f32; input.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_sqrt_f32(&input, &mut roots);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_sqrt_f32(&input, &mut roots) };
        } else if cpu.avx2 {
            // SAFETY: relies on `detect_cpu_features` reporting AVX2 for this CPU.
            unsafe { crate::simd_optimize::avx2_ops::avx2_sqrt_f32(&input, &mut roots) };
        } else {
            // Portable fallback when neither vector extension is reported.
            for (dst, &x) in roots.iter_mut().zip(input.iter()) {
                *dst = x.sqrt();
            }
        }
    }

    Array::from_vec(roots).reshape(&a.shape())
}
/// Takes the element-wise square root of an `f64` array and returns it reshaped to
/// `a.shape()`.
///
/// The input is copied out with `to_vec`. When the crate is built with `avx512f` enabled
/// the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and the AVX-512 kernel, the
/// AVX2 kernel, or a scalar `f64::sqrt` loop is used, in that order of preference.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_sqrt_f64(a: &Array<f64>) -> Array<f64> {
    let input = a.to_vec();
    let mut roots = vec![0.0f64; input.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_sqrt_f64(&input, &mut roots);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_sqrt_f64(&input, &mut roots) };
        } else if cpu.avx2 {
            // SAFETY: relies on `detect_cpu_features` reporting AVX2 for this CPU.
            unsafe { crate::simd_optimize::avx2_ops::avx2_sqrt_f64(&input, &mut roots) };
        } else {
            // Portable fallback when neither vector extension is reported.
            for (dst, &x) in roots.iter_mut().zip(input.iter()) {
                *dst = x.sqrt();
            }
        }
    }

    Array::from_vec(roots).reshape(&a.shape())
}
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_sum_f32(a: &Array<f32>) -> f32 {
let a_data = a.to_vec();
#[cfg(target_feature = "avx512f")]
unsafe {
return avx512_sum_f32(&a_data);
}
#[cfg(not(target_feature = "avx512f"))]
{
let features = crate::simd_optimize::detect_cpu_features();
if features.avx512f {
unsafe {
return avx512_sum_f32(&a_data);
}
} else {
if features.avx2 {
unsafe {
return crate::simd_optimize::avx2_ops::avx2_sum_f32(&a_data);
}
} else {
return a_data.iter().sum();
}
}
}
a_data.iter().sum()
}
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_sum_f64(a: &Array<f64>) -> f64 {
let a_data = a.to_vec();
#[cfg(target_feature = "avx512f")]
unsafe {
return avx512_sum_f64(&a_data);
}
#[cfg(not(target_feature = "avx512f"))]
{
let features = crate::simd_optimize::detect_cpu_features();
if features.avx512f {
unsafe {
return avx512_sum_f64(&a_data);
}
} else {
if features.avx2 {
unsafe {
return crate::simd_optimize::avx2_ops::avx2_sum_f64(&a_data);
}
} else {
return a_data.iter().sum();
}
}
}
a_data.iter().sum()
}
/// Applies `exp` element-wise to an `f32` array and returns the result reshaped to
/// `a.shape()`.
///
/// The input is copied out with `to_vec`. When the crate is built with `avx512f` enabled
/// the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and either the AVX-512
/// kernel or a scalar `f32::exp` loop is used. There is no AVX2 path for this operation.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_exp_f32(a: &Array<f32>) -> Array<f32> {
    let input = a.to_vec();
    let mut output = vec![0.0f32; input.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_exp_f32(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_exp_f32(&input, &mut output) };
        } else {
            // Portable fallback when AVX-512F is not reported.
            for (dst, &x) in output.iter_mut().zip(input.iter()) {
                *dst = x.exp();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}
/// Applies `exp` element-wise to an `f64` array and returns the result reshaped to
/// `a.shape()`.
///
/// The input is copied out with `to_vec`. When the crate is built with `avx512f` enabled
/// the AVX-512 kernel is called directly; otherwise
/// `crate::simd_optimize::detect_cpu_features()` is consulted and either the AVX-512
/// kernel or a scalar `f64::exp` loop is used. There is no AVX2 path for this operation.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_exp_f64(a: &Array<f64>) -> Array<f64> {
    let input = a.to_vec();
    let mut output = vec![0.0f64; input.len()];

    #[cfg(target_feature = "avx512f")]
    // SAFETY: this statement is only compiled when the build itself enables `avx512f`.
    unsafe {
        avx512_exp_f64(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        let cpu = crate::simd_optimize::detect_cpu_features();
        if cpu.avx512f {
            // SAFETY: relies on `detect_cpu_features` reporting AVX-512F for this CPU.
            unsafe { avx512_exp_f64(&input, &mut output) };
        } else {
            // Portable fallback when AVX-512F is not reported.
            for (dst, &x) in output.iter_mut().zip(input.iter()) {
                *dst = x.exp();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}