/// Logistic sigmoid, `1 / (1 + exp(-x))`.
///
/// Inputs above `500.0` return exactly `1.0` and inputs below `-500.0`
/// return exactly `0.0`. In between, the formula is chosen by the sign of
/// `x` so that `exp` is only ever called on a non-positive argument and
/// therefore cannot overflow. A NaN input yields NaN.
#[inline]
pub fn sigmoid(x: f64) -> f64 {
    // Magnitude beyond which the result is pinned to the exact limit value.
    const SATURATION: f64 = 500.0;

    if x > SATURATION {
        1.0
    } else if x < -SATURATION {
        0.0
    } else if x >= 0.0 {
        // exp(-x) lies in (0, 1] here.
        (1.0 + (-x).exp()).recip()
    } else {
        // Negative x (and NaN, which fails every comparison above):
        // use the equivalent form exp(x) / (1 + exp(x)).
        let grown = x.exp();
        grown / (1.0 + grown)
    }
}
/// Derivative of the logistic sigmoid, `s'(x) = s(x) * (1 - s(x))`.
///
/// The product is evaluated in the algebraically equivalent form
/// `e / (1 + e)^2` with `e = exp(-|x|)`. Computing `1 - s(x)` directly
/// cancels catastrophically once `s(x)` rounds to `1.0` (around `x > 37`),
/// which would return `0.0` for large positive `x` while the mirrored
/// negative input still produced a non-zero value. This form is exactly
/// even in `x`, never overflows (`e` is in `[0, 1]`), and keeps the
/// exponentially small tail until `exp` itself underflows.
///
/// Returns `0.25` at `x = 0`, `0.0` at `±∞`, and NaN for a NaN input.
#[inline]
pub fn sigmoid_derivative(x: f64) -> f64 {
    // exp of a non-positive argument: always in [0, 1].
    let e = (-x.abs()).exp();
    let denom = 1.0 + e;
    e / (denom * denom)
}
/// Natural logarithm of the logistic sigmoid, `ln(1 / (1 + exp(-x)))`.
///
/// Computed directly in a form that neither overflows nor loses the small
/// tail:
///
/// * `x >= 0`: `-ln_1p(exp(-x))`
/// * `x <  0`: `x - ln_1p(exp(x))`
///
/// In both branches `exp` receives a non-positive argument, and `ln_1p`
/// keeps full precision when that exponential is tiny. The result is
/// therefore strictly negative for every finite `x` until `exp(-x)`
/// underflows, instead of collapsing to `-0.0` for moderately large
/// positive inputs.
///
/// Returns `-0.0` at `+∞`, `-∞` at `-∞`, and NaN for a NaN input.
#[inline]
pub fn log_sigmoid(x: f64) -> f64 {
    if x >= 0.0 {
        -(-x).exp().ln_1p()
    } else {
        // Also taken for NaN, which then propagates through the arithmetic.
        x - x.exp().ln_1p()
    }
}
/// Softplus, `ln(1 + exp(x))`, evaluated without overflow or tail loss.
///
/// Uses the identity `softplus(x) = max(x, 0) + ln_1p(exp(-|x|))`:
///
/// * `x >  0`: `x + ln_1p(exp(-x))`
/// * `x <= 0`: `ln_1p(exp(x))`
///
/// `exp` only ever sees a non-positive argument, so it cannot overflow, and
/// `ln_1p` stays accurate when the exponential is far below `1.0`. Unlike a
/// hard cut-off at `±20`, this is continuous and keeps `f64` precision in
/// both tails: the result is strictly positive for negative `x` until
/// `exp(x)` underflows, and approaches `x` smoothly for large positive `x`.
///
/// Returns `0.0` at `-∞`, `+∞` at `+∞`, and NaN for a NaN input.
#[inline]
pub fn softplus(x: f64) -> f64 {
    if x > 0.0 {
        x + (-x).exp().ln_1p()
    } else {
        // Also taken for NaN, which then propagates through exp/ln_1p.
        x.exp().ln_1p()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the approximate comparisons below.
    const TOLERANCE: f64 = 1e-10;

    /// The sigmoid passes through 0.5 at the origin.
    #[test]
    fn test_sigmoid_center() {
        let deviation = (sigmoid(0.0) - 0.5).abs();
        assert!(deviation < TOLERANCE);
    }

    /// Large-magnitude inputs approach, and beyond the clamp exactly hit,
    /// the limits 0 and 1.
    #[test]
    fn test_sigmoid_extremes() {
        let near_one = sigmoid(100.0);
        let near_zero = sigmoid(-100.0);
        assert!(near_one > 0.999999);
        assert!(near_zero < 0.000001);

        assert_eq!(sigmoid(600.0), 1.0);
        assert_eq!(sigmoid(-600.0), 0.0);
    }

    /// sigmoid(x) + sigmoid(-x) == 1 for a spread of inputs.
    #[test]
    fn test_sigmoid_symmetry() {
        let samples = [-5.0, -1.0, 0.5, 2.0, 10.0];
        for x in samples {
            let pair_sum = sigmoid(x) + sigmoid(-x);
            let sym_diff = (pair_sum - 1.0).abs();
            assert!(sym_diff < TOLERANCE, "symmetry failed for x={}", x);
        }
    }

    /// The derivative peaks at 0.25 at the origin and is smaller on either side.
    #[test]
    fn test_sigmoid_derivative() {
        let peak = sigmoid_derivative(0.0);
        assert!((peak - 0.25).abs() < TOLERANCE);
        assert!(sigmoid_derivative(1.0) < peak);
        assert!(sigmoid_derivative(-1.0) < peak);
    }

    /// log_sigmoid(0) == -ln 2, and the function is negative on the samples.
    #[test]
    fn test_log_sigmoid() {
        let expected = -std::f64::consts::LN_2;
        let at_zero = log_sigmoid(0.0);
        assert!((at_zero - expected).abs() < TOLERANCE);

        let samples = [-10.0, -1.0, 0.0, 1.0, 10.0];
        for x in samples {
            assert!(log_sigmoid(x) < 0.0);
        }
    }

    /// softplus(0) == ln 2, and the function is positive on the samples.
    #[test]
    fn test_softplus() {
        let at_zero = softplus(0.0);
        assert!((at_zero - std::f64::consts::LN_2).abs() < TOLERANCE);

        let samples = [-10.0, -1.0, 0.0, 1.0, 10.0];
        for x in samples {
            assert!(softplus(x) > 0.0);
        }
    }
}