/// Logistic sigmoid, `σ(x) = 1 / (1 + e^(-x))`, computed without overflow.
///
/// Evaluating `e^(-x)` directly would overflow to infinity for large negative
/// `x`. Instead, this always exponentiates `-|x|`, a non-positive value, so the
/// intermediate stays in `[0, 1]` for every non-NaN input. The sign of `x` then
/// selects between two algebraically equivalent forms:
/// `1 / (1 + e^(-|x|))` for `x >= 0`, and `e^(-|x|) / (1 + e^(-|x|))` for `x < 0`.
///
/// The result lies in `[0, 1]`; a NaN input yields NaN.
#[inline]
pub fn sigmoid(x: f32) -> f32 {
    // e^(-|x|) is at most 1, so neither branch below can overflow.
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;

    if x >= 0.0 {
        1.0 / denom
    } else {
        decay / denom
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons that involve `f32` rounding error.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_sigmoid_at_zero() {
        assert!((sigmoid(0.0) - 0.5).abs() < EPS);
    }

    #[test]
    fn test_sigmoid_bounded() {
        assert!(sigmoid(-10.0) > 0.0);
        assert!(sigmoid(-10.0) < 0.001);
        assert!(sigmoid(10.0) > 0.999);
        // 1 - σ(10) ≈ 4.5e-5 is far above f32 resolution near 1.0, so the
        // result must not have rounded up to exactly 1.
        assert!(sigmoid(10.0) < 1.0);
    }

    #[test]
    fn test_sigmoid_symmetry() {
        // σ(-x) = 1 - σ(x) should hold across a range of magnitudes, not just one point.
        for x in [0.1_f32, 1.0, 2.5, 5.0, 10.0] {
            let lhs = sigmoid(-x);
            let rhs = 1.0 - sigmoid(x);
            assert!(
                (lhs - rhs).abs() < EPS,
                "sigmoid(-{}) = {} but 1 - sigmoid({}) = {}",
                x,
                lhs,
                x,
                rhs
            );
        }
    }

    #[test]
    fn test_sigmoid_numerical_stability() {
        let extreme_values = [-1000.0, -100.0, 100.0, 1000.0];
        for x in extreme_values {
            let s = sigmoid(x);
            assert!(s.is_finite(), "sigmoid({}) = {} is not finite", x, s);
            assert!((0.0..=1.0).contains(&s), "sigmoid({}) = {} out of [0,1]", x, s);
        }
    }

    #[test]
    fn test_sigmoid_saturates_at_extremes() {
        // e^(-1000) underflows to 0 in f32, so the result saturates exactly.
        assert_eq!(sigmoid(1000.0), 1.0);
        assert_eq!(sigmoid(-1000.0), 0.0);
        assert_eq!(sigmoid(f32::INFINITY), 1.0);
        assert_eq!(sigmoid(f32::NEG_INFINITY), 0.0);
    }

    #[test]
    fn test_sigmoid_propagates_nan() {
        assert!(sigmoid(f32::NAN).is_nan());
    }
}