nevermind_neu/optimizers/optim_adagrad.rs

use crate::optimizers::*;
use crate::cpu_params::*;
use crate::util::*;

use log::debug;

use std::collections::HashMap;
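
/// AdaGrad optimizer: keeps a per-element sum of squared gradients and scales
/// the base learning rate by the inverse square root of that sum, so
/// frequently updated parameters take progressively smaller steps.
///
/// Per-element update (see `optimize_layer` below):
///
///     g <- g + grad^2
///     w <- w + (learn_rate / sqrt(g + theta)) * grad
///
/// Note the `+` in the second line: the gradient buffers in this codebase are
/// assumed to already carry the descent direction.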
pub struct OptimizerAdaGrad {
    /// Base learning rate.
    pub learn_rate: f32,
    /// Small smoothing constant added before the square root for numerical
    /// stability (commonly called epsilon in the AdaGrad literature).
    pub theta: f32,
    /// Squared-gradient accumulators, keyed by learn-params id, then by
    /// gradient buffer id.
    pub g: HashMap<u64, HashMap<i32, VariantParam>>,
}

impl OptimizerAdaGrad {
    pub fn new(learn_rate: f32) -> Self {
        Self {
            learn_rate,
            theta: 1e-8,
            g: HashMap::new(),
        }
    }
}

impl Default for OptimizerAdaGrad {
    fn default() -> Self {
        Self {
            learn_rate: 1e-2,
            theta: 1e-6,
            g: HashMap::new(),
        }
    }
}

impl OptimizerAdaGrad {
    /// Applies one AdaGrad step to a single flat parameter buffer.
    fn optimize_layer(
        buf: &mut [f32],
        buf_grad: &[f32],
        g: &mut [f32],
        learn_rate: f32,
        theta: f32,
    ) {
        for ((buf_v, buf_grad_v), g_v) in buf.iter_mut().zip(buf_grad.iter()).zip(g.iter_mut()) {
            // A zero gradient contributes nothing; skip it as a fast path.
            if *buf_grad_v == 0.0 {
                continue;
            }

            // Accumulate the squared gradient, then step with the learning rate
            // scaled by its inverse square root; theta avoids division by zero.
            *g_v += buf_grad_v.powi(2);
            *buf_v += (learn_rate / (*g_v + theta).sqrt()) * buf_grad_v;
        }
    }
}

impl Optimizer for OptimizerAdaGrad {
    fn optimize_params(&mut self, lp: &mut CpuParams, opt_prms: TrainableBufsIds) {
        // Per-layer accumulator map, created on first sight of this layer's id.
        let g_val = self.g.entry(lp.id).or_insert_with(|| {
            debug!("[opt_ada_grad] Inserted learn_params with id {}", lp.id);
            HashMap::new()
        });

        for (buf_id, buf_grad_id) in opt_prms.0.iter().zip(opt_prms.1.iter()) {
            let buf_grad = lp.get_param(*buf_grad_id);

            // Lazily create a zeroed accumulator with the same shape as the gradient.
            let g_m = g_val
                .entry(*buf_grad_id)
                .or_insert_with(|| VariantParam::copy_zeroed_shape_from(&buf_grad));

            match g_m {
                VariantParam::Array1(arr1) => {
                    let buf_grad_slice = buf_grad.get_arr_1d();
                    let buf_grad_slice = buf_grad_slice.borrow();
                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();

                    let buf_slice = lp.get_1d_buf(*buf_id);
                    let mut buf_slice = buf_slice.borrow_mut();
                    let buf_slice = buf_slice.as_slice_mut().unwrap();

                    let g_slice = arr1.as_slice_mut().unwrap();

                    OptimizerAdaGrad::optimize_layer(
                        buf_slice,
                        buf_grad_slice,
                        g_slice,
                        self.learn_rate,
                        self.theta,
                    );
                }
                VariantParam::Array2(arr2) => {
                    let buf_grad_slice = buf_grad.get_arr_2d();
                    let buf_grad_slice = buf_grad_slice.borrow();
                    let buf_grad_slice = buf_grad_slice.as_slice().unwrap();

                    let buf_slice = lp.get_2d_buf(*buf_id);
                    let mut buf_slice = buf_slice.borrow_mut();
                    let buf_slice = buf_slice.as_slice_mut().unwrap();

                    let g_slice = arr2.as_slice_mut().unwrap();

                    OptimizerAdaGrad::optimize_layer(
                        buf_slice,
                        buf_grad_slice,
                        g_slice,
                        self.learn_rate,
                        self.theta,
                    );
                }
            }
        }
    }
}

impl WithParams for OptimizerAdaGrad {
    fn cfg(&self) -> HashMap<String, Variant> {
        let mut cfg_params = HashMap::new();

        cfg_params.insert("type".to_string(), Variant::String("adagrad".to_string()));
        cfg_params.insert("learning_rate".to_string(), Variant::Float(self.learn_rate));
        cfg_params.insert("theta".to_string(), Variant::Float(self.theta));

        cfg_params
    }

    fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
        if let Some(Variant::Float(v)) = args.get("learning_rate") {
            self.learn_rate = *v;
        }

        if let Some(Variant::Float(v)) = args.get("theta") {
            self.theta = *v;
        }
    }
}
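
// A minimal sanity check for `optimize_layer`, sketched under the assumption
// that the gradient buffer already carries the descent direction (the update
// above uses `+=`). The expected numbers are worked by hand, not taken from
// the crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn adagrad_single_step() {
        let mut buf = vec![1.0f32];
        let grad = vec![0.5f32];
        let mut g = vec![0.0f32];

        OptimizerAdaGrad::optimize_layer(&mut buf, &grad, &mut g, 0.1, 1e-8);

        // The accumulator gains the squared gradient: 0.5^2 = 0.25.
        assert!((g[0] - 0.25).abs() < 1e-6);
        // Step: (0.1 / sqrt(0.25 + 1e-8)) * 0.5 ~= 0.1, so 1.0 -> ~1.1.
        assert!((buf[0] - 1.1).abs() < 1e-4);
    }
}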