// nevermind_neu/optimizers/optim_rms.rs
use crate::optimizers::*;

use std::collections::HashMap;
use crate::util::*;
use crate::cpu_params::*;

/// RMSprop optimizer: scales each parameter step by the inverse square root
/// of a running (exponentially decayed) average of squared gradients.
#[derive(Clone)]
pub struct OptimizerRMS {
    pub learn_rate: f32, // step size applied to each normalized gradient
    pub alpha: f32,      // decay factor for the running squared-gradient average
    pub theta: f32,      // small epsilon added under the sqrt to avoid division by zero
    // Running squared-gradient averages, keyed by layer id (outer) and
    // gradient-buffer id (inner); entries are created lazily on first update.
    pub rms: HashMap<u64, HashMap<i32, VariantParam>>,
}
14
15impl OptimizerRMS {
16 pub fn new(learn_rate: f32, alpha: f32) -> Self {
17 Self {
18 learn_rate,
19 alpha,
20 theta: 1e-8,
21 rms: HashMap::new()
22 }
23 }
24}
25
26impl Default for OptimizerRMS {
27 fn default() -> Self {
28 Self {
29 learn_rate: 1e-2,
30 alpha: 0.9,
31 theta: 1e-8,
32 rms: HashMap::new()
33 }
34 }
35}
36
37impl OptimizerRMS {
38 fn optimize_layer(
39 buf: &mut [f32],
40 buf_grad: &[f32],
41 rms: &mut [f32],
42 learn_rate: &f32,
43 alpha: &f32,
44 theta: &f32,
45 ) {
46 for ((buf_v, buf_grad_v), rms_v) in buf.iter_mut().zip(buf_grad.iter()).zip(rms.iter_mut()) {
47 if *buf_grad_v == 0.0 {
48 continue;
49 }
50
51 *rms_v = *alpha * *rms_v + (1.0 - alpha) * buf_grad_v.powf(2.0);
52 *buf_v += (learn_rate / (*rms_v + theta).sqrt()) * buf_grad_v;
53 }
54 }
55}
56
57impl Optimizer for OptimizerRMS {
58 fn optimize_params(&mut self, lp: &mut CpuParams, opt_prms: TrainableBufsIds) {
59 if !self.rms.contains_key(&lp.id) {
60 self.rms.insert(lp.id, HashMap::new());
61 }
62
63 for (buf_id, buf_grad_id) in opt_prms.0.iter().zip(opt_prms.1.iter()) {
64 let buf_grad = lp.get_param(*buf_grad_id);
65 let rms_val = self.rms.get_mut(&lp.id).unwrap();
66
67 if !rms_val.contains_key(buf_grad_id) {
68 let zeroed_param = VariantParam::copy_zeroed_shape_from(&buf_grad);
69 rms_val.insert(*buf_grad_id, zeroed_param);
70 }
71
72 let rms_m = rms_val.get_mut(buf_grad_id).unwrap();
73
74 match rms_m {
75 VariantParam::Array1(arr1) => {
76 let buf_grad_slice = buf_grad.get_arr_1d();
77 let buf_grad_slice = buf_grad_slice.borrow();
78 let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
79
80 let buf_slice = lp.get_1d_buf(*buf_id);
81 let mut buf_slice = buf_slice.borrow_mut();
82 let buf_slice = buf_slice.as_slice_mut().unwrap();
83
84 let rms_slice = arr1.as_slice_mut().unwrap();
85
86 OptimizerRMS::optimize_layer(
87 buf_slice,
88 buf_grad_slice,
89 rms_slice,
90 &self.learn_rate,
91 &self.alpha,
92 &self.theta,
93 );
94 }
95 VariantParam::Array2(arr2) => {
96 let buf_grad_slice = buf_grad.get_arr_2d();
97 let buf_grad_slice = buf_grad_slice.borrow();
98 let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
99
100 let buf_slice = lp.get_2d_buf(*buf_id);
101 let mut buf_slice = buf_slice.borrow_mut();
102 let buf_slice = buf_slice.as_slice_mut().unwrap();
103
104 let rms_slice = arr2.as_slice_mut().unwrap();
105
106 OptimizerRMS::optimize_layer(
107 buf_slice,
108 buf_grad_slice,
109 rms_slice,
110 &self.learn_rate,
111 &self.alpha,
112 &self.theta,
113 );
114 }
115 }
116 }
117 }
118}
119
120impl WithParams for OptimizerRMS {
121 fn cfg(&self) -> HashMap<String, Variant> {
122 let mut cfg_params = HashMap::new();
123
124 cfg_params.insert("type".to_string(), Variant::String("rmsprop".to_string()));
125 cfg_params.insert("learning_rate".to_string(), Variant::Float(self.learn_rate));
126 cfg_params.insert("alpha".to_string(), Variant::Float(self.alpha));
127 cfg_params.insert("theta".to_string(), Variant::Float(self.theta));
128
129 cfg_params
130 }
131
132 fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
133 if args.contains_key("learning_rate") {
134 if let Variant::Float(v) = args.get("learning_rate").unwrap() {
135 self.learn_rate = *v;
136 }
137 }
138
139 if args.contains_key("alpha") {
140 if let Variant::Float(v) = args.get("alpha").unwrap() {
141 self.alpha = *v;
142 }
143 }
144
145 if args.contains_key("theta") {
146 if let Variant::Float(v) = args.get("theta").unwrap() {
147 self.theta = *v;
148 }
149 }
150 }
151}