use crate::array::Array;
use crate::error::Result;
use std::mem;
use std::f32::consts::PI as PI_F32;
use std::f64::consts::PI as PI_F64;
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_log_f32(a: &[f32], result: &mut [f32]) {
use std::arch::x86_64::*;
let simd_width = 16;
let simd_chunks = a.len() / simd_width;
let one = _mm512_set1_ps(1.0f32);
let neg_one_half = _mm512_set1_ps(-0.5f32);
let one_third = _mm512_set1_ps(1.0f32 / 3.0f32);
let neg_one_fourth = _mm512_set1_ps(-0.25f32);
let ln2 = _mm512_set1_ps(0.693147180559945f32);
for i in 0..simd_chunks {
let idx = i * simd_width;
let a_vec = _mm512_loadu_ps(a.as_ptr().add(idx));
let x_bits = _mm512_castps_si512(a_vec);
let exp_bits = _mm512_srli_epi32(x_bits, 23);
let exp = _mm512_sub_epi32(exp_bits, _mm512_set1_epi32(127));
let exp_f = _mm512_cvtepi32_ps(exp);
let mantissa_mask = _mm512_set1_epi32(0x007FFFFF);
let mantissa_bits = _mm512_and_si512(x_bits, mantissa_mask);
let mantissa_bits_with_ones = _mm512_or_si512(mantissa_bits, _mm512_set1_epi32(0x3F800000)); let mantissa = _mm512_castsi512_ps(mantissa_bits_with_ones);
let f = _mm512_sub_ps(mantissa, one);
let f2 = _mm512_mul_ps(f, f);
let f3 = _mm512_mul_ps(f2, f);
let f4 = _mm512_mul_ps(f2, f2);
let log_mantissa = _mm512_add_ps(
f, _mm512_add_ps(
_mm512_mul_ps(neg_one_half, f2), _mm512_add_ps(
_mm512_mul_ps(one_third, f3),
_mm512_mul_ps(neg_one_fourth, f4)
)
)
);
let n_log2 = _mm512_mul_ps(exp_f, ln2);
let log_x = _mm512_add_ps(n_log2, log_mantissa);
_mm512_storeu_ps(result.as_mut_ptr().add(idx), log_x);
}
let remainder_start = simd_chunks * simd_width;
for i in remainder_start..a.len() {
result[i] = a[i].ln();
}
}
/// Writes the natural logarithm of every element of `a` into `result` using AVX-512F.
///
/// Full 8-lane chunks are evaluated with a vector kernel: each input is split as
/// `x = 2^e * m` with `m` in roughly `[sqrt(1/2), sqrt(2))`, and `ln(m)` is obtained from
/// the series `2 * atanh(s)` with `s = (m - 1) / (m + 1)`. Lanes that are not positive
/// normal finite numbers (zero, negatives, subnormals, infinities, NaN) are recomputed
/// with the scalar `f64::ln`. Elements past the last full chunk also use `f64::ln`.
///
/// Only AVX-512F instructions are used; the exponent is converted through 32-bit integers
/// because the direct 64-bit integer to double conversion belongs to AVX-512DQ.
///
/// Only the first `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_log_f64(a: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;
    const ONE_BITS: i64 = 0x3FF0_0000_0000_0000;
    // High word of sqrt(1/2) with a zeroed low word; slightly below sqrt(1/2).
    const SQRT_HALF_BITS: i64 = 0x3FE6_A09E_0000_0000;
    const MANTISSA_MASK: i64 = 0x000F_FFFF_FFFF_FFFF;
    const EXPONENT_BIAS: i64 = 1023;
    // ln(2) split into a short high part and a low correction so that `e * LN2_HI`
    // carries no rounding error for the exponent range of f64.
    const LN2_HI: f64 = 6.93147180369123816490e-01;
    const LN2_LO: f64 = 1.90821492927058770002e-10;
    // Coefficients of atanh(s)/s - 1 in z = s^2, highest power first: z^10/21 ... z/3.
    const ATANH_COEFFS: [f64; 10] = [
        1.0 / 21.0,
        1.0 / 19.0,
        1.0 / 17.0,
        1.0 / 15.0,
        1.0 / 13.0,
        1.0 / 11.0,
        1.0 / 9.0,
        1.0 / 7.0,
        1.0 / 5.0,
        1.0 / 3.0,
    ];

    assert!(
        result.len() >= a.len(),
        "avx512_log_f64: result slice is shorter than the input slice"
    );

    let vector_len = a.len() - a.len() % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (r_body, r_tail) = result.split_at_mut(vector_len);

    let one = _mm512_set1_pd(1.0);
    let ln2_hi = _mm512_set1_pd(LN2_HI);
    let ln2_lo = _mm512_set1_pd(LN2_LO);
    let min_normal = _mm512_set1_pd(f64::MIN_POSITIVE);
    let max_finite = _mm512_set1_pd(f64::MAX);
    let rebias = _mm512_set1_epi64(ONE_BITS - SQRT_HALF_BITS);
    let sqrt_half_bits = _mm512_set1_epi64(SQRT_HALF_BITS);
    let mantissa_mask = _mm512_set1_epi64(MANTISSA_MASK);
    let exponent_bias = _mm512_set1_epi64(EXPONENT_BIAS);

    for (src, dst) in a_body
        .chunks_exact(LANES)
        .zip(r_body.chunks_exact_mut(LANES))
    {
        // `src` and `dst` are exactly LANES elements long, so the unaligned load and
        // store below stay inside their slices.
        let x = _mm512_loadu_pd(src.as_ptr());

        // Lanes in [MIN_POSITIVE, MAX]; both comparisons are false for NaN.
        let fast_lanes =
            _mm512_cmple_pd_mask(min_normal, x) & _mm512_cmple_pd_mask(x, max_finite);

        // Adding (1.0 - sqrt(1/2)) in bit space bumps the exponent field exactly when the
        // mantissa is >= sqrt(2), which yields m in [sqrt(1/2), sqrt(2)).
        let shifted = _mm512_add_epi64(_mm512_castpd_si512(x), rebias);
        let exponent_i64 = _mm512_sub_epi64(_mm512_srli_epi64(shifted, 52), exponent_bias);
        // The exponent fits in 32 bits, so narrow first and use the AVX-512F conversion.
        let exponent = _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(exponent_i64));
        let m = _mm512_castsi512_pd(_mm512_add_epi64(
            _mm512_and_si512(shifted, mantissa_mask),
            sqrt_half_bits,
        ));

        let s = _mm512_div_pd(_mm512_sub_pd(m, one), _mm512_add_pd(m, one));
        let z = _mm512_mul_pd(s, s);
        let mut poly = _mm512_set1_pd(ATANH_COEFFS[0]);
        for &coeff in &ATANH_COEFFS[1..] {
            poly = _mm512_fmadd_pd(poly, z, _mm512_set1_pd(coeff));
        }
        let correction = _mm512_mul_pd(z, poly);
        let two_s = _mm512_add_pd(s, s);
        // ln(m) = 2s * (1 + z/3 + z^2/5 + ...)
        let ln_m = _mm512_fmadd_pd(two_s, correction, two_s);

        // ln(x) = e * ln2_hi + (ln(m) + e * ln2_lo)
        let low = _mm512_fmadd_pd(exponent, ln2_lo, ln_m);
        let log_x = _mm512_fmadd_pd(exponent, ln2_hi, low);
        _mm512_storeu_pd(dst.as_mut_ptr(), log_x);

        // Rare path: overwrite lanes the vector kernel cannot handle.
        if fast_lanes != u8::MAX {
            for (lane, (out, value)) in dst.iter_mut().zip(src.iter()).enumerate() {
                if (fast_lanes >> lane) & 1 == 0 {
                    *out = value.ln();
                }
            }
        }
    }

    for (value, out) in a_tail.iter().zip(r_tail.iter_mut()) {
        *out = value.ln();
    }
}
/// Writes the absolute value of every element of `a` into `result` using AVX-512F.
///
/// Full 16-lane chunks clear the IEEE-754 sign bit with a bitwise AND; the remaining
/// elements use the scalar `f32::abs`. Only the first `a.len()` elements of `result`
/// are written.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`. Without this check the raw-pointer stores
/// would write past the end of `result`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_abs_f32(a: &[f32], result: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 16;

    assert!(
        result.len() >= a.len(),
        "avx512_abs_f32: result slice is shorter than the input slice"
    );

    let vector_len = a.len() - a.len() % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (r_body, r_tail) = result.split_at_mut(vector_len);

    // Every bit except the sign bit.
    let abs_mask = _mm512_set1_epi32(0x7FFF_FFFF);

    for (src, dst) in a_body
        .chunks_exact(LANES)
        .zip(r_body.chunks_exact_mut(LANES))
    {
        // `src` and `dst` are exactly LANES elements long, so the unaligned load and
        // store below stay inside their slices.
        let bits = _mm512_castps_si512(_mm512_loadu_ps(src.as_ptr()));
        let cleared = _mm512_and_si512(bits, abs_mask);
        _mm512_storeu_ps(dst.as_mut_ptr(), _mm512_castsi512_ps(cleared));
    }

    for (value, out) in a_tail.iter().zip(r_tail.iter_mut()) {
        *out = value.abs();
    }
}
/// Writes the absolute value of every element of `a` into `result` using AVX-512F.
///
/// Full 8-lane chunks clear the IEEE-754 sign bit with a bitwise AND; the remaining
/// elements use the scalar `f64::abs`. Only the first `a.len()` elements of `result`
/// are written.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`. Without this check the raw-pointer stores
/// would write past the end of `result`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_abs_f64(a: &[f64], result: &mut [f64]) {
    use std::arch::x86_64::*;

    const LANES: usize = 8;

    assert!(
        result.len() >= a.len(),
        "avx512_abs_f64: result slice is shorter than the input slice"
    );

    let vector_len = a.len() - a.len() % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (r_body, r_tail) = result.split_at_mut(vector_len);

    // Every bit except the sign bit.
    let abs_mask = _mm512_set1_epi64(0x7FFF_FFFF_FFFF_FFFF);

    for (src, dst) in a_body
        .chunks_exact(LANES)
        .zip(r_body.chunks_exact_mut(LANES))
    {
        // `src` and `dst` are exactly LANES elements long, so the unaligned load and
        // store below stay inside their slices.
        let bits = _mm512_castpd_si512(_mm512_loadu_pd(src.as_ptr()));
        let cleared = _mm512_and_si512(bits, abs_mask);
        _mm512_storeu_pd(dst.as_mut_ptr(), _mm512_castsi512_pd(cleared));
    }

    for (value, out) in a_tail.iter().zip(r_tail.iter_mut()) {
        *out = value.abs();
    }
}
/// Element-wise natural logarithm of an `f32` array, preferring the AVX-512 kernel.
///
/// The elements are copied out with `to_vec`, transformed into a fresh buffer of the
/// same length, and returned through `Array::from_vec(..).reshape(&a.shape())`.
/// Builds that enable `avx512f` call `avx512_log_f32` unconditionally; other builds ask
/// `detect_cpu_features()` and fall back to a scalar `ln` loop when its `avx512f` flag
/// is not set.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_log_f32(a: &Array<f32>) -> Array<f32> {
    let input = a.to_vec();
    let mut output = vec![0.0f32; input.len()];

    // SAFETY: this statement only exists in builds that enable `avx512f`, and `output`
    // has the same length as `input`.
    #[cfg(target_feature = "avx512f")]
    unsafe {
        avx512_log_f32(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        if crate::simd_optimize::detect_cpu_features().avx512f {
            // SAFETY: relies on the `avx512f` flag from `detect_cpu_features`
            // (NOTE(review): presumably a runtime CPU probe — confirm); `output` has
            // the same length as `input`.
            unsafe {
                avx512_log_f32(&input, &mut output);
            }
        } else {
            for (out, value) in output.iter_mut().zip(input.iter()) {
                *out = value.ln();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}
/// Element-wise natural logarithm of an `f64` array, preferring the AVX-512 kernel.
///
/// The elements are copied out with `to_vec`, transformed into a fresh buffer of the
/// same length, and returned through `Array::from_vec(..).reshape(&a.shape())`.
/// Builds that enable `avx512f` call `avx512_log_f64` unconditionally; other builds ask
/// `detect_cpu_features()` and fall back to a scalar `ln` loop when its `avx512f` flag
/// is not set.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_log_f64(a: &Array<f64>) -> Array<f64> {
    let input = a.to_vec();
    let mut output = vec![0.0f64; input.len()];

    // SAFETY: this statement only exists in builds that enable `avx512f`, and `output`
    // has the same length as `input`.
    #[cfg(target_feature = "avx512f")]
    unsafe {
        avx512_log_f64(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        if crate::simd_optimize::detect_cpu_features().avx512f {
            // SAFETY: relies on the `avx512f` flag from `detect_cpu_features`
            // (NOTE(review): presumably a runtime CPU probe — confirm); `output` has
            // the same length as `input`.
            unsafe {
                avx512_log_f64(&input, &mut output);
            }
        } else {
            for (out, value) in output.iter_mut().zip(input.iter()) {
                *out = value.ln();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}
/// Element-wise absolute value of an `f32` array, preferring the AVX-512 kernel.
///
/// The elements are copied out with `to_vec`, transformed into a fresh buffer of the
/// same length, and returned through `Array::from_vec(..).reshape(&a.shape())`.
/// Builds that enable `avx512f` call `avx512_abs_f32` unconditionally; other builds ask
/// `detect_cpu_features()` and fall back to a scalar `abs` loop when its `avx512f` flag
/// is not set.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_abs_f32(a: &Array<f32>) -> Array<f32> {
    let input = a.to_vec();
    let mut output = vec![0.0f32; input.len()];

    // SAFETY: this statement only exists in builds that enable `avx512f`, and `output`
    // has the same length as `input`.
    #[cfg(target_feature = "avx512f")]
    unsafe {
        avx512_abs_f32(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        if crate::simd_optimize::detect_cpu_features().avx512f {
            // SAFETY: relies on the `avx512f` flag from `detect_cpu_features`
            // (NOTE(review): presumably a runtime CPU probe — confirm); `output` has
            // the same length as `input`.
            unsafe {
                avx512_abs_f32(&input, &mut output);
            }
        } else {
            for (out, value) in output.iter_mut().zip(input.iter()) {
                *out = value.abs();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}
/// Element-wise absolute value of an `f64` array, preferring the AVX-512 kernel.
///
/// The elements are copied out with `to_vec`, transformed into a fresh buffer of the
/// same length, and returned through `Array::from_vec(..).reshape(&a.shape())`.
/// Builds that enable `avx512f` call `avx512_abs_f64` unconditionally; other builds ask
/// `detect_cpu_features()` and fall back to a scalar `abs` loop when its `avx512f` flag
/// is not set.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_abs_f64(a: &Array<f64>) -> Array<f64> {
    let input = a.to_vec();
    let mut output = vec![0.0f64; input.len()];

    // SAFETY: this statement only exists in builds that enable `avx512f`, and `output`
    // has the same length as `input`.
    #[cfg(target_feature = "avx512f")]
    unsafe {
        avx512_abs_f64(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        if crate::simd_optimize::detect_cpu_features().avx512f {
            // SAFETY: relies on the `avx512f` flag from `detect_cpu_features`
            // (NOTE(review): presumably a runtime CPU probe — confirm); `output` has
            // the same length as `input`.
            unsafe {
                avx512_abs_f64(&input, &mut output);
            }
        } else {
            for (out, value) in output.iter_mut().zip(input.iter()) {
                *out = value.abs();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}
/// Writes the sine of every element of `a` into `result` using AVX-512F.
///
/// Full 16-lane chunks are evaluated with a vector kernel:
///
/// 1. `k = round(x * 2/pi)` and `r = x - k * pi/2`, with `pi/2` split into a high and a
///    low part so the reduction keeps its accuracy as `k` grows.
/// 2. `sin(r)` and `cos(r)` are evaluated with Taylor polynomials on `|r| <= ~pi/4`.
/// 3. The quadrant `k mod 4` selects `sin(r)`, `cos(r)`, `-sin(r)` or `-cos(r)`.
///
/// Lanes whose magnitude exceeds `FAST_PATH_LIMIT`, and non-finite lanes, are recomputed
/// with the scalar `f32::sin`. Elements past the last full chunk also use `f32::sin`.
///
/// Only AVX-512F instructions are used; the sign flip is done on the integer view
/// because the floating-point XOR intrinsic belongs to AVX-512DQ.
///
/// Only the first `a.len()` elements of `result` are written.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx512f` target feature.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_sin_f32(a: &[f32], result: &mut [f32]) {
    use std::arch::x86_64::*;

    const LANES: usize = 16;
    // Beyond this magnitude the two-constant reduction is no longer trusted.
    const FAST_PATH_LIMIT: f32 = 8192.0;

    assert!(
        result.len() >= a.len(),
        "avx512_sin_f32: result slice is shorter than the input slice"
    );

    let vector_len = a.len() - a.len() % LANES;
    let (a_body, a_tail) = a.split_at(vector_len);
    let (r_body, r_tail) = result.split_at_mut(vector_len);

    // pi/2 = hi + lo, where `hi` is the nearest f32 and `lo` the rounding remainder.
    let half_pi_hi_scalar = PI_F32 / 2.0;
    let half_pi_lo_scalar = (PI_F64 / 2.0 - f64::from(half_pi_hi_scalar)) as f32;

    let two_over_pi = _mm512_set1_ps(0.6366197723675814f32);
    let half_pi_hi = _mm512_set1_ps(half_pi_hi_scalar);
    let half_pi_lo = _mm512_set1_ps(half_pi_lo_scalar);
    let limit = _mm512_set1_ps(FAST_PATH_LIMIT);
    let neg_limit = _mm512_set1_ps(-FAST_PATH_LIMIT);
    let one = _mm512_set1_ps(1.0);

    // sin(r) = r + r^3 * (s3 + s5 r^2 + s7 r^4 + s9 r^6)
    let s3 = _mm512_set1_ps(-1.0 / 6.0);
    let s5 = _mm512_set1_ps(1.0 / 120.0);
    let s7 = _mm512_set1_ps(-1.0 / 5040.0);
    let s9 = _mm512_set1_ps(1.0 / 362880.0);
    // cos(r) = 1 + r^2 * (c2 + c4 r^2 + c6 r^4 + c8 r^6)
    let c2 = _mm512_set1_ps(-0.5);
    let c4 = _mm512_set1_ps(1.0 / 24.0);
    let c6 = _mm512_set1_ps(-1.0 / 720.0);
    let c8 = _mm512_set1_ps(1.0 / 40320.0);

    let bit0 = _mm512_set1_epi32(1);
    let bit1 = _mm512_set1_epi32(2);

    for (src, dst) in a_body
        .chunks_exact(LANES)
        .zip(r_body.chunks_exact_mut(LANES))
    {
        // `src` and `dst` are exactly LANES elements long, so the unaligned load and
        // store below stay inside their slices.
        let x = _mm512_loadu_ps(src.as_ptr());

        // Lanes in [-limit, limit]; both comparisons are false for NaN.
        let fast_lanes = _mm512_cmple_ps_mask(neg_limit, x) & _mm512_cmple_ps_mask(x, limit);

        let k = _mm512_roundscale_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(
            _mm512_mul_ps(x, two_over_pi),
        );
        // r = (x - k * hi) - k * lo, each step fused so the product is not rounded.
        let r = _mm512_fnmadd_ps(k, half_pi_lo, _mm512_fnmadd_ps(k, half_pi_hi, x));
        let r2 = _mm512_mul_ps(r, r);

        let mut sin_poly = _mm512_fmadd_ps(s9, r2, s7);
        sin_poly = _mm512_fmadd_ps(sin_poly, r2, s5);
        sin_poly = _mm512_fmadd_ps(sin_poly, r2, s3);
        let sin_r = _mm512_fmadd_ps(_mm512_mul_ps(r, r2), sin_poly, r);

        let mut cos_poly = _mm512_fmadd_ps(c8, r2, c6);
        cos_poly = _mm512_fmadd_ps(cos_poly, r2, c4);
        cos_poly = _mm512_fmadd_ps(cos_poly, r2, c2);
        let cos_r = _mm512_fmadd_ps(r2, cos_poly, one);

        // Two's-complement low bits give k mod 4 for negative k as well.
        let k_int = _mm512_cvttps_epi32(k);
        // Odd quadrants (1 and 3) use the cosine polynomial.
        let use_cos = _mm512_cmpeq_epi32_mask(_mm512_and_epi32(k_int, bit0), bit0);
        let magnitude = _mm512_mask_blend_ps(use_cos, sin_r, cos_r);
        // Quadrants 2 and 3 are negated: move bit 1 of k into the sign-bit position.
        let sign = _mm512_slli_epi32(_mm512_and_epi32(k_int, bit1), 30);
        let value = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(magnitude), sign));
        _mm512_storeu_ps(dst.as_mut_ptr(), value);

        // Rare path: overwrite lanes the vector kernel cannot handle.
        if fast_lanes != u16::MAX {
            for (lane, (out, value)) in dst.iter_mut().zip(src.iter()).enumerate() {
                if (fast_lanes >> lane) & 1 == 0 {
                    *out = value.sin();
                }
            }
        }
    }

    for (value, out) in a_tail.iter().zip(r_tail.iter_mut()) {
        *out = value.sin();
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
pub unsafe fn avx512_sin_f64(a: &[f64], result: &mut [f64]) {
use std::arch::x86_64::*;
let simd_width = 8;
let simd_chunks = a.len() / simd_width;
let two_over_pi = _mm512_set1_pd(0.6366197723675814); let pi_over_two = _mm512_set1_pd(PI_F64 / 2.0);
let c1 = _mm512_set1_pd(1.0);
let c3 = _mm512_set1_pd(-1.0 / 6.0);
let c5 = _mm512_set1_pd(1.0 / 120.0);
let c7 = _mm512_set1_pd(-1.0 / 5040.0);
let c9 = _mm512_set1_pd(1.0 / 362880.0);
for i in 0..simd_chunks {
let idx = i * simd_width;
let a_vec = _mm512_loadu_pd(a.as_ptr().add(idx));
let k_float = _mm512_mul_pd(a_vec, two_over_pi);
let k = _mm512_roundscale_pd::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(k_float);
let x = _mm512_fnmadd_pd(k, pi_over_two, a_vec);
let x2 = _mm512_mul_pd(x, x);
let x3 = _mm512_mul_pd(x, x2);
let x5 = _mm512_mul_pd(x3, x2);
let x7 = _mm512_mul_pd(x5, x2);
let x9 = _mm512_mul_pd(x7, x2);
let sin_x = _mm512_add_pd(
x, _mm512_add_pd(
_mm512_mul_pd(c3, x3), _mm512_add_pd(
_mm512_mul_pd(c5, x5), _mm512_add_pd(
_mm512_mul_pd(c7, x7),
_mm512_mul_pd(c9, x9)
)
)
)
);
let k_int = _mm512_cvttpd_epi32(k);
let quadrant = _mm256_and_si256(_mm256_castsi128_si256(k_int), _mm256_set1_epi32(3));
let quadrant_64 = _mm512_cvtepi32_epi64(_mm256_castsi256_si128(quadrant));
let mask_q1 = _mm512_cmpeq_epi64_mask(quadrant_64, _mm512_set1_epi64(1));
let mask_q2 = _mm512_cmpeq_epi64_mask(quadrant_64, _mm512_set1_epi64(2));
let mask_q3 = _mm512_cmpeq_epi64_mask(quadrant_64, _mm512_set1_epi64(3));
let c2 = _mm512_set1_pd(-0.5);
let c4 = _mm512_set1_pd(1.0 / 24.0);
let c6 = _mm512_set1_pd(-1.0 / 720.0);
let c8 = _mm512_set1_pd(1.0 / 40320.0);
let x4 = _mm512_mul_pd(x2, x2);
let x6 = _mm512_mul_pd(x4, x2);
let x8 = _mm512_mul_pd(x4, x4);
let cos_x = _mm512_add_pd(
c1, _mm512_add_pd(
_mm512_mul_pd(c2, x2), _mm512_add_pd(
_mm512_mul_pd(c4, x4), _mm512_add_pd(
_mm512_mul_pd(c6, x6),
_mm512_mul_pd(c8, x8)
)
)
)
);
let result_vec = _mm512_mask_blend_pd(mask_q1 | mask_q3, sin_x, cos_x);
let neg_sign_mask = _mm512_castsi512_pd(_mm512_set1_epi64(0x8000000000000000)); let sign_mask = _mm512_mask_blend_pd(mask_q2 | mask_q3, _mm512_setzero_pd(), neg_sign_mask);
let final_result = _mm512_xor_pd(result_vec, sign_mask);
_mm512_storeu_pd(result.as_mut_ptr().add(idx), final_result);
}
let remainder_start = simd_chunks * simd_width;
for i in remainder_start..a.len() {
result[i] = a[i].sin();
}
}
/// Element-wise sine of an `f32` array, preferring the AVX-512 kernel.
///
/// The elements are copied out with `to_vec`, transformed into a fresh buffer of the
/// same length, and returned through `Array::from_vec(..).reshape(&a.shape())`.
/// Builds that enable `avx512f` call `avx512_sin_f32` unconditionally; other builds ask
/// `detect_cpu_features()` and fall back to a scalar `sin` loop when its `avx512f` flag
/// is not set.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_sin_f32(a: &Array<f32>) -> Array<f32> {
    let input = a.to_vec();
    let mut output = vec![0.0f32; input.len()];

    // SAFETY: this statement only exists in builds that enable `avx512f`, and `output`
    // has the same length as `input`.
    #[cfg(target_feature = "avx512f")]
    unsafe {
        avx512_sin_f32(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        if crate::simd_optimize::detect_cpu_features().avx512f {
            // SAFETY: relies on the `avx512f` flag from `detect_cpu_features`
            // (NOTE(review): presumably a runtime CPU probe — confirm); `output` has
            // the same length as `input`.
            unsafe {
                avx512_sin_f32(&input, &mut output);
            }
        } else {
            for (out, value) in output.iter_mut().zip(input.iter()) {
                *out = value.sin();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}
/// Element-wise sine of an `f64` array, preferring the AVX-512 kernel.
///
/// The elements are copied out with `to_vec`, transformed into a fresh buffer of the
/// same length, and returned through `Array::from_vec(..).reshape(&a.shape())`.
/// Builds that enable `avx512f` call `avx512_sin_f64` unconditionally; other builds ask
/// `detect_cpu_features()` and fall back to a scalar `sin` loop when its `avx512f` flag
/// is not set.
#[cfg(target_arch = "x86_64")]
pub fn avx512_optimized_sin_f64(a: &Array<f64>) -> Array<f64> {
    let input = a.to_vec();
    let mut output = vec![0.0f64; input.len()];

    // SAFETY: this statement only exists in builds that enable `avx512f`, and `output`
    // has the same length as `input`.
    #[cfg(target_feature = "avx512f")]
    unsafe {
        avx512_sin_f64(&input, &mut output);
    }

    #[cfg(not(target_feature = "avx512f"))]
    {
        if crate::simd_optimize::detect_cpu_features().avx512f {
            // SAFETY: relies on the `avx512f` flag from `detect_cpu_features`
            // (NOTE(review): presumably a runtime CPU probe — confirm); `output` has
            // the same length as `input`.
            unsafe {
                avx512_sin_f64(&input, &mut output);
            }
        } else {
            for (out, value) in output.iter_mut().zip(input.iter()) {
                *out = value.sin();
            }
        }
    }

    Array::from_vec(output).reshape(&a.shape())
}