auto_diff_data_pipe/dataloader/
mod.rs

1use auto_diff::{Var, AutoDiffError};
2
/// Selects which partition of a dataset a [`DataLoader`] call refers to.
///
/// The variants are plain tags; what each one maps to is decided by the
/// concrete loader. NOTE(review): the per-variant descriptions below are
/// inferred from the names only — confirm against the implementors.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum DataSlice {
    /// Presumably the training split.
    Train,
    /// Presumably the held-out test split.
    Test,
    /// Presumably the tuning / validation split.
    Tune,
    /// Any partition not covered by the other variants.
    Other,
}
10
11pub trait DataLoader {
12    /// The shape of the data if applicable.
13    fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError>;
14    /// Return one sample.
15    fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>;
16    /// Return a batch following original order.
17    fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>;
18    /// Return a batch given the index.
19    fn get_indexed_batch(&self, index: &[usize], slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
20        let mut data: Vec<Var> = vec![];
21        let mut label: Vec<Var> = vec![];
22
23        for elem_index in index {
24            let (elem_data, elem_label) = self.get_item(*elem_index, slice)?;
25            data.push(elem_data);
26            label.push(elem_label);
27        }
28        let d1 = data[0].cat(&data[1..], 0)?;
29        let d2 = label[0].cat(&label[1..], 0)?;
30        d1.reset_net();
31        d2.reset_net();
32        Ok((d1, d2))
33    }
34}
35
36pub mod mnist;