vision/
cifar100.rs

1use std::io;
2use std::io::ErrorKind;
3use std::io::Read;
4use std::path::{Path, PathBuf};
5use std::fs::{create_dir_all, File};
6
7use futures::{Future, Stream};
8use hyper::Client;
9use tokio_core::reactor::Core;
10
11use flate2::read::GzDecoder;
12use tar::Archive;
13
/// In-memory copy of the CIFAR-100 dataset as read from the binary batch files.
///
/// Labels and images are parallel vectors: entry `i` of a `*_labels` vector
/// belongs to entry `i` of the matching `*_imgs` vector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CIFAR100 {
    /// First two bytes of every training record, in file order.
    /// NOTE(review): per the published CIFAR-100 binary layout these are
    /// `(coarse_label, fine_label)` -- confirm against the dataset docs.
    pub train_labels: Vec<(u8, u8)>,
    /// Raw bytes following the two label bytes of every training record.
    pub train_imgs: Vec<Vec<u8>>,
    /// First two bytes of every test record, in file order.
    pub test_labels: Vec<(u8, u8)>,
    /// Raw bytes following the two label bytes of every test record.
    pub test_imgs: Vec<Vec<u8>>,
}
20
/// Builder that configures where and how the CIFAR-100 dataset is fetched
/// before it is loaded into memory.
#[derive(Debug, Clone)]
pub struct CIFAR100Builder {
    /// Directory the archive is unpacked into and the batch files are read from.
    data_home: String,
    /// When `true`, download the archive even if the batch files already exist.
    force_download: bool,
    /// When `true`, print progress messages to standard output.
    verbose: bool,
}
26
27impl CIFAR100Builder {
28    pub fn new() -> CIFAR100Builder {
29        CIFAR100Builder {
30            data_home: "CIFAR100".into(),
31            force_download: false,
32            verbose: false
33        }
34    }
35
36    pub fn data_home<S: Into<String>>(mut self, dh: S) -> CIFAR100Builder {
37        self.data_home = dh.into();
38        self
39    }
40
41    pub fn force_download(mut self) -> CIFAR100Builder {
42        self.force_download = true;
43        self
44
45    }
46
47    pub fn verbose(mut self) -> CIFAR100Builder {
48        self.verbose = true;
49        self
50    }
51
52    pub fn get_data(self) -> io::Result<CIFAR100> {
53        if self.verbose {
54            println!("Creating data directory: {}", self.data_home);
55        }
56        create_dir_all(&self.data_home)?;
57
58        if self.redownload() {
59            if self.verbose { println!("Downloading CIFAR-100 data"); }
60            self.download();
61        } else if self.verbose { println!("Already downloaded"); }
62
63        if self.verbose { println!("Extracting data"); }
64        
65        let (train_labels, train_imgs) = self.load_train_data()?;
66        let (test_labels, test_imgs) = self.load_test_data()?;
67        if self.verbose { println!("CIFAR-100 Loaded!"); }
68        Ok(CIFAR100 {
69            train_imgs: train_imgs,
70            train_labels: train_labels,
71            test_imgs: test_imgs,
72            test_labels: test_labels
73        })
74    }
75
76    /// Check whether dataset must be downloaded again
77    fn redownload(&self) -> bool {
78        if self.force_download {
79            true
80        } else {
81            let file_names = [
82                "cifar-100-binary/train.bin",
83                "cifar-100-binary/test.bin"
84            ];
85
86            !file_names.iter().all(|f| self.get_file_path(f).is_file())
87        }
88    }
89
90    fn get_file_path<P: AsRef<Path>>(&self, filename: P) -> PathBuf {
91        Path::new(&self.data_home).join(filename)
92    }
93
94    fn download(&self) {
95        let uri = String::from("http://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz").parse().unwrap();
96
97        let mut core = Core::new().unwrap();
98        let client = Client::new(&core.handle());
99
100        let get_data = client.get(uri).and_then(|res| {
101            res.body().concat2()
102        });
103        let all_data = core.run(get_data).unwrap();
104        let mut archive = Archive::new(GzDecoder::new(&*all_data));
105        archive.unpack(self.data_home.clone()).unwrap();
106    }
107
108    fn load_train_data(&self)
109        -> io::Result<(Vec<(u8, u8)>, Vec<Vec<u8>>)>
110    {
111        let file_path = "cifar-100-binary/train.bin";
112        let full_path = self.get_file_path(file_path);
113        self.load_batch_file(full_path)
114    }
115
116    fn load_test_data(&self)
117        -> io::Result<(Vec<(u8, u8)>, Vec<Vec<u8>>)>
118    {
119        let file_path = "cifar-100-binary/test.bin";
120        let full_path = self.get_file_path(file_path);
121        self.load_batch_file(full_path)
122    }
123
124    fn load_batch_file<P: AsRef<Path>>(&self, path: P)
125        -> io::Result<(Vec<(u8, u8)>, Vec<Vec<u8>>)>
126    {
127        let mut buf = vec![0u8; 3074];
128        let mut file = File::open(path)?;
129
130        let mut labels = vec![];
131        let mut pixels = vec![];
132
133        loop {
134            match file.read_exact(&mut buf) {
135                Ok(_) => {
136                    labels.push((buf[0], buf[1]));
137                    pixels.push(buf[2..].into());
138                },
139                Err(e) => match e.kind() {
140                    ErrorKind::UnexpectedEof => break,
141                    _ => return Err(e)
142                }
143            }
144        }
145
146        Ok((labels, pixels))
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::remove_dir_all;

    /// End-to-end check that the builder produces the expected number of
    /// training and test samples with 3072-byte images.
    ///
    /// Ignored by default; it calls `get_data`, which downloads the archive
    /// when the batch files are not already present.
    #[test]
    #[ignore]
    fn test_builder() {
        let builder = CIFAR100Builder::new();
        let loaded = builder.data_home("CIFAR100").get_data();
        // Remove the data directory before asserting so a failed load or a
        // failed assertion does not leave the extracted dataset behind.
        let removed = remove_dir_all("CIFAR100");

        let cifar100 = loaded.unwrap();
        assert_eq!(cifar100.train_imgs.len(), 50000);
        assert_eq!(cifar100.train_imgs[0].len(), 3072);
        assert_eq!(cifar100.train_labels.len(), 50000);
        assert_eq!(cifar100.test_imgs.len(), 10000);
        assert_eq!(cifar100.test_imgs[0].len(), 3072);
        assert_eq!(cifar100.test_labels.len(), 10000);
        removed.unwrap();
    }
}