#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use super::super::common::{exp_coefficients, log_coefficients};
/// Computes `e^x` for each of the eight `f32` lanes of `x`.
///
/// The input is clamped to `[MIN_F32, MAX_F32]` and `x * log2(e)` is split into
/// an integer part `n` (round-to-nearest) and a fraction `f` with `|f| <= 0.5`.
/// The result is `2^n * P(f * ln 2)`, where `P` is the degree-6 polynomial with
/// coefficients `C0_F32..=C6_F32` from `exp_coefficients`.
///
/// Because of the clamp, out-of-range inputs (including `±inf`) saturate at the
/// value produced for the nearest bound. NaN lanes propagate to the result.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn exp_f32(x: __m256) -> __m256 {
    use exp_coefficients::*;

    let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let ln2 = _mm256_set1_ps(std::f32::consts::LN_2);
    let c0 = _mm256_set1_ps(C0_F32);
    let c1 = _mm256_set1_ps(C1_F32);
    let c2 = _mm256_set1_ps(C2_F32);
    let c3 = _mm256_set1_ps(C3_F32);
    let c4 = _mm256_set1_ps(C4_F32);
    let c5 = _mm256_set1_ps(C5_F32);
    let c6 = _mm256_set1_ps(C6_F32);

    // `vmaxps`/`vminps` return their *second* operand whenever either input is
    // NaN. Keeping `x` as the second operand lets NaN lanes survive the clamp
    // instead of being silently replaced by the bound.
    let x = _mm256_max_ps(_mm256_set1_ps(MIN_F32), x);
    let x = _mm256_min_ps(_mm256_set1_ps(MAX_F32), x);

    // Range reduction: x * log2(e) = n + f.
    let y = _mm256_mul_ps(x, log2e);
    let n = _mm256_round_ps::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(y);
    let f = _mm256_sub_ps(y, n);
    let r = _mm256_mul_ps(f, ln2);

    // Powers of r for the term-by-term polynomial evaluation.
    let r2 = _mm256_mul_ps(r, r);
    let r3 = _mm256_mul_ps(r2, r);
    let r4 = _mm256_mul_ps(r2, r2);
    let r5 = _mm256_mul_ps(r4, r);
    let r6 = _mm256_mul_ps(r4, r2);

    let mut poly = c0;
    poly = _mm256_fmadd_ps(c1, r, poly);
    poly = _mm256_fmadd_ps(c2, r2, poly);
    poly = _mm256_fmadd_ps(c3, r3, poly);
    poly = _mm256_fmadd_ps(c4, r4, poly);
    poly = _mm256_fmadd_ps(c5, r5, poly);
    poly = _mm256_fmadd_ps(c6, r6, poly);

    // Build 2^n by writing the biased exponent straight into the float bits.
    // NOTE(review): this presumes the clamp bounds keep n + 127 inside the
    // normal exponent range (1..=254) -- confirm against MIN_F32 / MAX_F32.
    let n_i32 = _mm256_cvtps_epi32(n);
    let bias = _mm256_set1_epi32(127);
    let exp_bits = _mm256_slli_epi32::<23>(_mm256_add_epi32(n_i32, bias));
    let pow2n = _mm256_castsi256_ps(exp_bits);

    _mm256_mul_ps(pow2n, poly)
}
/// Computes `e^x` for each of the four `f64` lanes of `x`.
///
/// The input is clamped to `[MIN_F64, MAX_F64]` and `x * log2(e)` is split into
/// an integer part `n` (round-to-nearest) and a fraction `f` with `|f| <= 0.5`.
/// The result is `2^n * P(f * ln 2)`, where `P` is the degree-6 polynomial with
/// coefficients `C0_F64..=C6_F64` from `exp_coefficients`.
///
/// Because of the clamp, out-of-range inputs (including `±inf`) saturate at the
/// value produced for the nearest bound. NaN lanes propagate to the result.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn exp_f64(x: __m256d) -> __m256d {
    use exp_coefficients::*;

    let log2e = _mm256_set1_pd(std::f64::consts::LOG2_E);
    let ln2 = _mm256_set1_pd(std::f64::consts::LN_2);
    let c0 = _mm256_set1_pd(C0_F64);
    let c1 = _mm256_set1_pd(C1_F64);
    let c2 = _mm256_set1_pd(C2_F64);
    let c3 = _mm256_set1_pd(C3_F64);
    let c4 = _mm256_set1_pd(C4_F64);
    let c5 = _mm256_set1_pd(C5_F64);
    let c6 = _mm256_set1_pd(C6_F64);

    // `vmaxpd`/`vminpd` return their *second* operand whenever either input is
    // NaN. Keeping `x` as the second operand lets NaN lanes survive the clamp
    // instead of being silently replaced by the bound.
    let x = _mm256_max_pd(_mm256_set1_pd(MIN_F64), x);
    let x = _mm256_min_pd(_mm256_set1_pd(MAX_F64), x);

    // Range reduction: x * log2(e) = n + f.
    let y = _mm256_mul_pd(x, log2e);
    let n = _mm256_round_pd::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(y);
    let f = _mm256_sub_pd(y, n);
    let r = _mm256_mul_pd(f, ln2);

    // Powers of r for the term-by-term polynomial evaluation.
    let r2 = _mm256_mul_pd(r, r);
    let r3 = _mm256_mul_pd(r2, r);
    let r4 = _mm256_mul_pd(r2, r2);
    let r5 = _mm256_mul_pd(r4, r);
    let r6 = _mm256_mul_pd(r4, r2);

    let mut poly = c0;
    poly = _mm256_fmadd_pd(c1, r, poly);
    poly = _mm256_fmadd_pd(c2, r2, poly);
    poly = _mm256_fmadd_pd(c3, r3, poly);
    poly = _mm256_fmadd_pd(c4, r4, poly);
    poly = _mm256_fmadd_pd(c5, r5, poly);
    poly = _mm256_fmadd_pd(c6, r6, poly);

    // Build 2^n entirely in registers: n is already integral, so narrowing to
    // i32 is exact; widen to 64-bit lanes, add the exponent bias and shift the
    // result into the exponent field.
    // NOTE(review): this presumes the clamp bounds keep n + 1023 inside the
    // normal exponent range (1..=2046) -- confirm against MIN_F64 / MAX_F64.
    let n_i64 = _mm256_cvtepi32_epi64(_mm256_cvtpd_epi32(n));
    let biased = _mm256_add_epi64(n_i64, _mm256_set1_epi64x(1023));
    let pow2n = _mm256_castsi256_pd(_mm256_slli_epi64::<52>(biased));

    _mm256_mul_pd(pow2n, poly)
}
/// Computes the natural logarithm of each of the eight `f32` lanes of `x`.
///
/// Each positive finite lane is decomposed as `x = 2^n * m` with
/// `m` in `(sqrt(2)/2, sqrt(2)]`, and the result is
/// `n * ln 2 + f * Q(f)` where `f = m - 1` and `Q` is the degree-6 polynomial
/// with coefficients `C1_F32..=C7_F32` from `log_coefficients`.
///
/// Special values follow the usual `ln` conventions:
///
/// | lane value      | result |
/// |-----------------|--------|
/// | `+0.0` / `-0.0` | `-inf` |
/// | negative        | NaN    |
/// | `+inf`          | `+inf` |
/// | NaN             | NaN    |
///
/// Subnormal lanes are rescaled by `2^23` before the exponent is extracted.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log_f32(x: __m256) -> __m256 {
    use log_coefficients::*;

    let zero = _mm256_setzero_ps();
    let one = _mm256_set1_ps(1.0);
    let half = _mm256_set1_ps(0.5);
    let ln2 = _mm256_set1_ps(std::f32::consts::LN_2);
    let sqrt2 = _mm256_set1_ps(std::f32::consts::SQRT_2);
    let pos_inf = _mm256_set1_ps(f32::INFINITY);
    let c1 = _mm256_set1_ps(C1_F32);
    let c2 = _mm256_set1_ps(C2_F32);
    let c3 = _mm256_set1_ps(C3_F32);
    let c4 = _mm256_set1_ps(C4_F32);
    let c5 = _mm256_set1_ps(C5_F32);
    let c6 = _mm256_set1_ps(C6_F32);
    let c7 = _mm256_set1_ps(C7_F32);

    // Subnormals have no implicit leading one, so their exponent field cannot
    // be used directly. Multiply them by 2^23 (exact) to make them normal and
    // subtract 23 from the exponent afterwards. Zero and negative lanes also
    // match this mask, but they are overwritten by the special-case blends.
    let is_subnormal = _mm256_cmp_ps::<_CMP_LT_OQ>(x, _mm256_set1_ps(f32::MIN_POSITIVE));
    let rescaled = _mm256_mul_ps(x, _mm256_set1_ps(8_388_608.0));
    let x_norm = _mm256_blendv_ps(x, rescaled, is_subnormal);
    let exp_fix = _mm256_and_ps(is_subnormal, _mm256_set1_ps(23.0));

    // Split into unbiased exponent n and mantissa m in [1, 2).
    let x_bits = _mm256_castps_si256(x_norm);
    let exp_raw = _mm256_srli_epi32::<23>(x_bits);
    let exp_unbiased = _mm256_sub_epi32(exp_raw, _mm256_set1_epi32(EXP_BIAS_F32));
    let mut n = _mm256_sub_ps(_mm256_cvtepi32_ps(exp_unbiased), exp_fix);

    let mantissa_mask = _mm256_set1_epi32(MANTISSA_MASK_F32);
    let exp_zero = _mm256_set1_epi32(EXP_ZERO_F32);
    let m_bits = _mm256_or_si256(_mm256_and_si256(x_bits, mantissa_mask), exp_zero);
    let mut m = _mm256_castsi256_ps(m_bits);

    // Fold m from (sqrt(2), 2) down to (sqrt(2)/2, 1) so that f = m - 1 stays
    // small and centred on zero.
    let need_adjust = _mm256_cmp_ps::<_CMP_GT_OQ>(m, sqrt2);
    m = _mm256_blendv_ps(m, _mm256_mul_ps(m, half), need_adjust);
    n = _mm256_blendv_ps(n, _mm256_add_ps(n, one), need_adjust);

    // ln(m) ~= f * (c1 + c2*f + ... + c7*f^6), evaluated with Horner's scheme.
    let f = _mm256_sub_ps(m, one);
    let mut poly = c7;
    poly = _mm256_fmadd_ps(poly, f, c6);
    poly = _mm256_fmadd_ps(poly, f, c5);
    poly = _mm256_fmadd_ps(poly, f, c4);
    poly = _mm256_fmadd_ps(poly, f, c3);
    poly = _mm256_fmadd_ps(poly, f, c2);
    poly = _mm256_fmadd_ps(poly, f, c1);
    poly = _mm256_mul_ps(poly, f);
    let finite = _mm256_fmadd_ps(n, ln2, poly);

    // Special values. `NGE_UQ` is true for x < 0 and for NaN; -0.0 compares
    // equal to zero, so it takes the -inf branch rather than the NaN branch.
    let is_inf = _mm256_cmp_ps::<_CMP_EQ_OQ>(x, pos_inf);
    let is_zero = _mm256_cmp_ps::<_CMP_EQ_OQ>(x, zero);
    let is_invalid = _mm256_cmp_ps::<_CMP_NGE_UQ>(x, zero);
    let result = _mm256_blendv_ps(finite, pos_inf, is_inf);
    let result = _mm256_blendv_ps(result, _mm256_set1_ps(f32::NEG_INFINITY), is_zero);
    _mm256_blendv_ps(result, _mm256_set1_ps(f32::NAN), is_invalid)
}
/// Computes the natural logarithm of each of the four `f64` lanes of `x`.
///
/// Each positive finite lane is decomposed as `x = 2^n * m` with
/// `m` in `(sqrt(2)/2, sqrt(2)]`, and the result is
/// `n * ln 2 + f * Q(f)` where `f = m - 1` and `Q` is the degree-8 polynomial
/// with coefficients `C1_F64..=C9_F64` from `log_coefficients`.
///
/// Special values follow the usual `ln` conventions:
///
/// | lane value      | result |
/// |-----------------|--------|
/// | `+0.0` / `-0.0` | `-inf` |
/// | negative        | NaN    |
/// | `+inf`          | `+inf` |
/// | NaN             | NaN    |
///
/// Subnormal lanes are rescaled by `2^52` before the exponent is extracted.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log_f64(x: __m256d) -> __m256d {
    use log_coefficients::*;

    let zero = _mm256_setzero_pd();
    let one = _mm256_set1_pd(1.0);
    let half = _mm256_set1_pd(0.5);
    let ln2 = _mm256_set1_pd(std::f64::consts::LN_2);
    let sqrt2 = _mm256_set1_pd(std::f64::consts::SQRT_2);
    let pos_inf = _mm256_set1_pd(f64::INFINITY);
    let c1 = _mm256_set1_pd(C1_F64);
    let c2 = _mm256_set1_pd(C2_F64);
    let c3 = _mm256_set1_pd(C3_F64);
    let c4 = _mm256_set1_pd(C4_F64);
    let c5 = _mm256_set1_pd(C5_F64);
    let c6 = _mm256_set1_pd(C6_F64);
    let c7 = _mm256_set1_pd(C7_F64);
    let c8 = _mm256_set1_pd(C8_F64);
    let c9 = _mm256_set1_pd(C9_F64);

    // Subnormals have no implicit leading one, so their exponent field cannot
    // be used directly. Multiply them by 2^52 (exact) to make them normal and
    // subtract 52 from the exponent afterwards. Zero and negative lanes also
    // match this mask, but they are overwritten by the special-case blends.
    let is_subnormal = _mm256_cmp_pd::<_CMP_LT_OQ>(x, _mm256_set1_pd(f64::MIN_POSITIVE));
    let rescaled = _mm256_mul_pd(x, _mm256_set1_pd(4_503_599_627_370_496.0));
    let x_norm = _mm256_blendv_pd(x, rescaled, is_subnormal);
    let exp_fix = _mm256_and_pd(is_subnormal, _mm256_set1_pd(52.0));

    // Split into unbiased exponent and mantissa m in [1, 2).
    let x_bits = _mm256_castpd_si256(x_norm);
    let exp_raw = _mm256_srli_epi64::<52>(x_bits);
    let exp_unbiased = _mm256_sub_epi64(exp_raw, _mm256_set1_epi64x(EXP_BIAS_F64));

    // AVX2 has no i64 -> f64 conversion. The unbiased exponent always fits in
    // 32 bits, so gather the low dword of every 64-bit lane into the lower
    // 128 bits and convert from i32 instead.
    let low_dwords =
        _mm256_permutevar8x32_epi32(exp_unbiased, _mm256_setr_epi32(0, 2, 4, 6, 0, 2, 4, 6));
    let mut n = _mm256_sub_pd(
        _mm256_cvtepi32_pd(_mm256_castsi256_si128(low_dwords)),
        exp_fix,
    );

    let mantissa_mask = _mm256_set1_epi64x(MANTISSA_MASK_F64 as i64);
    let exp_zero = _mm256_set1_epi64x(EXP_ZERO_F64 as i64);
    let m_bits = _mm256_or_si256(_mm256_and_si256(x_bits, mantissa_mask), exp_zero);
    let mut m = _mm256_castsi256_pd(m_bits);

    // Fold m from (sqrt(2), 2) down to (sqrt(2)/2, 1) so that f = m - 1 stays
    // small and centred on zero.
    let need_adjust = _mm256_cmp_pd::<_CMP_GT_OQ>(m, sqrt2);
    m = _mm256_blendv_pd(m, _mm256_mul_pd(m, half), need_adjust);
    n = _mm256_blendv_pd(n, _mm256_add_pd(n, one), need_adjust);

    // ln(m) ~= f * (c1 + c2*f + ... + c9*f^8), evaluated with Horner's scheme.
    let f = _mm256_sub_pd(m, one);
    let mut poly = c9;
    poly = _mm256_fmadd_pd(poly, f, c8);
    poly = _mm256_fmadd_pd(poly, f, c7);
    poly = _mm256_fmadd_pd(poly, f, c6);
    poly = _mm256_fmadd_pd(poly, f, c5);
    poly = _mm256_fmadd_pd(poly, f, c4);
    poly = _mm256_fmadd_pd(poly, f, c3);
    poly = _mm256_fmadd_pd(poly, f, c2);
    poly = _mm256_fmadd_pd(poly, f, c1);
    poly = _mm256_mul_pd(poly, f);
    let finite = _mm256_fmadd_pd(n, ln2, poly);

    // Special values. `NGE_UQ` is true for x < 0 and for NaN; -0.0 compares
    // equal to zero, so it takes the -inf branch rather than the NaN branch.
    let is_inf = _mm256_cmp_pd::<_CMP_EQ_OQ>(x, pos_inf);
    let is_zero = _mm256_cmp_pd::<_CMP_EQ_OQ>(x, zero);
    let is_invalid = _mm256_cmp_pd::<_CMP_NGE_UQ>(x, zero);
    let result = _mm256_blendv_pd(finite, pos_inf, is_inf);
    let result = _mm256_blendv_pd(result, _mm256_set1_pd(f64::NEG_INFINITY), is_zero);
    _mm256_blendv_pd(result, _mm256_set1_pd(f64::NAN), is_invalid)
}
/// Computes `2^x` for each of the eight `f32` lanes of `x`.
///
/// Uses the identity `2^x = e^(x * ln 2)` and defers to [`exp_f32`], so range
/// and special-value behaviour are whatever `exp_f32` provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn exp2_f32(x: __m256) -> __m256 {
    // Convert the base-2 exponent into the equivalent natural exponent.
    let natural_exponent = _mm256_mul_ps(x, _mm256_set1_ps(std::f32::consts::LN_2));
    exp_f32(natural_exponent)
}
/// Computes `2^x` for each of the four `f64` lanes of `x`.
///
/// Uses the identity `2^x = e^(x * ln 2)` and defers to [`exp_f64`], so range
/// and special-value behaviour are whatever `exp_f64` provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn exp2_f64(x: __m256d) -> __m256d {
    // Convert the base-2 exponent into the equivalent natural exponent.
    let natural_exponent = _mm256_mul_pd(x, _mm256_set1_pd(std::f64::consts::LN_2));
    exp_f64(natural_exponent)
}
/// Computes `e^x - 1` for each of the eight `f32` lanes of `x`.
///
/// * For `|x| <= 0.5` (and for NaN lanes) the Maclaurin series
///   `x + x^2/2! + ... + x^9/9!` is evaluated in Horner form. Its first
///   omitted term is below `3e-10` on this interval, well under `f32`
///   precision, so the result stays accurate where `exp(x) - 1` would cancel.
/// * For `|x| > 0.5` the result is `exp_f32(x) - 1`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn expm1_f32(x: __m256) -> __m256 {
    // 1/k! for k = 2..=9.
    const INV_FACTORIAL: [f32; 8] = [
        1.0 / 2.0,
        1.0 / 6.0,
        1.0 / 24.0,
        1.0 / 120.0,
        1.0 / 720.0,
        1.0 / 5_040.0,
        1.0 / 40_320.0,
        1.0 / 362_880.0,
    ];

    let one = _mm256_set1_ps(1.0);
    let half = _mm256_set1_ps(0.5);
    // Clearing the sign bit yields |x|.
    let abs_x = _mm256_andnot_ps(_mm256_set1_ps(-0.0), x);

    // p = 1/2! + x/3! + ... + x^7/9!, highest-order coefficient first.
    let mut p = _mm256_set1_ps(INV_FACTORIAL[7]);
    for &c in INV_FACTORIAL[..7].iter().rev() {
        p = _mm256_fmadd_ps(p, x, _mm256_set1_ps(c));
    }
    // series = x + x^2 * p; adding x last keeps full precision for tiny x.
    let series = _mm256_fmadd_ps(_mm256_mul_ps(x, x), p, x);

    let exp_result = _mm256_sub_ps(exp_f32(x), one);

    // Lanes with |x| > 0.5 take the exp-based value. The series lanes computed
    // for large |x| may overflow, but they are discarded by the blend.
    let mask = _mm256_cmp_ps::<_CMP_GT_OQ>(abs_x, half);
    _mm256_blendv_ps(series, exp_result, mask)
}
/// Computes `e^x - 1` for each of the four `f64` lanes of `x`.
///
/// * For `|x| <= 0.5` (and for NaN lanes) the Maclaurin series
///   `x + x^2/2! + ... + x^15/15!` is evaluated in Horner form. Its first
///   omitted term is below `1e-18` on this interval, under `f64` precision,
///   so the result stays accurate where `exp(x) - 1` would cancel.
/// * For `|x| > 0.5` the result is `exp_f64(x) - 1`.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn expm1_f64(x: __m256d) -> __m256d {
    // 1/k! for k = 2..=15.
    const INV_FACTORIAL: [f64; 14] = [
        1.0 / 2.0,
        1.0 / 6.0,
        1.0 / 24.0,
        1.0 / 120.0,
        1.0 / 720.0,
        1.0 / 5_040.0,
        1.0 / 40_320.0,
        1.0 / 362_880.0,
        1.0 / 3_628_800.0,
        1.0 / 39_916_800.0,
        1.0 / 479_001_600.0,
        1.0 / 6_227_020_800.0,
        1.0 / 87_178_291_200.0,
        1.0 / 1_307_674_368_000.0,
    ];

    let one = _mm256_set1_pd(1.0);
    let half = _mm256_set1_pd(0.5);
    // Clearing the sign bit yields |x|.
    let abs_x = _mm256_andnot_pd(_mm256_set1_pd(-0.0), x);

    // p = 1/2! + x/3! + ... + x^13/15!, highest-order coefficient first.
    let mut p = _mm256_set1_pd(INV_FACTORIAL[13]);
    for &c in INV_FACTORIAL[..13].iter().rev() {
        p = _mm256_fmadd_pd(p, x, _mm256_set1_pd(c));
    }
    // series = x + x^2 * p; adding x last keeps full precision for tiny x.
    let series = _mm256_fmadd_pd(_mm256_mul_pd(x, x), p, x);

    let exp_result = _mm256_sub_pd(exp_f64(x), one);

    // Lanes with |x| > 0.5 take the exp-based value. The series lanes computed
    // for large |x| may overflow, but they are discarded by the blend.
    let mask = _mm256_cmp_pd::<_CMP_GT_OQ>(abs_x, half);
    _mm256_blendv_pd(series, exp_result, mask)
}
/// Computes the base-2 logarithm of each of the eight `f32` lanes of `x`.
///
/// Evaluated as `ln(x) * log2(e)` on top of [`log_f32`], so domain and
/// special-value behaviour are whatever `log_f32` provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log2_f32(x: __m256) -> __m256 {
    let natural_log = log_f32(x);
    // Change of base: log2(x) = ln(x) * log2(e).
    _mm256_mul_ps(natural_log, _mm256_set1_ps(std::f32::consts::LOG2_E))
}
/// Computes the base-2 logarithm of each of the four `f64` lanes of `x`.
///
/// Evaluated as `ln(x) * log2(e)` on top of [`log_f64`], so domain and
/// special-value behaviour are whatever `log_f64` provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log2_f64(x: __m256d) -> __m256d {
    let natural_log = log_f64(x);
    // Change of base: log2(x) = ln(x) * log2(e).
    _mm256_mul_pd(natural_log, _mm256_set1_pd(std::f64::consts::LOG2_E))
}
/// Computes the base-10 logarithm of each of the eight `f32` lanes of `x`.
///
/// Evaluated as `ln(x) * log10(e)` on top of [`log_f32`], so domain and
/// special-value behaviour are whatever `log_f32` provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log10_f32(x: __m256) -> __m256 {
    let natural_log = log_f32(x);
    // Change of base: log10(x) = ln(x) * log10(e).
    _mm256_mul_ps(natural_log, _mm256_set1_ps(std::f32::consts::LOG10_E))
}
/// Computes the base-10 logarithm of each of the four `f64` lanes of `x`.
///
/// Evaluated as `ln(x) * log10(e)` on top of [`log_f64`], so domain and
/// special-value behaviour are whatever `log_f64` provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log10_f64(x: __m256d) -> __m256d {
    let natural_log = log_f64(x);
    // Change of base: log10(x) = ln(x) * log10(e).
    _mm256_mul_pd(natural_log, _mm256_set1_pd(std::f64::consts::LOG10_E))
}
/// Computes `ln(1 + x)` for each of the eight `f32` lanes of `x`.
///
/// With `u = 1 + x` (rounded):
///
/// * For `|x| <= 0.5` (and for NaN lanes) the result is
///   `ln(u) * x / (u - 1)`, or `x` itself when `u - 1 == 0`. The factor
///   `x / (u - 1)` compensates for the rounding error committed when forming
///   `1 + x`, so the result keeps its relative accuracy for small `x`.
/// * For `|x| > 0.5` the result is `log_f32(u)` directly.
///
/// Domain and special-value behaviour for `|x| > 0.5` are whatever `log_f32`
/// provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log1p_f32(x: __m256) -> __m256 {
    let one = _mm256_set1_ps(1.0);
    let half = _mm256_set1_ps(0.5);
    // Clearing the sign bit yields |x|.
    let abs_x = _mm256_andnot_ps(_mm256_set1_ps(-0.0), x);

    let u = _mm256_add_ps(one, x);
    let log_u = log_f32(u);

    // d is the part of x that actually made it into u. For |x| <= 0.5 the
    // subtraction is exact because u lies within a factor of two of 1.
    let d = _mm256_sub_ps(u, one);
    let corrected = _mm256_div_ps(_mm256_mul_ps(log_u, x), d);
    // When x vanished entirely in 1 + x, ln(1 + x) rounds to x; this blend
    // also replaces the 0/0 produced by the division above.
    let x_vanished = _mm256_cmp_ps::<_CMP_EQ_OQ>(d, _mm256_setzero_ps());
    let small = _mm256_blendv_ps(corrected, x, x_vanished);

    let mask = _mm256_cmp_ps::<_CMP_GT_OQ>(abs_x, half);
    _mm256_blendv_ps(small, log_u, mask)
}
/// Computes `ln(1 + x)` for each of the four `f64` lanes of `x`.
///
/// With `u = 1 + x` (rounded):
///
/// * For `|x| <= 0.5` (and for NaN lanes) the result is
///   `ln(u) * x / (u - 1)`, or `x` itself when `u - 1 == 0`. The factor
///   `x / (u - 1)` compensates for the rounding error committed when forming
///   `1 + x`, so the result keeps its relative accuracy for small `x`.
/// * For `|x| > 0.5` the result is `log_f64(u)` directly.
///
/// Domain and special-value behaviour for `|x| > 0.5` are whatever `log_f64`
/// provides.
///
/// # Safety
///
/// The caller must ensure the executing CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub unsafe fn log1p_f64(x: __m256d) -> __m256d {
    let one = _mm256_set1_pd(1.0);
    let half = _mm256_set1_pd(0.5);
    // Clearing the sign bit yields |x|.
    let abs_x = _mm256_andnot_pd(_mm256_set1_pd(-0.0), x);

    let u = _mm256_add_pd(one, x);
    let log_u = log_f64(u);

    // d is the part of x that actually made it into u. For |x| <= 0.5 the
    // subtraction is exact because u lies within a factor of two of 1.
    let d = _mm256_sub_pd(u, one);
    let corrected = _mm256_div_pd(_mm256_mul_pd(log_u, x), d);
    // When x vanished entirely in 1 + x, ln(1 + x) rounds to x; this blend
    // also replaces the 0/0 produced by the division above.
    let x_vanished = _mm256_cmp_pd::<_CMP_EQ_OQ>(d, _mm256_setzero_pd());
    let small = _mm256_blendv_pd(corrected, x, x_vanished);

    let mask = _mm256_cmp_pd::<_CMP_GT_OQ>(abs_x, half);
    _mm256_blendv_pd(small, log_u, mask)
}