nevermind_neu/optimizers/optim_rms.rs

1use crate::optimizers::*;
2
3use std::collections::HashMap;
4use crate::util::*;
5use crate::cpu_params::*;
6
7#[derive(Clone)]
8pub struct OptimizerRMS {
9    pub learn_rate: f32,
10    pub alpha: f32,
11    pub theta: f32,
12    pub rms: HashMap<u64, HashMap<i32, VariantParam>>,
13}
14
15impl OptimizerRMS {
16    pub fn new(learn_rate: f32, alpha: f32) -> Self {
17        Self {
18            learn_rate,
19            alpha,
20            theta: 1e-8,
21            rms: HashMap::new()
22        }
23    }
24}
25
26impl Default for OptimizerRMS {
27    fn default() -> Self {
28        Self {
29            learn_rate: 1e-2,
30            alpha: 0.9,
31            theta: 1e-8,
32            rms: HashMap::new()
33        }
34    }
35}
36
37impl OptimizerRMS {
38    fn optimize_layer(
39        buf: &mut [f32],
40        buf_grad: &[f32],
41        rms: &mut [f32],
42        learn_rate: &f32,
43        alpha: &f32,
44        theta: &f32,
45    ) {
46        for ((buf_v, buf_grad_v), rms_v) in buf.iter_mut().zip(buf_grad.iter()).zip(rms.iter_mut()) {
47            if *buf_grad_v == 0.0 {
48                continue;
49            }
50
51            *rms_v = *alpha * *rms_v + (1.0 - alpha) * buf_grad_v.powf(2.0);
52            *buf_v += (learn_rate / (*rms_v + theta).sqrt()) * buf_grad_v;
53        }
54    }
55}
56
57impl Optimizer for OptimizerRMS {
58    fn optimize_params(&mut self, lp: &mut CpuParams, opt_prms: TrainableBufsIds) {
59        if !self.rms.contains_key(&lp.id) {
60            self.rms.insert(lp.id, HashMap::new());
61        }
62
63        for (buf_id, buf_grad_id) in opt_prms.0.iter().zip(opt_prms.1.iter()) {
64            let buf_grad = lp.get_param(*buf_grad_id);
65            let rms_val = self.rms.get_mut(&lp.id).unwrap();
66
67            if !rms_val.contains_key(buf_grad_id) {
68                let zeroed_param = VariantParam::copy_zeroed_shape_from(&buf_grad);
69                rms_val.insert(*buf_grad_id, zeroed_param);
70            }
71
72            let rms_m = rms_val.get_mut(buf_grad_id).unwrap();
73
74            match rms_m {
75                VariantParam::Array1(arr1) => {
76                    let buf_grad_slice = buf_grad.get_arr_1d();
77                    let buf_grad_slice = buf_grad_slice.borrow();
78                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
79
80                    let buf_slice = lp.get_1d_buf(*buf_id);
81                    let mut buf_slice = buf_slice.borrow_mut();
82                    let buf_slice = buf_slice.as_slice_mut().unwrap();
83
84                    let rms_slice = arr1.as_slice_mut().unwrap();
85
86                    OptimizerRMS::optimize_layer(
87                        buf_slice,
88                        buf_grad_slice,
89                        rms_slice,
90                        &self.learn_rate,
91                        &self.alpha,
92                        &self.theta,
93                    );
94                }
95                VariantParam::Array2(arr2) => {
96                    let buf_grad_slice = buf_grad.get_arr_2d();
97                    let buf_grad_slice = buf_grad_slice.borrow();
98                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
99
100                    let buf_slice = lp.get_2d_buf(*buf_id);
101                    let mut buf_slice = buf_slice.borrow_mut();
102                    let buf_slice = buf_slice.as_slice_mut().unwrap();
103
104                    let rms_slice = arr2.as_slice_mut().unwrap();
105
106                    OptimizerRMS::optimize_layer(
107                        buf_slice,
108                        buf_grad_slice,
109                        rms_slice,
110                        &self.learn_rate,
111                        &self.alpha,
112                        &self.theta,
113                    );
114                }
115            }
116        }
117    }
118}
119
120impl WithParams for OptimizerRMS {
121    fn cfg(&self) -> HashMap<String, Variant> {
122        let mut cfg_params = HashMap::new();
123
124        cfg_params.insert("type".to_string(), Variant::String("rmsprop".to_string()));
125        cfg_params.insert("learning_rate".to_string(), Variant::Float(self.learn_rate));
126        cfg_params.insert("alpha".to_string(), Variant::Float(self.alpha));
127        cfg_params.insert("theta".to_string(), Variant::Float(self.theta));
128
129        cfg_params
130    }
131
132    fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
133        if args.contains_key("learning_rate") {
134            if let Variant::Float(v) = args.get("learning_rate").unwrap() {
135                self.learn_rate = *v;
136            }
137        }
138
139        if args.contains_key("alpha") {
140            if let Variant::Float(v) = args.get("alpha").unwrap() {
141                self.alpha = *v;
142            }
143        }
144
145        if args.contains_key("theta") {
146            if let Variant::Float(v) = args.get("theta").unwrap() {
147                self.theta = *v;
148            }
149        }
150    }
151}