neural_network_rs/neural_network/optimizer/
sgd_optimzer.rs

1use itertools::izip;
2use ndarray::Array2;
3
4use crate::neural_network::{layer::Layer, Summary};
5
6use super::Optimizer;
7
8pub struct SGD {
9    momentum: f64,
10    learning_rate: f64,
11    decay: f64,
12    iteration: usize,
13    current_learning_rate: f64,
14    weights_momentum: Vec<Array2<f64>>,
15    biases_momentum: Vec<Array2<f64>>,
16}
17
18impl SGD {
19    pub fn new(learning_rate: f64, momentum: f64, decay: f64) -> SGD {
20        SGD {
21            learning_rate,
22            momentum,
23            decay,
24            iteration: 0,
25            current_learning_rate: learning_rate,
26            weights_momentum: Vec::new(),
27            biases_momentum: Vec::new(),
28        }
29    }
30
31    pub fn default() -> SGD {
32        SGD::new(0.1, 0.5, 0.0005)
33    }
34}
35
36impl Optimizer for SGD {
37    fn update_params(
38        &mut self,
39        layers: &mut Vec<Layer>,
40        nabla_bs: &Vec<Array2<f64>>,
41        nabla_ws: &Vec<Array2<f64>>,
42    ) {
43        for (i, (layer, nabla_b, nabla_w)) in izip!(layers, nabla_bs, nabla_ws).enumerate() {
44            //Calculate standart update_params
45            let mut weights_update = -self.current_learning_rate * nabla_w;
46            let mut biases_update = -self.current_learning_rate * nabla_b;
47
48            //Add momentum
49            if self.momentum > 0.0 {
50                let weights_momentum = &self.weights_momentum[i];
51                let biases_momentum = &self.biases_momentum[i];
52
53                weights_update = weights_update + self.momentum * weights_momentum;
54                biases_update = biases_update + self.momentum * biases_momentum;
55
56                self.weights_momentum[i] = weights_update.clone();
57                self.biases_momentum[i] = biases_update.clone();
58            }
59
60            //Update weights and biases
61            layer.weights = &layer.weights + weights_update;
62            layer.biases = &layer.biases + biases_update;
63        }
64    }
65
66    fn initialize(&mut self, layers: &Vec<Layer>) {
67        for layer in layers {
68            self.weights_momentum
69                .push(Array2::zeros(layer.weights.dim()));
70            self.biases_momentum.push(Array2::zeros(layer.biases.dim()));
71        }
72    }
73
74    fn pre_update(&mut self) {
75        if self.decay > 0.0 {
76            self.current_learning_rate =
77                self.learning_rate * (1.0 / (1.0 + self.decay * self.iteration as f64));
78        }
79    }
80
81    fn post_update(&mut self) {
82        self.iteration += 1;
83    }
84}
85
86impl Summary for SGD {
87    fn summerize(&self) -> String {
88        "SGD".to_string()
89    }
90}