1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use std::cell::RefCell;
use tensor_rs::tensor::Tensor;
use super::var::{Func, Module};
use crate::rand;
/// Draws fixed-size random mini-batches of (data, label) rows from a dataset.
pub struct MiniBatch {
/// Index sampler; `RefCell` gives interior mutability so `next` can take `&self`.
rng: RefCell<rand::RNG>,
/// Number of samples drawn per mini-batch.
size: usize,
}
impl MiniBatch {
    /// Creates a sampler that draws mini-batches of `size` rows using `rng`.
    pub fn new(rng: rand::RNG, size: usize) -> MiniBatch {
        MiniBatch {
            rng: RefCell::new(rng),
            // Field-init shorthand (was the redundant `size: size`).
            size,
        }
    }

    /// Samples one random mini-batch from `data` and `label`.
    ///
    /// The same `self.size` row indices (dimension 0) are selected from both
    /// tensors so each data row stays paired with its label.
    ///
    /// # Panics
    /// Panics if `data` and `label` disagree on the size of dimension 0.
    pub fn next(&self, data: &Tensor, label: &Tensor) -> (Tensor, Tensor) {
        let n_data = data.size()[0];
        let n_label = label.size()[0];
        if n_data != n_label {
            panic!("minibatch needs data and label has the same N {}, {}",
                   n_data, n_label);
        }
        // Draw `self.size` indices in [0, n_data).
        let index = self.rng.borrow_mut().gen_range_usize(0, n_data, Some(self.size));
        let index_t = Tensor::from_vec_usize(&index, &[index.len()]);
        // Gather the same rows from data and label so pairs stay aligned.
        let mdata = data.index_select(0, &index_t);
        let mlabel = label.index_select(0, &index_t);
        (mdata, mlabel)
    }
}
/// Interface for parameter-update strategies applied after a backward pass.
pub trait Optimizer {
/// Applies one update step to every op visited in the `Module`'s graph.
fn step(&mut self, m: &Module);
/// Applies one update step to every op visited through the `Func` graph,
/// skipping ops that have not yet seen a backward pass.
fn step2(&mut self, m: &Func);
}
/// Plain stochastic gradient descent with a fixed learning rate.
pub struct SGD {
/// Learning rate stored as a 1-element tensor so it can multiply gradients directly.
lr: Tensor,
}
impl SGD {
    /// Creates an SGD optimizer with learning rate `lr`.
    ///
    /// The rate is stored as a 1-element tensor so it can be broadcast-multiplied
    /// against gradient tensors in `step`/`step2`.
    pub fn new(lr: f32) -> SGD {
        SGD {
            // Slice literals instead of `&vec![..]` — avoids a heap allocation
            // that was immediately borrowed as a slice (clippy `useless_vec`).
            lr: Tensor::from_vec_f32(&[lr], &[1]),
        }
    }
}
impl Optimizer for SGD {
    /// One SGD step over every op in the module: `w <- w - lr * grad`.
    fn step(&mut self, m: &Module) {
        m._visit_op(|x| {
            let weights = x.get_values();
            let grads = x.get_grads();
            let mut new_weight = Vec::new();
            for (w, g) in weights.iter().zip(grads.iter()) {
                // Descend along the gradient. The previous `add` performed
                // gradient *ascent*, maximizing the loss instead of minimizing it.
                new_weight.push(w.sub(&g.mul(&self.lr)));
            }
            x.set_values(&new_weight);
        });
    }

    /// One SGD step over every op reachable from the `Func` graph.
    ///
    /// Ops whose update counter is still zero (and which are not "Nop") have
    /// not seen a backward pass yet; they are skipped with a warning rather
    /// than updated with stale/empty gradients.
    fn step2(&mut self, m: &Func) {
        m._visit_op(|x| {
            if x.get_update_counter() <= 0 && x.get_name() != "Nop" {
                println!("name: {}, ", x.get_name(), );
                println!("Warning: haven't seen a backward pass, missing .backward call before update?");
                return;
            }
            let weights = x.get_values();
            let grads = x.get_grads();
            let mut new_weight = Vec::new();
            for (w, g) in weights.iter().zip(grads.iter()) {
                // Same fix as `step`: subtract, don't add, the scaled gradient.
                new_weight.push(w.sub(&g.mul(&self.lr)));
            }
            x.set_values(&new_weight);
        });
    }
}
#[cfg(test)]
mod tests {
    use tensor_rs::tensor::Tensor;
    use crate::rand::RNG;
    use super::*;

    /// `MiniBatch::next` must return batch-sized slices of both tensors.
    #[test]
    fn mini_batch() {
        let features = Tensor::ones(&[10, 3]);
        let targets = Tensor::zeros(&[10]);
        let sampler = MiniBatch::new(RNG::new(), 4);
        let (batch_x, batch_y) = sampler.next(&features, &targets);
        assert_eq!(batch_x.size(), [4, 3]);
        assert_eq!(batch_y.size(), [4]);
        println!("{:?}, {:?}", batch_x, batch_y);
    }
}