use super::DataLoader;
use std::{
io::{self, Read},
path::Path,
};
/// An in-memory MNIST-style dataset of flattened images and their labels.
///
/// When built through [`MnistDataset::load`] or [`MnistDataset::load_gz`], `images` holds
/// every pixel scaled to `[0.0, 1.0]` (raw byte divided by 255) and `labels` holds a
/// 10-wide one-hot vector per sample. The fields are public, so values constructed by
/// hand are not guaranteed to follow that layout.
#[derive(Debug, Clone, PartialEq)]
pub struct MnistDataset {
    /// Pixel values of all samples, concatenated sample after sample.
    pub images: Vec<f32>,
    /// Label encodings of all samples, concatenated sample after sample.
    pub labels: Vec<f32>,
    /// Number of samples in the dataset.
    pub n: usize,
}
impl MnistDataset {
pub fn load(images_path: &Path, labels_path: &Path) -> io::Result<Self> {
let images_raw = std::fs::read(images_path)?;
let labels_raw = std::fs::read(labels_path)?;
Self::from_buffers(&images_raw, &labels_raw)
}
pub fn load_gz(images_path: &Path, labels_path: &Path) -> io::Result<Self> {
let images_raw = read_gz(images_path)?;
let labels_raw = read_gz(labels_path)?;
Self::from_buffers(&images_raw, &labels_raw)
}
fn from_buffers(images_raw: &[u8], labels_raw: &[u8]) -> io::Result<Self> {
let images = parse_idx_images(images_raw)?;
let labels = parse_idx_labels(labels_raw)?;
if images.n != labels.n {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("image count ({}) != label count ({})", images.n, labels.n),
));
}
Ok(Self {
images: images.data,
labels: labels.data,
n: images.n,
})
}
pub fn loader(self, batch_size: usize) -> DataLoader {
DataLoader::new(self.images, self.labels, 784, 10, batch_size)
}
}
/// Reads the gzip-compressed file at `path` and returns its decompressed bytes.
///
/// # Errors
/// Propagates any I/O error raised while opening the file or while decompressing it.
fn read_gz(path: &Path) -> io::Result<Vec<u8>> {
    let source = std::fs::File::open(path)?;
    let mut decompressed = Vec::new();
    flate2::read::GzDecoder::new(source).read_to_end(&mut decompressed)?;
    Ok(decompressed)
}
/// Result of parsing an IDX image buffer.
///
/// Derives `Debug` so parser results can be used with `unwrap_err()` and `{:?}` in
/// tests and diagnostics.
#[derive(Debug, Clone, PartialEq)]
struct ParsedImages {
    /// Pixel values of all images, concatenated image after image.
    data: Vec<f32>,
    /// Image count declared in the IDX header.
    n: usize,
}
/// Result of parsing an IDX label buffer.
///
/// Derives `Debug` so parser results can be used with `unwrap_err()` and `{:?}` in
/// tests and diagnostics.
#[derive(Debug, Clone, PartialEq)]
struct ParsedLabels {
    /// Label encodings of all samples, concatenated sample after sample.
    data: Vec<f32>,
    /// Label count declared in the IDX header.
    n: usize,
}
/// Decodes the four bytes of `buf` starting at `offset` as a big-endian `u32`.
///
/// # Panics
/// Panics when `buf` holds fewer than `offset + 4` bytes; callers check the length first.
fn read_u32_be(buf: &[u8], offset: usize) -> u32 {
    // Most significant byte comes first, so each step shifts the accumulator up one byte.
    (0..4).fold(0_u32, |word, i| (word << 8) | u32::from(buf[offset + i]))
}
/// Parses an IDX image buffer (magic `0x00000803`) into normalized pixel data.
///
/// The buffer starts with four big-endian `u32` words — magic, image count, rows and
/// columns — followed by `count * rows * cols` unsigned bytes. Every byte is scaled to
/// `[0.0, 1.0]` by dividing by 255. Bytes beyond the declared payload are ignored.
///
/// # Errors
/// Returns an `InvalidData` error when the buffer is shorter than the 16-byte header,
/// the magic number is wrong, the declared dimensions do not fit in `usize`, or the
/// buffer is shorter than the header says it should be.
fn parse_idx_images(buf: &[u8]) -> io::Result<ParsedImages> {
    const HEADER_LEN: usize = 16;

    if buf.len() < HEADER_LEN {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "IDX image file too short",
        ));
    }
    let magic = read_u32_be(buf, 0);
    if magic != 0x00000803 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("bad IDX image magic: 0x{:08x}, expected 0x00000803", magic),
        ));
    }
    let n = read_u32_be(buf, 4) as usize;
    let rows = read_u32_be(buf, 8) as usize;
    let cols = read_u32_be(buf, 12) as usize;

    // The header comes from the file and is untrusted. Unchecked multiplication would
    // panic in debug builds and wrap in release builds, where a wrapped size could pass
    // the length check below and yield an `n` that does not match `data`.
    let expected = rows
        .checked_mul(cols)
        .and_then(|pixels| pixels.checked_mul(n))
        .and_then(|payload| payload.checked_add(HEADER_LEN))
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "IDX image dimensions overflow: {} images of {}x{}",
                    n, rows, cols
                ),
            )
        })?;
    if buf.len() < expected {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "IDX image file truncated: need {} bytes, got {}",
                expected,
                buf.len()
            ),
        ));
    }

    let data: Vec<f32> = buf[HEADER_LEN..expected]
        .iter()
        .map(|&b| f32::from(b) / 255.0)
        .collect();
    Ok(ParsedImages { data, n })
}
/// Parses an IDX label buffer (magic `0x00000801`) into one-hot encoded labels.
///
/// The buffer starts with two big-endian `u32` words — magic and label count — followed
/// by one byte per label. Each label becomes a run of 10 floats with a single `1.0` at
/// the label's position. Bytes beyond the declared payload are ignored.
///
/// # Errors
/// Returns an `InvalidData` error when the buffer is shorter than the 8-byte header, the
/// magic number is wrong, the buffer is shorter than the header declares, or a label is
/// 10 or greater.
fn parse_idx_labels(buf: &[u8]) -> io::Result<ParsedLabels> {
    if buf.len() < 8 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "IDX label file too short",
        ));
    }

    let magic = read_u32_be(buf, 0);
    if magic != 0x00000801 {
        let message = format!("bad IDX label magic: 0x{:08x}, expected 0x00000801", magic);
        return Err(io::Error::new(io::ErrorKind::InvalidData, message));
    }

    let n = read_u32_be(buf, 4) as usize;
    let expected = 8 + n;
    if buf.len() < expected {
        let message = format!(
            "IDX label file truncated: need {} bytes, got {}",
            expected,
            buf.len()
        );
        return Err(io::Error::new(io::ErrorKind::InvalidData, message));
    }

    let mut data = vec![0.0_f32; n * 10];
    // Walk each label byte alongside the 10-float slot reserved for it.
    let slots = data.chunks_exact_mut(10);
    for (i, (&raw, one_hot)) in buf[8..expected].iter().zip(slots).enumerate() {
        let label = usize::from(raw);
        match one_hot.get_mut(label) {
            Some(cell) => *cell = 1.0,
            // `get_mut` fails exactly when the label is outside 0..10.
            None => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("label {} out of range at index {}", label, i),
                ));
            }
        }
    }
    Ok(ParsedLabels { data, n })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an IDX byte buffer: big-endian header words followed by raw payload bytes.
    fn idx_buffer(header_words: &[u32], payload: &[u8]) -> Vec<u8> {
        let mut bytes = Vec::new();
        for word in header_words {
            bytes.extend_from_slice(&word.to_be_bytes());
        }
        bytes.extend_from_slice(payload);
        bytes
    }

    /// Two 2x2 images are parsed and their bytes scaled by 1/255.
    #[test]
    fn test_parse_idx_images() {
        let buf = idx_buffer(
            &[0x00000803, 2, 2, 2],
            &[0, 128, 255, 64, 32, 96, 192, 160],
        );
        let parsed = parse_idx_images(&buf).unwrap();

        assert_eq!(parsed.n, 2);
        assert_eq!(parsed.data.len(), 8);
        assert_eq!(parsed.data[0], 0.0);
        assert!((parsed.data[1] - 128.0 / 255.0).abs() < 1e-6);
        assert!((parsed.data[2] - 1.0).abs() < 1e-6);
    }

    /// Labels 7, 2 and 0 each set exactly their own position in a 10-wide slot.
    #[test]
    fn test_parse_idx_labels() {
        let buf = idx_buffer(&[0x00000801, 3], &[7, 2, 0]);
        let parsed = parse_idx_labels(&buf).unwrap();

        assert_eq!(parsed.n, 3);
        assert_eq!(parsed.data.len(), 30);
        assert_eq!(parsed.data[7], 1.0);
        assert_eq!(parsed.data[0], 0.0);
        assert_eq!(parsed.data[10 + 2], 1.0);
        assert_eq!(parsed.data[20], 1.0);
    }

    /// A label buffer whose first word is not the label magic is rejected.
    #[test]
    fn test_parse_idx_bad_magic() {
        let buf = [0, 0, 0, 0, 0, 0, 0, 1, 0];
        assert!(parse_idx_labels(&buf).is_err());
    }

    /// A header declaring 100 labels with only 2 payload bytes is rejected.
    #[test]
    fn test_parse_idx_truncated() {
        let buf = idx_buffer(&[0x00000801, 100], &[0, 1]);
        assert!(parse_idx_labels(&buf).is_err());
    }
}