// nevermind_neu/optimizers/optim_sgd.rs

1use crate::optimizers::*;
2use crate::cpu_params::*;
3use crate::util::*;
4
5use std::collections::HashMap;
6
7pub struct OptimizerSGD {
8    pub learn_rate: f32,
9    pub momentum: f32,
10    pub delta: HashMap<u64, HashMap<i32, VariantParam>>,
11}
12
13impl OptimizerSGD {
14    pub fn new(learn_rate: f32, momentum: f32) -> Self {
15        Self {
16            learn_rate,
17            momentum,
18            delta: HashMap::new(),
19        }
20    }
21}
22
23impl Default for OptimizerSGD {
24    fn default() -> Self {
25        Self {
26            learn_rate: 1e-2,
27            momentum: 0.8,
28            delta: HashMap::new(),
29        }
30    }
31}
32
33impl OptimizerSGD {
34    fn optimize_layer(
35        buf: &mut [f32],
36        buf_grad: &[f32],
37        delta: &mut [f32],
38        learn_rate: &f32,
39        momentum: &f32,
40    ) {
41        for ((buf_v, buf_grad_v), delta_v) in buf.iter_mut().zip(buf_grad.iter()).zip(delta.iter_mut()) {
42            if *buf_grad_v == 0.0 {
43                continue;
44            }
45
46            *buf_v += *momentum * *delta_v;
47            *delta_v = *learn_rate * buf_grad_v;
48            *buf_v += *delta_v;
49        }
50    }
51}
52
53impl Optimizer for OptimizerSGD {
54    fn optimize_params(&mut self, lp: &mut CpuParams, opt_prms: TrainableBufsIds) {
55        if !self.delta.contains_key(&lp.id) {
56            self.delta.insert(lp.id, HashMap::new());
57        }
58
59        let delta_val = self.delta.get_mut(&lp.id).unwrap();
60
61        for (buf_id, buf_grad_id) in opt_prms.0.iter().zip(opt_prms.1.iter()) {
62            let buf_grad = lp.get_param(*buf_grad_id);
63
64            if !delta_val.contains_key(buf_grad_id) {
65                let zeroed_param = VariantParam::copy_zeroed_shape_from(&buf_grad);
66                delta_val.insert(*buf_grad_id, zeroed_param);
67            }
68
69            let delta_m = delta_val.get_mut(buf_grad_id).unwrap();
70
71            match delta_m {
72                VariantParam::Array1(arr1) => {
73                    let buf_grad_slice = buf_grad.get_arr_1d();
74                    let buf_grad_slice = buf_grad_slice.borrow();
75                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
76
77                    let buf_slice = lp.get_1d_buf(*buf_id);
78                    let mut buf_slice = buf_slice.borrow_mut();
79                    let buf_slice = buf_slice.as_slice_mut().unwrap();
80
81                    let delta_slice = arr1.as_slice_mut().unwrap();
82                    OptimizerSGD::optimize_layer(
83                        buf_slice,
84                        buf_grad_slice,
85                        delta_slice,
86                        &self.learn_rate,
87                        &self.momentum
88                    );
89                }
90                VariantParam::Array2(arr2) => {
91                    let buf_grad_slice = buf_grad.get_arr_2d();
92                    let buf_grad_slice = buf_grad_slice.borrow();
93                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
94
95                    let buf_slice = lp.get_2d_buf(*buf_id);
96                    let mut buf_slice = buf_slice.borrow_mut();
97                    let buf_slice = buf_slice.as_slice_mut().unwrap();
98
99                    let delta_slice = arr2.as_slice_mut().unwrap();
100                    OptimizerSGD::optimize_layer(
101                        buf_slice,
102                        buf_grad_slice,
103                        delta_slice,
104                        &self.learn_rate,
105                        &self.momentum
106                    );
107                }
108            }
109        }
110    }
111}
112
113impl WithParams for OptimizerSGD {
114    fn cfg(&self) -> HashMap<String, Variant> {
115        let mut cfg_params = HashMap::new();
116
117        cfg_params.insert("type".to_string(), Variant::String("sgd".to_string()));
118        cfg_params.insert("learning_rate".to_string(), Variant::Float(self.learn_rate));
119        cfg_params.insert("momentum".to_string(), Variant::Float(self.momentum));
120
121        cfg_params
122    }
123
124    fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
125        if args.contains_key("learning_rate") {
126            if let Variant::Float(v) = args.get("learning_rate").unwrap() {
127                self.learn_rate = *v;
128            }
129        }
130
131        if args.contains_key("momentum") {
132            if let Variant::Float(v) = args.get("momentum").unwrap() {
133                self.momentum = *v;
134            }
135        }
136    }
137}