//! astrai/network/initialization.rs — network construction and weight initialization.

use super::*;
2
/// Strategy used by `Network::initialize_weights` to seed weights and biases.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InitializationMethod {
    /// Xavier/Glorot initialization: draws from N(0, 2 / (fan_in + fan_out)).
    Xavier,
}
7
8impl Network {
9	pub fn new(
10        input_neuron_amount: usize,
11        output_neuron_amount: usize,
12        other_layer_amounts: Option<Vec<usize>>,
13        layer_activation_functions: Option<Vec<Option<ActivationFunction>>>,
14    ) -> Network {
15        let mut net = Network {
16            layers: vec![],
17            connections: vec![],
18            shape: vec![]
19        };
20
21        net.layers.push(Layer::new(
22            input_neuron_amount,
23            ActivationFunction::Identity,
24        ));
25        net.shape.push(input_neuron_amount);
26
27        let acts: Vec<ActivationFunction> = layer_activation_functions
28            .unwrap_or_else(|| {
29                vec![
30                    Some(ActivationFunction::Sigmoid);
31                    other_layer_amounts.clone().unwrap_or_default().len() + 1
32                ]
33            })
34            .iter()
35            .map(|act| act.unwrap_or(ActivationFunction::Sigmoid))
36            .collect();
37
38        if let Some(other_layer_amounts) = other_layer_amounts {
39            for (idx, amount) in other_layer_amounts.iter().enumerate() {
40                net.layers.push(Layer::new(*amount, acts[idx]));
41                net.shape.push(*amount);
42            }
43        }
44
45        net.layers
46            .push(Layer::new(output_neuron_amount, *acts.last().unwrap()));
47        net.shape.push(output_neuron_amount);
48        net
49    }
50
51    pub fn all_connect(&mut self) {
52        for (i, layer) in self.layers.clone().iter().enumerate() {
53            if i == self.layers.len() - 1 {
54                break;
55            }
56            for j in 0..layer.neuron_amt {
57                for k in 0..self.layers[i + 1].clone().neuron_amt {
58                    self.add_connection((i, j), (i + 1, k), 0.0);
59                }
60            }
61        }
62    }
63
64	pub fn initialize_weights(&mut self, method: InitializationMethod) {
65        let mut rng = rand::thread_rng();
66
67        match method {
68            InitializationMethod::Xavier => self.xavier_init(&mut rng),
69        }
70    }
71
72    fn xavier_init(&mut self, rng: &mut impl Rng) {
73        for (i, layer) in self.layers.clone().iter().enumerate() {
74            let fan_in = self.shape[i] as f64;
75
76            let fan_out = if i == self.shape.len() - 1 {
77                // this reverses the positions of fan_in and fan_out, but since they're just summed we're fine
78                self.shape[i - 1]
79            } else {
80                self.shape[i + 1]
81            } as f64;
82
83            let normal = Normal::new(0.0, (2.0 / (fan_in + fan_out)).sqrt()).unwrap();
84
85            self.layers[i].bias = Array1::from_shape_fn(layer.neuron_amt, |_| normal.sample(rng));
86        }
87
88        for conn in self.connections.iter_mut() {
89            let fan_in = self.shape[conn.in_neuron_id.0] as f64;
90            let fan_out = self.shape[conn.out_neuron_id.0] as f64;
91            let std_dev = (2.0 / (fan_in + fan_out)).sqrt();
92
93            let normal = Normal::new(0.0, std_dev).unwrap();
94            conn.weight = normal.sample(rng);
95        }
96    }
97}