use super::*;
/// Selectable activation function for a neuron/layer.
///
/// Evaluate with `ActivationFunction::call` and differentiate with
/// `ActivationFunction::derivative`. `Serialize`/`Deserialize` are only derived when the
/// `serde` feature is enabled.
///
/// NOTE: the derived `PartialOrd`/`Ord` order follows variant declaration order
/// (and presumably so does `EnumIter` iteration — confirm against the crate
/// providing it), so reordering variants is a behavior change.
#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord, EnumIter)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ActivationFunction {
/// Logistic sigmoid, `1 / (1 + e^(-x))`. This is the `Default` variant.
#[default]
Sigmoid,
/// The sigmoid's derivative, `σ(x)·(1 − σ(x))`, used directly as the activation output.
SigmoidDerivative,
/// Rectified linear unit: `x` when `x > 0`, otherwise `0`.
Relu,
/// Hyperbolic tangent.
Tanh,
/// The tanh derivative, `1 − tanh²(x)`, used directly as the activation output.
TanhDerivative,
/// Step function: `1` when `x > 0`, otherwise `0`.
BinaryStep,
/// Passes the input through unchanged.
Identity,
}
impl ActivationFunction {
    /// Evaluates this activation function at `x`.
    pub fn call(&self, x: f64) -> f64 {
        match self {
            Self::Sigmoid => sigmoid(x),
            Self::SigmoidDerivative => sigmoid_derivative(x),
            Self::Relu => relu(x),
            Self::Tanh => tanh(x),
            Self::TanhDerivative => tanh_derivative(x),
            Self::BinaryStep => binary_step(x),
            Self::Identity => x,
        }
    }

    /// Returns the derivative of [`call`](Self::call) with respect to its input,
    /// evaluated at `x`.
    ///
    /// At the non-differentiable point `x == 0` of `Relu` and `BinaryStep`, the
    /// value `0.0` is used (the usual subgradient convention).
    pub fn derivative(&self, x: f64) -> f64 {
        match self {
            Self::Sigmoid => sigmoid_derivative(x),
            // d/dx [σ(1 − σ)] = σ'(1 − σ) − σσ' = σ'·(1 − 2σ)
            Self::SigmoidDerivative => sigmoid_derivative(x) * (1.0 - 2.0 * sigmoid(x)),
            // ReLU has slope 1 for positive inputs and 0 otherwise; this is exactly
            // the step function (previously this returned 1.0 for every input).
            Self::Relu => binary_step(x),
            Self::Tanh => tanh_derivative(x),
            // d/dx [1 − tanh²] = −2·tanh·(1 − tanh²)
            Self::TanhDerivative => -2.0 * tanh(x) * tanh_derivative(x),
            // Piecewise constant: zero slope everywhere it is differentiable.
            Self::BinaryStep => 0.0,
            Self::Identity => 1.0,
        }
    }
}
/// Logistic sigmoid `σ(x) = 1 / (1 + e^(-x))`.
///
/// Saturates cleanly: very negative inputs make `e^(-x)` overflow to infinity,
/// whose reciprocal is `0.0`, so no `NaN` is produced.
pub fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + f64::exp(-x);
    denominator.recip()
}
/// Derivative of [`sigmoid`], via the identity `σ'(x) = σ(x) · (1 − σ(x))`.
pub fn sigmoid_derivative(x: f64) -> f64 {
    let activated = sigmoid(x);
    let complement = 1.0 - activated;
    activated * complement
}
/// Rectified linear unit: returns `x` when it is strictly positive, otherwise `0.0`.
///
/// Because the test is `x > 0.0`, both `NaN` and `-0.0` map to `+0.0`.
pub fn relu(x: f64) -> f64 {
    match x {
        positive if positive > 0.0 => positive,
        _ => 0.0,
    }
}
/// Hyperbolic tangent of `x`; a thin free-function wrapper over [`f64::tanh`].
pub fn tanh(x: f64) -> f64 {
    f64::tanh(x)
}
/// Derivative of [`tanh`], using the identity `tanh'(x) = 1 − tanh²(x)`.
pub fn tanh_derivative(x: f64) -> f64 {
    let squared = tanh(x).powi(2);
    1.0 - squared
}
/// Step function: `1.0` when `x` is strictly positive, `0.0` otherwise.
///
/// `0.0` and `NaN` both map to `0.0`, because the comparison `x > 0.0` is false for them.
pub fn binary_step(x: f64) -> f64 {
    // bool -> 0/1 integer -> exact f64.
    f64::from(u8::from(x > 0.0))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for float comparisons; reference values below carry
    /// ~16 significant digits, so 1e-12 is comfortably loose yet meaningful.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    ///
    /// Comparing `.round()`ed values would only check the nearest integer (any
    /// value in [0.5, 1.5) would "match" 0.731), and exact `==` is brittle for
    /// transcendental functions, so a tolerance is used instead.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_sigmoid() {
        assert_close(sigmoid(0.0), 0.5);
        assert_close(sigmoid(1.0), 0.7310585786300049);
        assert_close(sigmoid(-1.0), 0.2689414213699951);
    }

    #[test]
    fn test_sigmoid_derivative() {
        assert_close(sigmoid_derivative(0.0), 0.25);
        assert_close(sigmoid_derivative(1.0), 0.19661193324148185);
        assert_close(sigmoid_derivative(-1.0), 0.19661193324148185);
    }

    #[test]
    fn test_relu() {
        assert_close(relu(0.0), 0.0);
        assert_close(relu(1.0), 1.0);
        assert_close(relu(-1.0), 0.0);
        assert_close(relu(2.5), 2.5);
    }

    #[test]
    fn test_tanh() {
        assert_close(tanh(0.0), 0.0);
        assert_close(tanh(1.0), 0.7615941559557649);
        assert_close(tanh(-1.0), -0.7615941559557649);
    }

    #[test]
    fn test_tanh_derivative() {
        assert_close(tanh_derivative(0.0), 1.0);
        assert_close(tanh_derivative(1.0), 0.41997434161402614);
        assert_close(tanh_derivative(-1.0), 0.41997434161402614);
    }

    #[test]
    fn test_binary_step() {
        assert_close(binary_step(0.0), 0.0);
        assert_close(binary_step(1.0), 1.0);
        assert_close(binary_step(-1.0), 0.0);
    }

    #[test]
    fn test_activation_call_dispatch() {
        assert_eq!(ActivationFunction::default(), ActivationFunction::Sigmoid);
        assert_close(ActivationFunction::Sigmoid.call(0.0), 0.5);
        assert_close(ActivationFunction::Relu.call(-2.0), 0.0);
        assert_close(ActivationFunction::Tanh.call(1.0), 0.7615941559557649);
        assert_close(ActivationFunction::BinaryStep.call(0.5), 1.0);
        assert_close(ActivationFunction::Identity.call(3.5), 3.5);
    }

    #[test]
    fn test_activation_derivative_smooth_points() {
        assert_close(ActivationFunction::Sigmoid.derivative(0.0), 0.25);
        assert_close(ActivationFunction::Tanh.derivative(0.0), 1.0);
        assert_close(ActivationFunction::Relu.derivative(2.0), 1.0);
        assert_close(ActivationFunction::BinaryStep.derivative(1.0), 0.0);
        assert_close(ActivationFunction::Identity.derivative(5.0), 1.0);
    }
}