auto_diff/optim.rs

//!
//! Gradient-based optimization.
//!
use std::cell::RefCell;
use std::rc::Rc;
use tensor_rs::tensor::Tensor;
use rand::prelude::StdRng;
use crate::err::AutoDiffError;
use super::compute_graph::Net;
use crate::var::Var;

/// Create a random batch view from a larger batch.
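///
/// A minimal usage sketch (marked `ignore` because the data setup is assumed;
/// see the `mini_batch` test at the bottom of this file for a compiled version):
///
/// ```ignore
/// use rand::prelude::*;
///
/// let rng = StdRng::seed_from_u64(671);
/// let mut minibatch = MiniBatch::new(rng, 4);
/// // data: [N, features], label: [N]; next() samples `size` random rows.
/// let (mdata, mlabel) = minibatch.next(&data, &label)?;
/// ```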
pub struct MiniBatch {
    rng: StdRng,
    size: usize,
}
impl MiniBatch {
    pub fn new(rng: StdRng, size: usize) -> MiniBatch {
        MiniBatch {
            rng,
            size,
        }
    }

    /// Draw a random mini-batch of `size` rows from `data` and `label`,
    /// selecting along dimension 0.
    pub fn next(&mut self, data: &Var, label: &Var) -> Result<(Var, Var), AutoDiffError> {
        let sample_size = data.size()[0];
        let sample_size2 = label.size()[0];

        if sample_size != sample_size2 {
            return Err(AutoDiffError::new(&format!("MiniBatch requires data and label to have the same first dimension N: {}, {}",
                                                   sample_size, sample_size2)));
        }
        // Sample `size` row indices uniformly from [0, sample_size).
        let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size);

        let mdata = data.index_select(0, index_t.clone())?;
        let mlabel = label.index_select(0, index_t)?;
        // Detach the views from the source computation graph so each batch
        // starts with a fresh net.
        mdata.reset_net();
        mlabel.reset_net();
        Ok((mdata, mlabel))
    }
}

/// An optimizer updates network parameters using gradients accumulated in the
/// compute graph.
pub trait Optimizer {
    fn step(&mut self, net: Rc<RefCell<Net>>);
}

/// Stochastic gradient descent. The update rule itself is plain gradient
/// descent, w <- w - lr * grad(w); the "stochastic" part comes from feeding
/// it mini-batches (see [`MiniBatch`]).
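///
/// A hypothetical training-loop fragment (marked `ignore`: how the
/// `Rc<RefCell<Net>>` handle is obtained is assumed, not shown here):
///
/// ```ignore
/// let mut opt = SGD::new(0.1);
/// opt.step(net); // one in-place update: w <- w - lr * grad(w)
/// ```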
pub struct SGD {
    lr: Tensor,
}
impl SGD {
    // The scalar type of the default constructor follows the crate's
    // precision feature: `use-f64` or `use-f32`.
    #[cfg(feature = "use-f64")]
    pub fn new(lr: f64) -> SGD {
        Self::new_f64(lr)
    }
    #[cfg(feature = "use-f32")]
    pub fn new(lr: f32) -> SGD {
        Self::new_f32(lr)
    }

    pub fn new_f64(lr: f64) -> SGD {
        SGD {
            lr: Tensor::from_vec_f64(&[lr], &[1]),
        }
    }
    pub fn new_f32(lr: f32) -> SGD {
        SGD {
            lr: Tensor::from_vec_f32(&[lr], &[1]),
        }
    }
}
impl Optimizer for SGD {
    fn step(&mut self, net: Rc<RefCell<Net>>) {
        // Visit every op in the graph and apply w <- w - lr * grad(w)
        // to each of its parameter tensors.
        net.borrow_mut().visit_op(|x| {
            let weights = x.get_values();
            let grads = x.get_grads();

            let mut new_weight = Vec::new();
            for (i, j) in weights.iter().zip(grads.iter()) {
                new_weight.push(i.sub(&j.mul(&self.lr)));
            }
            x.set_values(&new_weight);
        }, None, None);
    }
}
#[cfg(test)]
mod tests {
    use crate::var::Var;
    use super::*;
    use rand::prelude::*;

    #[test]
    fn mini_batch() {
        let data = Var::ones(&[10, 3]);
        let label = Var::zeros(&[10]);

        let rng = StdRng::seed_from_u64(671);
        let mut minibatch = MiniBatch::new(rng, 4);
        let (mdata, mlabel) = minibatch.next(&data, &label).unwrap();

        assert_eq!(mdata.size(), [4, 3]);
        assert_eq!(mlabel.size(), [4]);
        println!("{:?}, {:?}", mdata, mlabel);
    }
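
    // A sketch of the update rule applied in `SGD::step`, w <- w - lr * grad,
    // checked at the Tensor level. Assumes the `use-f64` feature and that
    // `Tensor` implements `PartialEq`; it does not exercise `step` itself,
    // which needs a populated `Net`. Values are chosen to be exact in binary
    // floating point so the equality check is safe.
    #[cfg(feature = "use-f64")]
    #[test]
    fn sgd_update_rule() {
        use tensor_rs::tensor::Tensor;

        let w = Tensor::from_vec_f64(&[1.0, 2.0], &[2]);
        let g = Tensor::from_vec_f64(&[0.5, -0.5], &[2]);
        let lr = Tensor::from_vec_f64(&[0.25], &[1]);

        // Same expression as in SGD::step: w - lr * g.
        let new_w = w.sub(&g.mul(&lr));
        assert_eq!(new_w, Tensor::from_vec_f64(&[0.875, 2.125], &[2]));
    }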
}