1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
use super::tensor::Tensor;
use super::var::Module;
pub trait Optimizer {
fn step(&mut self, m: &Module);
}
pub struct SGD {
lr: Tensor,
}
impl SGD {
pub fn new(lr: f32) -> SGD {
SGD {
lr: Tensor::from_vec_f32(&vec![lr], &vec![1]),
}
}
}
impl Optimizer for SGD {
fn step(&mut self, m: &Module) {
m._visit_op(|x| {
let weights = x.get_values();
let grads = x.get_grads();
let mut new_weight = Vec::new();
for (i, j) in weights.iter().zip(grads.iter()) {
new_weight.push(i.add(&j.mul(&self.lr)));
}
x.set_values(&new_weight);
});
}
}