auto_diff_data_pipe/dataloader/mnist.rs

1use crate::dataloader::{DataLoader, DataSlice};
2use auto_diff::{Var, AutoDiffError};
3use std::path::{Path, };
4use std::io;
5use std::fs::File;
6use std::io::Read;
7
8pub struct Mnist {
9    //path: PathBuf,
10    train: Var,
11    test: Var,
12    train_label: Var,
13    test_label: Var,
14}
15impl Mnist {
16    pub fn new() -> Mnist {
17        // TODO download the data if it is not there.
18        unimplemented!()
19    }
20    pub fn load(path: &Path) -> Mnist {
21	
22        let train_fn = path.join("train-images-idx3-ubyte");
23        let test_fn = path.join("t10k-images-idx3-ubyte");
24        let train_label_fn = path.join("train-labels-idx1-ubyte");
25        let test_label_fn = path.join("t10k-labels-idx1-ubyte");
26
27	let train_img;
28	let test_img;
29	let train_label;
30	let test_label;
31	if path.exists() {
32	    train_img = Self::load_images(train_fn);
33	    test_img = Self::load_images(test_fn);
34	    train_label = Self::load_labels(train_label_fn);
35	    test_label = Self::load_labels(test_label_fn);
36	} else {
37	    // TODO download the data if it is not there.
38	    
39	    unimplemented!()
40	}
41	
42        Mnist {
43            //path: PathBuf::from(path),
44	    train: train_img,
45	    test: test_img,
46	    train_label,
47	    test_label,
48        }
49    }
50
51    fn load_images<P: AsRef<Path>>(path: P) -> Var {
52        let mut reader = io::BufReader::new(File::open(path).expect(""));
53        let magic = Self::read_as_u32(&mut reader);
54        if magic != 2051 {
55            panic!("Invalid magic number. expected 2051, got {}", magic)
56        }
57        let num_image = Self::read_as_u32(&mut reader) as usize;
58        let rows = Self::read_as_u32(&mut reader) as usize;
59        let cols = Self::read_as_u32(&mut reader) as usize;
60        assert!(rows == 28 && cols == 28);
61    
62        // read images
63        let mut buf: Vec<u8> = vec![0u8; num_image * rows * cols];
64        let _ = reader.read_exact(buf.as_mut());
65        let ret: Vec<f64> = buf.into_iter().map(|x| (x as f64) / 255.).collect();
66        Var::new(&ret[..], &[num_image, rows, cols])
67    }
68
69    fn load_labels<P: AsRef<Path>>(path: P) -> Var {
70        let mut reader = io::BufReader::new(File::open(path).expect(""));
71        let magic = Self::read_as_u32(&mut reader);
72        if magic != 2049 {
73            panic!("Invalid magic number. Got expect 2049, got {}", magic);
74        }
75        let num_label = Self::read_as_u32(&mut reader) as usize;
76        // read labels
77        let mut buf: Vec<u8> = vec![0u8; num_label];
78        let _ = reader.read_exact(buf.as_mut());
79        let ret: Vec<f64> = buf.into_iter().map(|x| x as f64).collect();
80        Var::new(&ret[..], &[num_label])
81    }
82
83    fn read_as_u32<T: Read>(reader: &mut T) -> u32 {
84        let mut buf: [u8; 4] = [0, 0, 0, 0];
85        let _ = reader.read_exact(&mut buf);
86        u32::from_be_bytes(buf)
87    }
88}
89impl DataLoader for Mnist {
90    fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError> {
91        match slice {
92	    Some(DataSlice::Train) => {Ok(self.train.size())},
93	    Some(DataSlice::Test) => {Ok(self.test.size())},
94	    None => {
95                let n = self.train.size()[0] + self.test.size()[1];
96                let mut new_size = self.train.size();
97                new_size[0] = n;
98                Ok(new_size)
99            },
100	    _ => {Err(AutoDiffError::new("TODO"))}
101	}
102    }
103    fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
104        match slice {
105	    Some(DataSlice::Train) => {
106                let dim = self.train.size().len();
107                let mut index_block = vec![(index, index+1)];
108                index_block.append(
109                    &mut vec![0; dim-1].iter().zip(&self.train.size()[1..])
110                        .map(|(x,y)| (*x, *y)).collect());
111                let data = self.train.get_patch(&index_block, None)?;
112                let label = self.train_label.get_patch(&[(index, index+1)], None)?;
113		self.train.reset_net();
114		self.train_label.reset_net();
115                Ok((data, label))
116            },
117	    Some(DataSlice::Test) => {
118                let dim = self.test.size().len();
119                let mut index_block = vec![(index, index+1)];
120                index_block.append(
121                    &mut vec![0; dim-1].iter().zip(&self.test.size()[1..])
122                        .map(|(x,y)| (*x, *y)).collect());
123                let data = self.test.get_patch(&index_block, None)?;
124                let label = self.test_label.get_patch(&[(index, index+1)], None)?;
125		self.test.reset_net();
126		self.test_label.reset_net();
127                Ok((data, label))
128            },
129	    _ => {Err(AutoDiffError::new("only train and test"))}
130	}
131    }
132    fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
133        match slice {
134	    Some(DataSlice::Train) => {
135                let dim = self.train.size().len();
136                let mut index_block = vec![(start, end)];
137                index_block.append(
138                    &mut vec![0; dim-1].iter().zip(&self.train.size()[1..])
139                        .map(|(x,y)| (*x, *y)).collect());
140                let data = self.train.get_patch(&index_block, None)?;
141                let label = self.train_label.get_patch(&[(start, end)], None)?;
142		self.train.reset_net();
143		self.train_label.reset_net();
144                Ok((data, label))
145            },
146	    Some(DataSlice::Test) => {
147                let dim = self.test.size().len();
148                let mut index_block = vec![(start, end)];
149                index_block.append(
150                    &mut vec![0; dim-1].iter().zip(&self.test.size()[1..])
151                        .map(|(x,y)| (*x, *y)).collect());
152                let data = self.test.get_patch(&index_block, None)?;
153                let label = self.test_label.get_patch(&[(start, end)], None)?;
154		self.test.reset_net();
155		self.test_label.reset_net();
156                Ok((data, label))
157            },
158	    _ => {Err(AutoDiffError::new("only train and test"))}
159	}
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: loads MNIST from a path relative to the working directory
    /// and prints the first test image. Requires the IDX files to be present
    /// at that path.
    #[test]
    fn mnist() {
        let data_dir = Path::new("../auto-diff/examples/data/mnist/");
        let dataset = Mnist::load(data_dir);
        let (first_image, _first_label) = dataset
            .get_item(0, Some(DataSlice::Test))
            .unwrap();
        println!("{:?}", first_image);
    }
}
174