use serde::{Deserialize, Serialize};
/// Scalar activation function; evaluate it on a single `f64` with `ActivationFn::run`.
///
/// This type derives `Serialize`/`Deserialize`, so its variant names and the `Step`
/// payload are part of the serialized form. Renaming a variant changes its serialized
/// name. For formats that encode the variant index, reordering variants changes the
/// encoding too. Either change can break previously serialized data.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Serialize)]
pub enum ActivationFn {
    /// Rectified linear unit: `max(x, 0)`.
    ReLU,
    /// Logistic sigmoid: `1 / (1 + e^(-x))`.
    Sigmoid,
    /// Hyperbolic tangent of the input.
    Tanh,
    /// Identity: the input is returned unchanged.
    Linear,
    /// Step function with the given threshold. It yields `1.0` when the input is
    /// strictly greater than the threshold, and `0.0` otherwise.
    Step(f64),
}
impl ActivationFn {
    /// Applies this activation function to a single scalar input `x`.
    ///
    /// * `ReLU` — negative inputs clamp to `0.0`; everything else is returned as-is.
    /// * `Sigmoid` — `1 / (1 + e^(-x))`. The result lies in `[0, 1]`; the endpoints are
    ///   reachable after floating-point rounding.
    /// * `Tanh` — `x.tanh()`, in `[-1, 1]`.
    /// * `Linear` — identity; returns `x` unchanged.
    /// * `Step(threshold)` — `1.0` if `x > threshold`, otherwise `0.0`.
    ///
    /// NaN handling:
    /// * `ReLU`, `Sigmoid`, `Tanh` and `Linear` return NaN for a NaN input.
    /// * `Step` returns `0.0` when `x` or the threshold is NaN, because the comparison
    ///   `x > threshold` is false whenever either side is NaN.
    pub(crate) fn run(&self, x: f64) -> f64 {
        match *self {
            // Deliberately not `x.max(0.0)`: `f64::max` returns the non-NaN operand, which
            // would silently turn a NaN input into 0.0 and mask upstream numerical issues.
            // With an explicit comparison, NaN (and +inf) fall through unchanged.
            ActivationFn::ReLU => if x < 0.0 { 0.0 } else { x },
            ActivationFn::Sigmoid => 1.0 / (1.0 + (-x).exp()),
            ActivationFn::Tanh => x.tanh(),
            ActivationFn::Linear => x,
            ActivationFn::Step(threshold) => {
                // Strictly greater: an input exactly at the threshold maps to 0.0.
                if x > threshold {
                    1.0
                } else {
                    0.0
                }
            }
        }
    }
}