/// A scalar activation function paired with its first derivative.
///
/// Both methods are associated functions (there is no `self` receiver), so an
/// implementor is used purely as a type, e.g. `Sigmoid::activate(x)`.
pub trait Activator {
    /// Evaluates the activation function at `input`.
    fn activate(input: f64) -> f64;

    /// Evaluates the derivative of the activation, with respect to its input,
    /// at `input`.
    fn derivative(input: f64) -> f64;
}
/// Identity activation: the output equals the input.
pub struct Linear;

impl Activator for Linear {
    /// Returns `input` unchanged.
    fn activate(input: f64) -> f64 {
        input
    }

    /// The identity has a constant slope, so the input is ignored.
    fn derivative(_input: f64) -> f64 {
        1.0
    }
}
/// Logistic sigmoid activation, `1 / (1 + e^(-x))`.
pub struct Sigmoid;

impl Activator for Sigmoid {
    fn activate(x: f64) -> f64 {
        let denominator = 1.0 + libm::exp(-x);
        1.0 / denominator
    }

    /// Uses the identity `sigma'(x) = sigma(x) * (1 - sigma(x))`, so only one
    /// exponential is evaluated.
    fn derivative(x: f64) -> f64 {
        let sig = Self::activate(x);
        sig * (1.0 - sig)
    }
}
/// Hyperbolic tangent activation.
pub struct Tanh;

impl Activator for Tanh {
    fn activate(x: f64) -> f64 {
        libm::tanh(x)
    }

    /// Evaluates `d/dx tanh(x) = sech^2(x) = 1 / cosh(x)^2`.
    ///
    /// The textbook form `1 - tanh(x)^2` cancels catastrophically once
    /// `tanh(x)` rounds to exactly `±1.0` (|x| above roughly 19). It then
    /// returns `0.0` where the true gradient is still a normal positive
    /// number. Dividing by `cosh(x)^2` keeps full relative precision until
    /// `cosh(x)^2` overflows near |x| ≈ 355. By that point the true gradient
    /// is already below the normal `f64` range, and the result is `0.0`.
    fn derivative(x: f64) -> f64 {
        let c = libm::cosh(x);
        1.0 / (c * c)
    }
}
/// Swish (SiLU) activation: `x * sigmoid(x)`.
pub struct Swish;

impl Activator for Swish {
    fn activate(x: f64) -> f64 {
        let gate = 1.0 / (1.0 + libm::exp(-x));
        x * gate
    }

    /// Product rule: `d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x))`,
    /// where `s` is the logistic sigmoid.
    fn derivative(x: f64) -> f64 {
        let gate = 1.0 / (1.0 + libm::exp(-x));
        let slope_term = x * gate * (1.0 - gate);
        gate + slope_term
    }
}
/// Rectified linear unit: `x` for positive inputs, `0` otherwise.
pub struct ReLU;

impl Activator for ReLU {
    /// Returns `x` when `x > 0`, otherwise `+0.0`.
    ///
    /// NaN is propagated rather than clamped. `x.max(0.0)` would return `0.0`
    /// for NaN, because `f64::max` ignores a NaN operand, and that silently
    /// masks upstream numerical failures. The explicit branch also pins
    /// `ReLU(-0.0)` to `+0.0`, which `f64::max` leaves unspecified.
    fn activate(x: f64) -> f64 {
        if x > 0.0 || x.is_nan() {
            x
        } else {
            0.0
        }
    }

    /// Step function: `1.0` for `x > 0`, otherwise `0.0`.
    /// At `x == 0` the subgradient `0.0` is used.
    fn derivative(x: f64) -> f64 {
        if x > 0.0 {
            1.0
        } else {
            0.0
        }
    }
}
/// Leaky ReLU: identity for positive inputs and a small fixed slope otherwise.
pub struct LeakyReLU;

impl LeakyReLU {
    /// Slope applied to inputs that are not strictly positive.
    const NEGATIVE_SLOPE: f64 = 0.01;
}

impl Activator for LeakyReLU {
    fn activate(x: f64) -> f64 {
        match x > 0.0 {
            true => x,
            false => Self::NEGATIVE_SLOPE * x,
        }
    }

