auto_diff_data_pipe/dataloader/
mnist.rs1use crate::dataloader::{DataLoader, DataSlice};
2use auto_diff::{Var, AutoDiffError};
3use std::path::{Path, };
4use std::io;
5use std::fs::File;
6use std::io::Read;
7
8pub struct Mnist {
9 train: Var,
11 test: Var,
12 train_label: Var,
13 test_label: Var,
14}
15impl Mnist {
16 pub fn new() -> Mnist {
17 unimplemented!()
19 }
20 pub fn load(path: &Path) -> Mnist {
21
22 let train_fn = path.join("train-images-idx3-ubyte");
23 let test_fn = path.join("t10k-images-idx3-ubyte");
24 let train_label_fn = path.join("train-labels-idx1-ubyte");
25 let test_label_fn = path.join("t10k-labels-idx1-ubyte");
26
27 let train_img;
28 let test_img;
29 let train_label;
30 let test_label;
31 if path.exists() {
32 train_img = Self::load_images(train_fn);
33 test_img = Self::load_images(test_fn);
34 train_label = Self::load_labels(train_label_fn);
35 test_label = Self::load_labels(test_label_fn);
36 } else {
37 unimplemented!()
40 }
41
42 Mnist {
43 train: train_img,
45 test: test_img,
46 train_label,
47 test_label,
48 }
49 }
50
51 fn load_images<P: AsRef<Path>>(path: P) -> Var {
52 let mut reader = io::BufReader::new(File::open(path).expect(""));
53 let magic = Self::read_as_u32(&mut reader);
54 if magic != 2051 {
55 panic!("Invalid magic number. expected 2051, got {}", magic)
56 }
57 let num_image = Self::read_as_u32(&mut reader) as usize;
58 let rows = Self::read_as_u32(&mut reader) as usize;
59 let cols = Self::read_as_u32(&mut reader) as usize;
60 assert!(rows == 28 && cols == 28);
61
62 let mut buf: Vec<u8> = vec![0u8; num_image * rows * cols];
64 let _ = reader.read_exact(buf.as_mut());
65 let ret: Vec<f64> = buf.into_iter().map(|x| (x as f64) / 255.).collect();
66 Var::new(&ret[..], &[num_image, rows, cols])
67 }
68
69 fn load_labels<P: AsRef<Path>>(path: P) -> Var {
70 let mut reader = io::BufReader::new(File::open(path).expect(""));
71 let magic = Self::read_as_u32(&mut reader);
72 if magic != 2049 {
73 panic!("Invalid magic number. Got expect 2049, got {}", magic);
74 }
75 let num_label = Self::read_as_u32(&mut reader) as usize;
76 let mut buf: Vec<u8> = vec![0u8; num_label];
78 let _ = reader.read_exact(buf.as_mut());
79 let ret: Vec<f64> = buf.into_iter().map(|x| x as f64).collect();
80 Var::new(&ret[..], &[num_label])
81 }
82
83 fn read_as_u32<T: Read>(reader: &mut T) -> u32 {
84 let mut buf: [u8; 4] = [0, 0, 0, 0];
85 let _ = reader.read_exact(&mut buf);
86 u32::from_be_bytes(buf)
87 }
88}
89impl DataLoader for Mnist {
90 fn get_size(&self, slice: Option<DataSlice>) -> Result<Vec<usize>, AutoDiffError> {
91 match slice {
92 Some(DataSlice::Train) => {Ok(self.train.size())},
93 Some(DataSlice::Test) => {Ok(self.test.size())},
94 None => {
95 let n = self.train.size()[0] + self.test.size()[1];
96 let mut new_size = self.train.size();
97 new_size[0] = n;
98 Ok(new_size)
99 },
100 _ => {Err(AutoDiffError::new("TODO"))}
101 }
102 }
103 fn get_item(&self, index: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
104 match slice {
105 Some(DataSlice::Train) => {
106 let dim = self.train.size().len();
107 let mut index_block = vec![(index, index+1)];
108 index_block.append(
109 &mut vec![0; dim-1].iter().zip(&self.train.size()[1..])
110 .map(|(x,y)| (*x, *y)).collect());
111 let data = self.train.get_patch(&index_block, None)?;
112 let label = self.train_label.get_patch(&[(index, index+1)], None)?;
113 self.train.reset_net();
114 self.train_label.reset_net();
115 Ok((data, label))
116 },
117 Some(DataSlice::Test) => {
118 let dim = self.test.size().len();
119 let mut index_block = vec![(index, index+1)];
120 index_block.append(
121 &mut vec![0; dim-1].iter().zip(&self.test.size()[1..])
122 .map(|(x,y)| (*x, *y)).collect());
123 let data = self.test.get_patch(&index_block, None)?;
124 let label = self.test_label.get_patch(&[(index, index+1)], None)?;
125 self.test.reset_net();
126 self.test_label.reset_net();
127 Ok((data, label))
128 },
129 _ => {Err(AutoDiffError::new("only train and test"))}
130 }
131 }
132 fn get_batch(&self, start: usize, end: usize, slice: Option<DataSlice>) -> Result<(Var, Var), AutoDiffError> {
133 match slice {
134 Some(DataSlice::Train) => {
135 let dim = self.train.size().len();
136 let mut index_block = vec![(start, end)];
137 index_block.append(
138 &mut vec![0; dim-1].iter().zip(&self.train.size()[1..])
139 .map(|(x,y)| (*x, *y)).collect());
140 let data = self.train.get_patch(&index_block, None)?;
141 let label = self.train_label.get_patch(&[(start, end)], None)?;
142 self.train.reset_net();
143 self.train_label.reset_net();
144 Ok((data, label))
145 },
146 Some(DataSlice::Test) => {
147 let dim = self.test.size().len();
148 let mut index_block = vec![(start, end)];
149 index_block.append(
150 &mut vec![0; dim-1].iter().zip(&self.test.size()[1..])
151 .map(|(x,y)| (*x, *y)).collect());
152 let data = self.test.get_patch(&index_block, None)?;
153 let label = self.test_label.get_patch(&[(start, end)], None)?;
154 self.test.reset_net();
155 self.test_label.reset_net();
156 Ok((data, label))
157 },
158 _ => {Err(AutoDiffError::new("only train and test"))}
159 }
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Location of the raw MNIST files, relative to the directory `cargo test`
    /// runs in (the package root).
    const MNIST_DIR: &str = "../auto-diff/examples/data/mnist/";

    #[test]
    fn mnist() {
        let dir = Path::new(MNIST_DIR);
        // The dataset is not vendored with this crate; skip instead of hitting
        // the `unimplemented!` branch of `Mnist::load`.
        if !dir.exists() {
            eprintln!("skipping mnist test: {} not found", dir.display());
            return;
        }
        let mnist = Mnist::load(dir);

        // `load_images` only accepts 28x28 images.
        let test_size = mnist.get_size(Some(DataSlice::Test)).unwrap();
        assert_eq!(&test_size[1..], &[28, 28]);

        let (t0, _l0) = mnist.get_item(0, Some(DataSlice::Test)).unwrap();
        println!("{:?}", t0);
    }
}
174