use std::io::Read;
use crate::data::BatchDataSet;
use crate::tensor::{Device, Result, Tensor, TensorError};
pub struct Mnist {
pub images: Tensor,
pub labels: Tensor,
}
impl Mnist {
pub fn parse(images_gz: &[u8], labels_gz: &[u8]) -> Result<Self> {
let images_raw = gunzip(images_gz)?;
let labels_raw = gunzip(labels_gz)?;
if images_raw.len() < 16 {
return Err(TensorError::new("MNIST images: file too short for header"));
}
let magic = read_u32_be(&images_raw[0..4]);
if magic != 2051 {
return Err(TensorError::new(&format!(
"MNIST images: bad magic {magic}, expected 2051"
)));
}
let n = read_u32_be(&images_raw[4..8]) as usize;
let rows = read_u32_be(&images_raw[8..12]) as usize;
let cols = read_u32_be(&images_raw[12..16]) as usize;
let pixel_count = n * rows * cols;
if images_raw.len() < 16 + pixel_count {
return Err(TensorError::new(&format!(
"MNIST images: expected {} pixels, got {}",
pixel_count,
images_raw.len() - 16
)));
}
let pixels: Vec<f32> = images_raw[16..16 + pixel_count]
.iter()
.map(|&b| b as f32 / 255.0)
.collect();
let images = Tensor::from_f32(&pixels, &[n as i64, 1, rows as i64, cols as i64], Device::CPU)?;
if labels_raw.len() < 8 {
return Err(TensorError::new("MNIST labels: file too short for header"));
}
let magic = read_u32_be(&labels_raw[0..4]);
if magic != 2049 {
return Err(TensorError::new(&format!(
"MNIST labels: bad magic {magic}, expected 2049"
)));
}
let n_labels = read_u32_be(&labels_raw[4..8]) as usize;
if n_labels != n {
return Err(TensorError::new(&format!(
"MNIST: {n} images but {n_labels} labels"
)));
}
if labels_raw.len() < 8 + n {
return Err(TensorError::new("MNIST labels: truncated"));
}
let label_vals: Vec<i64> = labels_raw[8..8 + n]
.iter()
.map(|&b| b as i64)
.collect();
let labels = Tensor::from_i64(&label_vals, &[n as i64], Device::CPU)?;
Ok(Mnist { images, labels })
}
pub fn len(&self) -> usize {
self.images.shape()[0] as usize
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl BatchDataSet for Mnist {
fn len(&self) -> usize {
self.images.shape()[0] as usize
}
fn get_batch(&self, indices: &[usize]) -> Result<Vec<Tensor>> {
let idx: Vec<i64> = indices.iter().map(|&i| (i % self.len()) as i64).collect();
let idx_tensor = Tensor::from_i64(&idx, &[idx.len() as i64], Device::CPU)?;
let images = self.images.index_select(0, &idx_tensor)?;
let labels = self.labels.index_select(0, &idx_tensor)?;
Ok(vec![images, labels])
}
}
fn gunzip(data: &[u8]) -> Result<Vec<u8>> {
let mut decoder = flate2::read::GzDecoder::new(data);
let mut out = Vec::new();
decoder
.read_to_end(&mut out)
.map_err(|e| TensorError::new(&format!("gzip decompression failed: {e}")))?;
Ok(out)
}
fn read_u32_be(bytes: &[u8]) -> u32 {
u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])
}
#[cfg(test)]
mod tests {
use super::*;
fn make_idx3(n: u32, rows: u32, cols: u32, pixels: &[u8]) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&2051u32.to_be_bytes());
buf.extend_from_slice(&n.to_be_bytes());
buf.extend_from_slice(&rows.to_be_bytes());
buf.extend_from_slice(&cols.to_be_bytes());
buf.extend_from_slice(pixels);
buf
}
fn make_idx1(n: u32, labels: &[u8]) -> Vec<u8> {
let mut buf = Vec::new();
buf.extend_from_slice(&2049u32.to_be_bytes());
buf.extend_from_slice(&n.to_be_bytes());
buf.extend_from_slice(labels);
buf
}
fn gzip(data: &[u8]) -> Vec<u8> {
use flate2::write::GzEncoder;
use flate2::Compression;
use std::io::Write;
let mut enc = GzEncoder::new(Vec::new(), Compression::fast());
enc.write_all(data).unwrap();
enc.finish().unwrap()
}
#[test]
fn parse_small_mnist() {
let pixels = vec![0u8, 128, 255, 0, 0, 0, 255, 255, 255, 10, 20, 30, 40, 50, 60, 70, 80, 90];
let images_raw = make_idx3(2, 3, 3, &pixels);
let labels_raw = make_idx1(2, &[3, 7]);
let images_gz = gzip(&images_raw);
let labels_gz = gzip(&labels_raw);
let mnist = Mnist::parse(&images_gz, &labels_gz).unwrap();
assert_eq!(mnist.images.shape(), &[2, 1, 3, 3]);
assert_eq!(mnist.labels.shape(), &[2]);
let first = mnist.images.select(0, 0).unwrap();
let val: f64 = first.select(0, 0).unwrap()
.select(0, 0).unwrap()
.select(0, 0).unwrap()
.item().unwrap();
assert!((val - 0.0).abs() < 1e-6);
let val: f64 = first.select(0, 0).unwrap()
.select(0, 0).unwrap()
.select(0, 1).unwrap()
.item().unwrap();
assert!((val - 128.0 / 255.0).abs() < 1e-4);
let l0 = mnist.labels.select(0, 0).unwrap().to_i64_vec().unwrap()[0];
let l1 = mnist.labels.select(0, 1).unwrap().to_i64_vec().unwrap()[0];
assert_eq!(l0, 3);
assert_eq!(l1, 7);
}
#[test]
fn get_batch_wraps_indices() {
let pixels = vec![0u8; 2 * 3 * 3];
let images_raw = make_idx3(2, 3, 3, &pixels);
let labels_raw = make_idx1(2, &[0, 1]);
let mnist = Mnist::parse(&gzip(&images_raw), &gzip(&labels_raw)).unwrap();
let batch = mnist.get_batch(&[0, 5]).unwrap();
assert_eq!(batch[0].shape(), &[2, 1, 3, 3]);
assert_eq!(batch[1].shape(), &[2]);
}
#[test]
fn bad_magic_rejected() {
let mut images_raw = make_idx3(1, 2, 2, &[0; 4]);
images_raw[3] = 99; let labels_raw = make_idx1(1, &[0]);
let result = Mnist::parse(&gzip(&images_raw), &gzip(&labels_raw));
assert!(result.is_err());
}
}