#![cfg(all(target_feature = "avx512f", feature = "unstable"))]
use std::arch::asm;
#[derive(Clone, Copy)]
#[repr(transparent)]
#[allow(dead_code)]
pub struct F64x8(pub std::arch::x86_64::__m512d);
impl F64x8 {
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn new(vals: [f64; 8]) -> Self {
let ptr = vals.as_ptr();
let vec = std::arch::x86_64::_mm512_loadu_pd(ptr);
Self(vec)
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn splat(val: f64) -> Self {
let vec = std::arch::x86_64::_mm512_set1_pd(val);
Self(vec)
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn from_vals(
v0: f64, v1: f64, v2: f64, v3: f64,
v4: f64, v5: f64, v6: f64, v7: f64,
) -> Self {
let vec = std::arch::x86_64::_mm512_set_pd(v7, v6, v5, v4, v3, v2, v1, v0);
Self(vec)
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn to_array(self) -> [f64; 8] {
let mut result = [0.0; 8];
std::arch::x86_64::_mm512_storeu_pd(result.as_mut_ptr(), self.0);
result
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn add(self, other: Self) -> Self {
Self(std::arch::x86_64::_mm512_add_pd(self.0, other.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn mul(self, other: Self) -> Self {
Self(std::arch::x86_64::_mm512_mul_pd(self.0, other.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn sub(self, other: Self) -> Self {
Self(std::arch::x86_64::_mm512_sub_pd(self.0, other.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn div(self, other: Self) -> Self {
Self(std::arch::x86_64::_mm512_div_pd(self.0, other.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn max(self, other: Self) -> Self {
Self(std::arch::x86_64::_mm512_max_pd(self.0, other.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn min(self, other: Self) -> Self {
Self(std::arch::x86_64::_mm512_min_pd(self.0, other.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn sqrt(self) -> Self {
Self(std::arch::x86_64::_mm512_sqrt_pd(self.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn fmadd(self, mul: Self, add: Self) -> Self {
Self(std::arch::x86_64::_mm512_fmadd_pd(self.0, mul.0, add.0))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn reduce_add(self) -> f64 {
let lo = std::arch::x86_64::_mm512_castpd512_pd256(self.0);
let hi = std::arch::x86_64::_mm512_extractf32x4_pd::<2>(self.0);
let sum_lo = std::arch::x86_64::_mm256_hadd_pd(lo, lo);
let sum_hi = std::arch::x86_64::_mm256_hadd_pd(hi, hi);
let sum = std::arch::x86_64::_mm256_add_pd(sum_lo, sum_hi);
let sum2 = std::arch::x86_64::_mm256_hadd_pd(sum, sum);
let result = std::arch::x86_64::_mm256_castpd256_pd128(sum2);
let lo_val = std::arch::x86_64::_mm_cvtsd_f64(result);
let hi_val = std::arch::x86_64::_mm_cvtsd_f64(std::arch::x86_64::_mm_shuffle_pd(result, result, 1));
lo_val + hi_val
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn load(ptr: *const f64) -> Self {
Self(std::arch::x86_64::_mm512_loadu_pd(ptr))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn store(self, ptr: *mut f64) {
std::arch::x86_64::_mm512_storeu_pd(ptr, self.0);
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn load_streaming(ptr: *const f64) -> Self {
Self(std::arch::x86_64::_mm512_loadu_pd(ptr))
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn store_streaming(self, ptr: *mut f64) {
std::arch::x86_64::_mm512_stream_pd(ptr, self.0);
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn abs(self) -> Self {
let sign_mask = std::arch::x86_64::_mm512_set1_pd(-0.0);
let not_sign = std::arch::x86_64::_mm512_andnot_pd(sign_mask, self.0);
Self(not_sign)
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn cmp_le(self, other: Self) -> i32 {
std::arch::x86_64::_mm512_cmp_pd_mask(self.0, other.0, std::arch::x86_64::_CMP_LE_OQ)
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn blend(self, other: Self, mask: i32) -> Self {
Self(std::arch::x86_64::_mm512_mask_blend_pd(mask, self.0, other.0))
}
}
/// In-place ReLU: replaces every element `x` of `data` with `max(x, 0.0)`.
///
/// Full 8-element blocks go through the AVX-512 `max`; the trailing
/// `data.len() % 8` elements use scalar `f64::max`.
///
/// # Safety
/// The CPU must support AVX-512F.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn relu_avx512(data: &mut [f64]) {
    let floor = F64x8::splat(0.0);
    let mut blocks = data.chunks_exact_mut(8);
    for block in &mut blocks {
        let clipped = F64x8::load(block.as_ptr()).max(floor);
        clipped.store(block.as_mut_ptr());
    }
    for slot in blocks.into_remainder() {
        *slot = (*slot).max(0.0);
    }
}
/// In-place GELU (tanh approximation) of every element of `data`:
/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// `tanh(t)` is approximated by the rational form `|t| * (27 + t^2) / (27 + 9 t^2)`,
/// clamped to at most 1 (the rational form exceeds 1 for large `|t|`), then given
/// the sign of `t`. The 8-lane vector path and the scalar tail evaluate the same
/// formula in the same operation order, so they agree.
///
/// # Safety
/// The CPU must support AVX-512F.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn gelu_avx512(data: &mut [f64]) {
    const SQRT_2_OVER_PI: f64 = 0.7978845608028654;
    const COEF: f64 = 0.044715;
    let sqrt_2_over_pi = F64x8::splat(SQRT_2_OVER_PI);
    let coef = F64x8::splat(COEF);
    let half = F64x8::splat(0.5);
    let ones = F64x8::splat(1.0);
    let neg_ones = F64x8::splat(-1.0);
    let zeros = F64x8::splat(0.0);
    let twenty_seven = F64x8::splat(27.0);
    let nine = F64x8::splat(9.0);

    let mut chunks = data.chunks_exact_mut(8);
    for chunk in &mut chunks {
        let x = F64x8::load(chunk.as_ptr());
        let x3 = x.mul(x).mul(x);
        let inner = sqrt_2_over_pi.mul(x.add(coef.mul(x3)));
        let inner_sq = inner.mul(inner);
        // |tanh(inner)| approximation, clamped to <= 1. `min` returns 1 for NaN
        // lanes, matching `f64::min` in the scalar tail.
        let tanh_abs = inner
            .abs()
            .mul(twenty_seven.add(inner_sq))
            .div(twenty_seven.add(nine.mul(inner_sq)))
            .min(ones);
        // -1 where inner <= 0, +1 elsewhere (the magnitude is 0 at inner == 0).
        let sign = ones.blend(neg_ones, inner.cmp_le(zeros));
        let tanh = sign.mul(tanh_abs);
        x.mul(half).mul(ones.add(tanh)).store(chunk.as_mut_ptr());
    }
    for x in chunks.into_remainder() {
        let v = *x;
        let inner = SQRT_2_OVER_PI * (v + COEF * (v * v * v));
        let inner_sq = inner * inner;
        let tanh_abs = (inner.abs() * (27.0 + inner_sq) / (27.0 + 9.0 * inner_sq)).min(1.0);
        let sign = if inner >= 0.0 { 1.0 } else { -1.0 };
        *x = v * 0.5 * (1.0 + tanh_abs * sign);
    }
}
/// In-place SiLU (swish) of every element of `data`: `silu(x) = x * sigmoid(x)`,
/// with `sigmoid(x) = 0.5 * (1 + tanh(x / 2))`.
///
/// `tanh(h)` is approximated by the rational form `|h| * (27 + h^2) / (27 + 9 h^2)`,
/// clamped to at most 1, then given the sign of `h`. The vector path and the scalar
/// tail evaluate the same formula in the same operation order.
///
/// # Safety
/// The CPU must support AVX-512F.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn silu_avx512(data: &mut [f64]) {
    let ones = F64x8::splat(1.0);
    let neg_ones = F64x8::splat(-1.0);
    let zeros = F64x8::splat(0.0);
    let half = F64x8::splat(0.5);
    let two = F64x8::splat(2.0);
    // Loop-invariant constants, hoisted out of the chunk loop.
    let twenty_seven = F64x8::splat(27.0);
    let nine = F64x8::splat(9.0);

    let mut chunks = data.chunks_exact_mut(8);
    for chunk in &mut chunks {
        let x = F64x8::load(chunk.as_ptr());
        let x_half = x.div(two);
        let x_half_sq = x_half.mul(x_half);
        // |tanh(x/2)| approximation, clamped to <= 1.
        let tanh_abs = x_half
            .abs()
            .mul(twenty_seven.add(x_half_sq))
            .div(twenty_seven.add(nine.mul(x_half_sq)))
            .min(ones);
        // -1 where x/2 <= 0, +1 elsewhere (the magnitude is 0 at x == 0).
        let sign = ones.blend(neg_ones, x_half.cmp_le(zeros));
        let sigmoid = half.mul(ones.add(sign.mul(tanh_abs)));
        x.mul(sigmoid).store(chunk.as_mut_ptr());
    }
    for x in chunks.into_remainder() {
        let x_half = *x / 2.0;
        let x_half_sq = x_half * x_half;
        let tanh_abs = (x_half.abs() * (27.0 + x_half_sq) / (27.0 + 9.0 * x_half_sq)).min(1.0);
        let sign = if x_half >= 0.0 { 1.0 } else { -1.0 };
        let sigmoid = 0.5 * (1.0 + tanh_abs * sign);
        *x *= sigmoid;
    }
}
// Runtime AVX-512F detection is provided by the single `has_avx512` in this module
// that wraps `is_x86_feature_detected!("avx512f")`. A second, CPUID-based definition
// here was a duplicate-item compile error. It also probed CPUID.(EAX=7,ECX=0):ECX
// bit 16 (LA57) instead of EBX bit 16 (AVX512F), and ignored OS (XCR0) support for
// the AVX-512 register state.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn dot_product_avx512(a: &[f64], b: &[f64]) -> f64 {
assert!(a.len() >= 8 && b.len() >= 8);
let va = F64x8::load(a.as_ptr());
let vb = F64x8::load(b.as_ptr());
let product = va.mul(vb);
product.reduce_add()
}
/// Computes `c[i * n + j] = Σ_p a[i * k + p] * b[j * k + p]` for `i < m`, `j < n`.
///
/// `a` is read as `m` rows of `k` contiguous values, and `b` as `n` rows of `k`
/// contiguous values. Each output element is therefore a dot product of two
/// contiguous rows, vectorised eight lanes at a time with fused multiply-add; the
/// `k % 8` tail is handled in scalar code.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - `a` must be valid for reads of `m * k` values, `b` for reads of `n * k` values,
///   and `c` for writes of `m * n` values; `c` must not overlap `a` or `b`.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn matmul_kernel_avx512(
    a: *const f64,
    b: *const f64, c: *mut f64,
    m: usize,
    k: usize,
    n: usize,
) {
    const LANES: usize = 8;
    // Largest multiple of LANES not exceeding k: the vectorised part of every row.
    let k_vec = k - k % LANES;
    for i in 0..m {
        let a_row = a.add(i * k);
        if i + 1 < m {
            crate::utils::prefetch_read_data(a.add((i + 1) * k), k * 8);
        }
        for j in 0..n {
            let b_row = b.add(j * k);
            if j + 1 < n {
                crate::utils::prefetch_read_data(b.add((j + 1) * k), k * 8);
            }
            let mut acc = F64x8::splat(0.0);
            for p in (0..k_vec).step_by(LANES) {
                acc = F64x8::load(a_row.add(p)).fmadd(F64x8::load(b_row.add(p)), acc);
            }
            let mut sum: f64 = acc.to_array().iter().sum();
            for p in k_vec..k {
                sum += *a_row.add(p) * *b_row.add(p);
            }
            *c.add(i * n + j) = sum;
        }
    }
}
/// Numerically stable in-place softmax over `dim_size` contiguous `f64` values:
/// `x[i] = exp(x[i] - max) / Σ_j exp(x[j] - max)`.
///
/// Subtracting the maximum first keeps `exp` from overflowing. NaN inputs are
/// skipped when finding the maximum. If every value is `-inf`, the result is NaN.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - `data` must be non-null, aligned, valid for reads and writes of `dim_size`
///   consecutive `f64`s, and not accessed through any other pointer or reference
///   during the call. `data` is not dereferenced when `dim_size == 0`.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn softmax_kernel_avx512(data: *mut f64, dim_size: usize) {
    if dim_size == 0 {
        return;
    }
    let row = std::slice::from_raw_parts_mut(data, dim_size);

    // Pass 1: row maximum, kept in a vector accumulator and reduced once.
    // `vals.max(acc)` is `_mm512_max_pd(vals, acc)`, which returns `acc` for NaN
    // lanes, so NaNs are ignored just as `f64::max` ignores them in the fold.
    let chunks = row.chunks_exact(8);
    let tail = chunks.remainder();
    let mut max_acc = F64x8::splat(f64::NEG_INFINITY);
    for chunk in chunks {
        max_acc = F64x8::load(chunk.as_ptr()).max(max_acc);
    }
    let max_val = max_acc
        .to_array()
        .iter()
        .chain(tail)
        .fold(f64::NEG_INFINITY, |m, &v| m.max(v));

    // Pass 2: exponentiate in place and accumulate the normaliser. AVX-512F has no
    // vector `exp`, so this pass is scalar.
    let mut sum_exp = 0.0;
    for x in row.iter_mut() {
        *x = (*x - max_val).exp();
        sum_exp += *x;
    }

    // Pass 3: scale by 1 / Σ exp.
    let inv_sum = 1.0 / sum_exp;
    let inv_sum_vec = F64x8::splat(inv_sum);
    let mut chunks = row.chunks_exact_mut(8);
    for chunk in &mut chunks {
        F64x8::load(chunk.as_ptr())
            .mul(inv_sum_vec)
            .store(chunk.as_mut_ptr());
    }
    for x in chunks.into_remainder() {
        *x *= inv_sum;
    }
}
/// Normalises `dim_size` contiguous values: `out[i] = (data[i] - mean) / (std + epsilon)`.
///
/// `mean` and the population variance (divisor `dim_size`) are computed in two
/// passes, and `std = sqrt(variance)`. No scale or shift is applied. Vector lanes
/// are accumulated element-wise and reduced horizontally once per pass.
///
/// NOTE(review): `epsilon` is added to the standard deviation, not inside the
/// square root (`sqrt(var + eps)`, the form many LayerNorm definitions use).
/// Confirm which convention callers expect.
///
/// With `dim_size == 0` the mean is NaN and nothing is written.
///
/// # Safety
/// - The CPU must support AVX-512F.
/// - `data` must be valid for reads and `out` valid for writes of `dim_size` `f64`s.
///   `out` may equal `data`: each element is read before its output is written.
///   Partially overlapping buffers are not supported.
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn layer_norm_kernel_avx512(
    data: *const f64,
    out: *mut f64,
    dim_size: usize,
    epsilon: f64,
) {
    const LANES: usize = 8;
    let vec_end = dim_size - dim_size % LANES;

    // Pass 1: mean.
    let mut sum_acc = F64x8::splat(0.0);
    for d in (0..vec_end).step_by(LANES) {
        sum_acc = sum_acc.add(F64x8::load(data.add(d)));
    }
    let mut sum: f64 = sum_acc.to_array().iter().sum();
    for d in vec_end..dim_size {
        sum += *data.add(d);
    }
    let mean = sum / dim_size as f64;
    let mean_vec = F64x8::splat(mean);

    // Pass 2: sum of squared deviations from the mean.
    let mut sq_acc = F64x8::splat(0.0);
    for d in (0..vec_end).step_by(LANES) {
        let diff = F64x8::load(data.add(d)).sub(mean_vec);
        sq_acc = diff.fmadd(diff, sq_acc);
    }
    let mut sum_sq_diff: f64 = sq_acc.to_array().iter().sum();
    for d in vec_end..dim_size {
        let diff = *data.add(d) - mean;
        sum_sq_diff += diff * diff;
    }
    let variance = sum_sq_diff / dim_size as f64;
    let inv_std = 1.0 / (variance.sqrt() + epsilon);
    let inv_std_vec = F64x8::splat(inv_std);

    // Pass 3: normalise.
    for d in (0..vec_end).step_by(LANES) {
        F64x8::load(data.add(d))
            .sub(mean_vec)
            .mul(inv_std_vec)
            .store(out.add(d));
    }
    for d in vec_end..dim_size {
        *out.add(d) = (*data.add(d) - mean) * inv_std;
    }
}
/// Returns `true` when the running CPU and OS support AVX-512F, as reported by
/// `is_x86_feature_detected!("avx512f")`. Always `false` off x86_64.
#[cfg(target_arch = "x86_64")]
#[inline]
pub fn has_avx512() -> bool {
    is_x86_feature_detected!("avx512f")
}

/// Returns `true` when the running CPU and OS support AVX-512F; never on this
/// architecture.
#[cfg(not(target_arch = "x86_64"))]
#[inline]
pub fn has_avx512() -> bool {
    false
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn sparse_compress_avx512(
data: &[f64],
col_indices: &mut Vec<usize>,
values: &mut Vec<f64>,
threshold: f64,
) {
use std::arch::x86_64::*;
let threshold_vec = _mm512_set1_pd(threshold);
let zeros = _mm512_setzero_pd();
let chunks = data.chunks_exact(8);
let remainder = chunks.remainder();
for (chunk_idx, chunk) in chunks.enumerate() {
let base_offset = chunk_idx * 8;
let v = _mm512_loadu_pd(chunk.as_ptr());
let abs_v = _mm512_max_pd(v, _mm512_sub_pd(zeros, v));
let mask = _mm512_cmp_pd_mask(abs_v, threshold_vec, _CMP_GT_OQ);
if mask != 0 {
let arr: [f64; 8] = std::mem::transmute(v);
for (i, &val) in arr.iter().enumerate() {
if val.abs() > threshold {
col_indices.push(base_offset + i);
values.push(val);
}
}
}
}
let base_idx = chunks.len() * 8;
for (i, &val) in remainder.iter().enumerate() {
if val.abs() > threshold {
col_indices.push(base_idx + i);
values.push(val);
}
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn dense_to_sparse_avx512(
data: &[f64],
rows: usize,
cols: usize,
row_offsets: &mut Vec<usize>,
col_indices: &mut Vec<usize>,
values: &mut Vec<f64>,
threshold: f64,
) {
let threshold_vec = _mm512_set1_pd(threshold);
let zeros = _mm512_setzero_pd();
for row in 0..rows {
let row_start = row * cols;
let row_data = &data[row_start..];
let chunks = row_data.chunks_exact(8);
let remainder = chunks.remainder();
for (chunk_idx, chunk) in chunks.enumerate() {
let base_offset = chunk_idx * 8;
let v = _mm512_loadu_pd(chunk.as_ptr());
let abs_v = _mm512_max_pd(v, _mm512_sub_pd(zeros, v));
let mask = _mm512_cmp_pd_mask(abs_v, threshold_vec, _CMP_GT_OQ);
if mask != 0 {
let arr: [f64; 8] = std::mem::transmute(v);
for (i, &val) in arr.iter().enumerate() {
if val.abs() > threshold {
col_indices.push(base_offset + i);
values.push(val);
}
}
}
}
let base_idx = chunks.len() * 8;
for (i, &val) in remainder.iter().enumerate() {
if val.abs() > threshold {
col_indices.push(base_idx + i);
values.push(val);
}
}
row_offsets.push(col_indices.len());
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn gemv_avx512(a: &[f64], x: &[f64], y: &mut [f64], m: usize, k: usize) {
for i in 0..m {
let row_start = i * k;
let mut sum = 0.0;
let chunks = k / 8;
for j in 0..chunks {
let idx = row_start + j * 8;
let a_vec = _mm512_loadu_pd(a.as_ptr().add(idx));
let x_vec = _mm512_loadu_pd(x.as_ptr().add(j * 8));
let prod = _mm512_mul_pd(a_vec, x_vec);
sum += _mm512_reduce_add_pd(prod);
}
for j in (chunks * 8)..k {
sum += a[row_start + j] * x[j];
}
y[i] = sum;
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn axpy_avx512(alpha: f64, x: &[f64], y: &mut [f64], n: usize) {
let alpha_vec = _mm512_set1_pd(alpha);
let chunks = n / 8;
for i in 0..chunks {
let idx = i * 8;
let x_vec = _mm512_loadu_pd(x.as_ptr().add(idx));
let y_vec = _mm512_loadu_pd(y.as_ptr().add(idx));
let result = _mm512_fmadd_pd(alpha_vec, x_vec, y_vec);
_mm512_storeu_pd(y.as_ptr().add(idx), result);
}
for i in (chunks * 8)..n {
y[i] = alpha * x[i] + y[i];
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn vector_add_avx512(a: &[f64], b: &[f64], c: &mut [f64], n: usize) {
let chunks = n / 8;
for i in 0..chunks {
let idx = i * 8;
let a_vec = _mm512_loadu_pd(a.as_ptr().add(idx));
let b_vec = _mm512_loadu_pd(b.as_ptr().add(idx));
let result = _mm512_add_pd(a_vec, b_vec);
_mm512_storeu_pd(c.as_ptr().add(idx), result);
}
for i in (chunks * 8)..n {
c[i] = a[i] + b[i];
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn vector_sub_avx512(a: &[f64], b: &[f64], c: &mut [f64], n: usize) {
let chunks = n / 8;
for i in 0..chunks {
let idx = i * 8;
let a_vec = _mm512_loadu_pd(a.as_ptr().add(idx));
let b_vec = _mm512_loadu_pd(b.as_ptr().add(idx));
let result = _mm512_sub_pd(a_vec, b_vec);
_mm512_storeu_pd(c.as_ptr().add(idx), result);
}
for i in (chunks * 8)..n {
c[i] = a[i] - b[i];
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn vector_mul_avx512(a: &[f64], b: &[f64], c: &mut [f64], n: usize) {
let chunks = n / 8;
for i in 0..chunks {
let idx = i * 8;
let a_vec = _mm512_loadu_pd(a.as_ptr().add(idx));
let b_vec = _mm512_loadu_pd(b.as_ptr().add(idx));
let result = _mm512_mul_pd(a_vec, b_vec);
_mm512_storeu_pd(c.as_ptr().add(idx), result);
}
for i in (chunks * 8)..n {
c[i] = a[i] * b[i];
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn sigmoid_avx512(data: &mut [f64]) {
let ones = _mm512_set1_pd(1.0);
let chunks = data.len() / 8;
for i in 0..chunks {
let idx = i * 8;
let x = _mm512_loadu_pd(data.as_ptr().add(idx));
let neg_x = _mm512_sub_pd(_mm512_setzero_pd(), x);
let exp_neg_x = _mm512_exp_pd(neg_x);
let denom = _mm512_add_pd(ones, exp_neg_x);
let sigmoid = _mm512_div_pd(ones, denom);
_mm512_storeu_pd(data.as_ptr().add(idx), sigmoid);
}
for i in (chunks * 8)..data.len() {
data[i] = 1.0 / (1.0 + (-data[i]).exp());
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn tanh_avx512(data: &mut [f64]) {
let chunks = data.len() / 8;
for i in 0..chunks {
let idx = i * 8;
let x = _mm512_loadu_pd(data.as_ptr().add(idx));
let exp_2x = _mm512_mul_pd(x, _mm512_set1_pd(2.0));
let exp_2x = _mm512_exp_pd(exp_2x);
let ones = _mm512_set1_pd(1.0);
let num = _mm512_sub_pd(exp_2x, ones);
let denom = _mm512_add_pd(exp_2x, ones);
let tanh = _mm512_div_pd(num, denom);
_mm512_storeu_pd(data.as_ptr().add(idx), tanh);
}
for i in (chunks * 8)..data.len() {
data[i] = data[i].tanh();
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn leaky_relu_avx512(data: &mut [f64], alpha: f64) {
let zeros = _mm512_setzero_pd();
let alpha_vec = _mm512_set1_pd(alpha);
let chunks = data.len() / 8;
for i in 0..chunks {
let idx = i * 8;
let x = _mm512_loadu_pd(data.as_ptr().add(idx));
let mask = _mm512_cmp_pd_mask(x, zeros, _CMP_GT_OQ);
let alpha_x = _mm512_mul_pd(x, alpha_vec);
let result = _mm512_mask_blend_pd(mask, alpha_x, x);
_mm512_storeu_pd(data.as_ptr().add(idx), result);
}
for i in (chunks * 8)..data.len() {
data[i] = if data[i] > 0.0 { data[i] } else { alpha * data[i] };
}
}
#[inline]
#[target_feature(enable = "avx512f")]
pub unsafe fn elu_avx512(data: &mut [f64], alpha: f64) {
let zeros = _mm512_setzero_pd();
let ones = _mm512_set1_pd(1.0);
let alpha_vec = _mm512_set1_pd(alpha);
let chunks = data.len() / 8;
for i in 0..chunks {
let idx = i * 8;
let x = _mm512_loadu_pd(data.as_ptr().add(idx));
let mask = _mm512_cmp_pd_mask(x, zeros, _CMP_GT_OQ);
let exp_x = _mm512_exp_pd(x);
let elu = _mm512_mul_pd(alpha_vec, _mm512_sub_pd(exp_x, ones));
let result = _mm512_mask_blend_pd(mask, x, elu);
_mm512_storeu_pd(data.as_ptr().add(idx), result);
}
for i in (chunks * 8)..data.len() {
data[i] = if data[i] > 0.0 {
data[i]
} else {
alpha * (data[i].exp() - 1.0)
};
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // `#[test]` functions must be safe and cannot carry `#[target_feature]`, so each
    // test checks for AVX-512F at runtime and returns early (passing) when the host
    // CPU lacks it.

    #[test]
    fn test_f64x8_add() {
        if !is_x86_feature_detected!("avx512f") {
            return;
        }
        // SAFETY: AVX-512F support was verified just above.
        unsafe {
            let a = F64x8::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
            let b = F64x8::new([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
            let arr = a.add(b).to_array();
            assert!((arr[0] - 2.0).abs() < 1e-10);
            assert!((arr[7] - 9.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_f64x8_mul() {
        if !is_x86_feature_detected!("avx512f") {
            return;
        }
        // SAFETY: AVX-512F support was verified just above.
        unsafe {
            let a = F64x8::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
            let b = F64x8::new([2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]);
            let arr = a.mul(b).to_array();
            assert!((arr[0] - 2.0).abs() < 1e-10);
            assert!((arr[7] - 16.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_f64x8_reduce_add() {
        if !is_x86_feature_detected!("avx512f") {
            return;
        }
        // SAFETY: AVX-512F support was verified just above.
        unsafe {
            let a = F64x8::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
            let sum = a.reduce_add();
            assert!((sum - 36.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_dot_product() {
        if !is_x86_feature_detected!("avx512f") {
            return;
        }
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        // SAFETY: AVX-512F support was verified just above.
        let result = unsafe { dot_product_avx512(&a, &b) };
        assert!((result - 36.0).abs() < 1e-10);
    }
}