#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
use std::is_x86_feature_detected;
/// Scalar dot-product fallback.
///
/// Mismatched slice lengths are handled by truncating to the shorter
/// slice (same convention as the AVX2 path). Returns `0.0` for empty
/// input.
#[inline]
fn dot_scalar(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, so no explicit `min` is needed,
    // and the iterator form lets the compiler elide bounds checks.
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
/// Scalar matrix-vector product fallback: `out = w * x`.
///
/// `w` is row-major with `cols` elements per row; `_rows` is unused
/// here (the row count is taken from `out.len()`) but kept for
/// signature parity with the AVX2 kernel.
///
/// Panics (like the original indexed version) if `w` or `x` is too
/// short for the requested shape; the public wrapper asserts these
/// preconditions first.
#[inline]
fn mat_vec_scalar(w: &[f64], x: &[f64], _rows: usize, cols: usize, out: &mut [f64]) {
    for (row, out_i) in out.iter_mut().enumerate() {
        let start = row * cols;
        // Slicing up-front preserves the out-of-range panic while the
        // zipped iterator elides per-element bounds checks.
        *out_i = w[start..start + cols]
            .iter()
            .zip(&x[..cols])
            .map(|(wi, xi)| wi * xi)
            .sum();
    }
}
/// Scalar fallback: element-wise `tanh` of `input` into `output`.
///
/// Panics if `output` is shorter than `input` (the public wrapper
/// asserts this before dispatching here).
#[inline]
fn tanh_scalar(input: &[f64], output: &mut [f64]) {
    // Slicing keeps the length-mismatch panic; the zipped iterator
    // avoids per-element bounds checks of `output[i]` indexing.
    for (dst, &src) in output[..input.len()].iter_mut().zip(input.iter()) {
        *dst = crate::math::tanh(src);
    }
}
/// Scalar fallback: element-wise `exp` of `input` into `output`.
///
/// Panics if `output` is shorter than `input` (the public wrapper
/// asserts this before dispatching here).
#[inline]
fn exp_scalar(input: &[f64], output: &mut [f64]) {
    // Slicing keeps the length-mismatch panic; the zipped iterator
    // avoids per-element bounds checks of `output[i]` indexing.
    for (dst, &src) in output[..input.len()].iter_mut().zip(input.iter()) {
        *dst = crate::math::exp(src);
    }
}
/// Scalar fallback: element-wise logistic sigmoid of `input` into
/// `output`.
///
/// Panics if `output` is shorter than `input` (the public wrapper
/// asserts this before dispatching here).
#[inline]
fn sigmoid_scalar(input: &[f64], output: &mut [f64]) {
    // Slicing keeps the length-mismatch panic; the zipped iterator
    // avoids per-element bounds checks of `output[i]` indexing.
    for (dst, &src) in output[..input.len()].iter_mut().zip(input.iter()) {
        *dst = crate::math::sigmoid(src);
    }
}
/// Scalar fallback: element-wise SiLU (`x * sigmoid(x)`) of `input`
/// into `output`.
///
/// Panics if `output` is shorter than `input` (the public wrapper
/// asserts this before dispatching here).
#[inline]
fn silu_scalar(input: &[f64], output: &mut [f64]) {
    // Slicing keeps the length-mismatch panic; the zipped iterator
    // avoids per-element bounds checks of `output[i]` indexing.
    for (dst, &src) in output[..input.len()].iter_mut().zip(input.iter()) {
        *dst = src * crate::math::sigmoid(src);
    }
}
#[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
mod avx2 {
    /// AVX2 dot product: 4 f64 lanes per iteration plus a scalar tail
    /// for the remaining 0..=3 elements. Truncates to the shorter
    /// slice, matching the scalar fallback.
    ///
    /// # Safety
    /// Caller must verify AVX2 support at runtime (e.g. via
    /// `is_x86_feature_detected!("avx2")`) before calling.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn dot_avx2(a: &[f64], b: &[f64]) -> f64 {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        let n = a.len().min(b.len());
        let chunks = n / 4;
        let remainder = n % 4;
        let a_ptr = a.as_ptr();
        let b_ptr = b.as_ptr();
        // SAFETY: every offset below stays within the first `n`
        // elements of `a` and `b`; unaligned loads are used, so no
        // alignment requirement beyond validity.
        unsafe {
            let mut acc = _mm256_setzero_pd();
            for i in 0..chunks {
                let offset = i * 4;
                let va = _mm256_loadu_pd(a_ptr.add(offset));
                let vb = _mm256_loadu_pd(b_ptr.add(offset));
                acc = _mm256_add_pd(acc, _mm256_mul_pd(va, vb));
            }
            // Horizontal reduction of the 4 accumulator lanes:
            // (hi128 + lo128) gives two partial sums, then fold the
            // upper 64 bits onto the lower one.
            let hi128 = _mm256_extractf128_pd(acc, 1);
            let lo128 = _mm256_castpd256_pd128(acc);
            let pair = _mm_add_pd(lo128, hi128);
            let high64 = _mm_unpackhi_pd(pair, pair);
            let total = _mm_add_sd(pair, high64);
            let mut scalar_sum = _mm_cvtsd_f64(total);
            // Scalar tail for elements not covered by full 4-lane
            // chunks. NOTE: summation order differs from the scalar
            // path, so results may differ by floating-point rounding.
            let base = chunks * 4;
            for i in 0..remainder {
                scalar_sum += *a_ptr.add(base + i) * *b_ptr.add(base + i);
            }
            scalar_sum
        }
    }
    /// AVX2 matrix-vector product: one `dot_avx2` per row of the
    /// row-major matrix `w` (each row is `cols` wide). `_rows` is
    /// unused; the row count is taken from `out.len()`.
    ///
    /// # Safety
    /// Same contract as [`dot_avx2`]: AVX2 must be available.
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn mat_vec_avx2(
        w: &[f64],
        x: &[f64],
        _rows: usize,
        cols: usize,
        out: &mut [f64],
    ) {
        for (row, out_i) in out.iter_mut().enumerate() {
            let row_start = row * cols;
            // SAFETY: caller-established AVX2 support carries over;
            // slice bounds are checked by the slicing itself.
            unsafe {
                *out_i = dot_avx2(&w[row_start..row_start + cols], &x[..cols]);
            }
        }
    }
    /// AVX2 element-wise tanh using the [3/2] Padé approximant
    /// `x * (15 + x^2) / (15 + 6 * x^2)`, clamped to [-1, 1] and
    /// saturated to exactly +/-1 for |x| > 4.97.
    ///
    /// NOTE(review): the vector path is a low-accuracy approximation
    /// while the scalar tail below calls `crate::math::tanh`, so
    /// values can differ slightly across the chunk boundary — confirm
    /// the ~1e-2 accuracy budget is intended.
    ///
    /// # Safety
    /// AVX2 must be available; `output` must be at least as long as
    /// `input` (asserted by the public wrapper).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn tanh_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        let n = input.len();
        let chunks = n / 4;
        // SAFETY: loads/stores cover indices < chunks*4 <= n, within
        // both `input` and `output` (wrapper guarantees output length).
        unsafe {
            let c15 = _mm256_set1_pd(15.0);
            let c6 = _mm256_set1_pd(6.0);
            // Beyond |x| ~= 4.97 the approximant is forced to +/-1.
            let pos_sat = _mm256_set1_pd(4.97);
            let neg_sat = _mm256_set1_pd(-4.97);
            let one = _mm256_set1_pd(1.0);
            let neg_one = _mm256_set1_pd(-1.0);
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let x2 = _mm256_mul_pd(x, x);
                // Padé [3/2]: x(15 + x^2) / (15 + 6x^2).
                let numer = _mm256_mul_pd(x, _mm256_add_pd(c15, x2));
                let denom = _mm256_add_pd(c15, _mm256_mul_pd(c6, x2));
                let approx = _mm256_div_pd(numer, denom);
                let clamped = _mm256_min_pd(one, _mm256_max_pd(neg_one, approx));
                // Blend in exact +/-1 where the input is saturated.
                let sat_pos = _mm256_cmp_pd(x, pos_sat, _CMP_GT_OQ);
                let sat_neg = _mm256_cmp_pd(x, neg_sat, _CMP_LT_OQ);
                let result = _mm256_blendv_pd(clamped, one, sat_pos);
                let result = _mm256_blendv_pd(result, neg_one, sat_neg);
                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }
        // Scalar tail for the last n % 4 elements.
        for i in (chunks * 4)..n {
            output[i] = crate::math::tanh(input[i]);
        }
    }
    /// AVX2 element-wise exp via range reduction: write
    /// `x = k*ln(2) + r` with k = round(x * log2(e)), evaluate a
    /// degree-5 Taylor polynomial for `exp(r)`, then scale by `2^k`
    /// constructed directly in the f64 exponent bits.
    ///
    /// Inputs are clamped to [-708, 708] so the result stays finite
    /// and normal (exp overflows f64 near 709.78).
    ///
    /// # Safety
    /// AVX2 must be available; `output` must be at least as long as
    /// `input` (asserted by the public wrapper).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn exp_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        let n = input.len();
        let chunks = n / 4;
        // SAFETY: loads/stores cover indices < chunks*4 <= n, within
        // both slices (wrapper guarantees output length).
        unsafe {
            let ln2 = _mm256_set1_pd(core::f64::consts::LN_2);
            let log2e = _mm256_set1_pd(core::f64::consts::LOG2_E);
            let clamp_hi = _mm256_set1_pd(708.0);
            let clamp_lo = _mm256_set1_pd(-708.0);
            let one = _mm256_set1_pd(1.0);
            let half = _mm256_set1_pd(0.5);
            // Taylor coefficients 1/3!, 1/4!, 1/5!.
            let c3 = _mm256_set1_pd(1.0 / 6.0);
            let c4 = _mm256_set1_pd(1.0 / 24.0);
            let c5 = _mm256_set1_pd(1.0 / 120.0);
            // IEEE-754 double exponent bias.
            let bias = _mm256_set1_epi64x(1023);
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let x = _mm256_min_pd(clamp_hi, _mm256_max_pd(clamp_lo, x));
                // k = round(x / ln 2) via floor(x*log2e + 0.5).
                let x_scaled = _mm256_mul_pd(x, log2e);
                let n_f = _mm256_floor_pd(_mm256_add_pd(x_scaled, half));
                // r = x - k*ln2, the reduced argument in ~[-ln2/2, ln2/2].
                let r = _mm256_sub_pd(x, _mm256_mul_pd(n_f, ln2));
                // Horner evaluation of 1 + r + r^2/2 + ... + r^5/120.
                let mut p = _mm256_add_pd(c4, _mm256_mul_pd(c5, r));
                p = _mm256_add_pd(c3, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(half, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                // Build 2^k by placing (k + bias) into the exponent
                // field (bits 52..62) of each lane.
                let n_i32 = _mm256_cvtpd_epi32(n_f);
                let n_i64 = _mm256_cvtepi32_epi64(n_i32);
                let shifted = _mm256_slli_epi64(_mm256_add_epi64(n_i64, bias), 52);
                let pow2n = _mm256_castsi256_pd(shifted);
                let result = _mm256_mul_pd(p, pow2n);
                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }
        // Scalar tail for the last n % 4 elements.
        for i in (chunks * 4)..n {
            output[i] = crate::math::exp(input[i]);
        }
    }
    /// AVX2 element-wise logistic sigmoid: computes `exp(-x)` with the
    /// same range-reduction scheme as [`exp_avx2`], then
    /// `1 / (1 + exp(-x))`.
    ///
    /// # Safety
    /// AVX2 must be available; `output` must be at least as long as
    /// `input` (asserted by the public wrapper).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn sigmoid_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        let n = input.len();
        let chunks = n / 4;
        // SAFETY: loads/stores cover indices < chunks*4 <= n, within
        // both slices (wrapper guarantees output length).
        unsafe {
            let ln2 = _mm256_set1_pd(core::f64::consts::LN_2);
            let log2e = _mm256_set1_pd(core::f64::consts::LOG2_E);
            let clamp_hi = _mm256_set1_pd(708.0);
            let clamp_lo = _mm256_set1_pd(-708.0);
            let one = _mm256_set1_pd(1.0);
            let half = _mm256_set1_pd(0.5);
            // Taylor coefficients 1/3!, 1/4!, 1/5! (as in exp_avx2).
            let c3 = _mm256_set1_pd(1.0 / 6.0);
            let c4 = _mm256_set1_pd(1.0 / 24.0);
            let c5 = _mm256_set1_pd(1.0 / 120.0);
            let bias = _mm256_set1_epi64x(1023);
            let neg_one = _mm256_set1_pd(-1.0);
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                // Work on -x so the final step is 1 / (1 + exp(-x)).
                let neg_x = _mm256_mul_pd(x, neg_one);
                let neg_x = _mm256_min_pd(clamp_hi, _mm256_max_pd(clamp_lo, neg_x));
                let x_scaled = _mm256_mul_pd(neg_x, log2e);
                let n_f = _mm256_floor_pd(_mm256_add_pd(x_scaled, half));
                let r = _mm256_sub_pd(neg_x, _mm256_mul_pd(n_f, ln2));
                // Horner evaluation of the degree-5 exp(r) polynomial.
                let mut p = _mm256_add_pd(c4, _mm256_mul_pd(c5, r));
                p = _mm256_add_pd(c3, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(half, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                p = _mm256_add_pd(one, _mm256_mul_pd(p, r));
                // 2^k via the f64 exponent bits, as in exp_avx2.
                let n_i32 = _mm256_cvtpd_epi32(n_f);
                let n_i64 = _mm256_cvtepi32_epi64(n_i32);
                let shifted = _mm256_slli_epi64(_mm256_add_epi64(n_i64, bias), 52);
                let pow2n = _mm256_castsi256_pd(shifted);
                let exp_neg_x = _mm256_mul_pd(p, pow2n);
                let result = _mm256_div_pd(one, _mm256_add_pd(one, exp_neg_x));
                _mm256_storeu_pd(output.as_mut_ptr().add(off), result);
            }
        }
        // Scalar tail for the last n % 4 elements.
        for i in (chunks * 4)..n {
            output[i] = crate::math::sigmoid(input[i]);
        }
    }
    /// AVX2 element-wise SiLU: `x * sigmoid(x)`, implemented as a
    /// sigmoid pass into `output` followed by an in-place multiply by
    /// the input.
    ///
    /// # Safety
    /// AVX2 must be available; `output` must be at least as long as
    /// `input` (asserted by the public wrapper).
    #[target_feature(enable = "avx2")]
    pub(super) unsafe fn silu_avx2(input: &[f64], output: &mut [f64]) {
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::*;
        // First pass: output[i] = sigmoid(input[i]).
        // SAFETY: AVX2 support carries over from this fn's contract.
        unsafe {
            sigmoid_avx2(input, output);
        }
        let n = input.len();
        let chunks = n / 4;
        // Second pass: output[i] *= input[i], vectorized.
        // SAFETY: same bounds as the sigmoid pass above.
        unsafe {
            for i in 0..chunks {
                let off = i * 4;
                let x = _mm256_loadu_pd(input.as_ptr().add(off));
                let sig = _mm256_loadu_pd(output.as_ptr().add(off));
                _mm256_storeu_pd(output.as_mut_ptr().add(off), _mm256_mul_pd(x, sig));
            }
        }
        // Scalar tail for the last n % 4 elements.
        for i in (chunks * 4)..n {
            output[i] *= input[i];
        }
    }
}
/// Dot product of `a` and `b`, truncated to the shorter slice.
///
/// Dispatches to the AVX2 kernel when compiled with the `simd-avx2`
/// feature on x86_64 and the CPU reports AVX2 at runtime; otherwise
/// uses the scalar fallback. Returns `0.0` for empty input.
pub fn simd_dot(a: &[f64], b: &[f64]) -> f64 {
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            return unsafe { avx2::dot_avx2(a, b) };
        }
    }
    dot_scalar(a, b)
}
/// Matrix-vector product `out = w * x` for a row-major `rows x cols`
/// matrix `w`, with runtime AVX2 dispatch.
///
/// # Panics
/// Panics if `w.len() < rows * cols`, `out.len() < rows`, or
/// `x.len() < cols`.
pub fn simd_mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
    // Validate all shape preconditions up front so both the SIMD and
    // scalar kernels can assume in-bounds access.
    assert!(
        w.len() >= rows * cols,
        "simd_mat_vec: w.len()={} < rows*cols={}",
        w.len(),
        rows * cols
    );
    assert!(
        out.len() >= rows,
        "simd_mat_vec: out.len()={} < rows={}",
        out.len(),
        rows
    );
    assert!(
        x.len() >= cols,
        "simd_mat_vec: x.len()={} < cols={}",
        x.len(),
        cols
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime.
            unsafe {
                avx2::mat_vec_avx2(w, x, rows, cols, out);
            }
            return;
        }
    }
    mat_vec_scalar(w, x, rows, cols, out);
}
/// Element-wise tanh of `input` into `output`, with runtime AVX2
/// dispatch.
///
/// The AVX2 path uses a Padé approximation (see `avx2::tanh_avx2`),
/// so results may differ slightly from the scalar path.
///
/// # Panics
/// Panics if `output.len() < input.len()`.
pub fn simd_tanh(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_tanh: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime,
            // and the output-length precondition was asserted above.
            unsafe {
                avx2::tanh_avx2(input, output);
            }
            return;
        }
    }
    tanh_scalar(input, output);
}
/// Element-wise exp of `input` into `output`, with runtime AVX2
/// dispatch. The AVX2 path clamps inputs to [-708, 708] to keep
/// results finite.
///
/// # Panics
/// Panics if `output.len() < input.len()`.
pub fn simd_exp(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_exp: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime,
            // and the output-length precondition was asserted above.
            unsafe {
                avx2::exp_avx2(input, output);
            }
            return;
        }
    }
    exp_scalar(input, output);
}
/// Element-wise logistic sigmoid of `input` into `output`, with
/// runtime AVX2 dispatch.
///
/// # Panics
/// Panics if `output.len() < input.len()`.
pub fn simd_sigmoid(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_sigmoid: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime,
            // and the output-length precondition was asserted above.
            unsafe {
                avx2::sigmoid_avx2(input, output);
            }
            return;
        }
    }
    sigmoid_scalar(input, output);
}
/// Element-wise SiLU (`x * sigmoid(x)`) of `input` into `output`,
/// with runtime AVX2 dispatch.
///
/// # Panics
/// Panics if `output.len() < input.len()`.
pub fn simd_silu(input: &[f64], output: &mut [f64]) {
    assert!(
        output.len() >= input.len(),
        "simd_silu: output.len()={} < input.len()={}",
        output.len(),
        input.len()
    );
    #[cfg(all(target_arch = "x86_64", feature = "simd-avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was just verified at runtime,
            // and the output-length precondition was asserted above.
            unsafe {
                avx2::silu_avx2(input, output);
            }
            return;
        }
    }
    silu_scalar(input, output);
}
#[cfg(test)]
mod tests {
    //! Tests cover both dispatch paths where possible: on AVX2-capable
    //! x86_64 with `simd-avx2` enabled they exercise the SIMD kernels;
    //! elsewhere they exercise the scalar fallbacks.
    use super::*;
    // `alloc` imports suggest the crate is no_std + alloc —
    // NOTE(review): confirm against the crate root.
    use alloc::vec;
    use alloc::vec::Vec;
    /// Minimal deterministic xorshift64 PRNG so tests need no external
    /// rand dependency.
    struct TestRng(u64);
    impl TestRng {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        // xorshift64 step; period 2^64 - 1 for nonzero seeds.
        fn next_u64(&mut self) -> u64 {
            let mut x = self.0;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.0 = x;
            x
        }
        // Uniform f64 in [-1, 1) built from the top 53 random bits.
        fn next_f64(&mut self) -> f64 {
            (self.next_u64() >> 11) as f64 / ((1u64 << 53) as f64) * 2.0 - 1.0
        }
        fn fill_vec(&mut self, n: usize) -> Vec<f64> {
            (0..n).map(|_| self.next_f64()).collect()
        }
    }
    // --- simd_dot ---------------------------------------------------
    #[test]
    fn dot_empty_returns_zero() {
        let a: [f64; 0] = [];
        let b: [f64; 0] = [];
        assert_eq!(simd_dot(&a, &b), 0.0, "dot of empty slices should be 0");
    }
    #[test]
    fn dot_single_element() {
        let a = [3.0];
        let b = [4.0];
        assert!(
            (simd_dot(&a, &b) - 12.0).abs() < 1e-12,
            "dot([3], [4]) should be 12, got {}",
            simd_dot(&a, &b)
        );
    }
    #[test]
    fn dot_known_result() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "dot([1,2,3], [4,5,6]) should be 32, got {}",
            result
        );
    }
    // SIMD and scalar sum in different orders, hence the 1e-9 slack.
    #[test]
    fn dot_large_matches_scalar() {
        let mut rng = TestRng::new(42);
        let a = rng.fill_vec(1000);
        let b = rng.fill_vec(1000);
        let simd_result = simd_dot(&a, &b);
        let scalar_result = dot_scalar(&a, &b);
        assert!(
            (simd_result - scalar_result).abs() < 1e-9,
            "1000-element dot: SIMD={} vs scalar={}, diff={}",
            simd_result,
            scalar_result,
            (simd_result - scalar_result).abs()
        );
    }
    #[test]
    fn dot_mismatched_lengths() {
        let a = [1.0, 2.0, 3.0, 999.0];
        let b = [4.0, 5.0, 6.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 32.0).abs() < 1e-12,
            "mismatched lengths should use min, expected 32, got {}",
            result
        );
    }
    // Length 7 = one full 4-lane chunk plus a 3-element tail.
    #[test]
    fn dot_non_aligned_length() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let b = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 28.0).abs() < 1e-12,
            "dot of [1..7] with [1..1] should be 28, got {}",
            result
        );
    }
    #[test]
    fn dot_negative_values() {
        let a = [-1.0, -2.0, -3.0, -4.0];
        let b = [4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - (-20.0)).abs() < 1e-12,
            "expected -20, got {}",
            result
        );
    }
    #[test]
    fn dot_orthogonal_vectors() {
        let a = [1.0, 0.0, 0.0, 0.0];
        let b = [0.0, 1.0, 0.0, 0.0];
        let result = simd_dot(&a, &b);
        assert!(
            result.abs() < 1e-12,
            "orthogonal vectors should have dot=0, got {}",
            result
        );
    }
    // --- simd_mat_vec -----------------------------------------------
    #[test]
    fn mat_vec_identity_like() {
        let w = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 3];
        simd_mat_vec(&w, &x, 3, 3, &mut out);
        assert!(
            (out[0] - 1.0).abs() < 1e-12,
            "identity row 0: expected 1, got {}",
            out[0]
        );
        assert!(
            (out[1] - 2.0).abs() < 1e-12,
            "identity row 1: expected 2, got {}",
            out[1]
        );
        assert!(
            (out[2] - 3.0).abs() < 1e-12,
            "identity row 2: expected 3, got {}",
            out[2]
        );
    }
    #[test]
    fn mat_vec_known_result() {
        let w = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
        assert!(
            (out[0] - 14.0).abs() < 1e-12,
            "row 0: expected 14, got {}",
            out[0]
        );
        assert!(
            (out[1] - 32.0).abs() < 1e-12,
            "row 1: expected 32, got {}",
            out[1]
        );
    }
    #[test]
    fn mat_vec_large_matches_scalar() {
        let mut rng = TestRng::new(7777);
        let rows = 100;
        let cols = 100;
        let w = rng.fill_vec(rows * cols);
        let x = rng.fill_vec(cols);
        let mut out_simd = vec![0.0; rows];
        let mut out_scalar = vec![0.0; rows];
        simd_mat_vec(&w, &x, rows, cols, &mut out_simd);
        mat_vec_scalar(&w, &x, rows, cols, &mut out_scalar);
        for i in 0..rows {
            assert!(
                (out_simd[i] - out_scalar[i]).abs() < 1e-9,
                "row {}: SIMD={} vs scalar={}, diff={}",
                i,
                out_simd[i],
                out_scalar[i],
                (out_simd[i] - out_scalar[i]).abs()
            );
        }
    }
    #[test]
    fn mat_vec_single_row() {
        let w = [1.0, 2.0, 3.0, 4.0, 5.0];
        let x = [2.0, 2.0, 2.0, 2.0, 2.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 5, &mut out);
        assert!(
            (out[0] - 30.0).abs() < 1e-12,
            "single-row mat_vec should be dot product, expected 30, got {}",
            out[0]
        );
    }
    #[test]
    fn mat_vec_single_element() {
        let w = [7.0];
        let x = [3.0];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 1, 1, &mut out);
        assert!(
            (out[0] - 21.0).abs() < 1e-12,
            "1x1 mat_vec: 7*3=21, got {}",
            out[0]
        );
    }
    // Precondition asserts: each should fire with its specific message.
    #[test]
    #[should_panic(expected = "simd_mat_vec: w.len()")]
    fn mat_vec_panics_w_too_short() {
        let w = [1.0, 2.0];
        let x = [1.0, 2.0, 3.0];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }
    #[test]
    #[should_panic(expected = "simd_mat_vec: out.len()")]
    fn mat_vec_panics_out_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 3];
        let mut out = [0.0; 1];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }
    #[test]
    #[should_panic(expected = "simd_mat_vec: x.len()")]
    fn mat_vec_panics_x_too_short() {
        let w = [1.0; 6];
        let x = [1.0; 2];
        let mut out = [0.0; 2];
        simd_mat_vec(&w, &x, 2, 3, &mut out);
    }
    // NOTE(review): this gate uses `feature = "std"` while dispatch
    // gates on `feature = "simd-avx2"` — confirm that is intentional.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    #[test]
    fn simd_available_on_x86() {
        let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = [8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let result = simd_dot(&a, &b);
        assert!(
            (result - 120.0).abs() < 1e-12,
            "8-element dot product should be 120, got {}",
            result
        );
    }
    // --- activations ------------------------------------------------
    // Loose 0.01 tolerance accommodates the AVX2 Padé approximation.
    #[test]
    fn tanh_known_values() {
        let input = [0.0, 1.0, -1.0, 5.0, -5.0, 0.5];
        let mut output = [0.0; 6];
        simd_tanh(&input, &mut output);
        let expected = [0.0, 0.7616, -0.7616, 0.9999, -0.9999, 0.4621];
        for (i, (&got, &exp)) in output.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - exp).abs() < 0.01,
                "tanh[{i}]: expected ~{exp}, got {got}"
            );
        }
    }
    #[test]
    fn tanh_matches_scalar() {
        let mut rng = TestRng::new(42);
        let input = rng.fill_vec(100);
        let mut simd_out = vec![0.0; 100];
        let mut scalar_out = vec![0.0; 100];
        simd_tanh(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            scalar_out[i] = crate::math::tanh(x);
        }
        for i in 0..100 {
            assert!(
                (simd_out[i] - scalar_out[i]).abs() < 0.01,
                "tanh[{i}]: SIMD={} vs scalar={}",
                simd_out[i],
                scalar_out[i]
            );
        }
    }
    #[test]
    fn exp_known_values() {
        let input = [0.0, 1.0, -1.0, 2.0, -2.0];
        let mut output = [0.0; 5];
        simd_exp(&input, &mut output);
        let expected = [
            1.0,
            core::f64::consts::E,
            1.0 / core::f64::consts::E,
            core::f64::consts::E * core::f64::consts::E,
            1.0 / (core::f64::consts::E * core::f64::consts::E),
        ];
        for (i, (&got, &exp)) in output.iter().zip(expected.iter()).enumerate() {
            // Relative error with a floor to avoid division by ~0.
            let rel = (got - exp).abs() / exp.abs().max(1e-15);
            assert!(
                rel < 1e-5,
                "exp[{i}]: expected {exp}, got {got}, rel_err={rel}"
            );
        }
    }
    #[test]
    fn exp_matches_scalar() {
        let mut rng = TestRng::new(99);
        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 10.0).collect();
        let mut simd_out = vec![0.0; 100];
        let mut scalar_out = vec![0.0; 100];
        simd_exp(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            scalar_out[i] = crate::math::exp(x);
        }
        for i in 0..100 {
            let rel = (simd_out[i] - scalar_out[i]).abs() / scalar_out[i].abs().max(1e-15);
            assert!(
                rel < 1e-5,
                "exp[{i}] (x={}): SIMD={} vs scalar={}, rel_err={}",
                input[i],
                simd_out[i],
                scalar_out[i],
                rel
            );
        }
    }
    // +/-700 sits just inside the AVX2 path's +/-708 clamp.
    #[test]
    fn exp_extreme_values() {
        let input = [700.0, -700.0, 0.0, 100.0, -100.0];
        let mut output = [0.0; 5];
        simd_exp(&input, &mut output);
        assert!(output[0].is_finite(), "exp(700) should be finite");
        assert!(output[0] > 0.0, "exp(700) should be positive");
        assert!(output[1] > 0.0, "exp(-700) should be positive");
        assert!(output[1].is_finite(), "exp(-700) should be finite");
        assert!((output[2] - 1.0).abs() < 1e-12, "exp(0) should be 1.0");
    }
    #[test]
    fn sigmoid_known_values() {
        let input = [0.0, 10.0, -10.0, 1.0];
        let mut output = [0.0; 4];
        simd_sigmoid(&input, &mut output);
        assert!(
            (output[0] - 0.5).abs() < 0.01,
            "sigmoid(0) should be ~0.5, got {}",
            output[0]
        );
        assert!(
            output[1] > 0.99,
            "sigmoid(10) should be ~1.0, got {}",
            output[1]
        );
        assert!(
            output[2] < 0.01,
            "sigmoid(-10) should be ~0.0, got {}",
            output[2]
        );
    }
    #[test]
    fn sigmoid_matches_scalar() {
        let mut rng = TestRng::new(123);
        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 20.0 - 10.0).collect();
        let mut simd_out = vec![0.0; 100];
        let mut scalar_out = vec![0.0; 100];
        simd_sigmoid(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            scalar_out[i] = crate::math::sigmoid(x);
        }
        for i in 0..100 {
            assert!(
                (simd_out[i] - scalar_out[i]).abs() < 1e-6,
                "sigmoid[{i}] (x={}): SIMD={} vs scalar={}, diff={}",
                input[i],
                simd_out[i],
                scalar_out[i],
                (simd_out[i] - scalar_out[i]).abs()
            );
        }
    }
    #[test]
    fn silu_known_values() {
        let input = [0.0, 1.0, -1.0, 3.0];
        let mut output = [0.0; 4];
        simd_silu(&input, &mut output);
        assert!(
            output[0].abs() < 0.01,
            "silu(0) should be ~0, got {}",
            output[0]
        );
        assert!(
            (output[1] - 0.731).abs() < 0.01,
            "silu(1) should be ~0.731, got {}",
            output[1]
        );
    }
    #[test]
    fn silu_matches_scalar() {
        let mut rng = TestRng::new(456);
        let input: Vec<f64> = (0..100).map(|_| rng.next_f64() * 10.0 - 5.0).collect();
        let mut simd_out = vec![0.0; 100];
        simd_silu(&input, &mut simd_out);
        for (i, &x) in input.iter().enumerate() {
            let expected = x * crate::math::sigmoid(x);
            assert!(
                (simd_out[i] - expected).abs() < 1e-6,
                "silu[{i}] (x={}): SIMD={} vs scalar={}, diff={}",
                x,
                simd_out[i],
                expected,
                (simd_out[i] - expected).abs()
            );
        }
    }
    // Smoke test: empty slices must not panic in any activation.
    #[test]
    fn activations_handle_empty() {
        let input: [f64; 0] = [];
        let mut output: [f64; 0] = [];
        simd_tanh(&input, &mut output);
        simd_exp(&input, &mut output);
        simd_sigmoid(&input, &mut output);
        simd_silu(&input, &mut output);
    }
}