nevermind_neu/optimizers/optim_adagrad.rs

use crate::optimizers::*;
use crate::cpu_params::*;
use crate::util::*;

use log::debug;

use std::collections::HashMap;
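
/// AdaGrad optimizer: accumulates the squared gradient of every trainable
/// element in `g` (keyed by the `CpuParams` id and then by gradient buffer id)
/// and scales each per-element step by `learn_rate / sqrt(g + theta)`, so
/// elements that keep receiving large gradients take progressively smaller steps.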
pub struct OptimizerAdaGrad {
    pub learn_rate: f32,
    /// Small constant added to the accumulator before taking the square root.
    pub theta: f32,
    /// Squared-gradient accumulators: layer id -> (gradient buffer id -> accumulator).
    pub g: HashMap<u64, HashMap<i32, VariantParam>>,
}

impl OptimizerAdaGrad {
    /// Creates an AdaGrad optimizer with the given learning rate and `theta` = 1e-8.
    pub fn new(learn_rate: f32) -> Self {
        Self {
            learn_rate,
            theta: 1e-8,
            g: HashMap::new(),
        }
    }
}

impl Default for OptimizerAdaGrad {
    fn default() -> Self {
        Self {
            learn_rate: 1e-2,
            theta: 1e-6,
            g: HashMap::new(),
        }
    }
}

impl OptimizerAdaGrad {
    /// Applies one AdaGrad step to a single buffer: each element's squared
    /// gradient is added to its accumulator `g`, and the parameter moves by
    /// `learn_rate / sqrt(g + theta)` times the stored gradient.
    fn optimize_layer(
        buf: &mut [f32],
        buf_grad: &[f32],
        g: &mut [f32],
        learn_rate: &f32,
        theta: &f32,
    ) {
        for ((buf_v, buf_grad_v), g_v) in buf.iter_mut().zip(buf_grad.iter()).zip(g.iter_mut()) {
            // A zero gradient changes nothing; skip the sqrt and division.
            if *buf_grad_v == 0.0 {
                continue;
            }

            *g_v += buf_grad_v.powf(2.0);
            *buf_v += (learn_rate / (*g_v + theta).sqrt()) * buf_grad_v;
        }
    }
}

impl Optimizer for OptimizerAdaGrad {
    fn optimize_params(&mut self, lp: &mut CpuParams, opt_prms: TrainableBufsIds) {
        // Lazily create the accumulator map for this layer the first time it is seen.
        if !self.g.contains_key(&lp.id) {
            self.g.insert(lp.id, HashMap::new());
            debug!("[opt_ada_grad] Inserted learn_params with id {}", lp.id);
        }

        for (buf_id, buf_grad_id) in opt_prms.0.iter().zip(opt_prms.1.iter()) {
            let buf_grad = lp.get_param(*buf_grad_id);
            let g_val = self.g.get_mut(&lp.id).unwrap();

            // Zero-initialize an accumulator with the same shape as the gradient buffer.
            if !g_val.contains_key(buf_grad_id) {
                let zeroed_param = VariantParam::copy_zeroed_shape_from(&buf_grad);
                g_val.insert(*buf_grad_id, zeroed_param);
            }

            let g_m = g_val.get_mut(buf_grad_id).unwrap();

            // Dispatch on the rank of the parameter buffer (1-D vs. 2-D).
            match g_m {
                VariantParam::Array1(arr1) => {
                    let buf_grad_slice = buf_grad.get_arr_1d();
                    let buf_grad_slice = buf_grad_slice.borrow();
                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();

                    let buf_slice = lp.get_1d_buf(*buf_id);
                    let mut buf_slice = buf_slice.borrow_mut();
                    let buf_slice = buf_slice.as_slice_mut().unwrap();

                    let v_slice = arr1.as_slice_mut().unwrap();

                    OptimizerAdaGrad::optimize_layer(
                        buf_slice,
                        buf_grad_slice,
                        v_slice,
                        &self.learn_rate,
                        &self.theta,
                    );
                }
                VariantParam::Array2(arr2) => {
                    let buf_grad_slice = buf_grad.get_arr_2d();
                    let buf_grad_slice = buf_grad_slice.borrow();
                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();

                    let buf_slice = lp.get_2d_buf(*buf_id);
                    let mut buf_slice = buf_slice.borrow_mut();
                    let buf_slice = buf_slice.as_slice_mut().unwrap();

                    let v_slice = arr2.as_slice_mut().unwrap();

                    OptimizerAdaGrad::optimize_layer(
                        buf_slice,
                        buf_grad_slice,
                        v_slice,
                        &self.learn_rate,
                        &self.theta,
                    );
                }
            }
        }
    }
}

impl WithParams for OptimizerAdaGrad {
    fn cfg(&self) -> HashMap<String, Variant> {
        let mut cfg_params = HashMap::new();

        cfg_params.insert("type".to_string(), Variant::String("adagrad".to_string()));
        cfg_params.insert("learning_rate".to_string(), Variant::Float(self.learn_rate));
        cfg_params.insert("theta".to_string(), Variant::Float(self.theta));

        cfg_params
    }

    fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
        if let Some(Variant::Float(v)) = args.get("learning_rate") {
            self.learn_rate = *v;
        }

        if let Some(Variant::Float(v)) = args.get("theta") {
            self.theta = *v;
        }
    }
}
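
// Minimal sketches exercising this file in isolation. They rely only on items
// defined or imported above (`OptimizerAdaGrad`, `Variant::Float`, `HashMap`,
// reached through `use super::*;`) and do not construct `CpuParams`, so no
// assumptions are made about a layer's actual buffer layout.
#[cfg(test)]
mod tests {
    use super::*;

    // With a fresh accumulator, the first step per element is
    // learn_rate / sqrt(grad^2 + theta) * grad, i.e. roughly learn_rate * sign(grad).
    #[test]
    fn optimize_layer_first_step() {
        let mut buf = [0.0_f32, 0.0, 0.0];
        let grad = [0.5_f32, -0.5, 0.0];
        let mut g = [0.0_f32; 3];

        OptimizerAdaGrad::optimize_layer(&mut buf, &grad, &mut g, &1e-2, &1e-8);

        // Zero gradients are skipped: neither the parameter nor its accumulator moves.
        assert_eq!(buf[2], 0.0);
        assert_eq!(g[2], 0.0);
        // Non-zero entries move by roughly learn_rate in the direction of the stored gradient.
        assert!((buf[0] - 1e-2).abs() < 1e-4);
        assert!((buf[1] + 1e-2).abs() < 1e-4);
    }

    // The WithParams round-trip: set_cfg overrides the hyperparameters that cfg() reports.
    #[test]
    fn set_cfg_overrides_defaults() {
        let mut opt = OptimizerAdaGrad::default();

        let mut args = HashMap::new();
        args.insert("learning_rate".to_string(), Variant::Float(5e-2));
        args.insert("theta".to_string(), Variant::Float(1e-7));
        opt.set_cfg(&args);

        assert_eq!(opt.learn_rate, 5e-2);
        assert_eq!(opt.theta, 1e-7);
        assert!(matches!(opt.cfg().get("learning_rate"), Some(Variant::Float(v)) if *v == 5e-2));
    }
}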