neural_network_rs/neural_network/optimizer/adam_optimizer.rs

use itertools::izip;
use ndarray::Array2;

use crate::neural_network::{layer::Layer, Summary};

use super::Optimizer;

#[allow(non_camel_case_types)]
pub struct ADAM {
    learning_rate: f64,
    // learning-rate decay factor (inverse time decay, see pre_update)
    decay: f64,
    // timestep counter, incremented after every update
    iteration: usize,
    // learning rate after decay has been applied
    current_learning_rate: f64,
    // small constant that avoids division by zero
    epsilon: f64,
    // exponential decay rate for the first-moment (momentum) estimates
    beta_1: f64,
    // exponential decay rate for the second-moment (cache) estimates
    beta_2: f64,
    // per-layer running estimates of the squared gradients (second moment)
    weights_cache: Vec<Array2<f64>>,
    biases_cache: Vec<Array2<f64>>,
    // per-layer running estimates of the gradients (first moment)
    weights_momentum: Vec<Array2<f64>>,
    biases_momentum: Vec<Array2<f64>>,
}

impl ADAM {
    pub fn new(learning_rate: f64, decay: f64, epsilon: f64, beta_1: f64, beta_2: f64) -> ADAM {
        ADAM {
            learning_rate,
            decay,
            iteration: 0,
            current_learning_rate: learning_rate,
            weights_cache: Vec::new(),
            biases_cache: Vec::new(),
            weights_momentum: Vec::new(),
            biases_momentum: Vec::new(),
            epsilon,
            beta_1,
            beta_2,
        }
    }
}

impl Default for ADAM {
    fn default() -> Self {
        ADAM::new(0.002, 1e-5, 1e-7, 0.9, 0.999)
    }
}
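
// Usage sketch (the values shown are the defaults above; wiring the optimizer
// into a network's training loop happens elsewhere in the crate):
//
//     let mut optimizer = ADAM::new(0.002, 1e-5, 1e-7, 0.9, 0.999);
//     // or, equivalently:
//     let mut optimizer = ADAM::default();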

impl Optimizer for ADAM {
    fn update_params(
        &mut self,
        layers: &mut Vec<Layer>,
        nabla_bs: &Vec<Array2<f64>>,
        nabla_ws: &Vec<Array2<f64>>,
    ) {
        for (i, (layer, nabla_b, nabla_w)) in izip!(layers, nabla_bs, nabla_ws).enumerate() {
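            // Standard Adam update (Kingma & Ba). For each layer's gradients g
            // at timestep t:
            //   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g          (momentum)
            //   v_t = beta_2 * v_{t-1} + (1 - beta_2) * g^2        (cache)
            //   m_hat = m_t / (1 - beta_1^t),  v_hat = v_t / (1 - beta_2^t)
            //   param -= lr * m_hat / (sqrt(v_hat) + epsilon)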
            // update momentum (first-moment estimates)
            self.weights_momentum[i] =
                self.beta_1 * &self.weights_momentum[i] + (1.0 - self.beta_1) * nabla_w;

            self.biases_momentum[i] =
                self.beta_1 * &self.biases_momentum[i] + (1.0 - self.beta_1) * nabla_b;

            // update cache (second-moment estimates) from the squared gradients
            self.weights_cache[i] =
                self.beta_2 * &self.weights_cache[i] + (1.0 - self.beta_2) * (nabla_w * nabla_w);

            self.biases_cache[i] =
                self.beta_2 * &self.biases_cache[i] + (1.0 - self.beta_2) * (nabla_b * nabla_b);

            // bias corrections; the exponent is the current timestep, not the
            // layer index, so every layer gets the same correction factor
            let t = self.iteration as i32 + 1;
            let weights_momentum_corrected =
                &self.weights_momentum[i] / (1.0 - self.beta_1.powi(t));
            let biases_momentum_corrected =
                &self.biases_momentum[i] / (1.0 - self.beta_1.powi(t));
            let weights_cache_corrected =
                &self.weights_cache[i] / (1.0 - self.beta_2.powi(t));
            let biases_cache_corrected =
                &self.biases_cache[i] / (1.0 - self.beta_2.powi(t));

            let weights_update = self.current_learning_rate * weights_momentum_corrected
                / (weights_cache_corrected.mapv(f64::sqrt) + self.epsilon);
            let biases_update = self.current_learning_rate * biases_momentum_corrected
                / (biases_cache_corrected.mapv(f64::sqrt) + self.epsilon);

            // apply the updates
            layer.weights = &layer.weights - &weights_update;
            layer.biases = &layer.biases - &biases_update;
        }
    }

    fn initialize(&mut self, layers: &Vec<Layer>) {
        // allocate zero-filled momentum and cache buffers matching each layer's shapes
        for layer in layers {
            self.weights_cache.push(Array2::zeros(layer.weights.dim()));
            self.biases_cache.push(Array2::zeros(layer.biases.dim()));
            self.weights_momentum
                .push(Array2::zeros(layer.weights.dim()));
            self.biases_momentum.push(Array2::zeros(layer.biases.dim()));
        }
    }

    fn pre_update(&mut self) {
        // inverse time decay: lr_t = lr / (1 + decay * t)
        if self.decay > 0.0 {
            self.current_learning_rate =
                self.learning_rate * (1.0 / (1.0 + self.decay * self.iteration as f64));
        }
    }

    fn post_update(&mut self) {
        self.iteration += 1;
    }
}

impl Summary for ADAM {
    fn summerize(&self) -> String {
        "ADAM".to_string()
    }
}
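
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of a test for the decay schedule; it assumes the fields
    // and trait methods above stay as written and deliberately avoids the
    // per-layer update path, which needs `Layer` instances to exercise.
    #[test]
    fn learning_rate_decays_over_iterations() {
        let mut optimizer = ADAM::new(0.002, 1e-3, 1e-7, 0.9, 0.999);

        optimizer.pre_update();
        let first = optimizer.current_learning_rate;

        optimizer.post_update();
        optimizer.pre_update();
        let second = optimizer.current_learning_rate;

        // with decay = 1e-3 the second rate is 0.002 / 1.001, i.e. strictly smaller
        assert!(second < first);
    }
}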