1use std::cell::RefCell;
5use std::rc::Rc;
6use tensor_rs::tensor::Tensor;
7use rand::prelude::StdRng;
8use crate::err::AutoDiffError;
9use super::compute_graph::Net;
10use crate::var::Var;
11
12pub struct MiniBatch {
14 rng: StdRng,
15 size: usize,
16}
17impl MiniBatch {
18 pub fn new(rng: StdRng, size: usize) -> MiniBatch {
19 MiniBatch {
20 rng,
21 size,
22 }
23 }
24
25 pub fn next(&mut self, data: &Var, label: &Var) -> Result<(Var, Var), AutoDiffError> {
26 let sample_size = data.size()[0];
27 let sample_size2 = label.size()[0];
28
29 if sample_size != sample_size2 {
30 return Err(AutoDiffError::new(&format!("minibatch needs data and label has the same N {}, {}",
31 sample_size, sample_size2)));
32 }
33 let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size);
34
35 let mdata = data.index_select(0, index_t.clone())?;
36 let mlabel = label.index_select(0, index_t)?;
37 mdata.reset_net();
38 mlabel.reset_net();
39 Ok((mdata, mlabel))
40 }
41}
42
/// A parameter-update strategy applied to a computation graph.
///
/// Implementors define how the values held by the graph's ops are updated;
/// see `SGD` in this module for one implementation.
pub trait Optimizer {
    /// Performs one update step over the shared computation graph `net`.
    fn step(&mut self, net: Rc<RefCell<Net>>);
}
46
47pub struct SGD {
49 lr: Tensor,
50}
51impl SGD {
52 #[cfg(feature = "use-f64")]
53 pub fn new(lr: f64) -> SGD {
54 Self::new_f64(lr)
55 }
56 #[cfg(feature = "use-f32")]
57 pub fn new(lr: f32) -> SGD {
58 Self::new_f32(lr)
59 }
60
61 pub fn new_f64(lr: f64) -> SGD {
62 SGD {
63 lr: Tensor::from_vec_f64(&[lr], &[1]),
64 }
65 }
66 pub fn new_f32(lr: f32) -> SGD {
67 SGD {
68 lr: Tensor::from_vec_f32(&[lr], &[1]),
69 }
70 }
71}
72impl Optimizer for SGD {
73 fn step(&mut self, net: Rc<RefCell<Net>>) {
74 net.borrow_mut().visit_op(|x| {
75 let weights = x.get_values();
76 let grads = x.get_grads();
77 let mut new_weight = Vec::new();
80 for (i, j) in weights.iter().zip(grads.iter()) {
81 new_weight.push(i.sub(&j.mul(&self.lr)));
83 }
84 x.set_values(&new_weight);
85 }, None, None);
86 }
87}
88
89
#[cfg(test)]
mod tests {
    use crate::var::Var;
    use super::*;
    use rand::prelude::*;

    /// A batch of 4 drawn from 10 samples keeps the trailing data dimension.
    #[test]
    fn mini_batch() {
        let data = Var::ones(&[10, 3]);
        let label = Var::zeros(&[10]);

        let rng = StdRng::seed_from_u64(671);
        let mut minibatch = MiniBatch::new(rng, 4);
        let (mdata, mlabel) = minibatch.next(&data, &label).unwrap();

        assert_eq!(mdata.size(), [4, 3]);
        assert_eq!(mlabel.size(), [4]);
        println!("{:?}, {:?}", mdata, mlabel);
    }

    /// Data and labels with different sample counts must be rejected
    /// rather than sampled.
    #[test]
    fn mini_batch_rejects_mismatched_n() {
        let data = Var::ones(&[10, 3]);
        let label = Var::zeros(&[9]);

        let rng = StdRng::seed_from_u64(671);
        let mut minibatch = MiniBatch::new(rng, 4);

        assert!(minibatch.next(&data, &label).is_err());
    }
}