/// Selects one of the activation functions defined in this module.
///
/// Call [`Activation::get_func`] to obtain the matching `fn(f64) -> f64`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Activation {
    /// Identity: `f(x) = x`.
    Linear,
    /// Step function: `1.0` when `x > 0.0`, otherwise `0.0`.
    Threshold,
    /// Sign function: `1.0`, `0.0` or `-1.0` according to the sign of `x`.
    Sign,
    /// Logistic sigmoid `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Softsign `x / (1 + |x|)`.
    SoftSign,
    /// Bent identity `(sqrt(x^2 + 1) - 1) / 2 + x`.
    BentIdentity,
    /// Rectified linear unit: `x` when `x > 0.0`, otherwise `0.0`.
    Relu,
}
impl Activation {
    /// Returns the function pointer that implements this activation.
    ///
    /// The match is exhaustive, so adding a variant without mapping it here
    /// is a compile error.
    pub fn get_func(&self) -> fn(f64) -> f64 {
        match self {
            Activation::Linear => linear,
            Activation::Threshold => threshold,
            Activation::Sign => sign,
            Activation::Sigmoid => sigmoid,
            Activation::Tanh => tanh,
            Activation::SoftSign => soft_sign,
            Activation::BentIdentity => bent_identity,
            Activation::Relu => relu,
        }
    }

    /// Decodes an integer code into an [`Activation`], returning `None` for
    /// any code outside `0..=7`.
    ///
    /// Codes: `0` Linear, `1` Threshold, `2` Sign, `3` Sigmoid, `4` Tanh,
    /// `5` SoftSign, `6` BentIdentity, `7` Relu.
    pub fn try_from_i32(n: i32) -> Option<Activation> {
        match n {
            0 => Some(Activation::Linear),
            1 => Some(Activation::Threshold),
            2 => Some(Activation::Sign),
            3 => Some(Activation::Sigmoid),
            4 => Some(Activation::Tanh),
            5 => Some(Activation::SoftSign),
            6 => Some(Activation::BentIdentity),
            7 => Some(Activation::Relu),
            _ => None,
        }
    }

    /// Decodes an integer code into an [`Activation`], using the same codes
    /// as [`Activation::try_from_i32`].
    ///
    /// Any code outside `0..=7` (including negative values) falls back to
    /// [`Activation::Relu`]. Use `try_from_i32` when an invalid code should
    /// be detected rather than silently treated as ReLU.
    pub fn from_i32(n: i32) -> Activation {
        // Preserve the historical lenient fallback for existing callers.
        Self::try_from_i32(n).unwrap_or(Activation::Relu)
    }
}
/// Identity activation: returns its input unchanged (NaN and signed zero included).
pub fn linear(x: f64) -> f64 {
    std::convert::identity(x)
}
/// Step activation: `1.0` for strictly positive input, `0.0` otherwise.
///
/// Zero, negative values and NaN all map to `0.0`, because `x > 0.0` is
/// false for each of them.
pub fn threshold(x: f64) -> f64 {
    let is_positive = x > 0.0;
    // `true` -> 1u8 -> 1.0, `false` -> 0u8 -> 0.0
    f64::from(u8::from(is_positive))
}
/// Sign activation: `1.0` for positive input, `-1.0` for negative input and
/// `0.0` for either zero.
///
/// Both `+0.0` and `-0.0` return `+0.0`. NaN has no sign and is returned
/// unchanged, consistent with [`f64::signum`]. It is not reported as `-1.0`.
pub fn sign(x: f64) -> f64 {
    if x.is_nan() {
        // Previously NaN fell through to the negative branch and became -1.0.
        x
    } else if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}
/// Logistic sigmoid `1 / (1 + e^(-x))`.
///
/// For large negative `x`, `e^(-x)` overflows to infinity and the result
/// saturates to `0.0`. For large positive `x` it rounds to `1.0`.
pub fn sigmoid(x: f64) -> f64 {
    let exp_neg_x = f64::exp(-x);
    (1.0 + exp_neg_x).recip()
}
/// Hyperbolic tangent activation; delegates to the standard library's [`f64::tanh`].
pub fn tanh(x: f64) -> f64 {
    f64::tanh(x)
}
/// Softsign activation `x / (1 + |x|)`, a smooth squashing function into `[-1, 1]`.
///
/// Infinite inputs return their limit (`1.0` for `+∞`, `-1.0` for `-∞`)
/// instead of the NaN that `∞ / ∞` would produce. NaN input propagates as NaN.
pub fn soft_sign(x: f64) -> f64 {
    if x.is_infinite() {
        // The direct formula evaluates to inf / inf = NaN here.
        x.signum()
    } else {
        x / (1.0 + x.abs())
    }
}
/// Bent identity activation `(sqrt(x^2 + 1) - 1) / 2 + x`.
///
/// `sqrt(x^2 + 1)` is computed with [`f64::hypot`], so large finite inputs
/// no longer overflow. Squaring first turned any |x| above about 1.3e154
/// into `inf`. Infinite inputs return their limit (`±∞`). Without this,
/// `-∞` would evaluate to `∞ - ∞ = NaN`.
pub fn bent_identity(x: f64) -> f64 {
    if x.is_infinite() {
        x
    } else {
        // hypot(x, 1) == sqrt(x^2 + 1) without the intermediate x^2 overflow.
        (x.hypot(1.0) - 1.0) / 2.0 + x
    }
}
/// Rectified linear unit: returns `x` when it is strictly positive, otherwise `0.0`.
///
/// Zero (including `-0.0`) and NaN map to `+0.0`, because `x > 0.0` is
/// false for them.
pub fn relu(x: f64) -> f64 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}