use std::fs::File;
use std::path::Path;
use std::io;
use std::io::Read;
use auto_diff::Var;
/// Loads an IDX image file into a `Var` of shape `[num_image, rows, cols]`.
///
/// The file must begin with four `u32` header fields read via `read_as_u32`:
/// the magic number (2051), the image count, the row count and the column
/// count. One byte per pixel follows; each byte is converted to `f64` and
/// divided by 255, so values lie in `[0, 1]`.
///
/// # Panics
///
/// Panics if the file cannot be opened, if the magic number is not 2051, if
/// the image size is not 28x28, or if the pixel data is shorter than the
/// header announces.
pub fn load_images<P: AsRef<Path>>(path: P) -> Var {
    let path = path.as_ref();
    let file = File::open(path)
        .unwrap_or_else(|e| panic!("failed to open image file {}: {}", path.display(), e));
    let mut reader = io::BufReader::new(file);

    let magic = read_as_u32(&mut reader);
    if magic != 2051 {
        panic!("Invalid magic number. expected 2051, got {}", magic)
    }

    let num_image = read_as_u32(&mut reader) as usize;
    let rows = read_as_u32(&mut reader) as usize;
    let cols = read_as_u32(&mut reader) as usize;
    assert!(
        rows == 28 && cols == 28,
        "expected 28x28 images, got {}x{}",
        rows,
        cols
    );

    let mut buf: Vec<u8> = vec![0u8; num_image * rows * cols];
    // A short read must not be ignored: the buffer would keep its zero fill
    // (or unspecified partial data) and be returned as if it were real pixels.
    reader.read_exact(&mut buf).unwrap_or_else(|e| {
        panic!(
            "failed to read {} pixel bytes from {}: {}",
            buf.len(),
            path.display(),
            e
        )
    });

    let pixels: Vec<f64> = buf.into_iter().map(|x| f64::from(x) / 255.).collect();
    Var::new(&pixels[..], &[num_image, rows, cols])
}
/// Loads an IDX label file into a one-dimensional `Var` of length `num_label`.
///
/// The file must begin with two `u32` header fields read via `read_as_u32`:
/// the magic number (2049) and the label count. One byte per label follows;
/// each byte is converted to `f64` without scaling.
///
/// # Panics
///
/// Panics if the file cannot be opened, if the magic number is not 2049, or
/// if the label data is shorter than the header announces.
pub fn load_labels<P: AsRef<Path>>(path: P) -> Var {
    let path = path.as_ref();
    let file = File::open(path)
        .unwrap_or_else(|e| panic!("failed to open label file {}: {}", path.display(), e));
    let mut reader = io::BufReader::new(file);

    let magic = read_as_u32(&mut reader);
    if magic != 2049 {
        panic!("Invalid magic number. expected 2049, got {}", magic);
    }

    let num_label = read_as_u32(&mut reader) as usize;
    let mut buf: Vec<u8> = vec![0u8; num_label];
    // A short read must not be ignored: unread entries would otherwise be
    // returned as if they were genuine labels.
    reader.read_exact(&mut buf).unwrap_or_else(|e| {
        panic!(
            "failed to read {} label bytes from {}: {}",
            buf.len(),
            path.display(),
            e
        )
    });

    let labels: Vec<f64> = buf.into_iter().map(f64::from).collect();
    Var::new(&labels[..], &[num_label])
}
/// Reads the next four bytes from `reader` and decodes them as a big-endian `u32`.
///
/// # Panics
///
/// Panics if four bytes cannot be read. Ignoring that error would decode a
/// buffer that was never (fully) filled and hand a bogus value to the caller.
fn read_as_u32<T: Read>(reader: &mut T) -> u32 {
    let mut buf = [0u8; 4];
    reader
        .read_exact(&mut buf)
        .expect("failed to read a 4-byte big-endian u32 from the input");
    u32::from_be_bytes(buf)
}
/// Demo driver exercising the loaders.
///
/// Loads the training images and, for image indices 0 through 9 followed by
/// 0 and 10, extracts the single-image patch, concatenates it with two clones
/// of itself along axis 0, permutes the axes to `(1, 2, 0)` and multiplies by
/// a `Var` filled with 255. Every product is discarded. Finally loads the
/// training labels and prints the entries at positions 0 and 10.
#[allow(dead_code)]
pub fn main() {
    let images = load_images("examples/data/mnist/train-images-idx3-ubyte");

    // One loop covers the 0..10 sweep and the two extra picks (0, then 10),
    // in the same order the operations were originally issued.
    for idx in (0..10).chain([0, 10]) {
        let gray = images
            .get_patch(&vec![(idx, idx + 1), (0, 28), (0, 28)], None)
            .unwrap();
        let stacked = gray.cat(&vec![gray.clone(), gray.clone()], 0).unwrap();
        let channels_last = stacked.permute(&vec![1, 2, 0]).unwrap();
        let _scaled = channels_last * Var::fill(&vec![1], &Var::new(&[255.], &[1]));
    }

    let labels = load_labels("examples/data/mnist/train-labels-idx1-ubyte");
    let first_label = labels.get_f32(&vec![0]).unwrap();
    let eleventh_label = labels.get_f32(&vec![10]).unwrap();
    println!("{}, {}", first_label, eleventh_label);
}