use crate::data::Dataset;
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use crate::vision::transforms::Transform;
use crate::vision::{Image, ImageFormat};
use num_traits::Float;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// MNIST handwritten-digit dataset for either the training or the test split.
///
/// NOTE(review): the samples are currently synthetic. The loader fills every image with the
/// constant 0.5 and labels sample `i` as `i % 10`; no real MNIST files are read.
#[derive(Debug)]
pub struct MNIST<T: Float> {
    /// Base directory; data is expected under `<root>/MNIST/raw`.
    pub root: PathBuf,
    /// `true` selects the 60000-sample training split, `false` the 10000-sample test split.
    pub train: bool,
    /// One `[1, 28, 28]` tensor per sample.
    images: Vec<Tensor<T>>,
    /// Class index for each entry of `images` (same length and order).
    labels: Vec<i64>,
    /// Optional transform applied to each image when it is fetched.
    transform: Option<Box<dyn Transform<T>>>,
    /// When `true`, a missing data directory is created instead of reported as an error.
    download: bool,
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> MNIST<T> {
    /// Creates the dataset rooted at `root` and eagerly loads the selected split.
    ///
    /// `train` picks the training split (otherwise the test split). `download` allows a
    /// missing data directory to be created rather than treated as an error.
    ///
    /// # Errors
    /// Propagates loading failures, e.g. a `DatasetError` when the data directory is missing
    /// and `download` is `false`, or an `IoError` if the directory cannot be created.
    pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> RusTorchResult<Self> {
        let mut mnist = Self {
            root: root.as_ref().to_path_buf(),
            train,
            images: Vec::new(),
            labels: Vec::new(),
            transform: None,
            download,
        };
        mnist.load_data()?;
        Ok(mnist)
    }

    /// Sets (or replaces) the transform applied to images in `get_item`.
    pub fn with_transform(mut self, transform: Box<dyn Transform<T>>) -> Self {
        self.transform = Some(transform);
        self
    }

    /// Ensures the data directory exists, then fills `images` and `labels`.
    ///
    /// The samples are placeholders: every pixel is 0.5 and sample `i` gets label `i % 10`.
    fn load_data(&mut self) -> RusTorchResult<()> {
        let raw_dir = self.root.join("MNIST").join("raw");
        if !raw_dir.exists() {
            if !self.download {
                return Err(RusTorchError::DatasetError(format!(
                    "MNIST data not found at {:?}. Set download=true to download.",
                    raw_dir
                )));
            }
            self.download_data()?;
        }

        let sample_count: usize = if self.train { 60000 } else { 10000 };
        let pixel = <T as From<f32>>::from(0.5f32);
        self.images.extend(
            (0..sample_count).map(|_| Tensor::from_vec(vec![pixel; 784], vec![1, 28, 28])),
        );
        self.labels
            .extend((0..sample_count).map(|i| (i % 10) as i64));
        Ok(())
    }

    /// Creates `<root>/MNIST/raw`; nothing is actually downloaded yet.
    fn download_data(&self) -> RusTorchResult<()> {
        let raw_dir = self.root.join("MNIST").join("raw");
        if let Err(err) = std::fs::create_dir_all(&raw_dir) {
            return Err(RusTorchError::IoError(format!(
                "Failed to create directory: {}",
                err
            )));
        }
        println!("Note: MNIST download not implemented - using dummy data");
        Ok(())
    }

    /// Number of digit classes (0 through 9).
    pub fn num_classes(&self) -> usize {
        10
    }
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> Dataset<(Tensor<T>, Tensor<T>)>
    for MNIST<T>
{
    /// Number of loaded samples.
    fn len(&self) -> usize {
        self.images.len()
    }

    /// Returns `(image, label)` for `index`; `label` is a one-element tensor holding the class.
    ///
    /// A configured transform is applied best-effort. If wrapping the tensor as an `Image`
    /// or running the transform fails, the untransformed image is returned and no error is
    /// raised.
    ///
    /// # Errors
    /// `InvalidParameters` when `index` is out of bounds.
    fn get_item(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>), crate::data::DataError> {
        let stored = match self.images.get(index) {
            Some(tensor) => tensor,
            None => {
                return Err(crate::error::RusTorchError::InvalidParameters {
                    operation: "MNIST::get_item".to_string(),
                    message: format!(
                        "Index {} out of bounds for dataset of size {}",
                        index,
                        self.images.len()
                    ),
                });
            }
        };

        let class_value = <T as From<u8>>::from(self.labels[index] as u8);
        let label = Tensor::from_vec(vec![class_value], vec![1]);

        let transformed = self.transform.as_ref().and_then(|transform| {
            Image::new(stored.clone(), ImageFormat::CHW)
                .ok()
                .and_then(|img| transform.apply(&img).ok())
                .map(|out| out.data)
        });
        let image = transformed.unwrap_or_else(|| stored.clone());

        Ok((image, label))
    }
}
/// CIFAR-10 dataset (10 object classes) for either the training or the test split.
///
/// NOTE(review): the samples are currently synthetic. The loader fills every image with the
/// constant 0.5 and labels sample `i` as `i % 10`; no real CIFAR batches are read.
#[derive(Debug)]
pub struct CIFAR10<T: Float> {
    /// Base directory; data is expected under `<root>/cifar-10-batches-py`.
    pub root: PathBuf,
    /// `true` selects the 50000-sample training split, `false` the 10000-sample test split.
    pub train: bool,
    /// One `[3, 32, 32]` tensor per sample.
    images: Vec<Tensor<T>>,
    /// Class index for each entry of `images` (same length and order).
    labels: Vec<i64>,
    /// Optional transform applied to each image when it is fetched.
    transform: Option<Box<dyn Transform<T>>>,
    /// When `true`, a missing data directory is created instead of reported as an error.
    download: bool,
    /// Human-readable class names, indexed by class id.
    classes: Vec<String>,
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> CIFAR10<T> {
    /// Creates the dataset rooted at `root` and eagerly loads the selected split.
    ///
    /// `train` picks the training split (otherwise the test split). `download` allows a
    /// missing data directory to be created rather than treated as an error.
    ///
    /// # Errors
    /// Propagates loading failures, e.g. a `DatasetError` when the data directory is missing
    /// and `download` is `false`, or an `IoError` if the directory cannot be created.
    pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> RusTorchResult<Self> {
        const CLASS_NAMES: [&str; 10] = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ];

        let mut cifar = Self {
            root: root.as_ref().to_path_buf(),
            train,
            images: Vec::new(),
            labels: Vec::new(),
            transform: None,
            download,
            classes: CLASS_NAMES.iter().map(|name| name.to_string()).collect(),
        };
        cifar.load_data()?;
        Ok(cifar)
    }

    /// Sets (or replaces) the transform applied to images in `get_item`.
    pub fn with_transform(mut self, transform: Box<dyn Transform<T>>) -> Self {
        self.transform = Some(transform);
        self
    }

    /// Ensures the data directory exists, then fills `images` and `labels`.
    ///
    /// The samples are placeholders: every pixel is 0.5 and sample `i` gets label `i % 10`.
    fn load_data(&mut self) -> RusTorchResult<()> {
        let batches_dir = self.root.join("cifar-10-batches-py");
        if !batches_dir.exists() {
            if !self.download {
                return Err(RusTorchError::DatasetError(format!(
                    "CIFAR-10 data not found at {:?}. Set download=true to download.",
                    batches_dir
                )));
            }
            self.download_data()?;
        }

        let sample_count: usize = if self.train { 50000 } else { 10000 };
        let pixel = <T as From<f32>>::from(0.5f32);
        self.images.extend(
            (0..sample_count).map(|_| Tensor::from_vec(vec![pixel; 3072], vec![3, 32, 32])),
        );
        self.labels
            .extend((0..sample_count).map(|i| (i % 10) as i64));
        Ok(())
    }

    /// Creates `<root>/cifar-10-batches-py`; nothing is actually downloaded yet.
    fn download_data(&self) -> RusTorchResult<()> {
        let batches_dir = self.root.join("cifar-10-batches-py");
        if let Err(err) = std::fs::create_dir_all(&batches_dir) {
            return Err(RusTorchError::IoError(format!(
                "Failed to create directory: {}",
                err
            )));
        }
        println!("Note: CIFAR-10 download not implemented - using dummy data");
        Ok(())
    }

    /// Number of object classes.
    pub fn num_classes(&self) -> usize {
        10
    }

    /// Class names in class-id order.
    pub fn class_names(&self) -> &[String] {
        self.classes.as_slice()
    }
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> Dataset<(Tensor<T>, Tensor<T>)>
    for CIFAR10<T>
{
    /// Number of loaded samples.
    fn len(&self) -> usize {
        self.images.len()
    }

    /// Returns `(image, label)` for `index`; `label` is a one-element tensor holding the class.
    ///
    /// A configured transform is applied best-effort. If wrapping the tensor as an `Image`
    /// or running the transform fails, the untransformed image is returned and no error is
    /// raised.
    ///
    /// # Errors
    /// `InvalidParameters` when `index` is out of bounds.
    fn get_item(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>), crate::data::DataError> {
        let stored = match self.images.get(index) {
            Some(tensor) => tensor,
            None => {
                return Err(crate::error::RusTorchError::InvalidParameters {
                    operation: "CIFAR10::get_item".to_string(),
                    message: format!(
                        "Index {} out of bounds for dataset of size {}",
                        index,
                        self.images.len()
                    ),
                });
            }
        };

        let class_value = <T as From<u8>>::from(self.labels[index] as u8);
        let label = Tensor::from_vec(vec![class_value], vec![1]);

        let image = match &self.transform {
            Some(transform) => Image::new(stored.clone(), ImageFormat::CHW)
                .ok()
                .and_then(|img| transform.apply(&img).ok())
                .map(|out| out.data)
                .unwrap_or_else(|| stored.clone()),
            None => stored.clone(),
        };

        Ok((image, label))
    }
}
/// CIFAR-100 dataset with 100 fine and 20 coarse classes, for either split.
///
/// NOTE(review): the samples are currently synthetic. Every image is filled with 0.5, sample
/// `i` gets fine label `i % 100` and coarse label `i % 20`, and the class names are generic
/// placeholders (`class_N` / `superclass_N`).
#[derive(Debug)]
pub struct CIFAR100<T: Float> {
    /// Base directory; data is expected under `<root>/cifar-100-python`.
    pub root: PathBuf,
    /// `true` selects the 50000-sample training split, `false` the 10000-sample test split.
    pub train: bool,
    /// One `[3, 32, 32]` tensor per sample.
    images: Vec<Tensor<T>>,
    /// Fine-grained class index (0-99) per sample; this is the label `get_item` returns.
    fine_labels: Vec<i64>,
    /// Coarse superclass index (0-19) per sample.
    /// NOTE(review): stored but not returned by `get_item` in the code visible here.
    coarse_labels: Vec<i64>,
    /// Optional transform applied to each image when it is fetched.
    transform: Option<Box<dyn Transform<T>>>,
    /// When `true`, a missing data directory is created instead of reported as an error.
    download: bool,
    /// Names of the 100 fine classes, indexed by fine class id.
    fine_classes: Vec<String>,
    /// Names of the 20 coarse classes, indexed by coarse class id.
    coarse_classes: Vec<String>,
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> CIFAR100<T> {
    /// Creates the dataset rooted at `root` and eagerly loads the selected split.
    ///
    /// `train` picks the training split (otherwise the test split). `download` allows a
    /// missing data directory to be created rather than treated as an error.
    ///
    /// # Errors
    /// Propagates loading failures, e.g. a `DatasetError` when the data directory is missing
    /// and `download` is `false`, or an `IoError` if the directory cannot be created.
    pub fn new<P: AsRef<Path>>(root: P, train: bool, download: bool) -> RusTorchResult<Self> {
        let fine_classes = (0..100)
            .map(|i| format!("class_{}", i))
            .collect::<Vec<String>>();
        let coarse_classes = (0..20)
            .map(|i| format!("superclass_{}", i))
            .collect::<Vec<String>>();

        let mut cifar = Self {
            root: root.as_ref().to_path_buf(),
            train,
            images: Vec::new(),
            fine_labels: Vec::new(),
            coarse_labels: Vec::new(),
            transform: None,
            download,
            fine_classes,
            coarse_classes,
        };
        cifar.load_data()?;
        Ok(cifar)
    }

    /// Sets (or replaces) the transform applied to images in `get_item`.
    pub fn with_transform(mut self, transform: Box<dyn Transform<T>>) -> Self {
        self.transform = Some(transform);
        self
    }

    /// Ensures the data directory exists, then fills `images` and both label vectors.
    ///
    /// The samples are placeholders: every pixel is 0.5, and sample `i` gets fine label
    /// `i % 100` and coarse label `i % 20`.
    fn load_data(&mut self) -> RusTorchResult<()> {
        let python_dir = self.root.join("cifar-100-python");
        if !python_dir.exists() {
            if !self.download {
                return Err(RusTorchError::DatasetError(format!(
                    "CIFAR-100 data not found at {:?}. Set download=true to download.",
                    python_dir
                )));
            }
            self.download_data()?;
        }

        let sample_count: usize = if self.train { 50000 } else { 10000 };
        let pixel = <T as From<f32>>::from(0.5f32);
        self.images.extend(
            (0..sample_count).map(|_| Tensor::from_vec(vec![pixel; 3072], vec![3, 32, 32])),
        );
        self.fine_labels
            .extend((0..sample_count).map(|i| (i % 100) as i64));
        self.coarse_labels
            .extend((0..sample_count).map(|i| (i % 20) as i64));
        Ok(())
    }

    /// Creates `<root>/cifar-100-python`; nothing is actually downloaded yet.
    fn download_data(&self) -> RusTorchResult<()> {
        let python_dir = self.root.join("cifar-100-python");
        if let Err(err) = std::fs::create_dir_all(&python_dir) {
            return Err(RusTorchError::IoError(format!(
                "Failed to create directory: {}",
                err
            )));
        }
        println!("Note: CIFAR-100 download not implemented - using dummy data");
        Ok(())
    }

    /// Number of fine-grained classes.
    pub fn num_fine_classes(&self) -> usize {
        100
    }

    /// Number of coarse superclasses.
    pub fn num_coarse_classes(&self) -> usize {
        20
    }

    /// Fine class names in fine-class-id order.
    pub fn fine_class_names(&self) -> &[String] {
        self.fine_classes.as_slice()
    }

    /// Coarse class names in coarse-class-id order.
    pub fn coarse_class_names(&self) -> &[String] {
        self.coarse_classes.as_slice()
    }
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> Dataset<(Tensor<T>, Tensor<T>)>
    for CIFAR100<T>
{
    /// Number of loaded samples.
    fn len(&self) -> usize {
        self.images.len()
    }

    /// Returns `(image, fine_label)` for `index`; the label is a one-element tensor.
    ///
    /// A configured transform is applied best-effort. If wrapping the tensor as an `Image`
    /// or running the transform fails, the untransformed image is returned and no error is
    /// raised.
    ///
    /// # Errors
    /// `InvalidParameters` when `index` is out of bounds.
    fn get_item(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>), crate::data::DataError> {
        let stored = match self.images.get(index) {
            Some(tensor) => tensor,
            None => {
                return Err(crate::error::RusTorchError::InvalidParameters {
                    operation: "CIFAR100::get_item".to_string(),
                    message: format!(
                        "Index {} out of bounds for dataset of size {}",
                        index,
                        self.images.len()
                    ),
                });
            }
        };

        let fine_value = <T as From<u8>>::from(self.fine_labels[index] as u8);
        let label = Tensor::from_vec(vec![fine_value], vec![1]);

        let transformed = self.transform.as_ref().and_then(|transform| {
            Image::new(stored.clone(), ImageFormat::CHW)
                .ok()
                .and_then(|img| transform.apply(&img).ok())
                .map(|out| out.data)
        });
        let image = transformed.unwrap_or_else(|| stored.clone());

        Ok((image, label))
    }
}
/// Generic image-classification dataset built from a `root/<class_name>/<image>` layout.
///
/// Each immediate sub-directory of `root` is one class. Files in it with a supported image
/// extension become samples.
#[derive(Debug)]
pub struct ImageFolder<T: Float> {
    /// Directory whose sub-directories are the classes.
    pub root: PathBuf,
    /// Class directory name -> class index; indices follow the sorted order of the names.
    class_to_idx: HashMap<String, usize>,
    /// `(file path, class index)` for every discovered image file.
    samples: Vec<(PathBuf, usize)>,
    /// Optional transform applied to each image when it is fetched.
    transform: Option<Box<dyn Transform<T>>>,
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> ImageFolder<T> {
    /// Builds the dataset by scanning `root/<class_name>/<image file>`.
    ///
    /// Classes are the immediate sub-directories of `root` (non-UTF-8 names are skipped).
    /// Class indices follow sorted name order. Samples are ordered by class index, then by
    /// file path, so a given index maps to the same file on every platform.
    ///
    /// # Errors
    /// `DatasetError` if `root` is missing or not a directory; `IoError` if listing any
    /// directory fails during the scan.
    pub fn new<P: AsRef<Path>>(root: P) -> RusTorchResult<Self> {
        let root = root.as_ref().to_path_buf();
        // `is_dir` is false for missing paths too, so this also covers the "does not exist" case.
        if !root.is_dir() {
            return Err(RusTorchError::DatasetError(format!(
                "Root directory {:?} does not exist or is not a directory",
                root
            )));
        }
        let mut dataset = ImageFolder {
            root,
            class_to_idx: HashMap::new(),
            samples: Vec::new(),
            transform: None,
        };
        dataset.scan_directory()?;
        Ok(dataset)
    }

    /// Sets (or replaces) the transform applied to images in `get_item`.
    pub fn with_transform(mut self, transform: Box<dyn Transform<T>>) -> Self {
        self.transform = Some(transform);
        self
    }

    /// Discovers class directories and their image files, filling `class_to_idx` and `samples`.
    fn scan_directory(&mut self) -> RusTorchResult<()> {
        let mut class_names = Vec::new();
        for entry in std::fs::read_dir(&self.root)
            .map_err(|e| RusTorchError::IoError(format!("Failed to read directory: {}", e)))?
        {
            let path = entry
                .map_err(|e| RusTorchError::IoError(e.to_string()))?
                .path();
            if !path.is_dir() {
                continue;
            }
            // Directory names that are not valid UTF-8 cannot be keyed by `String`; skip them.
            if let Some(name) = path.file_name().and_then(|name| name.to_str()) {
                class_names.push(name.to_string());
            }
        }
        class_names.sort();

        for (class_idx, class_name) in class_names.into_iter().enumerate() {
            let class_dir = self.root.join(&class_name);
            self.class_to_idx.insert(class_name, class_idx);

            let mut files = Vec::new();
            for entry in std::fs::read_dir(&class_dir).map_err(|e| {
                RusTorchError::IoError(format!("Failed to read class directory: {}", e))
            })? {
                let path = entry
                    .map_err(|e| RusTorchError::IoError(e.to_string()))?
                    .path();
                if path.is_file() && Self::has_image_extension(&path) {
                    files.push(path);
                }
            }
            // `read_dir` yields entries in a platform/filesystem-dependent order; sort so the
            // index -> sample mapping is reproducible across machines.
            files.sort();
            self.samples
                .extend(files.into_iter().map(|path| (path, class_idx)));
        }
        Ok(())
    }

    /// True when `path` ends in a supported image extension (ASCII case-insensitive).
    fn has_image_extension(path: &Path) -> bool {
        const EXTENSIONS: [&str; 5] = ["jpg", "jpeg", "png", "bmp", "tiff"];
        path.extension()
            .and_then(|ext| ext.to_str())
            .map_or(false, |ext| {
                EXTENSIONS
                    .iter()
                    .any(|known| ext.eq_ignore_ascii_case(known))
            })
    }

    /// Class names ordered by class index (i.e. sorted by name).
    pub fn classes(&self) -> Vec<String> {
        let mut classes: Vec<_> = self.class_to_idx.iter().collect();
        // Indices are unique, so an unstable sort gives a fully determined order.
        classes.sort_unstable_by_key(|&(_, &idx)| idx);
        classes.into_iter().map(|(name, _)| name.clone()).collect()
    }

    /// Number of discovered classes.
    pub fn num_classes(&self) -> usize {
        self.class_to_idx.len()
    }
}
impl<T: Float + From<f32> + From<u8> + Copy + 'static> Dataset<(Tensor<T>, Tensor<T>)>
    for ImageFolder<T>
{
    /// Number of discovered image files.
    fn len(&self) -> usize {
        self.samples.len()
    }

    /// Returns `(image, label)` for `index`; `label` is a one-element tensor with the class index.
    ///
    /// NOTE(review): the file at the sample path is not decoded yet. Every image is a
    /// placeholder `[3, 32, 32]` tensor filled with 0.5.
    ///
    /// A configured transform is applied best-effort. If wrapping the tensor as an `Image`
    /// or running the transform fails, the untransformed image is returned and no error is
    /// raised.
    ///
    /// # Errors
    /// `InvalidParameters` when `index` is out of bounds.
    fn get_item(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>), crate::data::DataError> {
        let class_idx = match self.samples.get(index) {
            Some(&(_, class_idx)) => class_idx,
            None => {
                return Err(crate::error::RusTorchError::InvalidParameters {
                    operation: "ImageFolder::get_item".to_string(),
                    message: format!(
                        "Index {} out of bounds for dataset of size {}",
                        index,
                        self.samples.len()
                    ),
                });
            }
        };

        let image_data: Vec<T> = vec![<T as From<f32>>::from(0.5f32); 3072];
        let mut image = Tensor::from_vec(image_data, vec![3, 32, 32]);

        // Encode via f32 instead of `as u8`: the u8 cast silently wrapped class indices >= 256
        // (e.g. 1000-class layouts). f32 represents every integer up to 2^24 exactly.
        let label = Tensor::from_vec(vec![<T as From<f32>>::from(class_idx as f32)], vec![1]);

        if let Some(ref transform) = self.transform {
            if let Ok(img) = Image::new(image.clone(), ImageFormat::CHW) {
                if let Ok(transformed_img) = transform.apply(&img) {
                    image = transformed_img.data;
                }
            }
        }
        Ok((image, label))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    // The synthetic loaders materialise every sample eagerly. A CIFAR training split is
    // 50000 x 3072 f32 values (~600 MB), so these smoke tests use the smaller test splits.

    #[test]
    fn test_mnist_creation() {
        let temp_dir = env::temp_dir().join("test_mnist");
        std::fs::create_dir_all(&temp_dir).unwrap();
        let mnist = MNIST::<f32>::new(&temp_dir, false, true).expect("MNIST construction failed");
        assert_eq!(Dataset::len(&mnist), 10000);
        assert_eq!(mnist.num_classes(), 10);
        assert!(Dataset::get_item(&mnist, 0).is_ok());
        assert!(Dataset::get_item(&mnist, 10000).is_err());
    }

    #[test]
    fn test_cifar10_creation() {
        let temp_dir = env::temp_dir().join("test_cifar10");
        std::fs::create_dir_all(&temp_dir).unwrap();
        let cifar10 =
            CIFAR10::<f32>::new(&temp_dir, false, true).expect("CIFAR10 construction failed");
        assert_eq!(Dataset::len(&cifar10), 10000);
        assert_eq!(cifar10.num_classes(), 10);
        assert_eq!(cifar10.class_names().len(), 10);
        assert!(Dataset::get_item(&cifar10, 0).is_ok());
        assert!(Dataset::get_item(&cifar10, 10000).is_err());
    }

    #[test]
    fn test_cifar100_creation() {
        let temp_dir = env::temp_dir().join("test_cifar100");
        std::fs::create_dir_all(&temp_dir).unwrap();
        let cifar100 =
            CIFAR100::<f32>::new(&temp_dir, false, true).expect("CIFAR100 construction failed");
        assert_eq!(Dataset::len(&cifar100), 10000);
        assert_eq!(cifar100.num_fine_classes(), 100);
        assert_eq!(cifar100.num_coarse_classes(), 20);
        assert_eq!(cifar100.fine_class_names().len(), 100);
        assert_eq!(cifar100.coarse_class_names().len(), 20);
        assert!(Dataset::get_item(&cifar100, 0).is_ok());
        assert!(Dataset::get_item(&cifar100, 10000).is_err());
    }
}