// neural_network_rs/neural_network/optimizer/rmsprop_optimizer.rs
use itertools::izip;
use ndarray::Array2;

use crate::neural_network::{layer::Layer, Summary};

use super::Optimizer;

8#[allow(non_camel_case_types)]
9pub struct RMS_PROP {
10 learning_rate: f64,
11 decay: f64,
12 iteration: usize,
13 current_learning_rate: f64,
14 epsilon: f64,
15 rho: f64,
16 weights_cache: Vec<Array2<f64>>,
17 biases_cache: Vec<Array2<f64>>,
18}
19
20impl RMS_PROP {
21 pub fn new(learning_rate: f64, decay: f64, epsilon: f64, rho: f64) -> RMS_PROP {
22 RMS_PROP {
23 learning_rate,
24 decay,
25 iteration: 0,
26 current_learning_rate: learning_rate,
27 weights_cache: Vec::new(),
28 biases_cache: Vec::new(),
29 epsilon,
30 rho,
31 }
32 }
33
34 pub fn default() -> RMS_PROP {
35 RMS_PROP::new(0.001, 1e-4, 1e-7, 0.9)
36 }
37}
38
39impl Optimizer for RMS_PROP {
40 fn update_params(
41 &mut self,
42 layers: &mut Vec<Layer>,
43
44 nabla_bs: &Vec<Array2<f64>>,
45 nabla_ws: &Vec<Array2<f64>>,
46 ) {
47 for (i, (layer, nabla_b, nabla_w)) in izip!(layers, nabla_bs, nabla_ws).enumerate() {
48 self.weights_cache[i] =
50 self.rho * &self.weights_cache[i] + (1.0 - self.rho) * (nabla_w * nabla_w);
51 self.biases_cache[i] =
52 self.rho * &self.biases_cache[i] + (1.0 - self.rho) * (nabla_b * nabla_b);
53
54 let weights_update = -self.current_learning_rate * nabla_w
56 / (self.weights_cache[i].mapv(|x| x.sqrt()) + self.epsilon);
57 let biases_update = -self.current_learning_rate * nabla_b
58 / (self.biases_cache[i].mapv(|x| x.sqrt()) + self.epsilon);
59
60 layer.weights = &layer.weights + weights_update;
62 layer.biases = &layer.biases + biases_update;
63 }
64 }
65
66 fn initialize(&mut self, layers: &Vec<Layer>) {
67 for layer in layers {
68 self.weights_cache.push(Array2::zeros(layer.weights.dim()));
69 self.biases_cache.push(Array2::zeros(layer.biases.dim()));
70 }
71 }
72
73 fn pre_update(&mut self) {
74 if self.decay > 0.0 {
75 self.current_learning_rate =
76 self.learning_rate * (1.0 / (1.0 + self.decay * self.iteration as f64));
77 }
78 }
79
80 fn post_update(&mut self) {
81 self.iteration += 1;
82 }
83}
84
85impl Summary for RMS_PROP {
86 fn summerize(&self) -> String {
87 "RMS_PROP".to_string()
88 }
89}