    /// `1.0` for `x > 0`; the negative-side slope otherwise (including at
    /// zero and for NaN, since both fail the `x > 0.0` test).
    fn derivative(x: f64) -> f64 {
        match x > 0.0 {
            true => 1.0,
            false => Self::NEGATIVE_SLOPE,
        }
    }
}
/// Exponential linear unit with `alpha = 1`: `x` for `x > 0`, `e^x - 1` otherwise.
pub struct ELU;

impl Activator for ELU {
    /// Uses `expm1` on the non-positive branch. Computing `exp(x) - 1.0`
    /// directly cancels catastrophically for small `|x|`: for example, it
    /// yields exactly `0.0` for `x = -1e-20` instead of `-1e-20`.
    fn activate(x: f64) -> f64 {
        if x > 0.0 {
            x
        } else {
            libm::expm1(x)
        }
    }

    /// `1.0` for `x > 0`, `e^x` otherwise. The two branches agree at zero
    /// because `e^0 = 1`.
    fn derivative(x: f64) -> f64 {
        if x > 0.0 {
            1.0
        } else {
            libm::exp(x)
        }
    }
}
/// Softplus activation, `ln(1 + e^x)`: a smooth approximation of ReLU.
pub struct Softplus;

impl Activator for Softplus {
    /// Evaluated in a numerically stable form.
    ///
    /// The naive `log(1 + exp(x))` has two failure modes:
    /// - It overflows to `+inf` once `exp(x)` does (x above ~709.78).
    /// - It rounds to exactly `0.0` once `1 + exp(x) == 1` (x below roughly -37).
    ///
    /// For positive `x`, the identity `ln(1 + e^x) = x + ln(1 + e^-x)` keeps
    /// the exponent non-positive on both branches, so `exp` cannot overflow.
    /// `log1p` preserves the small result for negative `x`.
    fn activate(x: f64) -> f64 {
        if x > 0.0 {
            x + libm::log1p(libm::exp(-x))
        } else {
            libm::log1p(libm::exp(x))
        }
    }

    /// The derivative of softplus is the logistic sigmoid `1 / (1 + e^-x)`.
    fn derivative(x: f64) -> f64 {
        1.0 / (1.0 + libm::exp(-x))
    }
}
/// Piecewise-linear approximation of the sigmoid: `clamp(0.2 * x + 0.5, 0, 1)`.
pub struct HardSigmoid;

impl Activator for HardSigmoid {
    fn activate(x: f64) -> f64 {
        let linear = 0.2 * x + 0.5;
        f64::clamp(linear, 0.0, 1.0)
    }

    /// `0.2` on the closed interval `[-2.5, 2.5]`, where the output is not
    /// clamped (both endpoints included). `0.0` outside it and for NaN.
    fn derivative(x: f64) -> f64 {
        let unclamped = (-2.5..=2.5).contains(&x);
        if unclamped {
            0.2
        } else {
            0.0
        }
    }
}
#[cfg(test)]
mod edge_case_tests {
    use super::*;

    /// Absolute tolerance for the approximate comparisons in this module.
    const TOL: f64 = 1e-6;

