astrai 2.2.0

A pretty bad neural network library
Documentation
use super::*;

/// One layer of neurons: per-neuron state arrays plus the layer's activation function.
///
/// Every per-neuron array is created with `neuron_amt` entries by [`Layer::new`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Layer {
    /// Number of neurons in this layer; the length every per-neuron array is created with.
    pub neuron_amt: usize,
    /// Per-neuron `z` values (presumably the weighted input before activation — confirm);
    /// `None` until set via `apply_z_array`.
    pub z: Array1<Option<f64>>,
    /// Per-neuron delta values (presumably backpropagation error terms — confirm);
    /// initialized to `0.0`.
    pub delta: Array1<f64>,
    /// Per-neuron activation values; `None` until set via `apply_activation_array`.
    /// `activation_matrix` requires every entry to be `Some`.
    pub activation: Array1<Option<f64>>,
    /// One (initially empty) list of values per neuron.
    /// NOTE(review): how this is filled is not visible in this module — confirm usage.
    pub activation_list: Vec<Vec<f64>>,
    /// Per-neuron bias values; initialized to `0.0`.
    pub bias: Array1<f64>,
    /// Activation function used by this layer (type defined elsewhere in the crate).
    pub activation_function: ActivationFunction,
}

impl PartialEq for Layer {
    fn eq(&self, other: &Self) -> bool {
        self.neuron_amt == other.neuron_amt
            && self.z == other.z
            && self.delta == other.delta
            && self.activation == other.activation
            && self.activation_list == other.activation_list
            && self.bias == other.bias
            && self.activation_function == other.activation_function
    }
}

// NOTE(review): `Layer` stores `f64` data (`delta`, `bias`, `activation_list`, and the
// `Option<f64>` arrays). `f64` equality is not reflexive for NaN, so the `Eq` contract
// does not strictly hold if any stored value is NaN. Kept for API compatibility.
impl Eq for Layer {}

#[profiling::all_functions]
impl Layer {
    /// Creates a layer of `neuron_amt` neurons that uses `activation_function`.
    ///
    /// Initial state:
    /// - `z` and `activation` are `None` for every neuron.
    /// - `delta` and `bias` are `0.0`.
    /// - `activation_list` holds one empty `Vec` per neuron.
    pub fn new(neuron_amt: usize, activation_function: ActivationFunction) -> Layer {
        Layer {
            neuron_amt,
            z: Array1::from_elem(neuron_amt, None),
            delta: Array1::from_elem(neuron_amt, 0.0),
            activation: Array1::from_elem(neuron_amt, None),
            activation_list: vec![vec![]; neuron_amt],
            bias: Array1::from_elem(neuron_amt, 0.0),
            activation_function,
        }
    }

    /// Returns a `neuron_amt x neuron_amt` matrix with this layer's activations on the
    /// diagonal and `0.0` everywhere else.
    ///
    /// # Panics
    ///
    /// - Panics if any activation is still `None`. This happens when no activation was
    ///   applied since construction or since the last [`Layer::zero_out`].
    /// - Panics if `activation` has more than `neuron_amt` entries.
    pub fn activation_matrix(&self) -> Array2<f64> {
        let mut matrix = Array2::<f64>::zeros((self.neuron_amt, self.neuron_amt));
        for (i, value) in self.activation.iter().enumerate() {
            matrix[[i, i]] = value.unwrap_or_else(|| {
                panic!(
                    "activation_matrix: activation of neuron {} has not been set",
                    i
                )
            });
        }
        matrix
    }

    /// Replaces this layer's `z` values with `z`, wrapping each entry in `Some`.
    ///
    /// The length of `z` is not checked against `neuron_amt`.
    pub fn apply_z_array(&mut self, z: Array1<f64>) {
        self.z = z.mapv(Some);
    }

    /// Replaces this layer's `delta` values with `delta`.
    ///
    /// The length of `delta` is not checked against `neuron_amt`.
    pub fn apply_delta_array(&mut self, delta: Array1<f64>) {
        self.delta = delta;
    }

    /// Replaces this layer's activations with `activation`, wrapping each entry in `Some`.
    ///
    /// The length of `activation` is not checked against `neuron_amt`.
    pub fn apply_activation_array(&mut self, activation: Array1<f64>) {
        self.activation = activation.mapv(Some);
    }

    /// Resets the per-neuron state in place.
    ///
    /// `z` and `activation` become `None`, and `delta` becomes `0.0`. `bias`,
    /// `activation_list` and `activation_function` are left untouched.
    pub fn zero_out(&mut self) {
        // `mapv` returns a *new* array, so the previous `self.x.mapv(...)` calls were
        // discarded and nothing was reset. `fill` overwrites every element in place.
        self.z.fill(None);
        self.delta.fill(0.0);
        self.activation.fill(None);
    }
}