auto_diff_ann/
minibatch.rs

1use ::rand::prelude::StdRng;
2use auto_diff::{Var, AutoDiffError};
3use auto_diff_data_pipe::dataloader::{DataLoader, DataSlice};
4
5pub struct MiniBatch {
6    rng: StdRng,
7    size: usize,
8}
9impl MiniBatch {
10    pub fn new(rng: StdRng, size: usize) -> MiniBatch {
11        MiniBatch {
12            rng,
13            size,
14        }
15    }
16
17    pub fn batch_size(&self) -> usize {
18	self.size
19    }
20
21    /// Get a random set of samples from the data loader.
22    pub fn next(&mut self, loader: &dyn DataLoader, part: &DataSlice) -> Result<(Var, Var), AutoDiffError> {
23        let sample_size = loader.get_size(Some(*part))?[0];
24        let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size);
25        loader.get_indexed_batch(&(Vec::<usize>::try_from(index_t)?), Some(*part))
26    }
27    /// Get a random set of samples given the data and label.
28    pub fn next_data_slice(&mut self, data: &Var, label: &Var) -> Result<(Var, Var), AutoDiffError> {
29        let sample_size = data.size()[0];
30        let sample_size2 = label.size()[0];
31
32        if sample_size != sample_size2 {
33            return Err(AutoDiffError::new(&format!("minibatch needs data and label has the same N {}, {}",
34                                                   sample_size, sample_size2)));
35        }
36        let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size);
37
38        let mdata = data.index_select(0, index_t.clone())?;
39        let mlabel = label.index_select(0, index_t)?;
40        mdata.reset_net();
41        mlabel.reset_net();
42        Ok((mdata, mlabel))
43    }
44
45    pub fn iter_block<'a>(&self, loader: &'a dyn DataLoader, part: & DataSlice) -> Result<BlockIterator<'a>, AutoDiffError> {
46	Ok(BlockIterator {
47            loader,
48	    part: *part,
49	    block_size: self.size,
50	    block_index: 0,
51        })
52    }
53}
54
55pub struct BlockIterator<'a> {
56    loader: &'a dyn DataLoader,
57    part: DataSlice,
58    block_size: usize,
59    block_index: usize,
60}
61impl<'a> Iterator for BlockIterator<'a> {
62    type Item = (Var, Var);
63    fn next(&mut self) -> Option<Self::Item> {
64	let n = if let Ok(size) = self.loader.get_size(Some(self.part)) {
65	    size[0]
66	} else {
67	    return None;
68	};
69
70	if self.block_index >= n {
71	    return None;
72	}
73	let mut end_index = self.block_index + self.block_size;
74	if end_index > n {
75	    end_index = n;
76	}
77
78	let result = self.loader.get_batch(self.block_index,
79					   end_index,
80					   Some(self.part));
81	self.block_index += self.block_size;
82	result.ok()
83    }
84}