1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
use auto_diff::{Var, AutoDiffError};
/// Selects which partition of a dataset a [`DataLoader`] call reads from.
///
/// Passed as `Option<DataSlice>` to every [`DataLoader`] method; how `None` is
/// interpreted is up to each implementor.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum DataSlice {
    /// Training slice.
    Train,
    /// Test slice.
    Test,
    /// Tuning slice (presumably validation / hyper-parameter tuning data — confirm per loader).
    Tune,
    /// Any slice not covered by the variants above.
    Other,
}
pub trait DataLoader {
fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError>;
fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>;
fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError>;
fn get_indexed_batch(&self, index: &[usize], slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
let mut data: Vec<Var> = vec![];
let mut label: Vec<Var> = vec![];
for elem_index in index {
let (elem_data, elem_label) = self.get_item(*elem_index, slice)?;
data.push(elem_data);
label.push(elem_label);
}
let d1 = data[0].cat(&data[1..], 0)?;
let d2 = label[0].cat(&label[1..], 0)?;
d1.reset_net();
d2.reset_net();
Ok((d1, d2))
}
}
pub mod mnist;