// nevermind_neu/optimizers/optim_sgd.rs
use crate::optimizers::*;
use crate::cpu_params::*;
use crate::util::*;

use std::collections::HashMap;
7pub struct OptimizerSGD {
8 pub learn_rate: f32,
9 pub momentum: f32,
10 pub delta: HashMap<u64, HashMap<i32, VariantParam>>,
11}
12
13impl OptimizerSGD {
14 pub fn new(learn_rate: f32, momentum: f32) -> Self {
15 Self {
16 learn_rate,
17 momentum,
18 delta: HashMap::new(),
19 }
20 }
21}
22
23impl Default for OptimizerSGD {
24 fn default() -> Self {
25 Self {
26 learn_rate: 1e-2,
27 momentum: 0.8,
28 delta: HashMap::new(),
29 }
30 }
31}
32
33impl OptimizerSGD {
34 fn optimize_layer(
35 buf: &mut [f32],
36 buf_grad: &[f32],
37 delta: &mut [f32],
38 learn_rate: &f32,
39 momentum: &f32,
40 ) {
41 for ((buf_v, buf_grad_v), delta_v) in buf.iter_mut().zip(buf_grad.iter()).zip(delta.iter_mut()) {
42 if *buf_grad_v == 0.0 {
43 continue;
44 }
45
46 *buf_v += *momentum * *delta_v;
47 *delta_v = *learn_rate * buf_grad_v;
48 *buf_v += *delta_v;
49 }
50 }
51}
52
53impl Optimizer for OptimizerSGD {
54 fn optimize_params(&mut self, lp: &mut CpuParams, opt_prms: TrainableBufsIds) {
55 if !self.delta.contains_key(&lp.id) {
56 self.delta.insert(lp.id, HashMap::new());
57 }
58
59 let delta_val = self.delta.get_mut(&lp.id).unwrap();
60
61 for (buf_id, buf_grad_id) in opt_prms.0.iter().zip(opt_prms.1.iter()) {
62 let buf_grad = lp.get_param(*buf_grad_id);
63
64 if !delta_val.contains_key(buf_grad_id) {
65 let zeroed_param = VariantParam::copy_zeroed_shape_from(&buf_grad);
66 delta_val.insert(*buf_grad_id, zeroed_param);
67 }
68
69 let delta_m = delta_val.get_mut(buf_grad_id).unwrap();
70
71 match delta_m {
72 VariantParam::Array1(arr1) => {
73 let buf_grad_slice = buf_grad.get_arr_1d();
74 let buf_grad_slice = buf_grad_slice.borrow();
75 let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
76
77 let buf_slice = lp.get_1d_buf(*buf_id);
78 let mut buf_slice = buf_slice.borrow_mut();
79 let buf_slice = buf_slice.as_slice_mut().unwrap();
80
81 let delta_slice = arr1.as_slice_mut().unwrap();
82 OptimizerSGD::optimize_layer(
83 buf_slice,
84 buf_grad_slice,
85 delta_slice,
86 &self.learn_rate,
87 &self.momentum
88 );
89 }
90 VariantParam::Array2(arr2) => {
91 let buf_grad_slice = buf_grad.get_arr_2d();
92 let buf_grad_slice = buf_grad_slice.borrow();
93 let buf_grad_slice = buf_grad_slice.as_slice().unwrap();
94
95 let buf_slice = lp.get_2d_buf(*buf_id);
96 let mut buf_slice = buf_slice.borrow_mut();
97 let buf_slice = buf_slice.as_slice_mut().unwrap();
98
99 let delta_slice = arr2.as_slice_mut().unwrap();
100 OptimizerSGD::optimize_layer(
101 buf_slice,
102 buf_grad_slice,
103 delta_slice,
104 &self.learn_rate,
105 &self.momentum
106 );
107 }
108 }
109 }
110 }
111}
112
113impl WithParams for OptimizerSGD {
114 fn cfg(&self) -> HashMap<String, Variant> {
115 let mut cfg_params = HashMap::new();
116
117 cfg_params.insert("type".to_string(), Variant::String("sgd".to_string()));
118 cfg_params.insert("learning_rate".to_string(), Variant::Float(self.learn_rate));
119 cfg_params.insert("momentum".to_string(), Variant::Float(self.momentum));
120
121 cfg_params
122 }
123
124 fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
125 if args.contains_key("learning_rate") {
126 if let Variant::Float(v) = args.get("learning_rate").unwrap() {
127 self.learn_rate = *v;
128 }
129 }
130
131 if args.contains_key("momentum") {
132 if let Variant::Float(v) = args.get("momentum").unwrap() {
133 self.momentum = *v;
134 }
135 }
136 }
137}