neural_network_rs/neural_network/optimizer/
sgd_optimzer.rs1use itertools::izip;
2use ndarray::Array2;
3
4use crate::neural_network::{layer::Layer, Summary};
5
6use super::Optimizer;
7
8pub struct SGD {
9 momentum: f64,
10 learning_rate: f64,
11 decay: f64,
12 iteration: usize,
13 current_learning_rate: f64,
14 weights_momentum: Vec<Array2<f64>>,
15 biases_momentum: Vec<Array2<f64>>,
16}
17
18impl SGD {
19 pub fn new(learning_rate: f64, momentum: f64, decay: f64) -> SGD {
20 SGD {
21 learning_rate,
22 momentum,
23 decay,
24 iteration: 0,
25 current_learning_rate: learning_rate,
26 weights_momentum: Vec::new(),
27 biases_momentum: Vec::new(),
28 }
29 }
30
31 pub fn default() -> SGD {
32 SGD::new(0.1, 0.5, 0.0005)
33 }
34}
35
36impl Optimizer for SGD {
37 fn update_params(
38 &mut self,
39 layers: &mut Vec<Layer>,
40 nabla_bs: &Vec<Array2<f64>>,
41 nabla_ws: &Vec<Array2<f64>>,
42 ) {
43 for (i, (layer, nabla_b, nabla_w)) in izip!(layers, nabla_bs, nabla_ws).enumerate() {
44 let mut weights_update = -self.current_learning_rate * nabla_w;
46 let mut biases_update = -self.current_learning_rate * nabla_b;
47
48 if self.momentum > 0.0 {
50 let weights_momentum = &self.weights_momentum[i];
51 let biases_momentum = &self.biases_momentum[i];
52
53 weights_update = weights_update + self.momentum * weights_momentum;
54 biases_update = biases_update + self.momentum * biases_momentum;
55
56 self.weights_momentum[i] = weights_update.clone();
57 self.biases_momentum[i] = biases_update.clone();
58 }
59
60 layer.weights = &layer.weights + weights_update;
62 layer.biases = &layer.biases + biases_update;
63 }
64 }
65
66 fn initialize(&mut self, layers: &Vec<Layer>) {
67 for layer in layers {
68 self.weights_momentum
69 .push(Array2::zeros(layer.weights.dim()));
70 self.biases_momentum.push(Array2::zeros(layer.biases.dim()));
71 }
72 }
73
74 fn pre_update(&mut self) {
75 if self.decay > 0.0 {
76 self.current_learning_rate =
77 self.learning_rate * (1.0 / (1.0 + self.decay * self.iteration as f64));
78 }
79 }
80
81 fn post_update(&mut self) {
82 self.iteration += 1;
83 }
84}
85
86impl Summary for SGD {
87 fn summerize(&self) -> String {
88 "SGD".to_string()
89 }
90}