use crate::data::BatchDataSet;
use crate::tensor::{Device, Result, Tensor, TensorError};
/// Human-readable names of the ten CIFAR-10 classes, in label order:
/// entry `i` names label value `i` (labels accepted by `Cifar10::parse` are `0..=9`).
pub const CLASS_NAMES: [&str; 10] = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
];
/// Colour channels per image.
const IMAGE_CHANNELS: usize = 3;
/// Width and height (in pixels) of each square image.
const IMAGE_SIDE: usize = 32;
/// Pixel bytes per image: one byte per channel per pixel.
const PIXELS_PER_IMAGE: usize = IMAGE_CHANNELS * IMAGE_SIDE * IMAGE_SIDE;
/// Bytes per on-disk record: one label byte followed by the pixel bytes.
const BYTES_PER_RECORD: usize = 1 + PIXELS_PER_IMAGE;
/// Number of records every batch buffer must contain.
const IMAGES_PER_BATCH: usize = 10_000;
/// An in-memory CIFAR-10 dataset: a stack of images and their class labels.
pub struct Cifar10 {
    /// Image tensor. `Cifar10::parse` builds it on the CPU with shape
    /// `[n, 3, 32, 32]` and `f32` pixel values scaled to `0.0..=1.0`.
    pub images: Tensor,
    /// Label tensor. `Cifar10::parse` builds it on the CPU with shape `[n]`
    /// holding `i64` class indices in `0..=9`.
    pub labels: Tensor,
}
impl Cifar10 {
pub fn parse(batches: &[&[u8]]) -> Result<Self> {
if batches.is_empty() {
return Err(TensorError::new("CIFAR-10: no batch data provided"));
}
let mut all_pixels: Vec<f32> = Vec::new();
let mut all_labels: Vec<i64> = Vec::new();
for (batch_idx, &batch) in batches.iter().enumerate() {
let expected = IMAGES_PER_BATCH * BYTES_PER_RECORD;
if batch.len() != expected {
return Err(TensorError::new(&format!(
"CIFAR-10 batch {}: expected {} bytes, got {}",
batch_idx, expected, batch.len()
)));
}
for img_idx in 0..IMAGES_PER_BATCH {
let offset = img_idx * BYTES_PER_RECORD;
let label = batch[offset] as i64;
if label > 9 {
return Err(TensorError::new(&format!(
"CIFAR-10 batch {} image {}: invalid label {}",
batch_idx, img_idx, label
)));
}
all_labels.push(label);
let pixel_start = offset + 1;
let pixel_end = pixel_start + PIXELS_PER_IMAGE;
for &b in &batch[pixel_start..pixel_end] {
all_pixels.push(b as f32 / 255.0);
}
}
}
let n = all_labels.len() as i64;
let images = Tensor::from_f32(&all_pixels, &[n, 3, 32, 32], Device::CPU)?;
let labels = Tensor::from_i64(&all_labels, &[n], Device::CPU)?;
Ok(Cifar10 { images, labels })
}
pub fn len(&self) -> usize {
self.images.shape()[0] as usize
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl BatchDataSet for Cifar10 {
fn len(&self) -> usize {
self.images.shape()[0] as usize
}
fn get_batch(&self, indices: &[usize]) -> Result<Vec<Tensor>> {
let idx: Vec<i64> = indices.iter().map(|&i| (i % self.len()) as i64).collect();
let idx_tensor = Tensor::from_i64(&idx, &[idx.len() as i64], Device::CPU)?;
let images = self.images.index_select(0, &idx_tensor)?;
let labels = self.labels.index_select(0, &idx_tensor)?;
Ok(vec![images, labels])
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bytes in one colour plane of an image (one channel's 32x32 pixels).
    const PLANE: usize = PIXELS_PER_IMAGE / 3;

    /// Builds a synthetic batch of `n` records.
    ///
    /// Record `i` has label `i % 10`. Its channel 0 is filled with `i % 256`,
    /// channel 1 with 0 and channel 2 with 255.
    fn make_batch(n: usize) -> Vec<u8> {
        let mut buf = Vec::with_capacity(n * BYTES_PER_RECORD);
        for i in 0..n {
            buf.push((i % 10) as u8);
            buf.resize(buf.len() + PLANE, (i % 256) as u8);
            buf.resize(buf.len() + PLANE, 0u8);
            buf.resize(buf.len() + PLANE, 255u8);
        }
        buf
    }

    #[test]
    fn parse_single_batch() {
        let batch = make_batch(IMAGES_PER_BATCH);
        let cifar = Cifar10::parse(&[&batch]).unwrap();
        assert_eq!(cifar.images.shape(), &[10000, 3, 32, 32]);
        assert_eq!(cifar.labels.shape(), &[10000]);
        let l = cifar.labels.select(0, 0).unwrap().to_i64_vec().unwrap()[0];
        assert_eq!(l, 0);
        let l = cifar.labels.select(0, 1).unwrap().to_i64_vec().unwrap()[0];
        assert_eq!(l, 1);
    }

    #[test]
    fn parse_multiple_batches() {
        let b1 = make_batch(IMAGES_PER_BATCH);
        let b2 = make_batch(IMAGES_PER_BATCH);
        let cifar = Cifar10::parse(&[&b1, &b2]).unwrap();
        assert_eq!(cifar.images.shape(), &[20000, 3, 32, 32]);
    }

    #[test]
    fn wrong_size_rejected() {
        let batch = [0u8; 100];
        assert!(Cifar10::parse(&[&batch[..]]).is_err());
    }

    #[test]
    fn empty_input_rejected() {
        assert!(Cifar10::parse(&[]).is_err());
    }

    #[test]
    fn invalid_label_rejected() {
        let mut batch = make_batch(IMAGES_PER_BATCH);
        // Only labels 0..=9 are valid; corrupt the first record's label byte.
        batch[0] = 10;
        assert!(Cifar10::parse(&[&batch]).is_err());
    }

    #[test]
    fn pixel_normalization() {
        let batch = make_batch(IMAGES_PER_BATCH);
        let cifar = Cifar10::parse(&[&batch]).unwrap();
        let img0 = cifar.images.select(0, 0).unwrap();
        let r_pixel: f64 = img0
            .select(0, 0)
            .unwrap()
            .select(0, 0)
            .unwrap()
            .select(0, 0)
            .unwrap()
            .item()
            .unwrap();
        assert!((r_pixel - 0.0).abs() < 1e-6);
        let b_pixel: f64 = img0
            .select(0, 2)
            .unwrap()
            .select(0, 0)
            .unwrap()
            .select(0, 0)
            .unwrap()
            .item()
            .unwrap();
        assert!((b_pixel - 1.0).abs() < 1e-6);
    }
}