//!
//! Gradient based optimization.
//!
use std::cell::RefCell;
use tensor_rs::tensor::Tensor;
//use super::var1::{Func, Module};
use crate::rand;
/// Draws random mini-batches of a fixed size from a (data, label) pair.
pub struct MiniBatch {
// Random index generator; wrapped in RefCell so `next(&self)` can
// advance the RNG state through a shared reference.
rng: RefCell<rand::RNG>,
// Number of samples drawn per mini-batch.
size: usize,
}
impl MiniBatch {
    /// Create a sampler that draws `size` samples per call to [`MiniBatch::next`],
    /// using `rng` as the source of indices.
    pub fn new(rng: rand::RNG, size: usize) -> MiniBatch {
        MiniBatch {
            rng: RefCell::new(rng),
            // clippy: redundant field name (`size: size`) shortened.
            size,
        }
    }

    /// Draw one random mini-batch from `data` and `label`.
    ///
    /// Both tensors are indexed along dimension 0, so they must have the
    /// same leading dimension N (number of samples).
    ///
    /// Returns `(mini_data, mini_label)`, each with `self.size` rows
    /// selected at the same random indices.
    ///
    /// # Panics
    /// Panics if `data` and `label` disagree on N.
    // NOTE(review): assumes `gen_range_usize(0, n, Some(k))` yields `k`
    // indices in `[0, n)` — confirm against the rand module; also assumes
    // both tensors have at least one dimension (`size()[0]` would panic
    // on a scalar).
    pub fn next(&self, data: &Tensor, label: &Tensor) -> (Tensor, Tensor) {
        let sample_size = data.size()[0];
        let sample_size2 = label.size()[0];
        if sample_size != sample_size2 {
            // Message grammar fixed; same panic condition as before.
            panic!("minibatch needs data and label with the same N: {}, {}",
                   sample_size, sample_size2);
        }
        // Draw `self.size` random row indices, then gather the matching
        // rows from both tensors so data and labels stay aligned.
        let index = self.rng.borrow_mut().gen_range_usize(0, sample_size, Some(self.size));
        let index_t = Tensor::from_vec_usize(&index, &[index.len()]);
        let mdata = data.index_select(0, &index_t);
        let mlabel = label.index_select(0, &index_t);
        (mdata, mlabel)
    }
}
//pub trait Optimizer {
// fn step(&mut self, m: &Module);
// fn step2(&mut self, m: &Func);
//}
//
//// actually it's GD
//pub struct SGD {
// lr: Tensor,
//}
//impl SGD {
// pub fn new(lr: f32) -> SGD {
// SGD {
// lr: Tensor::from_vec_f32(&[lr], &[1]),
// }
// }
//}
//impl Optimizer for SGD {
// fn step(&mut self, m: &Module) {
// m._visit_op(|x| {
// let weights = x.get_values();
// let grads = x.get_grads();
// // println!("name: {}, {}, {}", x.get_name(), weights.len(), grads.len());
//
// let mut new_weight = Vec::new();
// for (i, j) in weights.iter().zip(grads.iter()) {
// // println!("{:?}, {:?}, {:?}", i.size(), j.size(), self.lr.size());
//
// new_weight.push(i.add(&j.mul(&self.lr)));
// }
// x.set_values(&new_weight);
// });
// }
//
// fn step2(&mut self, m: &Func) {
// m._visit_op(|x| {
// if x.get_update_counter() == 0 && x.get_name() != "Nop" {
// println!("name: {}, ", x.get_name(), );
// println!("Warning: haven't seen a backward pass, missing .backward call before update?");
// return;
// }
//
// let weights = x.get_values();
// let grads = x.get_grads();
// //println!("name: {}, {}, {}", x.get_name(), weights.len(), grads.len());
//
// let mut new_weight = Vec::new();
// for (i, j) in weights.iter().zip(grads.iter()) {
// //println!("{:?}, {:?}, {:?}", i.size(), j.size(), self.lr.size());
//
// new_weight.push(i.add(&j.mul(&self.lr)));
// }
// x.set_values(&new_weight);
// });
// }
//}
#[cfg(test)]
mod tests {
    use tensor_rs::tensor::Tensor;
    use crate::rand::RNG;
    use super::*;

    /// Drawing a batch of 4 out of 10 samples must yield tensors whose
    /// leading dimension is 4 while the remaining shape is preserved.
    #[test]
    fn mini_batch() {
        let features = Tensor::ones(&[10, 3]);
        let targets = Tensor::zeros(&[10]);
        let sampler = MiniBatch::new(RNG::new(), 4);

        let (batch_x, batch_y) = sampler.next(&features, &targets);

        assert_eq!(batch_x.size(), [4, 3]);
        assert_eq!(batch_y.size(), [4]);
        println!("{:?}, {:?}", batch_x, batch_y);
    }
}