use crate::{array::Array, dtype::Dtype, error::Result, ops};
/// Builds a zero-dimensional array holding `value`, cast to the dtype of `like`.
///
/// The constant is first materialised as an `f32` array with an empty shape and
/// then converted, so arithmetic against `like` can stay in `like`'s dtype.
///
/// # Errors
///
/// Propagates any failure from querying the dtype, creating the constant, or
/// casting it.
#[inline]
fn scalar_like(value: f32, like: &Array) -> Result<Array> {
    // NOTE(review): presumably makes sure the crate's error handler is in place
    // before any array op runs, so failures surface as `Err` — confirm in `error`.
    crate::error::ensure_handler_installed();

    let target: Dtype = like.dtype()?;
    let constant = Array::full::<f32>(&[0i32; 0], value)?;
    ops::misc::astype(&constant, target)
}
/// SiLU (swish) activation: `x * sigmoid(x)`.
///
/// # Errors
///
/// Propagates any failure from the sigmoid or the multiplication.
pub fn silu(x: &Array) -> Result<Array> {
    let gate = x.sigmoid()?;
    x.multiply(&gate)
}
/// SwiGLU gating: `silu(gate) * x`.
///
/// # Errors
///
/// Propagates any failure from [`silu`] or from the final multiplication.
pub fn swiglu(gate: &Array, x: &Array) -> Result<Array> {
    let activated_gate = silu(gate)?;
    activated_gate.multiply(x)
}
/// Exact GELU activation: `0.5 * x * (1 + erf(x / sqrt(2)))`.
///
/// The constants are created in `x`'s dtype via `scalar_like`, so the result
/// keeps the dtype of the input.
///
/// # Errors
///
/// Propagates any failure from building the constants or from the underlying
/// array operations.
pub fn gelu(x: &Array) -> Result<Array> {
    let half = scalar_like(0.5, x)?;
    let one = scalar_like(1.0, x)?;
    let inv_sqrt2 = scalar_like(std::f32::consts::FRAC_1_SQRT_2, x)?;

    let erf_term = x.multiply(&inv_sqrt2)?.erf()?;

    // Halve `x` *before* multiplying by `(1 + erf)`, matching `gelu_approx`.
    // The factor `(1 + erf)` approaches 2 for large positive inputs, so
    // computing `x * (1 + erf)` first and dividing by two afterwards doubles
    // `x` in the intermediate; assuming the ops run in the input dtype, that
    // overflows to infinity in narrow float types even though the final value
    // (about `x`) is representable. Scaling by 0.5 is an exact power-of-two
    // scaling, so in-range results are unchanged.
    half.multiply(x)?.multiply(&one.add(&erf_term)?)
}
/// Tanh-based GELU approximation:
/// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
///
/// All constants are created in `x`'s dtype via `scalar_like`.
///
/// # Errors
///
/// Propagates any failure from building the constants or from the underlying
/// array operations.
pub fn gelu_approx(x: &Array) -> Result<Array> {
    let half = scalar_like(0.5, x)?;
    let one = scalar_like(1.0, x)?;
    let scale = scalar_like((2.0 / std::f32::consts::PI).sqrt(), x)?;
    let cubic_coeff = scalar_like(0.044715, x)?;

    // x^3 computed as (x * x) * x.
    let x_squared = x.square()?;
    let x_cubed = x_squared.multiply(x)?;

    // x + 0.044715 * x^3
    let cubic_term = cubic_coeff.multiply(&x_cubed)?;
    let polynomial = x.add(&cubic_term)?;

    let tanh_term = scale.multiply(&polynomial)?.tanh()?;

    let half_x = half.multiply(x)?;
    let one_plus_tanh = one.add(&tanh_term)?;
    half_x.multiply(&one_plus_tanh)
}
/// Sigmoid-based GELU approximation: `x * sigmoid(1.702 * x)`.
///
/// # Errors
///
/// Propagates any failure from building the constant or from the underlying
/// array operations.
pub fn gelu_fast_approx(x: &Array) -> Result<Array> {
    let slope = scalar_like(1.702, x)?;
    let scaled = slope.multiply(x)?;
    let gate = scaled.sigmoid()?;
    x.multiply(&gate)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-point input shared by the per-formula tests.
    const SAMPLE: [f32; 5] = [-2.0, -0.5, 0.0, 0.5, 2.0];

    /// Wider sweep used when comparing the approximations against exact GELU.
    const SWEEP: [f32; 6] = [-6.0, -3.0, -1.0, 1.0, 3.0, 6.0];

    /// Builds a rank-1 `f32` array holding `values`.
    fn vector(values: &[f32]) -> Array {
        Array::from_slice::<f32>(values, &(values.len(),)).unwrap()
    }

    /// Extracts the contents of `array` as a `Vec<f32>`.
    fn read_back(mut array: Array) -> Vec<f32> {
        array.to_vec::<f32>().unwrap()
    }

    /// Scalar logistic sigmoid used to compute expected values.
    fn sigmoid_ref(v: f32) -> f32 {
        (1.0 + (-v).exp()).recip()
    }

    /// Scalar `erf` reference built from the Abramowitz–Stegun 7.1.26
    /// polynomial approximation (not an actual libm call).
    fn erf_reference(x: f64) -> f64 {
        let sign = if x < 0.0 { -1.0 } else { 1.0 };
        let magnitude = x.abs();
        let t = 1.0 / (1.0 + 0.3275911 * magnitude);

        // Horner evaluation of the degree-5 polynomial in `t`.
        let mut poly = 1.061405429 * t - 1.453152027;
        poly = poly * t + 1.421413741;
        poly = poly * t - 0.284496736;
        poly = poly * t + 0.254829592;

        let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
        sign * y
    }

    /// Asserts element-wise closeness with absolute and relative tolerance 1e-5.
    fn assert_close(got: &[f32], want: &[f32]) {
        assert_eq!(got.len(), want.len(), "length mismatch");
        got.iter().zip(want.iter()).for_each(|(g, w)| {
            assert!(
                (g - w).abs() <= 1e-5 + 1e-5 * w.abs(),
                "activation mismatch: got {g}, want {w}"
            );
        });
    }

    #[test]
    fn silu_matches_reference_formula() {
        let got = read_back(silu(&vector(&SAMPLE)).unwrap());
        let want: Vec<f32> = SAMPLE.iter().map(|&v| v * sigmoid_ref(v)).collect();
        assert_close(&got, &want);
    }

    #[test]
    fn silu_zero_is_zero() {
        let zero = vector(&[0.0]);
        let got = read_back(silu(&zero).unwrap());
        assert_eq!(got, vec![0.0]);
    }

    #[test]
    fn swiglu_matches_silu_gate_times_x() {
        let gate_values = [-1.0f32, 0.0, 1.0, 3.0];
        let x_values = [2.0f32, 5.0, -4.0, 0.5];

        let gate = vector(&gate_values);
        let x = vector(&x_values);
        let got = read_back(swiglu(&gate, &x).unwrap());

        let want: Vec<f32> = gate_values
            .iter()
            .zip(x_values.iter())
            .map(|(&gi, &xi)| gi * sigmoid_ref(gi) * xi)
            .collect();
        assert_close(&got, &want);
    }

    #[test]
    fn swiglu_hand_traced_scalar() {
        let gate = vector(&[1.0]);
        let x = vector(&[2.0]);
        let got = read_back(swiglu(&gate, &x).unwrap());
        // silu(1) * 2 = 2 * sigmoid(1)
        assert_close(&got, &[1.462_117_2]);
    }

    #[test]
    fn gelu_matches_reference_formula() {
        let got = read_back(gelu(&vector(&SAMPLE)).unwrap());
        let want: Vec<f32> = SAMPLE
            .iter()
            .map(|&v| {
                let x = f64::from(v);
                (x * (1.0 + erf_reference(x / std::f64::consts::SQRT_2)) / 2.0) as f32
            })
            .collect();
        assert_close(&got, &want);
    }

    #[test]
    fn gelu_zero_is_zero() {
        let zero = vector(&[0.0]);
        let got = read_back(gelu(&zero).unwrap());
        assert_close(&got, &[0.0]);
    }

    #[test]
    fn gelu_approx_matches_reference_formula() {
        let got = read_back(gelu_approx(&vector(&SAMPLE)).unwrap());
        let want: Vec<f32> = SAMPLE
            .iter()
            .map(|&v| {
                let x = f64::from(v);
                let inner = (2.0f64 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x * x * x);
                (0.5 * x * (1.0 + inner.tanh())) as f32
            })
            .collect();
        assert_close(&got, &want);
    }

    #[test]
    fn gelu_approx_tracks_exact_gelu_within_error_bound() {
        let x = vector(&SWEEP);
        let exact = gelu(&x).unwrap();
        let approx = gelu_approx(&x).unwrap();
        let e = read_back(exact);
        let a = read_back(approx);
        e.iter().zip(a.iter()).for_each(|(ev, av)| {
            assert!(
                (ev - av).abs() < 5e-4,
                "gelu_approx strayed beyond the documented 5e-4 bound: exact={ev}, approx={av}"
            );
        });
    }

    #[test]
    fn gelu_fast_approx_matches_reference_formula() {
        let got = read_back(gelu_fast_approx(&vector(&SAMPLE)).unwrap());
        let want: Vec<f32> = SAMPLE
            .iter()
            .map(|&v| v * sigmoid_ref(1.702 * v))
            .collect();
        assert_close(&got, &want);
    }

    #[test]
    fn gelu_fast_approx_tracks_exact_gelu_within_error_bound() {
        let x = vector(&SWEEP);
        let exact = gelu(&x).unwrap();
        let fast = gelu_fast_approx(&x).unwrap();
        let e = read_back(exact);
        let f = read_back(fast);
        e.iter().zip(f.iter()).for_each(|(ev, fv)| {
            assert!(
                (ev - fv).abs() < 1.5e-2,
                "gelu_fast_approx strayed beyond the documented 1.5e-2 bound: exact={ev}, fast={fv}"
            );
        });
    }

    #[test]
    fn activations_preserve_shape() {
        let data: Vec<f32> = (0..24).map(|i| i as f32 * 0.1).collect();
        let x = Array::from_slice::<f32>(&data, &(2, 3, 4)).unwrap();

        assert_eq!(silu(&x).unwrap().shape(), vec![2, 3, 4]);
        assert_eq!(gelu(&x).unwrap().shape(), vec![2, 3, 4]);
        assert_eq!(gelu_approx(&x).unwrap().shape(), vec![2, 3, 4]);
        assert_eq!(gelu_fast_approx(&x).unwrap().shape(), vec![2, 3, 4]);
        assert_eq!(swiglu(&x, &x).unwrap().shape(), vec![2, 3, 4]);
    }

    #[test]
    fn gelu_variants_preserve_input_dtype() {
        for dtype in [Dtype::F16, Dtype::BF16, Dtype::F32] {
            let x = vector(&SAMPLE).astype(dtype).unwrap();

            let exact_dtype = gelu(&x).unwrap().dtype().unwrap();
            assert_eq!(exact_dtype, dtype, "gelu must preserve {dtype:?}");

            let approx_dtype = gelu_approx(&x).unwrap().dtype().unwrap();
            assert_eq!(approx_dtype, dtype, "gelu_approx must preserve {dtype:?}");

            let fast_dtype = gelu_fast_approx(&x).unwrap().dtype().unwrap();
            assert_eq!(
                fast_dtype, dtype,
                "gelu_fast_approx must preserve {dtype:?}"
            );
        }
    }

    #[test]
    fn activations_on_rank0_scalar_stay_rank0() {
        let x = Array::full::<f32>(&[0i32; 0], 0.7).unwrap();
        assert_eq!(x.ndim(), 0, "rank-0 input precondition");

        let exact_rank = gelu(&x).unwrap().ndim();
        assert_eq!(exact_rank, 0, "gelu must keep rank 0");

        let approx_rank = gelu_approx(&x).unwrap().ndim();
        assert_eq!(approx_rank, 0, "gelu_approx must keep rank 0");

        let fast_rank = gelu_fast_approx(&x).unwrap().ndim();
        assert_eq!(fast_rank, 0, "gelu_fast_approx must keep rank 0");

        let silu_rank = silu(&x).unwrap().ndim();
        assert_eq!(silu_rank, 0, "silu must keep rank 0");

        let swiglu_rank = swiglu(&x, &x).unwrap().ndim();
        assert_eq!(swiglu_rank, 0, "swiglu must keep rank 0");
    }
}