1use std::io;
2use std::io::ErrorKind;
3use std::io::Read;
4use std::path::{Path, PathBuf};
5use std::fs::{create_dir_all, File};
6
7use futures::{Future, Stream};
8use hyper::Client;
9use tokio_core::reactor::Core;
10
11use flate2::read::GzDecoder;
12use tar::Archive;
13
/// The CIFAR-100 dataset loaded into memory.
///
/// Every record of a batch file contributes one entry to a labels vector and
/// one entry to the matching images vector, so the two vectors of a split are
/// index-aligned. Each image holds the raw pixel bytes of one record (the 3072
/// bytes that follow the two label bytes). Per the CIFAR-100 binary format
/// description, the label pair is `(coarse_label, fine_label)`.
#[derive(Debug, Clone)]
pub struct CIFAR100 {
    /// Label pairs for the training split, index-aligned with `train_imgs`.
    pub train_labels: Vec<(u8, u8)>,
    /// Raw pixel bytes for each training image.
    pub train_imgs: Vec<Vec<u8>>,
    /// Label pairs for the test split, index-aligned with `test_imgs`.
    pub test_labels: Vec<(u8, u8)>,
    /// Raw pixel bytes for each test image.
    pub test_imgs: Vec<Vec<u8>>,
}
20
/// Configures where CIFAR-100 is stored and how it is fetched; call
/// `get_data` to produce a loaded `CIFAR100`.
#[derive(Debug, Clone)]
pub struct CIFAR100Builder {
    /// Directory the archive is extracted into and the batch files are read from.
    data_home: String,
    /// When `true`, fetch and extract the archive even if the batch files already exist.
    force_download: bool,
    /// When `true`, print progress messages to stdout.
    verbose: bool,
}
26
27impl CIFAR100Builder {
28 pub fn new() -> CIFAR100Builder {
29 CIFAR100Builder {
30 data_home: "CIFAR100".into(),
31 force_download: false,
32 verbose: false
33 }
34 }
35
36 pub fn data_home<S: Into<String>>(mut self, dh: S) -> CIFAR100Builder {
37 self.data_home = dh.into();
38 self
39 }
40
41 pub fn force_download(mut self) -> CIFAR100Builder {
42 self.force_download = true;
43 self
44
45 }
46
47 pub fn verbose(mut self) -> CIFAR100Builder {
48 self.verbose = true;
49 self
50 }
51
52 pub fn get_data(self) -> io::Result<CIFAR100> {
53 if self.verbose {
54 println!("Creating data directory: {}", self.data_home);
55 }
56 create_dir_all(&self.data_home)?;
57
58 if self.redownload() {
59 if self.verbose { println!("Downloading CIFAR-100 data"); }
60 self.download();
61 } else if self.verbose { println!("Already downloaded"); }
62
63 if self.verbose { println!("Extracting data"); }
64
65 let (train_labels, train_imgs) = self.load_train_data()?;
66 let (test_labels, test_imgs) = self.load_test_data()?;
67 if self.verbose { println!("CIFAR-100 Loaded!"); }
68 Ok(CIFAR100 {
69 train_imgs: train_imgs,
70 train_labels: train_labels,
71 test_imgs: test_imgs,
72 test_labels: test_labels
73 })
74 }
75
76 fn redownload(&self) -> bool {
78 if self.force_download {
79 true
80 } else {
81 let file_names = [
82 "cifar-100-binary/train.bin",
83 "cifar-100-binary/test.bin"
84 ];
85
86 !file_names.iter().all(|f| self.get_file_path(f).is_file())
87 }
88 }
89
90 fn get_file_path<P: AsRef<Path>>(&self, filename: P) -> PathBuf {
91 Path::new(&self.data_home).join(filename)
92 }
93
94 fn download(&self) {
95 let uri = String::from("http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz").parse().unwrap();
96
97 let mut core = Core::new().unwrap();
98 let client = Client::new(&core.handle());
99
100 let get_data = client.get(uri).and_then(|res| {
101 res.body().concat2()
102 });
103 let all_data = core.run(get_data).unwrap();
104 let mut archive = Archive::new(GzDecoder::new(&*all_data));
105 archive.unpack(self.data_home.clone()).unwrap();
106 }
107
108 fn load_train_data(&self)
109 -> io::Result<(Vec<(u8, u8)>, Vec<Vec<u8>>)>
110 {
111 let file_path = "cifar-100-binary/train.bin";
112 let full_path = self.get_file_path(file_path);
113 self.load_batch_file(full_path)
114 }
115
116 fn load_test_data(&self)
117 -> io::Result<(Vec<(u8, u8)>, Vec<Vec<u8>>)>
118 {
119 let file_path = "cifar-100-binary/test.bin";
120 let full_path = self.get_file_path(file_path);
121 self.load_batch_file(full_path)
122 }
123
124 fn load_batch_file<P: AsRef<Path>>(&self, path: P)
125 -> io::Result<(Vec<(u8, u8)>, Vec<Vec<u8>>)>
126 {
127 let mut buf = vec![0u8; 3074];
128 let mut file = File::open(path)?;
129
130 let mut labels = vec![];
131 let mut pixels = vec![];
132
133 loop {
134 match file.read_exact(&mut buf) {
135 Ok(_) => {
136 labels.push((buf[0], buf[1]));
137 pixels.push(buf[2..].into());
138 },
139 Err(e) => match e.kind() {
140 ErrorKind::UnexpectedEof => break,
141 _ => return Err(e)
142 }
143 }
144 }
145
146 Ok((labels, pixels))
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::remove_dir_all;

    /// Asserts that a split holds `expected` images, that its first image is
    /// 3072 bytes long, and that it carries `expected` labels (checked in that order).
    fn check_split(imgs: &[Vec<u8>], labels: &[(u8, u8)], expected: usize) {
        assert_eq!(imgs.len(), expected);
        assert_eq!(imgs[0].len(), 3072);
        assert_eq!(labels.len(), expected);
    }

    /// End-to-end load of the real dataset into `CIFAR100/`, which is removed
    /// afterwards. Ignored by default because it may need network access and
    /// touches the working directory.
    #[test]
    #[ignore]
    fn test_builder() {
        let dataset = CIFAR100Builder::new()
            .data_home("CIFAR100")
            .get_data()
            .unwrap();

        check_split(&dataset.train_imgs, &dataset.train_labels, 50000);
        check_split(&dataset.test_imgs, &dataset.test_labels, 10000);

        remove_dir_all("CIFAR100").unwrap();
    }
}