native_neural_network 0.3.1

A `no_std` Rust library for native neural networks (.rnn).
Documentation
/// Selects which scalar activation function to use.
///
/// Every variant spells out the discriminant the compiler would otherwise
/// assign implicitly (declaration order starting at 0), so `kind as u8`
/// yields exactly the same values as an unannotated enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationKind {
    /// Pass-through: the input is returned unchanged.
    Identity = 0,
    /// Rectified linear unit: `x` when `x > 0`, otherwise `0`.
    Relu = 1,
    /// Logistic sigmoid: `1 / (1 + e^-x)`.
    Sigmoid = 2,
    /// Hyperbolic tangent.
    Tanh = 3,
}

impl ActivationKind {
    /// Returns the numeric tag for this activation.
    ///
    /// Mapping: `Identity => 0`, `Relu => 1`, `Sigmoid => 2`, `Tanh => 3`.
    /// [`ActivationKind::from_u8`] is the inverse. Declared `const` so tags can
    /// be computed in constant contexts.
    // NOTE(review): presumably this is the tag persisted in serialized models — confirm
    // before changing any value.
    pub const fn to_u8(self) -> u8 {
        match self {
            ActivationKind::Identity => 0,
            ActivationKind::Relu => 1,
            ActivationKind::Sigmoid => 2,
            ActivationKind::Tanh => 3,
        }
    }

    /// Parses a numeric tag produced by [`ActivationKind::to_u8`].
    ///
    /// Returns `None` for any value outside `0..=3`. Declared `const` so it can
    /// be evaluated in constant contexts.
    pub const fn from_u8(v: u8) -> Option<Self> {
        match v {
            0 => Some(ActivationKind::Identity),
            1 => Some(ActivationKind::Relu),
            2 => Some(ActivationKind::Sigmoid),
            3 => Some(ActivationKind::Tanh),
            _ => None,
        }
    }

    /// Evaluates the activation at `x` in single precision.
    ///
    /// * `Identity` returns `x`.
    /// * `Relu` returns `x` when `x > 0.0`, else `0.0`. Because `NaN > 0.0` is
    ///   false, a `NaN` input yields `0.0`.
    /// * `Sigmoid` returns `1 / (1 + e^-x)`, computed in a numerically stable form.
    /// * `Tanh` delegates to `crate::math::tanhf`.
    pub fn apply(self, x: f32) -> f32 {
        match self {
            ActivationKind::Identity => x,
            ActivationKind::Relu => {
                if x > 0.0 {
                    x
                } else {
                    0.0
                }
            }
            ActivationKind::Sigmoid => {
                // Branch on the sign so the argument passed to exp is never
                // positive: the result stays in (0, 1] and cannot overflow.
                // Both branches are algebraically equal to 1 / (1 + e^-x).
                if x >= 0.0 {
                    let z = crate::math::expf(-x);
                    1.0 / (1.0 + z)
                } else {
                    let z = crate::math::expf(x);
                    z / (1.0 + z)
                }
            }
            ActivationKind::Tanh => crate::math::tanhf(x),
        }
    }

    /// Applies [`ActivationKind::apply`] to every element of `values`, overwriting them.
    pub fn apply_in_place(self, values: &mut [f32]) {
        for value in values {
            *value = self.apply(*value);
        }
    }

    /// Double-precision counterpart of [`ActivationKind::apply`].
    ///
    /// Same per-variant behavior, using `crate::math::expd` / `crate::math::tanhd`.
    pub fn apply_f64(self, x: f64) -> f64 {
        match self {
            ActivationKind::Identity => x,
            ActivationKind::Relu => {
                if x > 0.0 {
                    x
                } else {
                    0.0
                }
            }
            ActivationKind::Sigmoid => {
                // Same overflow-safe split as the f32 path: exp only ever sees
                // a non-positive argument.
                if x >= 0.0 {
                    let z = crate::math::expd(-x);
                    1.0 / (1.0 + z)
                } else {
                    let z = crate::math::expd(x);
                    z / (1.0 + z)
                }
            }
            ActivationKind::Tanh => crate::math::tanhd(x),
        }
    }

    /// Applies [`ActivationKind::apply_f64`] to every element of `values`, overwriting them.
    pub fn apply_in_place_f64(self, values: &mut [f64]) {
        for value in values {
            *value = self.apply_f64(*value);
        }
    }

    /// Returns the activation's derivative expressed in terms of its OUTPUT.
    ///
    /// `output` must be the value produced by [`ActivationKind::apply`], not the
    /// pre-activation input:
    /// * `Identity`: `1`
    /// * `Relu`: `1` when `output > 0`, else `0`
    /// * `Sigmoid`: `y * (1 - y)` (since `σ' = σ(1 - σ)`)
    /// * `Tanh`: `1 - y²` (since `tanh' = 1 - tanh²`)
    pub fn derivative_from_output(self, output: f32) -> f32 {
        match self {
            ActivationKind::Identity => 1.0,
            ActivationKind::Relu => {
                if output > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationKind::Sigmoid => output * (1.0 - output),
            ActivationKind::Tanh => 1.0 - output * output,
        }
    }

    /// Double-precision counterpart of [`ActivationKind::derivative_from_output`].
    ///
    /// `output` must likewise be the activation's output (from
    /// [`ActivationKind::apply_f64`]), not its input.
    pub fn derivative_from_output_f64(self, output: f64) -> f64 {
        match self {
            ActivationKind::Identity => 1.0,
            ActivationKind::Relu => {
                if output > 0.0 {
                    1.0
                } else {
                    0.0
                }
            }
            ActivationKind::Sigmoid => output * (1.0 - output),
            ActivationKind::Tanh => 1.0 - output * output,
        }
    }
}