#![doc(test(attr(allow(unused_variables), deny(warnings))))]
extern crate byteorder;
mod download;
mod tests;
use byteorder::{BigEndian, ReadBytesExt};
use std::fs::File;
use std::io::prelude::*;
use std::path::Path;
/// Default directory the data files are read from (see `MnistBuilder::new`).
static BASE_PATH: &str = "data/";
/// Default download location, used unless `base_url` or `use_fashion_data` overrides it.
static BASE_URL: &str = "http://yann.lecun.com/exdb/mnist";
/// Download location selected by `finalize` when `use_fashion_data` is set.
static FASHION_BASE_URL: &str = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com";
/// Default file name of the training images.
static TRN_IMG_FILENAME: &str = "train-images-idx3-ubyte";
/// Default file name of the training labels.
static TRN_LBL_FILENAME: &str = "train-labels-idx1-ubyte";
/// Default file name of the test images.
static TST_IMG_FILENAME: &str = "t10k-images-idx3-ubyte";
/// Default file name of the test labels.
static TST_LBL_FILENAME: &str = "t10k-labels-idx1-ubyte";
/// Value `images` requires in the first big-endian `u32` of an image file.
static IMG_MAGIC_NUMBER: u32 = 0x0000_0803;
/// Value `labels` requires in the first big-endian `u32` of a label file.
static LBL_MAGIC_NUMBER: u32 = 0x0000_0801;
/// Item count the training files must declare; also the default training set length.
static TRN_LEN: u32 = 60000;
/// Item count the test files must declare; also the default test set length.
static TST_LEN: u32 = 10000;
/// Length of each one-hot label vector produced by `finalize`.
static CLASSES: usize = 10;
/// Row count every image file must declare in its header.
static ROWS: usize = 28;
/// Column count every image file must declare in its header.
static COLS: usize = 28;
/// The data set as produced by `MnistBuilder::finalize`.
///
/// Image vectors hold the raw bytes of consecutive images back to back
/// (`ROWS * COLS` bytes per image). Label vectors hold one byte per item in
/// digit format, or `CLASSES` bytes per item in one-hot format.
#[derive(Debug)]
pub struct Mnist {
/// Training images.
pub trn_img: Vec<u8>,
/// Training labels.
pub trn_lbl: Vec<u8>,
/// Validation images (empty when the validation set length is 0, the default).
pub val_img: Vec<u8>,
/// Validation labels.
pub val_lbl: Vec<u8>,
/// Test images.
pub tst_img: Vec<u8>,
/// Test labels.
pub tst_lbl: Vec<u8>,
}
/// Builder that collects the options used by `finalize` to load the data set.
///
/// All string options are borrowed for `'a`; nothing is copied until
/// `finalize` runs.
#[derive(Debug)]
pub struct MnistBuilder<'a> {
// Whether labels are returned as plain digits or one-hot vectors.
lbl_format: LabelFormat,
// Requested number of items in the training / validation / test sets.
trn_len: u32,
val_len: u32,
tst_len: u32,
// Directory that contains (or receives) the data files.
base_path: &'a str,
// File names, joined onto `base_path` when loading.
trn_img_filename: &'a str,
trn_lbl_filename: &'a str,
tst_img_filename: &'a str,
tst_lbl_filename: &'a str,
// When true, `finalize` attempts a download before reading the files.
download_and_extract: bool,
// Download location; ignored when `use_fashion_data` is true.
base_url: &'a str,
// When true, `finalize` downloads from `FASHION_BASE_URL` instead.
use_fashion_data: bool,
}
impl<'a> MnistBuilder<'a> {
pub fn new() -> MnistBuilder<'a> {
MnistBuilder {
lbl_format: LabelFormat::Digit,
trn_len: TRN_LEN,
val_len: 0,
tst_len: TST_LEN,
base_path: BASE_PATH,
trn_img_filename: TRN_IMG_FILENAME,
trn_lbl_filename: TRN_LBL_FILENAME,
tst_img_filename: TST_IMG_FILENAME,
tst_lbl_filename: TST_LBL_FILENAME,
download_and_extract: false,
base_url: BASE_URL,
use_fashion_data: false,
}
}
pub fn label_format_digit(&mut self) -> &mut MnistBuilder<'a> {
self.lbl_format = LabelFormat::Digit;
self
}
pub fn label_format_one_hot(&mut self) -> &mut MnistBuilder<'a> {
self.lbl_format = LabelFormat::OneHotVector;
self
}
pub fn training_set_length(&mut self, length: u32) -> &mut MnistBuilder<'a> {
self.trn_len = length;
self
}
pub fn validation_set_length(&mut self, length: u32) -> &mut MnistBuilder<'a> {
self.val_len = length;
self
}
pub fn test_set_length(&mut self, length: u32) -> &mut MnistBuilder<'a> {
self.tst_len = length;
self
}
pub fn base_path(&mut self, base_path: &'a str) -> &mut MnistBuilder<'a> {
self.base_path = base_path;
self
}
pub fn training_images_filename(&mut self, trn_img_filename: &'a str) -> &mut MnistBuilder<'a> {
self.trn_img_filename = trn_img_filename;
self
}
pub fn training_labels_filename(&mut self, trn_lbl_filename: &'a str) -> &mut MnistBuilder<'a> {
self.trn_lbl_filename = trn_lbl_filename;
self
}
pub fn test_images_filename(&mut self, tst_img_filename: &'a str) -> &mut MnistBuilder<'a> {
self.tst_img_filename = tst_img_filename;
self
}
pub fn test_labels_filename(&mut self, tst_lbl_filename: &'a str) -> &mut MnistBuilder<'a> {
self.tst_lbl_filename = tst_lbl_filename;
self
}
pub fn download_and_extract(&mut self) -> &mut MnistBuilder<'a> {
self.download_and_extract = true;
self
}
pub fn use_fashion_data(&mut self) -> &mut MnistBuilder<'a> {
self.use_fashion_data = true;
self
}
pub fn base_url(&mut self, base_url: &'a str) -> &mut MnistBuilder<'a> {
self.base_url = base_url;
self
}
pub fn finalize(&self) -> Mnist {
if self.download_and_extract {
let base_url = if self.use_fashion_data {
FASHION_BASE_URL
} else if self.base_url != BASE_URL {
self.base_url
} else {
BASE_URL
};
#[cfg(feature = "download")]
download::download_and_extract(base_url, &self.base_path, self.use_fashion_data)
.unwrap();
#[cfg(not(feature = "download"))]
{
log::warn!("WARNING: Download disabled.");
log::warn!(" Please use the mnist crate's 'download' feature to enable.");
}
}
let &MnistBuilder {
trn_len,
val_len,
tst_len,
..
} = self;
let (trn_len, val_len, tst_len) = (trn_len as usize, val_len as usize, tst_len as usize);
let total_length = trn_len + val_len + tst_len;
let available_length = (TRN_LEN + TST_LEN) as usize;
assert!(
total_length <= available_length,
format!(
"Total data set length ({}) greater than maximum possible length ({}).",
total_length, available_length
)
);
let mut trn_img = images(
&Path::new(self.base_path).join(self.trn_img_filename),
TRN_LEN,
);
let mut trn_lbl = labels(
&Path::new(self.base_path).join(self.trn_lbl_filename),
TRN_LEN,
);
let mut tst_img = images(
&Path::new(self.base_path).join(self.tst_img_filename),
TST_LEN,
);
let mut tst_lbl = labels(
&Path::new(self.base_path).join(self.tst_lbl_filename),
TST_LEN,
);
trn_img.append(&mut tst_img);
trn_lbl.append(&mut tst_lbl);
let mut val_img = trn_img.split_off(trn_len * ROWS * COLS);
let mut val_lbl = trn_lbl.split_off(trn_len);
let mut tst_img = val_img.split_off(val_len * ROWS * COLS);
let mut tst_lbl = val_lbl.split_off(val_len);
tst_img.split_off(tst_len * ROWS * COLS);
tst_lbl.split_off(tst_len);
if self.lbl_format == LabelFormat::OneHotVector {
fn digit2one_hot(v: Vec<u8>) -> Vec<u8> {
v.iter()
.map(|&i| {
let mut v = vec![0; CLASSES as usize];
v[i as usize] = 1;
v
})
.flatten()
.collect()
}
trn_lbl = digit2one_hot(trn_lbl);
val_lbl = digit2one_hot(val_lbl);
tst_lbl = digit2one_hot(tst_lbl);
}
Mnist {
trn_img,
trn_lbl,
val_img,
val_lbl,
tst_img,
tst_lbl,
}
}
}
/// `Default` yields the same configuration as `MnistBuilder::new`.
impl Default for MnistBuilder<'_> {
fn default() -> Self {
Self::new()
}
}
impl Mnist {
/// Consumes the data set and converts it into a `NormalizedMnist`
/// by delegating to `NormalizedMnist::new`.
pub fn normalize(self) -> NormalizedMnist {
NormalizedMnist::new(self)
}
}
/// Variant of `Mnist` whose image data is stored as `f32` instead of `u8`.
///
/// Built by `NormalizedMnist::new`, which divides every pixel byte by 255;
/// labels are carried over unchanged.
#[derive(Debug)]
pub struct NormalizedMnist {
/// Training images as `f32` pixel values.
pub trn_img: Vec<f32>,
/// Training labels.
pub trn_lbl: Vec<u8>,
/// Validation images as `f32` pixel values.
pub val_img: Vec<f32>,
/// Validation labels.
pub val_lbl: Vec<u8>,
/// Test images as `f32` pixel values.
pub tst_img: Vec<f32>,
/// Test labels.
pub tst_lbl: Vec<u8>,
}
impl NormalizedMnist {
pub fn new(mnist: Mnist) -> NormalizedMnist {
NormalizedMnist {
trn_img: normalize_vector(&mnist.trn_img),
trn_lbl: mnist.trn_lbl,
val_img: normalize_vector(&mnist.val_img),
val_lbl: mnist.val_lbl,
tst_img: normalize_vector(&mnist.tst_img),
tst_lbl: mnist.tst_lbl,
}
}
}
/// Converts each byte of `v` to an `f32` and divides it by 255.
///
/// The result has the same length and order as the input.
fn normalize_vector(v: &[u8]) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(v.len());
    for &byte in v {
        scaled.push(f32::from(byte) / 255.0_f32);
    }
    scaled
}
/// How `MnistBuilder::finalize` encodes labels.
#[derive(Debug, PartialEq)]
enum LabelFormat {
/// One byte per item, exactly as stored in the label file.
Digit,
/// `CLASSES` bytes per item: all 0 except a 1 at the index given by the label.
OneHotVector,
}
/// Reads a label file and returns its labels, one byte per item.
///
/// The file must start with two big-endian `u32` values: `LBL_MAGIC_NUMBER`
/// and the number of labels, which has to equal `expected_length`. All bytes
/// after that header are the labels.
///
/// # Panics
///
/// Panics if the file cannot be opened or read, if the magic number or the
/// declared length is not the expected one, or if the number of label bytes
/// differs from the declared length.
fn labels(path: &Path, expected_length: u32) -> Vec<u8> {
    let mut file =
        File::open(path).unwrap_or_else(|_| panic!("Unable to find path to labels at {:?}.", path));
    let magic_number = file
        .read_u32::<BigEndian>()
        .unwrap_or_else(|_| panic!("Unable to read magic number from {:?}.", path));
    assert!(
        LBL_MAGIC_NUMBER == magic_number,
        "Expected magic number {} got {}.",
        LBL_MAGIC_NUMBER,
        magic_number
    );
    let length = file
        .read_u32::<BigEndian>()
        .unwrap_or_else(|_| panic!("Unable to length from {:?}.", path));
    assert!(
        expected_length == length,
        "Expected data set length of {} got {}.",
        expected_length,
        length
    );
    // Read the payload in bulk; iterating `file.bytes()` on an unbuffered
    // `File` issues one read call per label.
    let mut data = Vec::with_capacity(length as usize);
    file.read_to_end(&mut data)
        .unwrap_or_else(|_| panic!("Unable to read labels from {:?}.", path));
    // A payload that disagrees with the header would silently misalign
    // labels and images once the caller concatenates and splits the sets.
    assert!(
        data.len() == length as usize,
        "Expected {} label bytes got {} in {:?}.",
        length,
        data.len(),
        path
    );
    data
}
/// Reads an image file and returns the pixel bytes of all images back to back.
///
/// The file must start with four big-endian `u32` values: `IMG_MAGIC_NUMBER`,
/// the number of images (has to equal `expected_length`), the row count
/// (has to equal `ROWS`) and the column count (has to equal `COLS`). All
/// bytes after that header are pixel data.
///
/// # Panics
///
/// Panics if the file cannot be opened or read, if any header field is not
/// the expected one, or if the number of pixel bytes differs from
/// `length * rows * cols`.
fn images(path: &Path, expected_length: u32) -> Vec<u8> {
    let mut content: Vec<u8> = Vec::new();
    let mut fh = File::open(path)
        .unwrap_or_else(|_| panic!("Unable to find path to images at {:?}.", path));
    fh.read_to_end(&mut content)
        .unwrap_or_else(|_| panic!("Unable to read whole file in memory ({})", path.display()));

    // Parse the header through a slice cursor; each read advances the slice,
    // so whatever is left afterwards is the pixel payload.
    let header_len = {
        let mut cursor = &content[..];
        let magic_number = cursor
            .read_u32::<BigEndian>()
            .unwrap_or_else(|_| panic!("Unable to read magic number from {:?}.", path));
        assert!(
            IMG_MAGIC_NUMBER == magic_number,
            "Expected magic number {} got {}.",
            IMG_MAGIC_NUMBER,
            magic_number
        );
        let length = cursor
            .read_u32::<BigEndian>()
            .unwrap_or_else(|_| panic!("Unable to length from {:?}.", path));
        assert!(
            expected_length == length,
            "Expected data set length of {} got {}.",
            expected_length,
            length
        );
        let rows = cursor
            .read_u32::<BigEndian>()
            .unwrap_or_else(|_| panic!("Unable to number of rows from {:?}.", path))
            as usize;
        assert!(ROWS == rows, "Expected rows length of {} got {}.", ROWS, rows);
        let cols = cursor
            .read_u32::<BigEndian>()
            .unwrap_or_else(|_| panic!("Unable to number of columns from {:?}.", path))
            as usize;
        assert!(COLS == cols, "Expected cols length of {} got {}.", COLS, cols);

        // A payload that disagrees with the header would silently misalign
        // images and labels once the caller concatenates and splits the sets.
        let expected_bytes = length as usize * rows * cols;
        assert!(
            cursor.len() == expected_bytes,
            "Expected {} pixel bytes got {} in {:?}.",
            expected_bytes,
            cursor.len(),
            path
        );
        content.len() - cursor.len()
    };

    // Strip the header in place rather than copying the payload into a second vector.
    content.drain(..header_len);
    content
}