astrai 2.2.0

A pretty bad neural network library
use super::*;
// Explicit imports for the external types used below (presumably also
// reachable through the glob import from `super`):
use ndarray::Array1;
use rand::Rng;
use rand_distr::{Distribution, Normal};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitializationMethod {
    Xavier,
}

impl Network {
    pub fn new(
        input_neuron_amount: usize,
        output_neuron_amount: usize,
        other_layer_amounts: Option<Vec<usize>>,
        layer_activation_functions: Option<Vec<Option<ActivationFunction>>>,
    ) -> Network {
        let mut net = Network {
            layers: vec![],
            connections: vec![],
            shape: vec![],
        };

        // The input layer always uses the identity activation.
        net.layers.push(Layer::new(
            input_neuron_amount,
            ActivationFunction::Identity,
        ));
        net.shape.push(input_neuron_amount);

        // One activation per non-input layer (hidden layers plus the output
        // layer). If provided, `layer_activation_functions` is expected to
        // have exactly that many entries; `None` entries fall back to Sigmoid.
        let hidden_count = other_layer_amounts.as_ref().map_or(0, Vec::len);
        let acts: Vec<ActivationFunction> = layer_activation_functions
            .unwrap_or_else(|| vec![None; hidden_count + 1])
            .iter()
            .map(|act| act.unwrap_or(ActivationFunction::Sigmoid))
            .collect();

        if let Some(other_layer_amounts) = other_layer_amounts {
            for (idx, amount) in other_layer_amounts.iter().enumerate() {
                net.layers.push(Layer::new(*amount, acts[idx]));
                net.shape.push(*amount);
            }
        }

        // The last entry in `acts` belongs to the output layer.
        net.layers
            .push(Layer::new(output_neuron_amount, *acts.last().unwrap()));
        net.shape.push(output_neuron_amount);
        net
    }
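
    // A usage sketch (an illustration of intent, not part of the crate's
    // published docs): hidden layer sizes go in `other_layer_amounts`, and
    // passing `None` for the activations defaults every layer to Sigmoid.
    //
    //     let net = Network::new(2, 1, Some(vec![4, 3]), None);
    //     assert_eq!(net.shape, vec![2, 4, 3, 1]);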

    pub fn all_connect(&mut self) {
        // Densely connect every neuron in each layer to every neuron in the
        // next layer, with all weights starting at 0.0.
        for i in 0..self.layers.len().saturating_sub(1) {
            for j in 0..self.layers[i].neuron_amt {
                for k in 0..self.layers[i + 1].neuron_amt {
                    self.add_connection((i, j), (i + 1, k), 0.0);
                }
            }
        }
    }
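
    // Sketch of the resulting density (assuming `add_connection` records one
    // connection per call): for shape [2, 3, 1] this creates
    // 2*3 + 3*1 = 9 zero-weight connections.
    //
    //     let mut net = Network::new(2, 1, Some(vec![3]), None);
    //     net.all_connect();
    //     assert_eq!(net.connections.len(), 9);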

    pub fn initialize_weights(&mut self, method: InitializationMethod) {
        let mut rng = rand::thread_rng();

        match method {
            InitializationMethod::Xavier => self.xavier_init(&mut rng),
        }
    }

    fn xavier_init(&mut self, rng: &mut impl Rng) {
        // Glorot/Xavier initialization: draw each parameter from a normal
        // distribution with variance 2 / (fan_in + fan_out). This library
        // also draws biases from that distribution, per layer.
        for i in 0..self.layers.len() {
            let fan_in = self.shape[i] as f64;

            let fan_out = if i == self.shape.len() - 1 {
                // The output layer has no successor; using the previous layer
                // swaps fan_in and fan_out, but since they are only summed
                // the variance is unaffected.
                self.shape[i - 1]
            } else {
                self.shape[i + 1]
            } as f64;

            let normal = Normal::new(0.0, (2.0 / (fan_in + fan_out)).sqrt()).unwrap();

            self.layers[i].bias =
                Array1::from_shape_fn(self.layers[i].neuron_amt, |_| normal.sample(rng));
        }

        // One independent draw per connection, using the sizes of its two
        // endpoint layers.
        for conn in self.connections.iter_mut() {
            let fan_in = self.shape[conn.in_neuron_id.0] as f64;
            let fan_out = self.shape[conn.out_neuron_id.0] as f64;
            let std_dev = (2.0 / (fan_in + fan_out)).sqrt();

            let normal = Normal::new(0.0, std_dev).unwrap();
            conn.weight = normal.sample(rng);
        }
    }
}
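
// A minimal end-to-end sketch, written as a test. It assumes the surrounding
// module defines `Layer`, `ActivationFunction`, and `add_connection` as used
// above, and that `Connection` exposes a `weight: f64` field visible here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn xavier_init_produces_finite_weights() {
        let mut net = Network::new(2, 1, Some(vec![3]), None);
        net.all_connect();
        net.initialize_weights(InitializationMethod::Xavier);

        // Every connection should now carry a finite, sampled weight.
        assert!(net.connections.iter().all(|c| c.weight.is_finite()));
    }
}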