use serde::{Deserialize, Serialize};
/// Element-wise activation function, selected by variant.
///
/// The serde derives give each variant a serialized form. They also pass the
/// variant index to the serializer, so do not reorder the variants below:
/// doing so may change the encoding in index-based formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum Activation {
    /// Hyperbolic tangent, `tanh(x)`; saturates toward ±1.
    Tanh,
    /// Rectified linear unit, `max(x, 0)`.
    Relu,
    /// Logistic sigmoid, `1 / (1 + e^-x)`; saturates toward 0 and 1.
    Sigmoid,
    /// Exponential linear unit with α = 1: `x` for `x > 0`, otherwise `e^x - 1`.
    Elu,
    /// Softsign, `x / (1 + |x|)`; saturates toward ±1.
    Softsign,
    /// Identity, `x`.
    Linear,
}
impl Activation {
    /// Applies the activation function to the pre-activation value `x`.
    ///
    /// NaN inputs generally propagate through the underlying `f64` operations.
    /// The exception is [`Activation::Relu`], which uses [`f64::max`]; that
    /// returns the non-NaN operand, so a NaN input yields `0.0`.
    #[must_use]
    pub fn apply(&self, x: f64) -> f64 {
        match self {
            Activation::Tanh => x.tanh(),
            // `f64::max` ignores a NaN operand, so NaN maps to 0.0 here.
            Activation::Relu => x.max(0.0),
            Activation::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            Activation::Elu => {
                if x > 0.0 {
                    x
                } else {
                    // `exp_m1` computes e^x - 1 accurately near zero.
                    // `x.exp() - 1.0` loses most significant digits there
                    // to cancellation.
                    x.exp_m1()
                }
            }
            Activation::Softsign => x / (1.0 + x.abs()),
            Activation::Linear => x,
        }
    }