    /// Asserts `|actual - expected| < TOL`, printing both values on failure.
    /// A bare `assert!` on the difference reports only the expression text.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {} (+/- {}), got {}",
            expected,
            TOL,
            actual
        );
    }

    #[test]
    fn test_linear_edge_cases() {
        assert_eq!(Linear::activate(-1000.0), -1000.0);
        assert_eq!(Linear::activate(1000.0), 1000.0);
        assert_eq!(Linear::derivative(0.0), 1.0);
        assert_eq!(Linear::derivative(-1000.0), 1.0);
        assert_eq!(Linear::derivative(1000.0), 1.0);
    }

    #[test]
    fn test_sigmoid_edge_cases() {
        assert_close(Sigmoid::activate(100.0), 1.0);
        assert_close(Sigmoid::activate(-100.0), 0.0);
        assert_close(Sigmoid::activate(0.0), 0.5);
        assert_close(Sigmoid::derivative(0.0), 0.25);
        // The gradient vanishes in both saturated tails.
        assert_close(Sigmoid::derivative(100.0), 0.0);
        assert_close(Sigmoid::derivative(-100.0), 0.0);
    }

    #[test]
    fn test_tanh_edge_cases() {
        assert_close(Tanh::activate(100.0), 1.0);
        assert_close(Tanh::activate(-100.0), -1.0);
        assert_close(Tanh::activate(0.0), 0.0);
        assert_close(Tanh::derivative(0.0), 1.0);
        // The gradient vanishes in both saturated tails.
        assert_close(Tanh::derivative(100.0), 0.0);
        assert_close(Tanh::derivative(-100.0), 0.0);
    }

    #[test]
    fn test_relu_edge_cases() {
        assert_eq!(ReLU::activate(0.0), 0.0);
        assert_eq!(ReLU::activate(-0.0), 0.0);
        assert_eq!(ReLU::activate(-1.0), 0.0);
        assert_eq!(ReLU::activate(1.0), 1.0);
        assert_eq!(ReLU::derivative(0.0), 0.0);
        assert_eq!(ReLU::derivative(-1.0), 0.0);
        assert_eq!(ReLU::derivative(1.0), 1.0);
    }

    #[test]
    fn test_leaky_relu_edge_cases() {
        assert_eq!(LeakyReLU::activate(-1.0), -0.01);
        assert_eq!(LeakyReLU::activate(-100.0), -1.0);
        assert_eq!(LeakyReLU::activate(0.0), 0.0);
        assert_eq!(LeakyReLU::activate(1.0), 1.0);
        assert_eq!(LeakyReLU::derivative(-1.0), 0.01);
        assert_eq!(LeakyReLU::derivative(1.0), 1.0);
        // Zero falls on the negative side (`x > 0.0` is false).
        assert_eq!(LeakyReLU::derivative(0.0), 0.01);
    }

    #[test]
    fn test_swish_edge_cases() {
        assert_close(Swish::activate(-10.0), -10.0 * Sigmoid::activate(-10.0));
        assert_close(Swish::activate(0.0), 0.0);
        assert_close(Swish::activate(10.0), 10.0 * Sigmoid::activate(10.0));
        assert_close(Swish::derivative(0.0), 0.5);
        // The slope approaches 1 for large inputs and 0 for very negative ones.
        assert_close(Swish::derivative(100.0), 1.0);
        assert_close(Swish::derivative(-100.0), 0.0);
    }

    #[test]
    fn test_elu_edge_cases() {
        assert_close(ELU::activate(-1e-5), libm::exp(-1e-5) - 1.0);
        assert_close(ELU::activate(-10.0), libm::exp(-10.0) - 1.0);
        assert_eq!(ELU::activate(0.0), 0.0);
        assert_eq!(ELU::activate(1.0), 1.0);
        assert_close(ELU::derivative(-1e-5), libm::exp(-1e-5));
        assert_eq!(ELU::derivative(0.0), 1.0);
        // Saturates at -1 with a vanishing gradient for very negative inputs.
        assert_close(ELU::activate(-1000.0), -1.0);
        assert_close(ELU::derivative(-1000.0), 0.0);
    }

    #[test]
    fn test_softplus_edge_cases() {
        assert_close(Softplus::activate(-100.0), 0.0);
        assert_close(Softplus::activate(100.0), 100.0);
        assert_close(Softplus::activate(0.0), libm::log(2.0));
        // The derivative is the logistic sigmoid.
        assert_close(Softplus::derivative(0.0), 0.5);
        assert_close(Softplus::derivative(100.0), 1.0);
        assert_close(Softplus::derivative(-100.0), 0.0);
    }

    #[test]
    fn test_hard_sigmoid_edge_cases() {
        assert_eq!(HardSigmoid::activate(-100.0), 0.0);
        assert_eq!(HardSigmoid::activate(100.0), 1.0);
        assert_eq!(HardSigmoid::activate(-2.5), 0.0);
        assert_eq!(HardSigmoid::activate(2.5), 1.0);
        assert_eq!(HardSigmoid::activate(0.0), 0.5);
        assert_eq!(HardSigmoid::activate(1.0), 0.7);
        assert_eq!(HardSigmoid::activate(-1.0), 0.3);
        assert_eq!(HardSigmoid::derivative(0.0), 0.2);
        assert_eq!(HardSigmoid::derivative(2.0), 0.2);
        assert_eq!(HardSigmoid::derivative(-2.0), 0.2);
        assert_eq!(HardSigmoid::derivative(-3.0), 0.0);
        assert_eq!(HardSigmoid::derivative(3.0), 0.0);
    }
}