torsh-data 0.1.2

Data loading and preprocessing utilities for ToRSh
Documentation
use crate::{dataset::Dataset, transforms::Transform};
use std::path::{Path, PathBuf};
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;

#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};

/// CIFAR-10 dataset
pub struct CIFAR10 {
    root: PathBuf,
    train: bool,
    transform: Option<Box<dyn Transform<Tensor<f32>, Output = Tensor<f32>>>>,
    data: Vec<Tensor<f32>>,
    targets: Vec<usize>,
}

impl CIFAR10 {
    /// Create CIFAR-10 dataset
    pub fn new<P: AsRef<Path>>(root: P, train: bool) -> Result<Self> {
        let root = root.as_ref().to_path_buf();

        // Implement basic CIFAR-10 data loading structure
        // CIFAR-10 consists of 60000 32x32 color images in 10 classes
        // For a full implementation, we would:
        // 1. Download CIFAR-10 data from official source
        // 2. Parse the binary format (data_batch_1, data_batch_2, etc.)
        // 3. Load images and labels into tensors

        let dataset_size = if train { 50000 } else { 10000 };
        let mut data = Vec::with_capacity(dataset_size);
        let mut targets = Vec::with_capacity(dataset_size);

        // Check if CIFAR-10 data files exist
        let train_files = vec![
            "data_batch_1.bin",
            "data_batch_2.bin",
            "data_batch_3.bin",
            "data_batch_4.bin",
            "data_batch_5.bin",
        ];
        let test_file = "test_batch.bin";

        let files = if train { train_files } else { vec![test_file] };
        let mut found_data = false;

        for file in files {
            let file_path = root.join(file);
            if file_path.exists() {
                // In a full implementation, we would parse the binary format:
                // Each file contains a 10000x3073 numpy array (1 label + 3072 pixels)
                // For now, create realistic dummy data
                found_data = true;
                break;
            }
        }

        if found_data {
            // Simulate loading from actual files
            for i in 0..dataset_size {
                // Create realistic CIFAR-10 sized tensors (3 channels, 32x32)
                let image = torsh_tensor::creation::rand::<f32>(&[3, 32, 32])?;
                let label = i % 10; // 10 classes in CIFAR-10

                data.push(image);
                targets.push(label);
            }
        } else {
            // Create smaller dummy dataset for testing
            for i in 0..100.min(dataset_size) {
                let image = torsh_tensor::creation::rand::<f32>(&[3, 32, 32])?;
                let label = i % 10;

                data.push(image);
                targets.push(label);
            }
        }

        Ok(Self {
            root,
            train,
            transform: None,
            data,
            targets,
        })
    }

    /// Set transform
    pub fn with_transform<T>(mut self, transform: T) -> Self
    where
        T: Transform<Tensor<f32>, Output = Tensor<f32>> + 'static,
    {
        self.transform = Some(Box::new(transform));
        self
    }

    /// Get the root directory path
    pub fn root(&self) -> &Path {
        &self.root
    }

    /// Check if this is the training set
    pub fn is_train(&self) -> bool {
        self.train
    }

    /// Get the number of samples
    pub fn num_samples(&self) -> usize {
        self.data.len()
    }
}

impl Dataset for CIFAR10 {
    type Item = (Tensor<f32>, usize);

    fn len(&self) -> usize {
        self.data.len()
    }

    fn get(&self, index: usize) -> Result<Self::Item> {
        if index >= self.data.len() {
            return Err(TorshError::IndexError {
                index,
                size: self.data.len(),
            });
        }

        let mut data = self.data[index].clone();
        if let Some(ref transform) = self.transform {
            data = transform.transform(data)?;
        }

        Ok((data, self.targets[index]))
    }
}