    /// Returns the derivative of the activation with respect to its input.
    /// The derivative is expressed in terms of the activation's OUTPUT
    /// `fx = self.apply(x)`, not the input `x`.
    ///
    /// Formulas used, per variant:
    /// - `Tanh`: `1 - fx²`
    /// - `Relu`: `1` if `fx > 0`, else `0` (the derivative at `x = 0` is taken as `0`)
    /// - `Sigmoid`: `fx · (1 - fx)`
    /// - `Elu` (α = 1): `1` if `fx > 0`, else `fx + 1` (which equals `e^x`)
    /// - `Softsign`: `(1 - |fx|)²` (which equals `1 / (1 + |x|)²`)
    /// - `Linear`: `1`
    ///
    /// Passing a value that did not come from [`Activation::apply`] on the
    /// same variant gives a result with no meaning.
    #[must_use]
    pub fn derivative(&self, fx: f64) -> f64 {
        match self {
            Activation::Tanh => 1.0 - fx * fx,
            Activation::Relu => {
                if fx > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            Activation::Sigmoid => fx * (1.0 - fx),
            Activation::Elu => {
                if fx > 0.0 {
                    1.0
                } else {
                    // For x <= 0, fx = e^x - 1, so d/dx = e^x = fx + 1.
                    fx + 1.0
                }
            }
            Activation::Softsign => {
                // 1 - |fx| = 1 / (1 + |x|), so squaring gives the derivative.
                let t = 1.0 - fx.abs();
                t * t
            }
            Activation::Linear => 1.0,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant, in declaration order. Shared by the table-driven tests
    /// so that adding a variant only requires updating one list.
    const ALL_VARIANTS: [Activation; 6] = [
        Activation::Tanh,
        Activation::Relu,
        Activation::Sigmoid,
        Activation::Elu,
        Activation::Softsign,
        Activation::Linear,
    ];

    #[test]
    fn test_tanh_apply_zero() {
        assert_eq!(Activation::Tanh.apply(0.0), 0.0);
    }

    #[test]
    fn test_tanh_apply_known() {
        let expected = 1.0_f64.tanh();
        let result = Activation::Tanh.apply(1.0);
        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_tanh_apply_negative() {
        let expected = (-2.0_f64).tanh();
        let result = Activation::Tanh.apply(-2.0);
        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_relu_apply_negative_is_zero() {
        assert_eq!(Activation::Relu.apply(-5.0), 0.0);
    }

    #[test]
    fn test_relu_apply_zero_is_zero() {
        assert_eq!(Activation::Relu.apply(0.0), 0.0);
    }

    #[test]
    fn test_relu_apply_positive_is_identity() {
        assert_eq!(Activation::Relu.apply(3.7), 3.7);
    }

    #[test]
    fn test_sigmoid_apply_zero_is_half() {
        assert!((Activation::Sigmoid.apply(0.0) - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_sigmoid_apply_large_stays_below_one() {
        let result = Activation::Sigmoid.apply(30.0);
        assert!(result < 1.0);
        assert!(result > 0.99);
    }

    #[test]
    fn test_sigmoid_apply_very_negative_stays_above_zero() {
        let result = Activation::Sigmoid.apply(-100.0);
        assert!(result > 0.0);
    }

    #[test]
    fn test_elu_apply_positive_is_identity() {
        assert_eq!(Activation::Elu.apply(3.0), 3.0);
    }

    #[test]
    fn test_elu_apply_zero_is_zero() {
        assert!((Activation::Elu.apply(0.0)).abs() < 1e-12);
    }

    #[test]
    fn test_elu_apply_negative_is_exp_minus_one() {
        let expected = (-1.0_f64).exp() - 1.0;
        let result = Activation::Elu.apply(-1.0);
        assert!((result - expected).abs() < 1e-12);
    }

    #[test]
    fn test_elu_apply_large_negative_approaches_minus_one() {
        let result = Activation::Elu.apply(-100.0);
        assert!((result - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_softsign_apply_positive() {
        let result = Activation::Softsign.apply(2.0);
        assert!((result - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_apply_zero() {
        assert!((Activation::Softsign.apply(0.0)).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_apply_negative() {
        let result = Activation::Softsign.apply(-3.0);
        assert!((result - (-0.75)).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_apply_bounded() {
        assert!(Activation::Softsign.apply(100.0) < 1.0);
        assert!(Activation::Softsign.apply(-100.0) > -1.0);
    }

    #[test]
    fn test_linear_apply_is_identity() {
        assert_eq!(Activation::Linear.apply(42.0), 42.0);
    }

    #[test]
    fn test_tanh_derivative_formula() {
        let result = Activation::Tanh.derivative(0.5);
        assert!((result - 0.75).abs() < 1e-12);
    }

    #[test]
    fn test_tanh_derivative_at_zero_is_one() {
        assert!((Activation::Tanh.derivative(0.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_relu_derivative_zero_output_is_zero() {
        assert_eq!(Activation::Relu.derivative(0.0), 0.0);
    }

    #[test]
    fn test_relu_derivative_positive_output_is_one() {
        assert_eq!(Activation::Relu.derivative(2.0), 1.0);
    }

    #[test]
    fn test_sigmoid_derivative_formula() {
        let result = Activation::Sigmoid.derivative(0.7);
        assert!((result - 0.21).abs() < 1e-12);
    }

    #[test]
    fn test_sigmoid_derivative_at_half() {
        assert!((Activation::Sigmoid.derivative(0.5) - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_elu_derivative_positive_is_one() {
        assert_eq!(Activation::Elu.derivative(2.0), 1.0);
    }

    #[test]
    fn test_elu_derivative_negative_is_fx_plus_one() {
        let result = Activation::Elu.derivative(-0.6);
        assert!((result - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_elu_derivative_at_minus_one_is_zero() {
        assert!((Activation::Elu.derivative(-1.0)).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_derivative_at_zero() {
        assert!((Activation::Softsign.derivative(0.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_derivative_positive() {
        let result = Activation::Softsign.derivative(0.5);
        assert!((result - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_derivative_negative() {
        let result = Activation::Softsign.derivative(-0.5);
        assert!((result - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_derivative_high_saturation() {
        let result = Activation::Softsign.derivative(0.9);
        assert!((result - 0.01).abs() < 1e-12);
    }

    #[test]
    fn test_softsign_derivative_always_positive() {
        for &fx in &[-0.9, -0.5, 0.0, 0.5, 0.9] {
            assert!(Activation::Softsign.derivative(fx) > 0.0);
        }
    }

    #[test]
    fn test_linear_derivative_always_one() {
        assert_eq!(Activation::Linear.derivative(999.0), 1.0);
        assert_eq!(Activation::Linear.derivative(-42.0), 1.0);
        assert_eq!(Activation::Linear.derivative(0.0), 1.0);
    }

    /// Cross-checks every analytic `derivative(apply(x))` against a central
    /// finite difference `(f(x + h) - f(x - h)) / 2h`. This catches a formula
    /// that disagrees with `apply`. The hand-picked expected values in the
    /// tests above cannot catch that.
    #[test]
    fn test_derivative_matches_central_difference() {
        let h = 1e-6;
        // Sample points avoid x = 0, where ReLU has a kink.
        let samples = [-3.0, -1.5, -0.5, 0.25, 0.8, 2.0];
        for act in ALL_VARIANTS.iter() {
            for &x in samples.iter() {
                let numeric = (act.apply(x + h) - act.apply(x - h)) / (2.0 * h);
                let analytic = act.derivative(act.apply(x));
                assert!(
                    (numeric - analytic).abs() < 1e-6,
                    "{:?}: derivative mismatch at x = {}: analytic {} vs numeric {}",
                    act,
                    x,
                    analytic,
                    numeric
                );
            }
        }
    }

    #[test]
    fn test_all_activations_produce_finite_output_for_extreme_inputs() {
        for act in ALL_VARIANTS.iter() {
            for &x in &[-100.0, 100.0] {
                let y = act.apply(x);
                assert!(y.is_finite(), "{:?}.apply({}) was not finite", act, x);
            }
        }
    }

    #[test]
    fn test_all_derivatives_finite_for_typical_post_activation_values() {
        let cases: [(Activation, f64); 6] = [
            (Activation::Tanh, 0.5),
            (Activation::Relu, 1.0),
            (Activation::Sigmoid, 0.5),
            (Activation::Elu, -0.5),
            (Activation::Softsign, 0.5),
            (Activation::Linear, 0.0),
        ];
        for (act, fx) in &cases {
            let d = act.derivative(*fx);
            assert!(d.is_finite(), "{:?}.derivative({}) was not finite", act, fx);
        }
    }

    #[test]
    fn test_serde_roundtrip_all_variants() {
        for act in ALL_VARIANTS.iter() {
            let json = serde_json::to_string(act).unwrap();
            let back: Activation = serde_json::from_str(&json).unwrap();
            assert_eq!(*act, back);
        }
    }

    #[test]
    fn test_serde_unknown_variant_returns_error() {
        let result = serde_json::from_str::<Activation>("\"Softmax\"");
        assert!(result.is_err());
    }
}