use num_traits::Float;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Element-wise activation functions for scalar floating-point values.
///
/// Each variant corresponds to an associated function of the same (snake_case)
/// name on `Activation`; use `Activation::apply` to evaluate one.
///
/// With the `serde` feature enabled, variants are (de)serialized as their
/// lowercased names, e.g. `"unitstep"` or `"bentidentity"`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum Activation {
    /// Identity: `f(x) = x`.
    Linear,
    /// Heaviside step: `1` when `x > 0`, otherwise `0`.
    UnitStep,
    /// Sign: `1` for positive `x`, `0` for zero, `-1` for negative `x`.
    Sign,
    /// Logistic sigmoid: `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Softsign: `x / (1 + |x|)`.
    SoftSign,
    /// Bent identity: `(sqrt(x^2 + 1) - 1) / 2 + x`.
    BentIdentity,
    /// Rectified linear unit: `max(x, 0)`.
    Relu,
}
impl Activation {
    /// Evaluates this activation function at `x`.
    ///
    /// Produces the same result as calling the function pointer returned by
    /// `Activation::get_function`. Every activation returns NaN when `x` is NaN.
    pub fn apply<T: Float>(&self, x: T) -> T {
        match self {
            Activation::Linear => Self::linear(x),
            Activation::UnitStep => Self::unit_step(x),
            Activation::Sign => Self::sign(x),
            Activation::Sigmoid => Self::sigmoid(x),
            Activation::Tanh => Self::tanh(x),
            Activation::SoftSign => Self::soft_sign(x),
            Activation::BentIdentity => Self::bent_identity(x),
            Activation::Relu => Self::relu(x),
        }
    }

    /// Returns this activation as a plain function pointer for the float type `T`.
    ///
    /// The variant is resolved once here, so the returned pointer can be called
    /// repeatedly without re-matching on `self`.
    pub fn get_function<T: Float>(&self) -> fn(T) -> T {
        match self {
            Activation::Linear => Self::linear,
            Activation::UnitStep => Self::unit_step,
            Activation::Sign => Self::sign,
            Activation::Sigmoid => Self::sigmoid,
            Activation::Tanh => Self::tanh,
            Activation::SoftSign => Self::soft_sign,
            Activation::BentIdentity => Self::bent_identity,
            Activation::Relu => Self::relu,
        }
    }

    /// Identity: `f(x) = x`.
    pub fn linear<T: Float>(x: T) -> T {
        x
    }

    /// Heaviside step: `1` for `x > 0`, `0` for `x <= 0`, and NaN for NaN.
    pub fn unit_step<T: Float>(x: T) -> T {
        if x.is_nan() {
            // NaN fails `x > 0`; without this check it would silently become 0.
            x
        } else if x > T::zero() {
            T::one()
        } else {
            T::zero()
        }
    }

    /// Sign function: `1` for positive `x`, `-1` for negative `x`, `0` for
    /// `+0.0`/`-0.0`, and NaN for NaN.
    ///
    /// Unlike `Float::signum`, signed zeros map to `0` rather than `±1`.
    pub fn sign<T: Float>(x: T) -> T {
        if x > T::zero() {
            T::one()
        } else if x < T::zero() {
            -T::one()
        } else if x.is_nan() {
            // NaN fails both comparisons above; propagate it instead of
            // reporting a sign for it.
            x
        } else {
            T::zero()
        }
    }

    /// Logistic sigmoid: `1 / (1 + e^(-x))`, with outputs in `[0, 1]`.
    pub fn sigmoid<T: Float>(x: T) -> T {
        T::one() / (T::one() + (-x).exp())
    }

    /// Hyperbolic tangent, with outputs in `[-1, 1]`.
    pub fn tanh<T: Float>(x: T) -> T {
        x.tanh()
    }

    /// Softsign: `x / (1 + |x|)`, with outputs in `[-1, 1]`.
    pub fn soft_sign<T: Float>(x: T) -> T {
        if x.is_infinite() {
            // inf / (1 + inf) evaluates to NaN; return the limit ±1 instead.
            x.signum()
        } else {
            x / (T::one() + x.abs())
        }
    }

