neural_network_rs/neural_network/optimizer/adam_optimizer.rs

use itertools::izip;
use ndarray::Array2;

use crate::neural_network::{layer::Layer, Summary};

use super::Optimizer;

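/// Adam (Adaptive Moment Estimation) optimizer.
///
/// Keeps per-layer first-moment (`*_momentum`) and second-moment (`*_cache`)
/// estimates for both weights and biases, applies bias correction, and
/// supports optional inverse-time learning-rate decay.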
#[allow(non_camel_case_types)]
pub struct ADAM {
    learning_rate: f64,
    decay: f64,
    iteration: usize,
    current_learning_rate: f64,
    epsilon: f64,
    beta_1: f64,
    beta_2: f64,
    weights_cache: Vec<Array2<f64>>,
    biases_cache: Vec<Array2<f64>>,
    weights_momentum: Vec<Array2<f64>>,
    biases_momentum: Vec<Array2<f64>>,
}

impl ADAM {
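    /// `decay` drives the schedule in `pre_update`; `epsilon` guards the
    /// division in the update step; `beta_1` / `beta_2` are the exponential
    /// decay rates for the first and second moment estimates.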
    pub fn new(learning_rate: f64, decay: f64, epsilon: f64, beta_1: f64, beta_2: f64) -> ADAM {
        ADAM {
            learning_rate,
            decay,
            iteration: 0,
            current_learning_rate: learning_rate,
            weights_cache: Vec::new(),
            biases_cache: Vec::new(),
            weights_momentum: Vec::new(),
            biases_momentum: Vec::new(),
            epsilon,
            beta_1,
            beta_2,
        }
    }

    pub fn default() -> ADAM {
        ADAM::new(0.002, 1e-5, 1e-7, 0.9, 0.999)
    }
}

impl Optimizer for ADAM {
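    // Standard Adam update per layer, with gradient g and timestep t:
    //   m_t = beta_1 * m_{t-1} + (1 - beta_1) * g
    //   v_t = beta_2 * v_{t-1} + (1 - beta_2) * g^2
    //   m_hat = m_t / (1 - beta_1^t),  v_hat = v_t / (1 - beta_2^t)
    //   theta -= lr * m_hat / (sqrt(v_hat) + epsilon)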
    fn update_params(
        &mut self,
        layers: &mut Vec<Layer>,
        nabla_bs: &Vec<Array2<f64>>,
        nabla_ws: &Vec<Array2<f64>>,
    ) {
        for (i, (layer, nabla_b, nabla_w)) in izip!(layers, nabla_bs, nabla_ws).enumerate() {
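            // First moment: exponentially decayed average of past gradients.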
            self.weights_momentum[i] =
                self.beta_1 * &self.weights_momentum[i] + (1.0 - self.beta_1) * nabla_w;
            self.biases_momentum[i] =
                self.beta_1 * &self.biases_momentum[i] + (1.0 - self.beta_1) * nabla_b;

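            // Second moment: exponentially decayed average of squared gradients.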
            self.weights_cache[i] =
                self.beta_2 * &self.weights_cache[i] + (1.0 - self.beta_2) * (nabla_w * nabla_w);
            self.biases_cache[i] =
                self.beta_2 * &self.biases_cache[i] + (1.0 - self.beta_2) * (nabla_b * nabla_b);

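            // Bias-correct both moments with the global timestep (starting at 1)
            // so early estimates are not biased toward zero.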
            let t = self.iteration as i32 + 1;
            let weights_momentum_corrected =
                &self.weights_momentum[i] / (1.0 - self.beta_1.powi(t));
            let biases_momentum_corrected =
                &self.biases_momentum[i] / (1.0 - self.beta_1.powi(t));
            let weights_cache_corrected =
                &self.weights_cache[i] / (1.0 - self.beta_2.powi(t));
            let biases_cache_corrected =
                &self.biases_cache[i] / (1.0 - self.beta_2.powi(t));

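            // Scale the bias-corrected momentum by the inverse square root of
            // the bias-corrected cache; epsilon prevents division by zero.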
            let weights_update = self.current_learning_rate * weights_momentum_corrected
                / (weights_cache_corrected.mapv(f64::sqrt) + self.epsilon);
            let biases_update = self.current_learning_rate * biases_momentum_corrected
                / (biases_cache_corrected.mapv(f64::sqrt) + self.epsilon);

            layer.weights = &layer.weights - &weights_update;
            layer.biases = &layer.biases - &biases_update;
        }
    }

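    // Allocate zero-initialized momentum and cache buffers matching each
    // layer's parameter shapes; called once before the first update.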
    fn initialize(&mut self, layers: &Vec<Layer>) {
        for layer in layers {
            self.weights_cache.push(Array2::zeros(layer.weights.dim()));
            self.biases_cache.push(Array2::zeros(layer.biases.dim()));
            self.weights_momentum
                .push(Array2::zeros(layer.weights.dim()));
            self.biases_momentum.push(Array2::zeros(layer.biases.dim()));
        }
    }

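    // Inverse-time decay of the learning rate: lr_t = lr / (1 + decay * t).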
    fn pre_update(&mut self) {
        if self.decay > 0.0 {
            self.current_learning_rate =
                self.learning_rate * (1.0 / (1.0 + self.decay * self.iteration as f64));
        }
    }

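    // Advance the global timestep; used by both the decay schedule and the
    // bias correction above.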
    fn post_update(&mut self) {
        self.iteration += 1;
    }
}

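// `summerize` follows the spelling used by the `Summary` trait in the parent
// module.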
impl Summary for ADAM {
    fn summerize(&self) -> String {
        "ADAM".to_string()
    }
}
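
// A minimal usage sketch, assuming a surrounding `Network` type that owns
// `layers: Vec<Layer>` and a hypothetical `backprop` helper returning
// per-layer bias and weight gradients:
//
//     let mut optimizer = ADAM::default();
//     optimizer.initialize(&network.layers);
//     for _ in 0..epochs {
//         let (nabla_bs, nabla_ws) = network.backprop(&batch_x, &batch_y);
//         optimizer.pre_update();
//         optimizer.update_params(&mut network.layers, &nabla_bs, &nabla_ws);
//         optimizer.post_update();
//     }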