#[cfg(feature = "simd")]
use wide::f32x8;
/// Rectified linear unit: writes `max(input[i], 0.0)` into `output[i]` for every index
/// below `min(input.len(), output.len())`; any further elements of `output` are left
/// unchanged.
///
/// Whole groups of eight elements are processed as `f32x8` vectors; the trailing
/// `len % 8` elements are handled one at a time with the same scalar rule.
#[cfg(feature = "simd")]
pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) {
    let n = input.len().min(output.len());
    // Trim both sides to the common length so the chunk iterators stay in lock-step.
    let (src, dst) = (&input[..n], &mut output[..n]);
    let floor = f32x8::splat(0.0);

    let mut src_blocks = src.chunks_exact(8);
    let mut dst_blocks = dst.chunks_exact_mut(8);
    for (s, d) in (&mut src_blocks).zip(&mut dst_blocks) {
        let mut lanes = [0.0f32; 8];
        lanes.copy_from_slice(s);
        let clipped: [f32; 8] = f32x8::from(lanes).max(floor).into();
        d.copy_from_slice(&clipped);
    }

    for (d, &v) in dst_blocks.into_remainder().iter_mut().zip(src_blocks.remainder()) {
        *d = v.max(0.0);
    }
}
/// Logistic sigmoid `1 / (1 + e^{-x})`, element-wise over the first
/// `min(input.len(), output.len())` elements; further `output` elements are untouched.
///
/// Full groups of eight are evaluated as `f32x8` with `simd_fast_exp_f32x8`; the
/// trailing `len % 8` elements use scalar `f32::exp`.
#[cfg(feature = "simd")]
pub fn simd_sigmoid_f32(input: &[f32], output: &mut [f32]) {
    let n = input.len().min(output.len());
    // Trim both sides to the common length so the chunk iterators stay in lock-step.
    let (src, dst) = (&input[..n], &mut output[..n]);
    let ones = f32x8::splat(1.0);

    let mut src_blocks = src.chunks_exact(8);
    let mut dst_blocks = dst.chunks_exact_mut(8);
    for (s, d) in (&mut src_blocks).zip(&mut dst_blocks) {
        let mut lanes = [0.0f32; 8];
        lanes.copy_from_slice(s);
        let v = f32x8::from(lanes);
        let sig: [f32; 8] = (ones / (ones + simd_fast_exp_f32x8(-v))).into();
        d.copy_from_slice(&sig);
    }

    for (d, &v) in dst_blocks.into_remainder().iter_mut().zip(src_blocks.remainder()) {
        *d = 1.0 / (1.0 + (-v).exp());
    }
}
/// Hyperbolic tangent, element-wise over the first `min(input.len(), output.len())`
/// elements; any further elements of `output` are left unchanged.
///
/// Full groups of eight are evaluated as `f32x8`:
/// * lanes with `|x| < 0.25` use the odd Taylor series `x - x³/3 + 2x⁵/15 - 17x⁷/315`
///   (truncation error about 3e-7 relative at the boundary, less below it). The
///   exponential form is unusable there: `e^{2x} - 1` cancels catastrophically near zero
///   and becomes exactly `0` once `e^{2x}` rounds to `1.0`, turning tiny inputs into 0;
/// * all other lanes use `(e^{2x} - 1) / (e^{2x} + 1)` via `simd_fast_exp_f32x8`.
///
/// The trailing `len % 8` elements use scalar `f32::tanh`.
#[cfg(feature = "simd")]
pub fn simd_tanh_f32(input: &[f32], output: &mut [f32]) {
    let len = input.len().min(output.len());
    // Trim both sides to the common length so the chunk iterators stay in lock-step.
    let (input, output) = (&input[..len], &mut output[..len]);

    let one = f32x8::splat(1.0);
    let two = f32x8::splat(2.0);
    let c3 = f32x8::splat(1.0 / 3.0);
    let c5 = f32x8::splat(2.0 / 15.0);
    let c7 = f32x8::splat(17.0 / 315.0);

    let mut in_chunks = input.chunks_exact(8);
    let mut out_chunks = output.chunks_exact_mut(8);
    for (src, dst) in (&mut in_chunks).zip(&mut out_chunks) {
        let mut lanes = [0.0f32; 8];
        lanes.copy_from_slice(src);
        let x = f32x8::from(lanes);

        // Exponential form: accurate for moderate/large |x|.
        let exp_2x = simd_fast_exp_f32x8(x * two);
        let via_exp: [f32; 8] = ((exp_2x - one) / (exp_2x + one)).into();

        // Odd Taylor series in Horner form: accurate for small |x|.
        let x2 = x * x;
        let series: [f32; 8] = (x * (one - x2 * (c3 - x2 * (c5 - x2 * c7)))).into();

        for (((d, &xi), &e), &s) in dst.iter_mut().zip(src).zip(&via_exp).zip(&series) {
            *d = if xi.abs() < 0.25 { s } else { e };
        }
    }

    for (d, &xi) in out_chunks.into_remainder().iter_mut().zip(in_chunks.remainder()) {
        *d = xi.tanh();
    }
}
/// Gaussian Error Linear Unit using the tanh approximation
/// `0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))`, element-wise over the first
/// `min(input.len(), output.len())` elements; further `output` elements are untouched.
///
/// Full groups of eight use `f32x8` and `simd_fast_tanh_f32x8`; the trailing `len % 8`
/// elements use scalar `f32::tanh`, so the two paths can differ by whatever error the
/// vector tanh helper carries.
#[cfg(feature = "simd")]
pub fn simd_gelu_f32(input: &[f32], output: &mut [f32]) {
    let n = input.len().min(output.len());
    // Trim both sides to the common length so the chunk iterators stay in lock-step.
    let (src, dst) = (&input[..n], &mut output[..n]);

    let k = (2.0 / core::f32::consts::PI).sqrt();
    let half_v = f32x8::splat(0.5);
    let one_v = f32x8::splat(1.0);
    let k_v = f32x8::splat(k);
    let c_v = f32x8::splat(0.044715);

    let mut src_blocks = src.chunks_exact(8);
    let mut dst_blocks = dst.chunks_exact_mut(8);
    for (s, d) in (&mut src_blocks).zip(&mut dst_blocks) {
        let mut lanes = [0.0f32; 8];
        lanes.copy_from_slice(s);
        let x = f32x8::from(lanes);
        let cube = x * x * x;
        let t = simd_fast_tanh_f32x8(k_v * (x + c_v * cube));
        let y: [f32; 8] = (half_v * x * (one_v + t)).into();
        d.copy_from_slice(&y);
    }

    for (d, &x) in dst_blocks.into_remainder().iter_mut().zip(src_blocks.remainder()) {
        *d = 0.5 * x * (1.0 + (k * (x + 0.044715 * x * x * x)).tanh());
    }
}
/// Lane-wise exponential `e^x` for eight `f32` values.
///
/// The argument is split as `x = n·ln2 + r` with integer `n` and `|r| <= ~ln2/2`
/// (Cody–Waite reduction with a two-part `ln2`). `e^r` is evaluated with the Cephes
/// `expf` minimax polynomial, and the result is scaled by `2^n` assembled directly from
/// IEEE-754 exponent bits. The relative error is on the order of one f32 ulp, so vector
/// lanes agree closely with scalar `f32::exp`. The previous `(1 + x/256)^256` scheme was
/// off by roughly `x²/512` relative (about 50% at its ±20 clamp).
///
/// Inputs are first clamped to `[-87, 88]` so that `2^n` is always a normal `f32` and the
/// result is finite. Larger inputs saturate near `1.65e38`; smaller ones near `1.6e-38`.
#[cfg(feature = "simd")]
fn simd_fast_exp_f32x8(x: f32x8) -> f32x8 {
    // ln2 split into a high part with few mantissa bits (so `n * LN2_HI` is exact for
    // |n| <= 127) and a low-order correction term.
    const LN2_HI: f32 = 0.693_359_375;
    const LN2_LO: f32 = -2.121_944_4e-4;

    let x = x.max(f32x8::splat(-87.0)).min(f32x8::splat(88.0));

    // n = round(x / ln2) and the matching power of two, per lane. The float clamp above
    // bounds n to [-126, 127]. The integer clamp is defensive: it keeps the exponent
    // field valid even if a NaN lane slipped through (`NaN as i32` is 0).
    let mut n_lanes = (x * f32x8::splat(core::f32::consts::LOG2_E)).to_array();
    let mut scale_lanes = [0.0f32; 8];
    for (n, scale) in n_lanes.iter_mut().zip(scale_lanes.iter_mut()) {
        *n = n.round();
        let biased = (*n as i32).clamp(-126, 127) + 127;
        *scale = f32::from_bits((biased as u32) << 23);
    }
    let n = f32x8::from(n_lanes);

    // Reduced argument r = x - n·ln2, subtracting the high and low parts separately.
    let r = x - n * f32x8::splat(LN2_HI) - n * f32x8::splat(LN2_LO);

    // e^r ≈ 1 + r + r²·P(r), with P from Cephes `expf` (valid for |r| <= 0.5).
    let p = f32x8::splat(1.987_569_15e-4);
    let p = p * r + f32x8::splat(1.398_199_95e-3);
    let p = p * r + f32x8::splat(8.333_451_91e-3);
    let p = p * r + f32x8::splat(4.166_579_59e-2);
    let p = p * r + f32x8::splat(1.666_666_55e-1);
    let p = p * r + f32x8::splat(5.000_000_12e-1);
    let exp_r = p * (r * r) + r + f32x8::splat(1.0);

    exp_r * f32x8::from(scale_lanes)
}
/// Lane-wise hyperbolic tangent for eight `f32` values.
///
/// Lanes with `|x| < 0.25` use the odd Taylor series `x - x³/3 + 2x⁵/15 - 17x⁷/315`
/// (truncation error about 3e-7 relative at the boundary). This avoids the cancellation
/// that `1 - 2/(e^{2|x|} + 1)` suffers near zero. All other lanes use
/// `sign(x) · (1 - 2 / (e^{2|x|} + 1))`. Because the exponential is only evaluated on a
/// non-negative argument, `2/(e + 1)` simply tends to 0 for large |x|.
///
/// This replaces a Padé approximant `x(27+x²)/(27+9x²)` on |x| < 2. That approximant
/// overshot by up to ~0.024 and jumped from ~0.984 to ~0.963 at |x| = 2, so tanh (and
/// GELU built on it) was discontinuous and non-monotonic there.
#[cfg(feature = "simd")]
fn simd_fast_tanh_f32x8(x: f32x8) -> f32x8 {
    let one = f32x8::splat(1.0);
    let two = f32x8::splat(2.0);

    // |tanh(x)| for moderate/large |x|.
    let exp_2abs = simd_fast_exp_f32x8(two * x.abs());
    let magnitude: [f32; 8] = (one - two / (exp_2abs + one)).into();

    // Odd Taylor series in Horner form for small |x|.
    let x2 = x * x;
    let c3 = f32x8::splat(1.0 / 3.0);
    let c5 = f32x8::splat(2.0 / 15.0);
    let c7 = f32x8::splat(17.0 / 315.0);
    let series: [f32; 8] = (x * (one - x2 * (c3 - x2 * (c5 - x2 * c7)))).into();

    let lanes = x.to_array();
    let mut out = [0.0f32; 8];
    for (((o, &xi), &s), &m) in out.iter_mut().zip(&lanes).zip(&series).zip(&magnitude) {
        *o = if xi.abs() < 0.25 {
            s
        } else if xi < 0.0 {
            -m
        } else {
            m
        };
    }
    f32x8::from(out)
}
/// Scalar fallback: writes `max(input[i], 0.0)` into `output[i]` for every index below
/// `min(input.len(), output.len())`; any further elements of `output` are left unchanged.
#[cfg(not(feature = "simd"))]
pub fn simd_relu_f32(input: &[f32], output: &mut [f32]) {
    // `zip` stops at the shorter slice, which is exactly the min-length contract.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(dst, &src)| *dst = src.max(0.0));
}
/// Scalar fallback: writes the logistic sigmoid `1 / (1 + e^{-x})` of each input into
/// `output`, over the first `min(input.len(), output.len())` elements; further `output`
/// elements are left unchanged.
#[cfg(not(feature = "simd"))]
pub fn simd_sigmoid_f32(input: &[f32], output: &mut [f32]) {
    for (dst, &x) in output.iter_mut().zip(input.iter()) {
        let denom = 1.0 + (-x).exp();
        *dst = 1.0 / denom;
    }
}
/// Scalar fallback: writes `tanh(input[i])` into `output[i]` over the first
/// `min(input.len(), output.len())` elements; further `output` elements are left unchanged.
#[cfg(not(feature = "simd"))]
pub fn simd_tanh_f32(input: &[f32], output: &mut [f32]) {
    output
        .iter_mut()
        .zip(input)
        .for_each(|(y, x)| *y = x.tanh());
}
/// Scalar fallback: Gaussian Error Linear Unit using the tanh approximation
/// `0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³)))`, over the first
/// `min(input.len(), output.len())` elements; further `output` elements are left unchanged.
#[cfg(not(feature = "simd"))]
pub fn simd_gelu_f32(input: &[f32], output: &mut [f32]) {
    let scale = (2.0 / core::f32::consts::PI).sqrt();
    for (y, &x) in output.iter_mut().zip(input) {
        let inner = scale * (x + 0.044715 * x * x * x);
        *y = 0.5 * x * (1.0 + inner.tanh());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 19 inputs: two full 8-lane groups plus a 3-element remainder. Every test below
    /// therefore exercises the `f32x8` path (when the `simd` feature is enabled) as well as
    /// the scalar tail. The original tests used at most 5 elements and never reached the
    /// vector code.
    const SAMPLES: [f32; 19] = [
        -8.0, -4.0, -2.5, -2.0, -1.5, -1.0, -0.5, -0.1, -1.0e-3, 0.0, 1.0e-3, 0.1, 0.5, 1.0,
        1.5, 2.0, 2.5, 4.0, 8.0,
    ];

    /// Runs `f` over `input` into a fresh zeroed buffer of the same length.
    fn apply(f: fn(&[f32], &mut [f32]), input: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0; input.len()];
        f(input, &mut out);
        out
    }

    /// Asserts `actual[i]` is within `tol` of `reference(input[i])` for every index.
    fn assert_close(actual: &[f32], input: &[f32], reference: impl Fn(f32) -> f32, tol: f32) {
        assert_eq!(actual.len(), input.len());
        for (i, (&got, &x)) in actual.iter().zip(input).enumerate() {
            let want = reference(x);
            assert!(
                (got - want).abs() <= tol,
                "index {}: f({}) = {}, expected {} (tol {})",
                i,
                x,
                got,
                want,
                tol
            );
        }
    }

    /// Scalar tanh-approximation GELU, matching the formula used by `simd_gelu_f32`.
    fn gelu_reference(x: f32) -> f32 {
        0.5 * x
            * (1.0 + ((2.0 / core::f32::consts::PI).sqrt() * (x + 0.044715 * x * x * x)).tanh())
    }

    #[test]
    fn test_simd_relu_f32() {
        let input = [-2.0, -1.0, 0.0, 1.0, 2.0];
        let mut output = [0.0; 5];
        simd_relu_f32(&input, &mut output);
        assert_eq!(output, [0.0, 0.0, 0.0, 1.0, 2.0]);
    }

    #[test]
    fn test_simd_sigmoid_f32() {
        let input = [0.0];
        let mut output = [0.0];
        simd_sigmoid_f32(&input, &mut output);
        assert!((output[0] - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_simd_tanh_f32() {
        let input = [0.0];
        let mut output = [0.0];
        simd_tanh_f32(&input, &mut output);
        assert!(output[0].abs() < 0.1);
    }

    #[test]
    fn test_simd_gelu_f32() {
        let input = [0.0];
        let mut output = [0.0];
        simd_gelu_f32(&input, &mut output);
        assert!(output[0].abs() < 0.1);
    }

    #[test]
    fn test_activation_consistency() {
        let input = [0.5];
        let mut simd_output = [0.0];
        let mut scalar_output = [0.0];
        simd_relu_f32(&input, &mut simd_output);
        scalar_output[0] = input[0].max(0.0);
        assert_eq!(simd_output[0], scalar_output[0]);
    }

    // The tolerances below are loose enough to accept approximate vector kernels; they
    // still catch lane-ordering, remainder-handling and gross accuracy regressions.

    #[test]
    fn relu_matches_scalar_on_vector_lanes_and_tail() {
        let out = apply(simd_relu_f32, &SAMPLES);
        assert_close(&out, &SAMPLES, |x| x.max(0.0), 0.0);
    }

    #[test]
    fn sigmoid_matches_scalar_on_vector_lanes_and_tail() {
        let out = apply(simd_sigmoid_f32, &SAMPLES);
        assert_close(&out, &SAMPLES, |x| 1.0 / (1.0 + (-x).exp()), 2e-3);
    }

    #[test]
    fn tanh_matches_scalar_on_vector_lanes_and_tail() {
        let out = apply(simd_tanh_f32, &SAMPLES);
        assert_close(&out, &SAMPLES, |x| x.tanh(), 5e-3);
    }

    #[test]
    fn gelu_matches_scalar_on_vector_lanes_and_tail() {
        let out = apply(simd_gelu_f32, &SAMPLES);
        assert_close(&out, &SAMPLES, gelu_reference, 5e-2);
    }

    #[test]
    fn only_min_len_elements_are_written() {
        // Output shorter than input: exactly the first 10 results are produced.
        let mut short = [f32::NAN; 10];
        simd_relu_f32(&SAMPLES, &mut short);
        assert_close(&short, &SAMPLES[..10], |x| x.max(0.0), 0.0);

        // Output longer than input: the extra tail is left untouched.
        let mut long = [42.0f32; 24];
        simd_sigmoid_f32(&SAMPLES, &mut long);
        assert_close(&long[..SAMPLES.len()], &SAMPLES, |x| 1.0 / (1.0 + (-x).exp()), 2e-3);
        assert!(long[SAMPLES.len()..].iter().all(|&v| v == 42.0));
    }
}