//! SIMD-accelerated activation functions (SiLU, GELU, ReLU, leaky ReLU,
//! softmax) with NEON fast paths on aarch64 and scalar fallbacks elsewhere.
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

/// Number of f32 lanes in a 128-bit NEON register.
#[cfg(target_arch = "aarch64")]
const NEON_LANE_WIDTH: usize = 4;
/// Vectorized exp(x) for four f32 lanes: reduce x = n*ln(2) + r, evaluate a
/// degree-5 Taylor polynomial for e^r, and rebuild 2^n through the IEEE-754
/// exponent bits.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn exp_neon(x: float32x4_t) -> float32x4_t {
    // Clamp to the range where the exponent-bit trick stays valid: inputs
    // below about -87 flush toward zero and inputs above about 88 saturate
    // near f32::MAX instead of wrapping the biased exponent.
    let x = vmaxq_f32(vminq_f32(x, vdupq_n_f32(88.0)), vdupq_n_f32(-87.0));
    let log2e = vdupq_n_f32(1.442695041);
    // ln(2) split into high and low parts for an accurate reduction.
    let ln2_hi = vdupq_n_f32(0.693359375);
    let ln2_lo = vdupq_n_f32(-2.12194440e-4);
    // Taylor coefficients 1/k! for k = 0..=5.
    let c1 = vdupq_n_f32(1.0);
    let c2 = vdupq_n_f32(0.5);
    let c3 = vdupq_n_f32(0.166666666666);
    let c4 = vdupq_n_f32(0.041666666666);
    let c5 = vdupq_n_f32(0.008333333333);
    // n = round(x / ln(2)), so r = x - n*ln(2) lies in [-ln(2)/2, ln(2)/2].
    let n = vrndnq_f32(vmulq_f32(x, log2e));
    let r = vsubq_f32(vsubq_f32(x, vmulq_f32(n, ln2_hi)), vmulq_f32(n, ln2_lo));
    let r2 = vmulq_f32(r, r);
    let r3 = vmulq_f32(r2, r);
    let r4 = vmulq_f32(r2, r2);
    let r5 = vmulq_f32(r4, r);
    // e^r ~= 1 + r + r^2/2 + r^3/6 + r^4/24 + r^5/120.
let poly = vaddq_f32(
c1,
vaddq_f32(
r,
vaddq_f32(
vmulq_f32(r2, c2),
vaddq_f32(
vmulq_f32(r3, c3),
vaddq_f32(vmulq_f32(r4, c4), vmulq_f32(r5, c5)),
),
),
),
);
    // Reconstruct 2^n by placing (n + 127) into the f32 exponent field.
    let n_i32 = vcvtq_s32_f32(n);
    let bias = vdupq_n_s32(127);
    let shift = vdupq_n_s32(23);
    let exp_n = vreinterpretq_f32_s32(vshlq_s32(vaddq_s32(n_i32, bias), shift));
    vmulq_f32(poly, exp_n)
}
/// Vectorized logistic sigmoid: 1 / (1 + e^{-x}).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn sigmoid_neon(x: float32x4_t) -> float32x4_t {
    let one = vdupq_n_f32(1.0);
    let exp_neg_x = exp_neon(vnegq_f32(x));
    let denom = vaddq_f32(one, exp_neg_x);
    // Reciprocal estimate plus one Newton-Raphson step gives roughly
    // 16 bits of precision, ample for f32 activations.
    let recip_est = vrecpeq_f32(denom);
    vmulq_f32(recip_est, vrecpsq_f32(denom, recip_est))
}
/// In-place SiLU (swish): x * sigmoid(x).
#[inline]
pub fn silu(x: &mut [f32]) {
#[cfg(target_arch = "aarch64")]
unsafe {
silu_neon_impl(x);
}
#[cfg(not(target_arch = "aarch64"))]
{
silu_scalar(x);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn silu_neon_impl(x: &mut [f32]) {
let len = x.len();
let ptr = x.as_mut_ptr();
let chunks = len / NEON_LANE_WIDTH;
let mut idx = 0usize;
    // Process four lanes per iteration.
    for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
let sigmoid_v = sigmoid_neon(v);
let result = vmulq_f32(v, sigmoid_v);
vst1q_f32(ptr.add(idx), result);
idx += NEON_LANE_WIDTH;
}
    // Scalar tail for the remaining 0..3 elements.
    for i in idx..len {
let v = *ptr.add(i);
*ptr.add(i) = v / (1.0 + (-v).exp());
}
}
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn silu_scalar(x: &mut [f32]) {
for v in x.iter_mut() {
*v = *v / (1.0 + (-*v).exp());
}
}
/// Out-of-place SiLU: returns a new vector and leaves `x` untouched.
#[inline]
pub fn silu_vec(x: &[f32]) -> Vec<f32> {
let mut result = x.to_vec();
silu(&mut result);
result
}
/// In-place GELU using the tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
#[inline]
pub fn gelu(x: &mut [f32]) {
#[cfg(target_arch = "aarch64")]
unsafe {
gelu_neon_impl(x);
}
#[cfg(not(target_arch = "aarch64"))]
{
gelu_scalar(x);
}
}
/// Vectorized tanh(x) = (e^{2x} - 1) / (e^{2x} + 1).
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn tanh_neon(x: float32x4_t) -> float32x4_t {
let two = vdupq_n_f32(2.0);
let one = vdupq_n_f32(1.0);
let exp_2x = exp_neon(vmulq_f32(two, x));
let numerator = vsubq_f32(exp_2x, one);
let denominator = vaddq_f32(exp_2x, one);
    // Reciprocal estimate plus one Newton-Raphson refinement.
    let recip_est = vrecpeq_f32(denominator);
    let recip = vmulq_f32(recip_est, vrecpsq_f32(denominator, recip_est));
vmulq_f32(numerator, recip)
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn gelu_neon_impl(x: &mut [f32]) {
let len = x.len();
let ptr = x.as_mut_ptr();
let chunks = len / NEON_LANE_WIDTH;
    let half = vdupq_n_f32(0.5);
    let one = vdupq_n_f32(1.0);
    // sqrt(2/pi) and the cubic coefficient of the tanh approximation.
    let sqrt_2_over_pi = vdupq_n_f32(0.7978845608);
    let coeff = vdupq_n_f32(0.044715);
let mut idx = 0usize;
for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
let v2 = vmulq_f32(v, v);
let v3 = vmulq_f32(v2, v);
let inner = vmulq_f32(sqrt_2_over_pi, vaddq_f32(v, vmulq_f32(coeff, v3)));
let tanh_inner = tanh_neon(inner);
let result = vmulq_f32(half, vmulq_f32(v, vaddq_f32(one, tanh_inner)));
vst1q_f32(ptr.add(idx), result);
idx += NEON_LANE_WIDTH;
}
    // Scalar tail for the remaining elements.
    for i in idx..len {
let v = *ptr.add(i);
let inner = 0.7978845608 * (v + 0.044715 * v * v * v);
*ptr.add(i) = 0.5 * v * (1.0 + inner.tanh());
}
}
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn gelu_scalar(x: &mut [f32]) {
const SQRT_2_OVER_PI: f32 = 0.7978845608;
const COEFF: f32 = 0.044715;
for v in x.iter_mut() {
let inner = SQRT_2_OVER_PI * (*v + COEFF * *v * *v * *v);
*v = 0.5 * *v * (1.0 + inner.tanh());
}
}
/// Out-of-place GELU: returns a new vector and leaves `x` untouched.
#[inline]
pub fn gelu_vec(x: &[f32]) -> Vec<f32> {
let mut result = x.to_vec();
gelu(&mut result);
result
}
/// In-place exact GELU: x * 0.5 * (1 + erf(x / sqrt(2))).
#[inline]
pub fn gelu_exact(x: &mut [f32]) {
const INV_SQRT_2: f32 = 0.7071067812;
for v in x.iter_mut() {
*v = *v * 0.5 * (1.0 + erf(*v * INV_SQRT_2));
}
}
/// Scalar erf via the Abramowitz & Stegun 7.1.26 polynomial approximation
/// (maximum absolute error about 1.5e-7).
fn erf(x: f32) -> f32 {
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let p = 0.3275911;
let sign = if x < 0.0 { -1.0 } else { 1.0 };
let x = x.abs();
let t = 1.0 / (1.0 + p * x);
let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
sign * y
}
/// In-place ReLU: max(x, 0).
#[inline]
pub fn relu(x: &mut [f32]) {
#[cfg(target_arch = "aarch64")]
unsafe {
relu_neon_impl(x);
}
#[cfg(not(target_arch = "aarch64"))]
{
relu_scalar(x);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn relu_neon_impl(x: &mut [f32]) {
let len = x.len();
let ptr = x.as_mut_ptr();
let chunks = len / NEON_LANE_WIDTH;
let zero = vdupq_n_f32(0.0);
let mut idx = 0usize;
for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
let result = vmaxq_f32(v, zero);
vst1q_f32(ptr.add(idx), result);
idx += NEON_LANE_WIDTH;
}
for i in idx..len {
let v = *ptr.add(i);
*ptr.add(i) = v.max(0.0);
}
}
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn relu_scalar(x: &mut [f32]) {
for v in x.iter_mut() {
*v = v.max(0.0);
}
}
/// Out-of-place ReLU: returns a new vector and leaves `x` untouched.
#[inline]
pub fn relu_vec(x: &[f32]) -> Vec<f32> {
let mut result = x.to_vec();
relu(&mut result);
result
}
/// In-place leaky ReLU: x where x > 0, alpha * x elsewhere.
#[inline]
pub fn leaky_relu(x: &mut [f32], alpha: f32) {
#[cfg(target_arch = "aarch64")]
unsafe {
leaky_relu_neon_impl(x, alpha);
}
#[cfg(not(target_arch = "aarch64"))]
{
for v in x.iter_mut() {
*v = if *v > 0.0 { *v } else { alpha * *v };
}
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn leaky_relu_neon_impl(x: &mut [f32], alpha: f32) {
let len = x.len();
let ptr = x.as_mut_ptr();
let chunks = len / NEON_LANE_WIDTH;
let alpha_vec = vdupq_n_f32(alpha);
let zero = vdupq_n_f32(0.0);
let mut idx = 0usize;
for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
        let alpha_v = vmulq_f32(v, alpha_vec);
        // Lane-wise select: v where v > 0, alpha * v elsewhere.
        let mask = vcgtq_f32(v, zero);
        let result = vbslq_f32(mask, v, alpha_v);
vst1q_f32(ptr.add(idx), result);
idx += NEON_LANE_WIDTH;
}
for i in idx..len {
let v = *ptr.add(i);
*ptr.add(i) = if v > 0.0 { v } else { alpha * v };
}
}
/// In-place numerically stable softmax: subtracts the maximum before
/// exponentiating, then normalizes so the slice sums to 1.
#[inline]
pub fn softmax(x: &mut [f32]) {
if x.is_empty() {
return;
}
#[cfg(target_arch = "aarch64")]
unsafe {
softmax_neon_impl(x);
}
#[cfg(not(target_arch = "aarch64"))]
{
softmax_scalar(x);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn softmax_neon_impl(x: &mut [f32]) {
let len = x.len();
let ptr = x.as_mut_ptr();
let chunks = len / NEON_LANE_WIDTH;
    // Pass 1: find the maximum for numerical stability.
    let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);
let mut idx = 0usize;
for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
max_vec = vmaxq_f32(max_vec, v);
idx += NEON_LANE_WIDTH;
}
let mut max_val = vmaxvq_f32(max_vec);
for i in idx..len {
max_val = max_val.max(*ptr.add(i));
}
let max_vec = vdupq_n_f32(max_val);
    // Pass 2: exponentiate the shifted values and accumulate their sum.
    let mut sum_vec = vdupq_n_f32(0.0);
idx = 0;
for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
let shifted = vsubq_f32(v, max_vec);
let exp_val = exp_neon(shifted);
vst1q_f32(ptr.add(idx), exp_val);
sum_vec = vaddq_f32(sum_vec, exp_val);
idx += NEON_LANE_WIDTH;
}
let mut sum_val = vaddvq_f32(sum_vec);
for i in idx..len {
let shifted = *ptr.add(i) - max_val;
let exp_val = shifted.exp();
*ptr.add(i) = exp_val;
sum_val += exp_val;
}
    // Pass 3: scale by the reciprocal of the sum.
    let inv_sum = 1.0 / sum_val;
let inv_sum_vec = vdupq_n_f32(inv_sum);
idx = 0;
for _ in 0..chunks {
let v = vld1q_f32(ptr.add(idx));
vst1q_f32(ptr.add(idx), vmulq_f32(v, inv_sum_vec));
idx += NEON_LANE_WIDTH;
}
for i in idx..len {
*ptr.add(i) *= inv_sum;
}
}
#[cfg(not(target_arch = "aarch64"))]
#[inline]
fn softmax_scalar(x: &mut [f32]) {
let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0;
for v in x.iter_mut() {
*v = (*v - max_val).exp();
sum += *v;
}
let inv_sum = 1.0 / sum;
for v in x.iter_mut() {
*v *= inv_sum;
}
}
/// Out-of-place softmax: returns a new vector and leaves `x` untouched.
#[inline]
pub fn softmax_vec(x: &[f32]) -> Vec<f32> {
let mut result = x.to_vec();
softmax(&mut result);
result
}
/// In-place softmax with temperature scaling: softmax(x / temperature).
/// Non-positive temperatures are invalid and leave the slice unchanged.
#[inline]
pub fn softmax_temperature(x: &mut [f32], temperature: f32) {
if temperature <= 0.0 || x.is_empty() {
return;
}
let inv_temp = 1.0 / temperature;
for v in x.iter_mut() {
*v *= inv_temp;
}
softmax(x);
}
/// Applies SiLU row-by-row over a row-major buffer with `stride` columns.
#[inline]
pub fn batch_silu(data: &mut [f32], stride: usize) {
for chunk in data.chunks_mut(stride) {
silu(chunk);
}
}
/// Applies GELU row-by-row over a row-major buffer with `stride` columns.
#[inline]
pub fn batch_gelu(data: &mut [f32], stride: usize) {
for chunk in data.chunks_mut(stride) {
gelu(chunk);
}
}
/// Applies softmax row-by-row over a row-major buffer with `stride` columns.
#[inline]
pub fn batch_softmax(data: &mut [f32], stride: usize) {
for chunk in data.chunks_mut(stride) {
softmax(chunk);
}
}
#[cfg(test)]
mod tests {
use super::*;
const EPSILON: f32 = 1e-4;
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
#[test]
fn test_silu_basic() {
let mut x = vec![0.0, 1.0, -1.0, 2.0, -2.0, 0.5, -0.5, 3.0];
        let expected: Vec<f32> = x.iter().map(|&v| v / (1.0 + (-v).exp())).collect();
silu(&mut x);
for (got, exp) in x.iter().zip(expected.iter()) {
assert!(
approx_eq(*got, *exp, EPSILON),
"SiLU mismatch: got {}, expected {}",
got,
exp
);
}
}
#[test]
fn test_silu_zero() {
let mut x = vec![0.0];
silu(&mut x);
assert!(approx_eq(x[0], 0.0, EPSILON));
}
#[test]
fn test_silu_one() {
let mut x = vec![1.0];
silu(&mut x);
assert!(approx_eq(x[0], 0.7311, 0.001));
}
#[test]
fn test_silu_large_vector() {
let mut x: Vec<f32> = (0..128).map(|i| (i as f32 - 64.0) * 0.1).collect();
        let expected: Vec<f32> = x.iter().map(|&v| v / (1.0 + (-v).exp())).collect();
silu(&mut x);
for (i, (got, exp)) in x.iter().zip(expected.iter()).enumerate() {
assert!(
approx_eq(*got, *exp, EPSILON),
"SiLU mismatch at index {}: got {}, expected {}",
i,
got,
exp
);
}
}
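    // Added regression check for the clamped range reduction in the NEON
    // exp: extreme negative inputs should flush SiLU to ~0 instead of
    // wrapping the exponent bits into garbage.
    #[test]
    fn test_silu_extreme_negative() {
        let mut x = vec![-100.0, -500.0, -50.0, -200.0];
        silu(&mut x);
        for &v in &x {
            assert!(v.is_finite(), "SiLU of large negative must be finite");
            assert!(v.abs() < 1e-3, "SiLU of large negative should be ~0, got {}", v);
        }
    }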
#[test]
fn test_gelu_basic() {
let mut x = vec![0.0, 1.0, -1.0, 2.0];
        let expected = vec![0.0, 0.8412, -0.159, 1.954];
gelu(&mut x);
for (i, (got, exp)) in x.iter().zip(expected.iter()).enumerate() {
assert!(
approx_eq(*got, *exp, 0.01),
"GELU mismatch at index {}: got {}, expected {}",
i,
got,
exp
);
}
}
#[test]
fn test_gelu_zero() {
let mut x = vec![0.0];
gelu(&mut x);
assert!(approx_eq(x[0], 0.0, EPSILON));
}
#[test]
fn test_gelu_large_vector() {
let mut x: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
let original = x.clone();
gelu(&mut x);
for (i, (&orig, &result)) in original.iter().zip(x.iter()).enumerate() {
if orig > 1.0 {
assert!(
result > 0.0,
"GELU({}) should be positive, got {}",
orig,
result
);
}
if orig > 3.0 {
assert!(
approx_eq(result, orig, 0.1),
"GELU({}) should approach x, got {}",
orig,
result
);
}
}
}
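    // Cross-check (no external reference values): the tanh-based gelu()
    // and the erf-based gelu_exact() are two approximations of the same
    // function and are expected to agree to roughly 1e-2 here.
    #[test]
    fn test_gelu_exact_matches_tanh_approx() {
        let x: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) * 0.1).collect();
        let mut approx = x.clone();
        gelu(&mut approx);
        let mut exact = x.clone();
        gelu_exact(&mut exact);
        for (i, (a, e)) in approx.iter().zip(exact.iter()).enumerate() {
            assert!(
                approx_eq(*a, *e, 0.01),
                "GELU approx/exact mismatch at index {}: {} vs {}",
                i,
                a,
                e
            );
        }
    }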
#[test]
fn test_relu_basic() {
let mut x = vec![1.0, -2.0, 3.0, -4.0, 0.0, 0.5, -0.5, 10.0];
let expected = vec![1.0, 0.0, 3.0, 0.0, 0.0, 0.5, 0.0, 10.0];
relu(&mut x);
for (i, (got, exp)) in x.iter().zip(expected.iter()).enumerate() {
assert_eq!(
*got, *exp,
"ReLU mismatch at index {}: got {}, expected {}",
i, got, exp
);
}
}
#[test]
fn test_relu_all_positive() {
let mut x = vec![1.0, 2.0, 3.0, 4.0];
let expected = x.clone();
relu(&mut x);
assert_eq!(x, expected);
}
#[test]
fn test_relu_all_negative() {
let mut x = vec![-1.0, -2.0, -3.0, -4.0];
relu(&mut x);
assert!(x.iter().all(|&v| v == 0.0));
}
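    // Mirrors test_silu_vec for the out-of-place ReLU wrapper: the input
    // stays untouched and the copy is clamped at zero.
    #[test]
    fn test_relu_vec() {
        let x = vec![-1.0, 2.0, -3.0, 4.0];
        let original = x.clone();
        let result = relu_vec(&x);
        assert_eq!(x, original);
        assert_eq!(result, vec![0.0, 2.0, 0.0, 4.0]);
    }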
#[test]
fn test_leaky_relu() {
let mut x = vec![1.0, -2.0, 3.0, -4.0];
let alpha = 0.01;
let expected = vec![1.0, -0.02, 3.0, -0.04];
leaky_relu(&mut x, alpha);
for (i, (got, exp)) in x.iter().zip(expected.iter()).enumerate() {
assert!(
approx_eq(*got, *exp, EPSILON),
"Leaky ReLU mismatch at index {}: got {}, expected {}",
i,
got,
exp
);
}
}
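    // A length-5 input exercises the NEON chunk loop plus its scalar tail
    // on aarch64 (and the plain scalar loop elsewhere).
    #[test]
    fn test_leaky_relu_non_aligned() {
        let mut x = vec![1.0, -2.0, 3.0, -4.0, -5.0];
        leaky_relu(&mut x, 0.1);
        let expected = vec![1.0, -0.2, 3.0, -0.4, -0.5];
        for (got, exp) in x.iter().zip(expected.iter()) {
            assert!(approx_eq(*got, *exp, EPSILON));
        }
    }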
#[test]
fn test_softmax_basic() {
let mut x = vec![1.0, 2.0, 3.0, 4.0];
softmax(&mut x);
let sum: f32 = x.iter().sum();
assert!(
approx_eq(sum, 1.0, EPSILON),
"Softmax sum should be 1.0, got {}",
sum
);
assert!(x.iter().all(|&v| v > 0.0));
for i in 1..x.len() {
assert!(x[i] > x[i - 1], "Softmax should preserve order");
}
}
#[test]
fn test_softmax_uniform() {
let mut x = vec![1.0, 1.0, 1.0, 1.0];
softmax(&mut x);
for v in &x {
assert!(approx_eq(*v, 0.25, EPSILON));
}
}
#[test]
fn test_softmax_numerical_stability() {
let mut x = vec![1000.0, 1001.0, 1002.0, 1003.0];
softmax(&mut x);
let sum: f32 = x.iter().sum();
assert!(
approx_eq(sum, 1.0, EPSILON),
"Softmax sum should be 1.0, got {}",
sum
);
assert!(x.iter().all(|&v| v.is_finite()), "Values should be finite");
}
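    // With a large gap between logits the smaller entry should receive
    // essentially zero mass; this covers the flush-to-zero side of exp.
    #[test]
    fn test_softmax_large_gap() {
        let mut x = vec![0.0, -100.0];
        softmax(&mut x);
        assert!(approx_eq(x[0], 1.0, EPSILON));
        assert!(x[1] < 1e-6);
    }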
#[test]
fn test_softmax_temperature() {
let x = vec![1.0, 2.0, 3.0, 4.0];
let mut low_temp = x.clone();
softmax_temperature(&mut low_temp, 0.5);
let mut high_temp = x.clone();
softmax_temperature(&mut high_temp, 2.0);
let low_max = low_temp.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let high_max = high_temp.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
assert!(
low_max > high_max,
"Low temperature should give sharper distribution"
);
}
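    // Documented edge case: non-positive temperatures are rejected and
    // the slice is returned unchanged rather than normalized.
    #[test]
    fn test_softmax_temperature_non_positive() {
        let mut x = vec![1.0, 2.0, 3.0];
        let original = x.clone();
        softmax_temperature(&mut x, 0.0);
        assert_eq!(x, original);
        softmax_temperature(&mut x, -1.0);
        assert_eq!(x, original);
    }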
#[test]
fn test_batch_silu() {
let mut data = vec![0.0, 1.0, -1.0, 2.0, 0.5, -0.5, 1.5, -1.5];
let stride = 4;
        let expected: Vec<f32> = data.iter().map(|&v| v / (1.0 + (-v).exp())).collect();
batch_silu(&mut data, stride);
for (i, (got, exp)) in data.iter().zip(expected.iter()).enumerate() {
assert!(
approx_eq(*got, *exp, EPSILON),
"Batch SiLU mismatch at index {}: got {}, expected {}",
i,
got,
exp
);
}
}
#[test]
fn test_batch_softmax() {
let mut data = vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0];
let stride = 4;
batch_softmax(&mut data, stride);
let sum1: f32 = data[0..4].iter().sum();
assert!(approx_eq(sum1, 1.0, EPSILON));
let sum2: f32 = data[4..8].iter().sum();
assert!(approx_eq(sum2, 1.0, EPSILON));
for &v in &data[4..8] {
assert!(approx_eq(v, 0.25, EPSILON));
}
}
#[test]
fn test_silu_vec() {
let x = vec![0.0, 1.0, -1.0, 2.0];
let original = x.clone();
let result = silu_vec(&x);
assert_eq!(x, original);
assert!(approx_eq(result[0], 0.0, EPSILON));
assert!(approx_eq(result[1], 0.7311, 0.001));
}
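    // Same shape as test_silu_vec, for the out-of-place GELU wrapper;
    // gelu(1) is ~0.8412 under the tanh approximation.
    #[test]
    fn test_gelu_vec() {
        let x = vec![0.0, 1.0, -1.0, 2.0];
        let original = x.clone();
        let result = gelu_vec(&x);
        assert_eq!(x, original);
        assert!(approx_eq(result[0], 0.0, EPSILON));
        assert!(approx_eq(result[1], 0.8412, 0.01));
    }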
#[test]
fn test_softmax_vec() {
let x = vec![1.0, 2.0, 3.0, 4.0];
let result = softmax_vec(&x);
let sum: f32 = result.iter().sum();
assert!(approx_eq(sum, 1.0, EPSILON));
}
#[test]
fn test_empty_input() {
let mut empty: Vec<f32> = vec![];
silu(&mut empty);
gelu(&mut empty);
relu(&mut empty);
softmax(&mut empty);
}
#[test]
fn test_single_element() {
let mut x = vec![2.0];
softmax(&mut x);
assert!(
approx_eq(x[0], 1.0, EPSILON),
"Softmax of single element should be 1.0"
);
}
#[test]
fn test_non_aligned_length() {
let mut x = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let expected: Vec<f32> = x.iter().map(|&v| v / (1.0 + (-v).exp())).collect();
silu(&mut x);
for (i, (got, exp)) in x.iter().zip(expected.iter()).enumerate() {
assert!(
approx_eq(*got, *exp, EPSILON),
"Non-aligned SiLU mismatch at {}: got {}, expected {}",
i,
got,
exp
);
}
}
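    // GELU counterpart of test_non_aligned_length: seven elements cover
    // one NEON chunk plus a three-element scalar tail.
    #[test]
    fn test_gelu_non_aligned_length() {
        let mut x = vec![-3.0, -1.0, 0.0, 0.5, 1.0, 2.0, 3.0];
        let expected: Vec<f32> = x
            .iter()
            .map(|&v| {
                let inner = 0.7978845608 * (v + 0.044715 * v * v * v);
                0.5 * v * (1.0 + inner.tanh())
            })
            .collect();
        gelu(&mut x);
        for (i, (got, exp)) in x.iter().zip(expected.iter()).enumerate() {
            assert!(
                approx_eq(*got, *exp, 0.01),
                "Non-aligned GELU mismatch at {}: got {}, expected {}",
                i,
                got,
                exp
            );
        }
    }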
}