use ndarray;
use std::fs::File;
use std::io;
use std::io::Read;
use std::path::Path;
type NdArray = ndarray::Array<f32, ndarray::IxDyn>;
pub fn load() -> ((NdArray, NdArray), (NdArray, NdArray)) {
let (train_x, num_image_train): (Vec<f32>, usize) =
load_images("data/mnist/train-images-idx3-ubyte");
let (train_y, num_label_train): (Vec<f32>, usize) =
load_labels("data/mnist/train-labels-idx1-ubyte");
let (test_x, num_image_test): (Vec<f32>, usize) =
load_images("data/mnist/t10k-images-idx3-ubyte");
let (test_y, num_label_test): (Vec<f32>, usize) =
load_labels("data/mnist/t10k-labels-idx1-ubyte");
let as_arr = NdArray::from_shape_vec;
let x_train = as_arr(ndarray::IxDyn(&[num_image_train, 28 * 28]), train_x).unwrap();
let y_train = as_arr(ndarray::IxDyn(&[num_label_train, 1]), train_y).unwrap();
let x_test = as_arr(ndarray::IxDyn(&[num_image_test, 28 * 28]), test_x).unwrap();
let y_test = as_arr(ndarray::IxDyn(&[num_label_test, 1]), test_y).unwrap();
((x_train, y_train), (x_test, y_test))
}
fn load_images<P: AsRef<Path>>(path: P) -> (Vec<f32>, usize) {
let ref mut buf_reader =
io::BufReader::new(File::open(path).expect("Please run ./download_mnist.sh beforehand"));
let magic = read_be_u32(buf_reader);
if magic != 2051 {
panic!("Invalid magic number. expected 2051, got {}", magic)
}
let num_image = read_be_u32(buf_reader) as usize;
let rows = read_be_u32(buf_reader) as usize;
let cols = read_be_u32(buf_reader) as usize;
assert!(rows == 28 && cols == 28);
let mut buf: Vec<u8> = vec![0u8; num_image * rows * cols];
let _ = buf_reader.read_exact(buf.as_mut());
let ret = buf.into_iter().map(|x| (x as f32) / 255.).collect();
(ret, num_image)
}
fn load_labels<P: AsRef<Path>>(path: P) -> (Vec<f32>, usize) {
let ref mut buf_reader = io::BufReader::new(File::open(path).unwrap());
let magic = read_be_u32(buf_reader);
if magic != 2049 {
panic!("Invalid magic number. Got expect 2049, got {}", magic);
}
let num_label = read_be_u32(buf_reader) as usize;
let mut buf: Vec<u8> = vec![0u8; num_label];
let _ = buf_reader.read_exact(buf.as_mut());
let ret: Vec<f32> = buf.into_iter().map(|x| x as f32).collect();
(ret, num_label)
}
fn read_be_u32<T: Read>(reader: &mut T) -> u32 {
let mut buf: [u8; 4] = [0, 0, 0, 0];
let _ = reader.read_exact(&mut buf);
u32::from_be_bytes(buf)
}
#[allow(dead_code)]
fn main() {}