//! astrai 2.2.0: a pretty bad neural network library.
//!
//! Activation functions and their derivatives.
use super::*;

/// The activation functions supported by astrai.
///
/// `Sigmoid` is the default. `SigmoidDerivative` and `TanhDerivative` expose
/// the derivatives themselves as activations; calling `derivative` on those
/// variants differentiates once more (i.e. returns the second derivative of
/// the base function).
#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord, EnumIter)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ActivationFunction {
    #[default]
    Sigmoid,
    SigmoidDerivative,
    Relu,
    Tanh,
    TanhDerivative,
    BinaryStep,
    Identity,
}
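
// With the optional `serde` feature enabled, variants serialize by name
// (serde's default for unit variants). A hypothetical round-trip sketch,
// assuming a `serde_json` dev-dependency that this file does not itself
// confirm:
//
//     let json = serde_json::to_string(&ActivationFunction::Tanh).unwrap();
//     assert_eq!(json, "\"Tanh\"");
//     let back: ActivationFunction = serde_json::from_str(&json).unwrap();
//     assert_eq!(back, ActivationFunction::Tanh);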

impl ActivationFunction {
    /// Evaluates the selected activation function at `x`.
    pub fn call(&self, x: f64) -> f64 {
        match self {
            ActivationFunction::Sigmoid => sigmoid(x),
            ActivationFunction::SigmoidDerivative => sigmoid_derivative(x),
            ActivationFunction::Relu => relu(x),
            ActivationFunction::Tanh => tanh(x),
            ActivationFunction::TanhDerivative => tanh_derivative(x),
            ActivationFunction::BinaryStep => binary_step(x),
            ActivationFunction::Identity => x,
        }
    }

    /// Evaluates the derivative of the selected activation function at `x`.
    pub fn derivative(&self, x: f64) -> f64 {
        match self {
            ActivationFunction::Sigmoid => sigmoid_derivative(x),
            // Second derivative of the sigmoid: sigma'(x) * (1 - 2 * sigma(x)).
            ActivationFunction::SigmoidDerivative => {
                sigmoid_derivative(x) * (1.0 - 2.0 * sigmoid(x))
            }
            // ReLU has slope 1 only for positive inputs; we use the
            // subgradient 0 at x = 0 and below.
            ActivationFunction::Relu => {
                if x > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationFunction::Tanh => tanh_derivative(x),
            // d/dx (1 - tanh^2(x)) = -2 * tanh(x) * (1 - tanh^2(x)).
            ActivationFunction::TanhDerivative => -2.0 * tanh(x) * tanh_derivative(x),
            // Flat everywhere except the jump at x = 0.
            ActivationFunction::BinaryStep => 0.0,
            ActivationFunction::Identity => 1.0,
        }
    }
}
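
// A minimal usage sketch of the API above; the expected values follow
// directly from the formulas in this file:
//
//     let f = ActivationFunction::Relu;
//     assert_eq!(f.call(-2.0), 0.0);      // negatives are clamped to zero
//     assert_eq!(f.call(3.0), 3.0);       // positives pass through
//     assert_eq!(f.derivative(3.0), 1.0); // slope 1 on the positive side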

/// Logistic sigmoid: 1 / (1 + e^(-x)).
pub fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// First derivative of the sigmoid: sigma(x) * (1 - sigma(x)).
pub fn sigmoid_derivative(x: f64) -> f64 {
    let s = sigmoid(x);
    s * (1.0 - s)
}

/// Rectified linear unit: max(x, 0).
pub fn relu(x: f64) -> f64 {
    if x > 0.0 {
        x
    } else {
        0.0
    }
}

/// Hyperbolic tangent (a thin wrapper around `f64::tanh`).
pub fn tanh(x: f64) -> f64 {
    x.tanh()
}

/// First derivative of tanh: 1 - tanh^2(x).
pub fn tanh_derivative(x: f64) -> f64 {
    let t = tanh(x);
    1.0 - t.powi(2)
}

/// Heaviside step: 1 for positive inputs, 0 otherwise.
pub fn binary_step(x: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else {
        0.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The expected values below are exact to f64 precision, so a tight
    /// absolute tolerance is enough.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_sigmoid() {
        assert_close(sigmoid(0.0), 0.5);
        assert_close(sigmoid(1.0), 0.7310585786300049);
        assert_close(sigmoid(-1.0), 0.2689414213699951);
    }

    #[test]
    fn test_sigmoid_derivative() {
        assert_close(sigmoid_derivative(0.0), 0.25);
        assert_close(sigmoid_derivative(1.0), 0.19661193324148185);
        assert_close(sigmoid_derivative(-1.0), 0.19661193324148185);
    }

    #[test]
    fn test_relu() {
        assert_close(relu(0.0), 0.0);
        assert_close(relu(1.0), 1.0);
        assert_close(relu(-1.0), 0.0);
    }

    #[test]
    fn test_tanh() {
        assert_close(tanh(0.0), 0.0);
        assert_close(tanh(1.0), 0.7615941559557649);
        assert_close(tanh(-1.0), -0.7615941559557649);
    }

    #[test]
    fn test_tanh_derivative() {
        assert_close(tanh_derivative(0.0), 1.0);
        assert_close(tanh_derivative(1.0), 0.41997434161402614);
        assert_close(tanh_derivative(-1.0), 0.41997434161402614);
    }

    #[test]
    fn test_binary_step() {
        assert_close(binary_step(0.0), 0.0);
        assert_close(binary_step(1.0), 1.0);
        assert_close(binary_step(-1.0), 0.0);
    }
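
    // Sanity check: each analytic derivative should agree with a central
    // finite difference wherever the function is smooth (the test points
    // below avoid the kinks at x = 0).
    #[test]
    fn test_derivative_matches_finite_difference() {
        let functions = [
            ActivationFunction::Sigmoid,
            ActivationFunction::SigmoidDerivative,
            ActivationFunction::Relu,
            ActivationFunction::Tanh,
            ActivationFunction::TanhDerivative,
            ActivationFunction::BinaryStep,
            ActivationFunction::Identity,
        ];
        let h = 1e-5;
        for f in functions {
            for x in [-1.3, -0.5, 0.5, 0.9] {
                let numeric = (f.call(x + h) - f.call(x - h)) / (2.0 * h);
                let analytic = f.derivative(x);
                assert!(
                    (numeric - analytic).abs() < 1e-6,
                    "{f:?} at {x}: analytic {analytic}, numeric {numeric}"
                );
            }
        }
    }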
}