use scirs2_core::ndarray::{Array1, Array2, Array3};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{Result, TimeSeriesError};
/// Activation functions for neural-network layers.
///
/// The enum is fieldless, so `Copy`, `PartialEq`, and `Eq` are derived in
/// addition to `Debug` and `Clone` (backward-compatible; `Clone` callers
/// are unaffected).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationFunction {
    /// Logistic sigmoid: 1 / (1 + e^{-x}).
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Rectified linear unit: max(0, x).
    ReLU,
    /// Gaussian error linear unit (tanh approximation).
    GELU,
    /// Swish / SiLU: x * sigmoid(x).
    Swish,
    /// Identity (no nonlinearity).
    Linear,
}
impl ActivationFunction {
    /// Applies the activation function to the scalar `x`.
    ///
    /// GELU uses the standard tanh approximation:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    pub fn apply<F: Float>(&self, x: F) -> F {
        match self {
            ActivationFunction::Sigmoid => {
                // sigma(x) = 1 / (1 + e^{-x})
                let one = F::one();
                one / (one + (-x).exp())
            }
            ActivationFunction::Tanh => x.tanh(),
            ActivationFunction::ReLU => x.max(F::zero()),
            ActivationFunction::GELU => {
                let half = F::from(0.5).expect("Failed to convert constant to float");
                let one = F::one();
                // sqrt(2/pi) ~= 0.7978845608
                let sqrt_2_pi = F::from(0.7978845608).expect("Failed to convert constant to float");
                let coeff = F::from(0.044715).expect("Failed to convert constant to float");
                half * x * (one + (sqrt_2_pi * (x + coeff * x * x * x)).tanh())
            }
            ActivationFunction::Swish => {
                // Swish (SiLU): x * sigma(x)
                let sigmoid = F::one() / (F::one() + (-x).exp());
                x * sigmoid
            }
            ActivationFunction::Linear => x,
        }
    }

    /// First derivative of the activation function at `x`.
    ///
    /// The GELU derivative differentiates the same tanh approximation used by
    /// [`ActivationFunction::apply`], so `apply`/`derivative` stay consistent.
    pub fn derivative<F: Float>(&self, x: F) -> F {
        match self {
            ActivationFunction::Sigmoid => {
                // sigma'(x) = sigma(x) * (1 - sigma(x))
                let sigmoid = self.apply(x);
                sigmoid * (F::one() - sigmoid)
            }
            ActivationFunction::Tanh => {
                // d/dx tanh(x) = 1 - tanh^2(x)
                let tanh_x = x.tanh();
                F::one() - tanh_x * tanh_x
            }
            ActivationFunction::ReLU => {
                // Subgradient convention: the derivative at x == 0 is taken as 0.
                if x > F::zero() {
                    F::one()
                } else {
                    F::zero()
                }
            }
            ActivationFunction::GELU => {
                // BUG FIX: this arm previously returned sigma(x), which is not
                // the GELU derivative (e.g. at x = 1 the true value is ~1.083
                // while sigma(1) ~= 0.731). Differentiate the tanh
                // approximation used in `apply`:
                //   GELU(x)  = 0.5 * x * (1 + tanh(u)),  u = sqrt(2/pi) * (x + c*x^3)
                //   GELU'(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * u'
                //   u'       = sqrt(2/pi) * (1 + 3*c*x^2)
                let half = F::from(0.5).expect("Failed to convert constant to float");
                let one = F::one();
                let three = F::from(3.0).expect("Failed to convert constant to float");
                let sqrt_2_pi = F::from(0.7978845608).expect("Failed to convert constant to float");
                let coeff = F::from(0.044715).expect("Failed to convert constant to float");
                let u = sqrt_2_pi * (x + coeff * x * x * x);
                let tanh_u = u.tanh();
                // sech^2(u) = 1 - tanh^2(u)
                let sech2_u = one - tanh_u * tanh_u;
                let du = sqrt_2_pi * (one + three * coeff * x * x);
                half * (one + tanh_u) + half * x * sech2_u * du
            }
            ActivationFunction::Swish => {
                // d/dx [x * sigma(x)] = sigma(x) * (1 + x * (1 - sigma(x)))
                let sigmoid = F::one() / (F::one() + (-x).exp());
                sigmoid * (F::one() + x * (F::one() - sigmoid))
            }
            ActivationFunction::Linear => F::one(),
        }
    }
}