astrai/network/
initialization.rs1use super::*;
2
/// Strategy used by `Network::initialize_weights` to seed biases and connection weights.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InitializationMethod {
    /// Xavier/Glorot initialization: values are drawn from a zero-mean normal
    /// distribution with standard deviation `sqrt(2 / (fan_in + fan_out))`.
    Xavier,
}
7
8impl Network {
9 pub fn new(
10 input_neuron_amount: usize,
11 output_neuron_amount: usize,
12 other_layer_amounts: Option<Vec<usize>>,
13 layer_activation_functions: Option<Vec<Option<ActivationFunction>>>,
14 ) -> Network {
15 let mut net = Network {
16 layers: vec![],
17 connections: vec![],
18 shape: vec![]
19 };
20
21 net.layers.push(Layer::new(
22 input_neuron_amount,
23 ActivationFunction::Identity,
24 ));
25 net.shape.push(input_neuron_amount);
26
27 let acts: Vec<ActivationFunction> = layer_activation_functions
28 .unwrap_or_else(|| {
29 vec![
30 Some(ActivationFunction::Sigmoid);
31 other_layer_amounts.clone().unwrap_or_default().len() + 1
32 ]
33 })
34 .iter()
35 .map(|act| act.unwrap_or(ActivationFunction::Sigmoid))
36 .collect();
37
38 if let Some(other_layer_amounts) = other_layer_amounts {
39 for (idx, amount) in other_layer_amounts.iter().enumerate() {
40 net.layers.push(Layer::new(*amount, acts[idx]));
41 net.shape.push(*amount);
42 }
43 }
44
45 net.layers
46 .push(Layer::new(output_neuron_amount, *acts.last().unwrap()));
47 net.shape.push(output_neuron_amount);
48 net
49 }
50
51 pub fn all_connect(&mut self) {
52 for (i, layer) in self.layers.clone().iter().enumerate() {
53 if i == self.layers.len() - 1 {
54 break;
55 }
56 for j in 0..layer.neuron_amt {
57 for k in 0..self.layers[i + 1].clone().neuron_amt {
58 self.add_connection((i, j), (i + 1, k), 0.0);
59 }
60 }
61 }
62 }
63
64 pub fn initialize_weights(&mut self, method: InitializationMethod) {
65 let mut rng = rand::thread_rng();
66
67 match method {
68 InitializationMethod::Xavier => self.xavier_init(&mut rng),
69 }
70 }
71
72 fn xavier_init(&mut self, rng: &mut impl Rng) {
73 for (i, layer) in self.layers.clone().iter().enumerate() {
74 let fan_in = self.shape[i] as f64;
75
76 let fan_out = if i == self.shape.len() - 1 {
77 self.shape[i - 1]
79 } else {
80 self.shape[i + 1]
81 } as f64;
82
83 let normal = Normal::new(0.0, (2.0 / (fan_in + fan_out)).sqrt()).unwrap();
84
85 self.layers[i].bias = Array1::from_shape_fn(layer.neuron_amt, |_| normal.sample(rng));
86 }
87
88 for conn in self.connections.iter_mut() {
89 let fan_in = self.shape[conn.in_neuron_id.0] as f64;
90 let fan_out = self.shape[conn.out_neuron_id.0] as f64;
91 let std_dev = (2.0 / (fan_in + fan_out)).sqrt();
92
93 let normal = Normal::new(0.0, std_dev).unwrap();
94 conn.weight = normal.sample(rng);
95 }
96 }
97}