neural_network_rs/neural_network/optimizer/rmsprop_optimizer.rs

1use itertools::izip;
2use ndarray::Array2;
3
4use crate::neural_network::{layer::Layer, Summary};
5
6use super::Optimizer;
7
8#[allow(non_camel_case_types)]
9pub struct RMS_PROP {
10    learning_rate: f64,
11    decay: f64,
12    iteration: usize,
13    current_learning_rate: f64,
14    epsilon: f64,
15    rho: f64,
16    weights_cache: Vec<Array2<f64>>,
17    biases_cache: Vec<Array2<f64>>,
18}
19
20impl RMS_PROP {
21    pub fn new(learning_rate: f64, decay: f64, epsilon: f64, rho: f64) -> RMS_PROP {
22        RMS_PROP {
23            learning_rate,
24            decay,
25            iteration: 0,
26            current_learning_rate: learning_rate,
27            weights_cache: Vec::new(),
28            biases_cache: Vec::new(),
29            epsilon,
30            rho,
31        }
32    }
33
34    pub fn default() -> RMS_PROP {
35        RMS_PROP::new(0.001, 1e-4, 1e-7, 0.9)
36    }
37}
38
39impl Optimizer for RMS_PROP {
40    fn update_params(
41        &mut self,
42        layers: &mut Vec<Layer>,
43
44        nabla_bs: &Vec<Array2<f64>>,
45        nabla_ws: &Vec<Array2<f64>>,
46    ) {
47        for (i, (layer, nabla_b, nabla_w)) in izip!(layers, nabla_bs, nabla_ws).enumerate() {
48            //update cache
49            self.weights_cache[i] =
50                self.rho * &self.weights_cache[i] + (1.0 - self.rho) * (nabla_w * nabla_w);
51            self.biases_cache[i] =
52                self.rho * &self.biases_cache[i] + (1.0 - self.rho) * (nabla_b * nabla_b);
53
54            //calculate updates
55            let weights_update = -self.current_learning_rate * nabla_w
56                / (self.weights_cache[i].mapv(|x| x.sqrt()) + self.epsilon);
57            let biases_update = -self.current_learning_rate * nabla_b
58                / (self.biases_cache[i].mapv(|x| x.sqrt()) + self.epsilon);
59
60            //Update weights and biases
61            layer.weights = &layer.weights + weights_update;
62            layer.biases = &layer.biases + biases_update;
63        }
64    }
65
66    fn initialize(&mut self, layers: &Vec<Layer>) {
67        for layer in layers {
68            self.weights_cache.push(Array2::zeros(layer.weights.dim()));
69            self.biases_cache.push(Array2::zeros(layer.biases.dim()));
70        }
71    }
72
73    fn pre_update(&mut self) {
74        if self.decay > 0.0 {
75            self.current_learning_rate =
76                self.learning_rate * (1.0 / (1.0 + self.decay * self.iteration as f64));
77        }
78    }
79
80    fn post_update(&mut self) {
81        self.iteration += 1;
82    }
83}
84
85impl Summary for RMS_PROP {
86    fn summerize(&self) -> String {
87        "RMS_PROP".to_string()
88    }
89}