use num_traits::Float;
/// Logistic sigmoid, `σ(x) = 1 / (1 + e^(-x))`.
///
/// Returns a value in `[0, 1]` for every non-NaN input: `+∞` maps to `1` and
/// `-∞` maps to `0`. A `NaN` input yields `NaN`.
///
/// The computation is split on the sign of `x` so that the argument passed to
/// `exp` is never positive, and `exp` therefore never overflows. The direct
/// formula evaluates `exp(-x)` unconditionally. That overflows to infinity for
/// large negative `x` (below roughly -709.78 for `f64`), which flushes the
/// result to zero even where the true value is still a representable
/// (subnormal) number.
pub fn sigmoid<T: Float>(x: T) -> T {
    if x >= T::zero() {
        // -x <= 0, so exp(-x) lies in (0, 1] and cannot overflow.
        T::one() / (T::one() + (-x).exp())
    } else {
        // For x < 0, multiply numerator and denominator by e^x:
        // 1 / (1 + e^-x) == e^x / (e^x + 1), with exp(x) in (0, 1).
        // NaN also takes this branch (every comparison with NaN is false)
        // and propagates through `exp` to the result.
        let e = x.exp();
        e / (T::one() + e)
    }
}
/// SiLU (sigmoid-weighted linear unit): `x · σ(x)`, where `σ` is [`sigmoid`].
///
/// Returns `0` at `x = 0`. Because the input is gated by `sigmoid`, the output
/// tends toward `x` for large positive inputs and toward `0` for large negative
/// ones.
///
/// Note: at `x = -∞` the gate is exactly `0`, so the product is `-∞ · 0`,
/// which is `NaN` in IEEE arithmetic even though the mathematical limit is `0`.
pub fn silu<T: Float>(x: T) -> T {
    let gate = sigmoid(x);
    gate * x
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sigmoid_at_zero_is_one_half() {
        assert!((sigmoid(0.0_f64) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn sigmoid_saturates_for_large_positive_input() {
        assert!(sigmoid(50.0_f64) > 0.999999);
    }

    #[test]
    fn sigmoid_saturates_for_large_negative_input() {
        assert!(sigmoid(-50.0_f64) < 0.000001);
    }

    #[test]
    fn sigmoid_handles_non_finite_input() {
        assert_eq!(sigmoid(f64::INFINITY), 1.0);
        assert_eq!(sigmoid(f64::NEG_INFINITY), 0.0);
        assert!(sigmoid(f64::NAN).is_nan());
    }

    #[test]
    fn sigmoid_is_point_symmetric_about_one_half() {
        // Identity: σ(-x) = 1 - σ(x).
        for &x in &[0.25_f64, 1.0, 3.0, 10.0] {
            assert!((sigmoid(-x) - (1.0 - sigmoid(x))).abs() < 1e-12);
        }
    }

    #[test]
    fn sigmoid_works_for_f32() {
        assert_eq!(sigmoid(0.0_f32), 0.5);
        let hi = sigmoid(100.0_f32);
        let lo = sigmoid(-100.0_f32);
        assert!(hi <= 1.0 && hi > 0.999);
        assert!(lo >= 0.0 && lo < 0.001);
    }

    #[test]
    fn silu_at_zero_is_zero() {
        assert_eq!(silu(0.0_f64), 0.0);
    }

    #[test]
    fn silu_matches_definition() {
        let x = 2.0_f64;
        let expected = x * (1.0 / (1.0 + (-x).exp()));
        assert!((silu(x) - expected).abs() < 1e-12);
    }

    #[test]
    fn silu_approaches_identity_for_large_positive_input() {
        assert!((silu(50.0_f64) - 50.0).abs() < 1e-9);
    }

    #[test]
    fn silu_is_tiny_and_negative_for_large_negative_input() {
        let y = silu(-50.0_f64);
        assert!(y < 0.0 && y > -1e-15);
    }
}