1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
use ndarray_ext::NdArray;
use op;
use tensor::Tensor;
struct SGDOp {
pub lr: f32,
}
impl ::op::Op for SGDOp {
fn name(&self) -> &str {
"SGD"
}
fn compute(&self, mut ctx: ::runtime::OpComputeContext) -> op::ComputeResult {
let xs = unsafe { ctx.grab_assignable_inputs() };
let updates = {
let grad: &NdArray = xs[1];
grad * self.lr
};
xs[0].zip_mut_with(&updates, |a, &b| *a -= b);
vec![Err(::op::ComputeError::NoOutput)]
}
fn grad(&self, _: &Tensor, _: &[&Tensor], _: &Tensor) -> Vec<Option<Tensor>> {
vec![None]
}
}
pub struct SGD {
pub lr: f32,
}
impl<'a> SGD {
fn compute_updates<T: AsRef<Tensor>>(
&mut self,
params: &[&'a Tensor],
grads: &[T],
) -> Vec<Tensor> {
params
.into_iter()
.zip(grads)
.map(|(param, grad)| {
Tensor::builder()
.set_inputs(vec![param, grad.as_ref()])
.build(SGDOp { lr: self.lr })
})
.collect()
}
}