use std::io;
use std::io::{Error, ErrorKind};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::fs::{create_dir_all, File};
use futures::{Future, Stream};
use futures::future::join_all;
use hyper::Client;
use tokio_core::reactor::Core;
use byteorder::{BigEndian, ReadBytesExt};
use flate2::read::GzDecoder;
/// Value expected in the first big-endian `u32` of a label archive header;
/// `extract_labels` rejects any file whose first header word differs.
const LABEL_MAGIC_NO: u32 = 0x0000_0801;
/// Value expected in the first big-endian `u32` of an image archive header;
/// `extract_images` rejects any file whose first header word differs.
const IMG_MAGIC_NO: u32 = 0x0000_0803;
/// The decoded FashionMNIST dataset, as produced by
/// `FashionMNISTBuilder::get_data`.
///
/// Labels and images are stored in file order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FashionMNIST {
    /// Training labels, one byte per example
    /// (read from `train-labels-idx1-ubyte.gz`).
    pub train_labels: Vec<u8>,
    /// Training images; each inner vector holds the raw bytes of one image
    /// (read from `train-images-idx3-ubyte.gz`).
    pub train_imgs: Vec<Vec<u8>>,
    /// Test labels, one byte per example
    /// (read from `t10k-labels-idx1-ubyte.gz`).
    pub test_labels: Vec<u8>,
    /// Test images; each inner vector holds the raw bytes of one image
    /// (read from `t10k-images-idx3-ubyte.gz`).
    pub test_imgs: Vec<Vec<u8>>,
}
/// Builder that configures where and how the FashionMNIST archives are
/// obtained before they are loaded with `get_data`.
#[derive(Debug, Clone)]
pub struct FashionMNISTBuilder {
    /// Directory in which the dataset archives are stored and looked up.
    data_home: String,
    /// When `true`, the archives are downloaded even if they already exist.
    force_download: bool,
    /// When `true`, progress messages are printed to stdout.
    verbose: bool,
}
impl FashionMNISTBuilder {
pub fn new() -> FashionMNISTBuilder {
FashionMNISTBuilder {
data_home: "FashionMNIST".into(),
force_download: false,
verbose: false
}
}
pub fn data_home<S: Into<String>>(mut self, dh: S) -> FashionMNISTBuilder {
self.data_home = dh.into();
self
}
pub fn force_download(mut self) -> FashionMNISTBuilder {
self.force_download = true;
self
}
pub fn verbose(mut self) -> FashionMNISTBuilder {
self.verbose = true;
self
}
pub fn get_data(self) -> io::Result<FashionMNIST> {
if self.verbose {
println!("Creating data directory: {}", self.data_home);
}
create_dir_all(&self.data_home)?;
if self.redownload() {
if self.verbose { println!("Downloading FashionMNIST data"); }
self.download();
} else if self.verbose { println!("Already downloaded"); }
if self.verbose { println!("Extracting data"); }
let (_train_lbl_meta, train_labels) = self.extract_labels(
self.get_file_path("train-labels-idx1-ubyte.gz"))?;
let (_train_img_meta, train_imgs) = self.extract_images(
self.get_file_path("train-images-idx3-ubyte.gz"))?;
let (_test_lbl_meta, test_labels) = self.extract_labels(
self.get_file_path("t10k-labels-idx1-ubyte.gz"))?;
let (_test_img_meta, test_imgs) = self.extract_images(
self.get_file_path("t10k-images-idx3-ubyte.gz"))?;
if self.verbose { println!("FashionMNIST Loaded!"); }
Ok(FashionMNIST {
train_imgs: train_imgs,
train_labels: train_labels,
test_imgs: test_imgs,
test_labels: test_labels
})
}
fn redownload(&self) -> bool {
if self.force_download {
true
} else {
let file_names = [
"train-labels-idx1-ubyte.gz",
"train-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz"
];
!file_names.iter().all(|f| self.get_file_path(f).is_file())
}
}
fn get_file_path(&self, filename: &str) -> PathBuf {
Path::new(&self.data_home).join(filename)
}
fn download(&self) {
let base_uri = String::from("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/");
let file_names = [
"train-labels-idx1-ubyte.gz",
"train-images-idx3-ubyte.gz",
"t10k-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz"
];
let mut core = Core::new().unwrap();
let client = Client::new(&core.handle());
let all_gets = join_all(file_names.iter().map(move |f| {
let full_uri = (base_uri.clone() + f).parse().unwrap();
client.get(full_uri).and_then(move |res| {
let path = self.get_file_path(f);
let mut file = File::create(path).unwrap();
res.body().for_each(move |chunk| {
file.write_all(&chunk)
.map(|_| ())
.map_err(From::from)
})
})
}));
core.run(all_gets).unwrap();
}
fn extract_labels<P: AsRef<Path>>(&self, label_file_path: P)
-> io::Result<([u32; 2], Vec<u8>)>
{
let mut decoder = self.get_decoder(label_file_path)?;
let mut metadata_buf = [0u32; 2];
decoder.read_u32_into::<BigEndian>(&mut metadata_buf)?;
let mut labels = Vec::new();
decoder.read_to_end(&mut labels)?;
if metadata_buf[0] != LABEL_MAGIC_NO {
Err(Error::new(ErrorKind::InvalidData,
"Unable to verify FashionMNIST data. Force redownload."))
} else {
Ok((metadata_buf, labels))
}
}
fn extract_images<P: AsRef<Path>>(&self, img_file_path: P)
-> io::Result<([u32; 4], Vec<Vec<u8>>)>
{
let mut decoder = self.get_decoder(img_file_path)?;
let mut metadata_buf = [0u32; 4];
decoder.read_u32_into::<BigEndian>(&mut metadata_buf)?;
let mut imgs = Vec::new();
decoder.read_to_end(&mut imgs)?;
if metadata_buf[0] != IMG_MAGIC_NO {
Err(Error::new(ErrorKind::InvalidData,
"Unable to verify FashionMNIST data. Force redownload."))
} else {
Ok((metadata_buf, imgs.chunks(784).map(|x| x.into()).collect()))
}
}
fn get_decoder<P: AsRef<Path>>(&self, archive: P) -> io::Result<GzDecoder<File>> {
let archive = File::open(archive)?;
Ok(GzDecoder::new(archive))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::remove_dir_all;

    /// End-to-end check of the builder. Ignored by default because
    /// `get_data` downloads the archives when they are not already present.
    #[test]
    #[ignore]
    fn test_builder() {
        // A dedicated directory, so the test never deletes a dataset cached
        // in the default `FashionMNIST` directory.
        let data_home = "FashionMNIST_test_data";
        let result = FashionMNISTBuilder::new().data_home(data_home).get_data();
        // Clean up before asserting so a failed check leaves nothing behind.
        let cleanup = remove_dir_all(data_home);

        let mnist = result.unwrap();
        cleanup.unwrap();

        assert_eq!(mnist.train_imgs.len(), 60000);
        assert_eq!(mnist.train_labels.len(), 60000);
        assert_eq!(mnist.test_imgs.len(), 10000);
        assert_eq!(mnist.test_labels.len(), 10000);
        assert!(mnist
            .train_imgs
            .iter()
            .chain(mnist.test_imgs.iter())
            .all(|img| img.len() == 784));
    }
}