use crate::tensor::Tensor;
use core::f32::consts::PI;
/// Rectified linear unit, applied element-wise: each value `v` becomes `v.max(0.0)`.
///
/// The result is a new tensor created with `Tensor::from_vec` using the shape
/// reported by `x`; the input is not modified.
///
/// Note: `f32::max` returns the non-NaN operand, so a NaN input element
/// produces `0.0` rather than propagating.
pub fn relu(x: &Tensor) -> Tensor {
    // Zero-initialised buffer sized from the element count of the input.
    let mut activated = vec![0.0f32; x.numel()];
    activated
        .iter_mut()
        .zip(x.data())
        .for_each(|(slot, &value)| *slot = value.max(0.0));
    Tensor::from_vec(activated, x.shape().as_slice())
}
/// SiLU (a.k.a. swish) activation, applied element-wise: `v / (1 + e^(-v))`.
///
/// The result is a new tensor created with `Tensor::from_vec` using the shape
/// reported by `x`; the input is not modified.
pub fn silu(x: &Tensor) -> Tensor {
    // Zero-initialised buffer sized from the element count of the input.
    let mut activated = vec![0.0f32; x.numel()];
    activated.iter_mut().zip(x.data()).for_each(|(slot, &value)| {
        let denominator = 1.0 + (-value).exp();
        *slot = value / denominator;
    });
    Tensor::from_vec(activated, x.shape().as_slice())
}
/// GELU activation using the tanh-based approximation, applied element-wise:
///
/// `0.5 * v * (1 + tanh(sqrt(2 / PI) * (v + 0.044715 * v^3)))`
///
/// The result is a new tensor created with `Tensor::from_vec` using the shape
/// reported by `x`; the input is not modified.
pub fn gelu_tanh(x: &Tensor) -> Tensor {
    /// Weight of the cubic term inside the tanh argument.
    const CUBIC_COEFF: f32 = 0.044715;

    // Loop-invariant scale factor sqrt(2 / pi).
    let scale = (2.0 / PI).sqrt();

    let mut activated = vec![0.0f32; x.numel()];
    activated.iter_mut().zip(x.data()).for_each(|(slot, &value)| {
        let cubic = CUBIC_COEFF * value * value * value;
        let squashed = (scale * (value + cubic)).tanh();
        *slot = 0.5 * value * (1.0 + squashed);
    });
    Tensor::from_vec(activated, x.shape().as_slice())
}
/// GELU activation in its erf form, applied element-wise:
///
/// `0.5 * v * (1 + erf(v / sqrt(2)))`
///
/// The error function is evaluated with this module's `erf_approx`, so the
/// result carries that approximation's error. The result is a new tensor
/// created with `Tensor::from_vec` using the shape reported by `x`.
pub fn gelu_erf(x: &Tensor) -> Tensor {
    // Loop-invariant factor 1 / sqrt(2) used to rescale the erf argument.
    let rsqrt_two = 1.0 / 2.0f32.sqrt();

    let mut activated = vec![0.0f32; x.numel()];
    activated.iter_mut().zip(x.data()).for_each(|(slot, &value)| {
        let erf_term = erf_approx(value * rsqrt_two);
        *slot = 0.5 * value * (1.0 + erf_term);
    });
    Tensor::from_vec(activated, x.shape().as_slice())
}
/// Polynomial/exponential approximation of the error function `erf(x)`.
///
/// Computes `1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2)` with
/// `t = 1 / (1 + p*|x|)` on the magnitude of `x`, then negates the result for
/// negative inputs (erf is odd). The coefficients match the classic
/// Abramowitz–Stegun 7.1.26 formula.
///
/// A NaN input yields NaN.
fn erf_approx(x: f32) -> f32 {
    /// Highest-degree coefficient (a5); seeds the Horner evaluation.
    const LEADING: f32 = 1.061_405_4;
    /// Remaining coefficients in descending degree: a4, a3, a2, a1.
    const TRAILING: [f32; 4] = [-1.453_152_1, 1.421_413_8, -0.284_496_72, 0.254_829_6];
    const P: f32 = 0.327_591_1;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner's scheme: (((a5*t + a4)*t + a3)*t + a2)*t + a1.
    let poly = TRAILING.iter().fold(LEADING, |acc, &coeff| acc * t + coeff);
    let tail = poly * t * (-magnitude * magnitude).exp();
    let unsigned = 1.0 - tail;

    if x < 0.0 {
        -unsigned
    } else {
        unsigned
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        let gap = (a - b).abs();
        gap < eps
    }

    /// Negative inputs and zero map to zero; positive inputs pass through.
    #[test]
    fn relu_clamps_negatives() {
        let input = Tensor::from_vec(vec![-2.0, -0.5, 0.0, 0.5, 2.0], &[5]);
        let output = relu(&input);
        assert_eq!(output.data(), &[0.0, 0.0, 0.0, 0.5, 2.0]);
    }

    /// SiLU of zero is zero.
    #[test]
    fn silu_at_zero_is_zero() {
        let input = Tensor::from_vec(vec![0.0], &[1]);
        let output = silu(&input);
        let first = output.data()[0];
        assert!(approx_eq(first, 0.0, 1e-7));
    }

    /// Tanh-approximated GELU of zero is zero.
    #[test]
    fn gelu_tanh_at_zero_is_zero() {
        let input = Tensor::from_vec(vec![0.0], &[1]);
        let output = gelu_tanh(&input);
        let first = output.data()[0];
        assert!(approx_eq(first, 0.0, 1e-7));
    }

    /// Erf-based GELU is checked at +1 and -1 against reference values.
    #[test]
    fn gelu_erf_matches_known_values() {
        let input = Tensor::from_vec(vec![1.0, -1.0], &[2]);
        let output = gelu_erf(&input);
        let values = output.data();
        assert!(approx_eq(values[0], 0.841_344_7, 1e-4));
        assert!(approx_eq(values[1], -0.158_655_3, 1e-4));
    }

    /// The erf approximation is checked at 0 and at +/-1 against reference values.
    #[test]
    fn erf_known_values() {
        // (input, expected erf value, tolerance)
        let cases = [
            (0.0, 0.0, 1e-6),
            (1.0, 0.842_700_8, 1e-5),
            (-1.0, -0.842_700_8, 1e-5),
        ];
        for (input, expected, eps) in cases {
            assert!(approx_eq(erf_approx(input), expected, eps));
        }
    }
}