use crate::{dataset::Dataset, transforms::Transform};
use std::path::{Path, PathBuf};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};
/// CIFAR-10 image-classification dataset held entirely in memory.
///
/// Each sample is an image tensor in `data` paired with the class label in
/// `targets` at the same index. An optional transform is applied to the image
/// when a sample is fetched through `Dataset::get`.
pub struct CIFAR10 {
    /// Directory that is searched for the CIFAR-10 batch files.
    root: PathBuf,
    /// `true` for the training split, `false` for the test split.
    train: bool,
    /// Optional image transform applied (to a clone of the stored image) on `get`.
    transform: Option<Box<dyn Transform<Tensor<f32>, Output = Tensor<f32>>>>,
    /// One image tensor per sample.
    data: Vec<Tensor<f32>>,
    /// Class label for each entry of `data`, stored at the same index.
    targets: Vec<usize>,
}
impl CIFAR10 {
    /// Number of synthetic samples generated when no batch file for the
    /// requested split is present under `root`.
    const FALLBACK_SAMPLES: usize = 100;

    /// Binary batch files that make up the training split.
    const TRAIN_FILES: [&'static str; 5] = [
        "data_batch_1.bin",
        "data_batch_2.bin",
        "data_batch_3.bin",
        "data_batch_4.bin",
        "data_batch_5.bin",
    ];

    /// Binary batch file that makes up the test split.
    const TEST_FILES: [&'static str; 1] = ["test_batch.bin"];

    /// Creates the CIFAR-10 dataset for the training (`train == true`) or test
    /// split, rooted at `root`.
    ///
    /// NOTE: the batch files are **not parsed yet**. This constructor only checks
    /// whether at least one expected batch file (`data_batch_{1..5}.bin` for
    /// training, `test_batch.bin` for test) exists directly under `root`:
    ///
    /// * if one does, it generates the full split size (50 000 training /
    ///   10 000 test samples) of random `[3, 32, 32]` `f32` images;
    /// * otherwise it generates `min(100, split size)` random images.
    ///
    /// In both cases the label of sample `i` is `i % 10`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `torsh_tensor::creation::rand`.
    pub fn new<P: AsRef<Path>>(root: P, train: bool) -> Result<Self> {
        let root = root.as_ref().to_path_buf();

        let (split_size, files): (usize, &[&str]) = if train {
            (50_000, &Self::TRAIN_FILES[..])
        } else {
            (10_000, &Self::TEST_FILES[..])
        };

        let found_data = files.iter().any(|file| root.join(file).exists());

        // Size the buffers from the number of samples actually produced, so the
        // fallback path does not reserve room for the whole split.
        let num_samples = if found_data {
            split_size
        } else {
            Self::FALLBACK_SAMPLES.min(split_size)
        };

        let mut data = Vec::with_capacity(num_samples);
        let mut targets = Vec::with_capacity(num_samples);
        for i in 0..num_samples {
            // Placeholder content until the binary batch format is decoded.
            data.push(torsh_tensor::creation::rand::<f32>(&[3, 32, 32])?);
            targets.push(i % 10);
        }

        Ok(Self {
            root,
            train,
            transform: None,
            data,
            targets,
        })
    }

    /// Attaches a transform that is applied to every image returned by `get`,
    /// replacing any previously configured transform.
    #[must_use]
    pub fn with_transform<T>(mut self, transform: T) -> Self
    where
        T: Transform<Tensor<f32>, Output = Tensor<f32>> + 'static,
    {
        self.transform = Some(Box::new(transform));
        self
    }

    /// Directory the dataset was created from.
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// Returns `true` for the training split and `false` for the test split.
    pub fn is_train(&self) -> bool {
        self.train
    }

    /// Number of samples held in memory (the same value as `Dataset::len`).
    pub fn num_samples(&self) -> usize {
        self.data.len()
    }
}
impl Dataset for CIFAR10 {
    type Item = (Tensor<f32>, usize);

    /// Returns how many samples are stored.
    fn len(&self) -> usize {
        self.data.len()
    }

    /// Fetches the `(image, label)` pair stored at `index`.
    ///
    /// The stored image is cloned. If a transform is configured, it runs on
    /// that clone, so the stored sample itself is never modified.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::IndexError` when `index` is out of range, and
    /// propagates any error produced by the transform.
    fn get(&self, index: usize) -> Result<Self::Item> {
        let stored = self.data.get(index).ok_or_else(|| TorshError::IndexError {
            index,
            size: self.data.len(),
        })?;

        let image = match &self.transform {
            Some(transform) => transform.transform(stored.clone())?,
            None => stored.clone(),
        };

        Ok((image, self.targets[index]))
    }
}