auto_diff_data_pipe/dataloader/
mod.rs1use auto_diff::{Var, AutoDiffError};
2
/// Selects which partition of a dataset a `DataLoader` call should read from.
///
/// Every `DataLoader` method takes an `Option<DataSlice>`; how `None` is handled is
/// left to each implementation.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DataSlice {
    /// The training partition.
    Train,
    /// The test partition.
    Test,
    /// The tuning partition (presumably a validation split — confirm with the loaders).
    Tune,
    /// Any partition not covered by the other variants.
    Other,
}
10
11pub trait DataLoader {
12 fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError>;
14 fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>;
16 fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>;
18 fn get_indexed_batch(&self, index: &[usize], slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
20 let mut data: Vec<Var> = vec![];
21 let mut label: Vec<Var> = vec![];
22
23 for elem_index in index {
24 let (elem_data, elem_label) = self.get_item(*elem_index, slice)?;
25 data.push(elem_data);
26 label.push(elem_label);
27 }
28 let d1 = data[0].cat(&data[1..], 0)?;
29 let d2 = label[0].cat(&label[1..], 0)?;
30 d1.reset_net();
31 d2.reset_net();
32 Ok((d1, d2))
33 }
34}
35
36pub mod mnist;