#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use super::super::math::avx2::{exp_f32, exp_f64, tanh_f32, tanh_f64};
use super::{
gelu_mul_scalar_f32, gelu_mul_scalar_f64, relu_mul_scalar_f32, relu_mul_scalar_f64,
sigmoid_mul_scalar_f32, sigmoid_mul_scalar_f64, silu_mul_scalar_f32, silu_mul_scalar_f64,
};
const F32_LANES: usize = 8;
const F64_LANES: usize = 4;
/// Fused SiLU activation and element-wise multiply over `f32` buffers:
/// `out[i] = silu(a[i]) * b[i]`, where `silu(x) = x / (1 + e^-x)`.
///
/// Eight lanes are processed per iteration with AVX2; the trailing
/// `len % 8` elements are handled by the scalar fallback.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn silu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) {
    let ones = _mm256_set1_ps(1.0);
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F32_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_ps(a.add(i));
        let gate = _mm256_loadu_ps(b.add(i));
        // silu(x) = x / (1 + exp(-x)); negate via 0 - x.
        let minus_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
        let denom = _mm256_add_ps(ones, exp_f32(minus_x));
        let act = _mm256_div_ps(x, denom);
        _mm256_storeu_ps(out.add(i), _mm256_mul_ps(act, gate));
        i += F32_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        silu_mul_scalar_f32(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused SiLU activation and element-wise multiply over `f64` buffers:
/// `out[i] = silu(a[i]) * b[i]`, where `silu(x) = x / (1 + e^-x)`.
///
/// Four lanes are processed per iteration with AVX2; the trailing
/// `len % 4` elements are handled by the scalar fallback.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn silu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) {
    let ones = _mm256_set1_pd(1.0);
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F64_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_pd(a.add(i));
        let gate = _mm256_loadu_pd(b.add(i));
        // silu(x) = x / (1 + exp(-x)); negate via 0 - x.
        let minus_x = _mm256_sub_pd(_mm256_setzero_pd(), x);
        let denom = _mm256_add_pd(ones, exp_f64(minus_x));
        let act = _mm256_div_pd(x, denom);
        _mm256_storeu_pd(out.add(i), _mm256_mul_pd(act, gate));
        i += F64_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        silu_mul_scalar_f64(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused tanh-approximation GELU and element-wise multiply over `f32`
/// buffers: `out[i] = gelu(a[i]) * b[i]`, with
/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Eight lanes per iteration with AVX2; the trailing `len % 8` elements
/// go through the scalar fallback.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gelu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) {
    let halves = _mm256_set1_ps(0.5);
    let ones = _mm256_set1_ps(1.0);
    let k_sqrt_2_over_pi = _mm256_set1_ps(0.7978845608); // sqrt(2/pi)
    let k_cubic = _mm256_set1_ps(0.044715);
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F32_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_ps(a.add(i));
        let gate = _mm256_loadu_ps(b.add(i));
        let x_sq = _mm256_mul_ps(x, x);
        let x_cu = _mm256_mul_ps(x_sq, x);
        // inner = sqrt(2/pi) * (x + 0.044715 * x^3), polynomial fused via FMA.
        let inner = _mm256_mul_ps(k_sqrt_2_over_pi, _mm256_fmadd_ps(k_cubic, x_cu, x));
        let one_plus_t = _mm256_add_ps(ones, tanh_f32(inner));
        let act = _mm256_mul_ps(halves, _mm256_mul_ps(x, one_plus_t));
        _mm256_storeu_ps(out.add(i), _mm256_mul_ps(act, gate));
        i += F32_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        gelu_mul_scalar_f32(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused tanh-approximation GELU and element-wise multiply over `f64`
/// buffers: `out[i] = gelu(a[i]) * b[i]`, with
/// `gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// Four lanes per iteration with AVX2; the trailing `len % 4` elements
/// go through the scalar fallback.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn gelu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) {
    let halves = _mm256_set1_pd(0.5);
    let ones = _mm256_set1_pd(1.0);
    let k_sqrt_2_over_pi = _mm256_set1_pd(0.7978845608028654); // sqrt(2/pi)
    let k_cubic = _mm256_set1_pd(0.044715);
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F64_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_pd(a.add(i));
        let gate = _mm256_loadu_pd(b.add(i));
        let x_sq = _mm256_mul_pd(x, x);
        let x_cu = _mm256_mul_pd(x_sq, x);
        // inner = sqrt(2/pi) * (x + 0.044715 * x^3), polynomial fused via FMA.
        let inner = _mm256_mul_pd(k_sqrt_2_over_pi, _mm256_fmadd_pd(k_cubic, x_cu, x));
        let one_plus_t = _mm256_add_pd(ones, tanh_f64(inner));
        let act = _mm256_mul_pd(halves, _mm256_mul_pd(x, one_plus_t));
        _mm256_storeu_pd(out.add(i), _mm256_mul_pd(act, gate));
        i += F64_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        gelu_mul_scalar_f64(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused ReLU and element-wise multiply over `f32` buffers:
/// `out[i] = max(0, a[i]) * b[i]`.
///
/// Eight lanes per iteration with AVX2; the trailing `len % 8` elements
/// go through the scalar fallback.
///
/// NOTE(review): `_mm256_max_ps` returns the second operand when a lane
/// compares unordered, so NaN inputs in `a` propagate through the SIMD
/// path — confirm the scalar fallback agrees on NaN handling.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn relu_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) {
    let zeros = _mm256_setzero_ps();
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F32_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_ps(a.add(i));
        let gate = _mm256_loadu_ps(b.add(i));
        let act = _mm256_max_ps(zeros, x);
        _mm256_storeu_ps(out.add(i), _mm256_mul_ps(act, gate));
        i += F32_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        relu_mul_scalar_f32(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused ReLU and element-wise multiply over `f64` buffers:
/// `out[i] = max(0, a[i]) * b[i]`.
///
/// Four lanes per iteration with AVX2; the trailing `len % 4` elements
/// go through the scalar fallback.
///
/// NOTE(review): `_mm256_max_pd` returns the second operand when a lane
/// compares unordered, so NaN inputs in `a` propagate through the SIMD
/// path — confirm the scalar fallback agrees on NaN handling.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn relu_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) {
    let zeros = _mm256_setzero_pd();
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F64_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_pd(a.add(i));
        let gate = _mm256_loadu_pd(b.add(i));
        let act = _mm256_max_pd(zeros, x);
        _mm256_storeu_pd(out.add(i), _mm256_mul_pd(act, gate));
        i += F64_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        relu_mul_scalar_f64(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused sigmoid and element-wise multiply over `f32` buffers:
/// `out[i] = sigmoid(a[i]) * b[i]`, where `sigmoid(x) = 1 / (1 + e^-x)`.
///
/// Eight lanes per iteration with AVX2; the trailing `len % 8` elements
/// go through the scalar fallback.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn sigmoid_mul_f32(a: *const f32, b: *const f32, out: *mut f32, len: usize) {
    let ones = _mm256_set1_ps(1.0);
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F32_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_ps(a.add(i));
        let gate = _mm256_loadu_ps(b.add(i));
        // sigmoid(x) = 1 / (1 + exp(-x)); negate via 0 - x.
        let minus_x = _mm256_sub_ps(_mm256_setzero_ps(), x);
        let denom = _mm256_add_ps(ones, exp_f32(minus_x));
        let act = _mm256_div_ps(ones, denom);
        _mm256_storeu_ps(out.add(i), _mm256_mul_ps(act, gate));
        i += F32_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        sigmoid_mul_scalar_f32(a.add(i), b.add(i), out.add(i), len - i);
    }
}
/// Fused sigmoid and element-wise multiply over `f64` buffers:
/// `out[i] = sigmoid(a[i]) * b[i]`, where `sigmoid(x) = 1 / (1 + e^-x)`.
///
/// Four lanes per iteration with AVX2; the trailing `len % 4` elements
/// go through the scalar fallback.
///
/// # Safety
/// - `a`, `b`, and `out` must each be valid for `len` elements
///   (reads for `a`/`b`, writes for `out`).
/// - The caller must ensure the CPU supports AVX2 and FMA.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn sigmoid_mul_f64(a: *const f64, b: *const f64, out: *mut f64, len: usize) {
    let ones = _mm256_set1_pd(1.0);
    // Largest multiple of the lane width that fits in `len`.
    let vector_len = len - len % F64_LANES;
    let mut i = 0;
    while i < vector_len {
        let x = _mm256_loadu_pd(a.add(i));
        let gate = _mm256_loadu_pd(b.add(i));
        // sigmoid(x) = 1 / (1 + exp(-x)); negate via 0 - x.
        let minus_x = _mm256_sub_pd(_mm256_setzero_pd(), x);
        let denom = _mm256_add_pd(ones, exp_f64(minus_x));
        let act = _mm256_div_pd(ones, denom);
        _mm256_storeu_pd(out.add(i), _mm256_mul_pd(act, gate));
        i += F64_LANES;
    }
    // Scalar tail for the remainder.
    if i < len {
        sigmoid_mul_scalar_f64(a.add(i), b.add(i), out.add(i), len - i);
    }
}