// auto_diff_ann/minibatch.rs
use ::rand::prelude::StdRng;
use auto_diff::{Var, AutoDiffError};
use auto_diff_data_pipe::dataloader::{DataLoader, DataSlice};

5pub struct MiniBatch {
6 rng: StdRng,
7 size: usize,
8}
9impl MiniBatch {
10 pub fn new(rng: StdRng, size: usize) -> MiniBatch {
11 MiniBatch {
12 rng,
13 size,
14 }
15 }
16
17 pub fn batch_size(&self) -> usize {
18 self.size
19 }
20
21 pub fn next(&mut self, loader: &dyn DataLoader, part: &DataSlice) -> Result<(Var, Var), AutoDiffError> {
23 let sample_size = loader.get_size(Some(*part))?[0];
24 let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size);
25 loader.get_indexed_batch(&(Vec::<usize>::try_from(index_t)?), Some(*part))
26 }
27 pub fn next_data_slice(&mut self, data: &Var, label: &Var) -> Result<(Var, Var), AutoDiffError> {
29 let sample_size = data.size()[0];
30 let sample_size2 = label.size()[0];
31
32 if sample_size != sample_size2 {
33 return Err(AutoDiffError::new(&format!("minibatch needs data and label has the same N {}, {}",
34 sample_size, sample_size2)));
35 }
36 let index_t = Var::rand_usize(&mut self.rng, &[self.size], 0, sample_size);
37
38 let mdata = data.index_select(0, index_t.clone())?;
39 let mlabel = label.index_select(0, index_t)?;
40 mdata.reset_net();
41 mlabel.reset_net();
42 Ok((mdata, mlabel))
43 }
44
45 pub fn iter_block<'a>(&self, loader: &'a dyn DataLoader, part: & DataSlice) -> Result<BlockIterator<'a>, AutoDiffError> {
46 Ok(BlockIterator {
47 loader,
48 part: *part,
49 block_size: self.size,
50 block_index: 0,
51 })
52 }
53}
54
55pub struct BlockIterator<'a> {
56 loader: &'a dyn DataLoader,
57 part: DataSlice,
58 block_size: usize,
59 block_index: usize,
60}
61impl<'a> Iterator for BlockIterator<'a> {
62 type Item = (Var, Var);
63 fn next(&mut self) -> Option<Self::Item> {
64 let n = if let Ok(size) = self.loader.get_size(Some(self.part)) {
65 size[0]
66 } else {
67 return None;
68 };
69
70 if self.block_index >= n {
71 return None;
72 }
73 let mut end_index = self.block_index + self.block_size;
74 if end_index > n {
75 end_index = n;
76 }
77
78 let result = self.loader.get_batch(self.block_index,
79 end_index,
80 Some(self.part));
81 self.block_index += self.block_size;
82 result.ok()
83 }
84}