    /// Bent identity: `(sqrt(x^2 + 1) - 1) / 2 + x`.
    ///
    /// The function is unbounded in both directions, so `±inf` maps to itself.
    pub fn bent_identity<T: Float>(x: T) -> T {
        if x.is_infinite() {
            // For -inf the formula would compute inf + (-inf) = NaN.
            return x;
        }
        let two = T::one() + T::one();
        // `hypot` evaluates sqrt(x^2 + 1) without forming x^2 explicitly, so large
        // |x| (around 1.8e19 for f32) no longer overflows to infinity.
        (x.hypot(T::one()) - T::one()) / two + x
    }

    /// Rectified linear unit: `max(x, 0)`, returning NaN for NaN.
    pub fn relu<T: Float>(x: T) -> T {
        if x.is_nan() {
            // `Float::max` returns the non-NaN operand, which would turn NaN into 0.
            x
        } else {
            x.max(T::zero())
        }
    }
}
#[cfg(test)]
mod tests {
    use assert_approx_eq::assert_approx_eq;

    use super::*;

    /// Every variant, so properties shared by all activations can be checked in one loop.
    const ALL: [Activation; 8] = [
        Activation::Linear,
        Activation::UnitStep,
        Activation::Sign,
        Activation::Sigmoid,
        Activation::Tanh,
        Activation::SoftSign,
        Activation::BentIdentity,
        Activation::Relu,
    ];

    #[test]
    fn test_activation() {
        assert_approx_eq!(5.0, Activation::Linear.apply(5.0));
        assert_approx_eq!(0.0, Activation::UnitStep.apply(-5.0));
        assert_approx_eq!(-1.0, Activation::Sign.apply(-5.0));
        assert_approx_eq!(0.8807970779778823, Activation::Sigmoid.apply(2.0));
        assert_approx_eq!(0.9640275800758169, Activation::Tanh.apply(2.0));
        assert_approx_eq!(0.8333333333333334, Activation::SoftSign.apply(5.0));
        assert_approx_eq!(7.049509756796392, Activation::BentIdentity.apply(5.0));
        assert_approx_eq!(5.0, Activation::Relu.apply(5.0));
        assert_approx_eq!(0.0, Activation::Relu.apply(-5.0));
    }

    #[test]
    fn test_threshold_branches() {
        // Positive and zero branches not exercised by `test_activation`.
        assert_approx_eq!(1.0, Activation::UnitStep.apply(5.0));
        assert_approx_eq!(0.0, Activation::UnitStep.apply(0.0));
        assert_approx_eq!(1.0, Activation::Sign.apply(5.0));
        assert_approx_eq!(0.0, Activation::Sign.apply(0.0));
    }

    #[test]
    fn test_zero_and_negative_inputs() {
        assert_approx_eq!(0.5, Activation::Sigmoid.apply(0.0));
        assert_approx_eq!(0.0, Activation::Tanh.apply(0.0));
        assert_approx_eq!(-0.8333333333333334, Activation::SoftSign.apply(-5.0));
        assert_approx_eq!(0.0, Activation::BentIdentity.apply(0.0));
        assert_approx_eq!(-2.950490243203608, Activation::BentIdentity.apply(-5.0));
    }

    #[test]
    fn test_get_function_matches_apply() {
        for act in ALL.iter() {
            let f = act.get_function::<f64>();
            for &x in [-5.0, -0.5, 0.0, 0.5, 5.0].iter() {
                assert_eq!(f(x), act.apply(x), "{:?} at {}", act, x);
            }
        }
    }

    #[test]
    fn test_f32() {
        // The functions are generic over `Float`; make sure f32 works too.
        assert_approx_eq!(0.5f32, Activation::Sigmoid.apply(0.0f32));
        assert_approx_eq!(5.0f32, Activation::Relu.apply(5.0f32));
        assert_approx_eq!(-1.0f32, Activation::Sign.apply(-3.0f32));
    }
}