//! Data loading: configuration presets plus re-exports for core loaders,
//! prefetching, worker pools, and memory pinning.

use torsh_core::error::Result;
pub mod core;
pub mod memory;
pub mod prefetch;
pub mod simple;
pub mod workers;
// Core loader types and builder. `self::` disambiguates the local `core`
// module from the `core` standard library crate.
pub use self::core::{
    DataLoader, DataLoaderBuilder, DataLoaderIterator, DataLoaderTrait, RandomDataLoader,
    SimpleDataLoader,
};
// Convenience constructors for quick, minimal-configuration loading.
pub use simple::{
    simple_dataloader, simple_random_dataloader, SimpleConfig, SimpleRandomDataLoader,
};
// Background prefetching of upcoming batches.
pub use prefetch::{PrefetchConfig, PrefetchExt, PrefetchIterator};
// Multi-worker batch loading.
pub use workers::{
    MultiProcessIterator, PersistentWorkerPool, WorkerConfig, WorkerPool, WorkerResult,
};
// Pinned (page-locked) host memory for faster host-to-device copies.
pub use memory::{CpuMemoryPinner, MemoryPinning, MemoryPinningManager, PinningConfig};
#[cfg(feature = "cuda")]
pub use memory::CudaMemoryPinner;
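/// Configuration shared by the data loaders in this module.
///
/// A minimal builder-style sketch (an `ignore` doctest, since the exact
/// import path depends on how the crate re-exports this module):
///
/// ```ignore
/// let config = DataLoaderConfig::new()
///     .batch_size(32)
///     .shuffle(true)
///     .num_workers(4)
///     .prefetch_buffer_size(2);
/// assert_eq!(config.batch_size, 32);
/// ```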
#[derive(Debug, Clone)]
pub struct DataLoaderConfig {
    /// Number of samples per batch.
    pub batch_size: usize,
    /// Shuffle sample order each epoch.
    pub shuffle: bool,
    /// Worker threads for loading; `0` loads on the calling thread.
    pub num_workers: usize,
    /// Pin batches in page-locked memory for faster device transfers.
    pub pin_memory: bool,
    /// Drop the final batch if it is smaller than `batch_size`.
    pub drop_last: bool,
    /// Upper bound on how long to wait for a batch from the workers.
    pub timeout: Option<std::time::Duration>,
    /// RNG seed for shuffling; `None` uses a nondeterministic seed.
    pub generator: Option<u64>,
    /// Batches to prefetch ahead of consumption; `0` disables prefetching.
    pub prefetch_buffer_size: usize,
    /// Keep worker threads alive across epochs.
    pub persistent_workers: bool,
    /// Pinning settings; `None` falls back to defaults when `pin_memory` is set.
    pub pinning_config: Option<PinningConfig>,
}
impl Default for DataLoaderConfig {
fn default() -> Self {
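        // Conservative defaults: single-sample batches, no shuffling, and
        // synchronous loading on the calling thread (zero workers).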
Self {
batch_size: 1,
shuffle: false,
num_workers: 0,
pin_memory: false,
drop_last: false,
timeout: None,
generator: None,
prefetch_buffer_size: 0,
persistent_workers: false,
pinning_config: None,
}
}
}
impl DataLoaderConfig {
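    /// Equivalent to [`DataLoaderConfig::default`]; provided as the
    /// conventional entry point for the builder-style setters below.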
pub fn new() -> Self {
Self::default()
}
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
pub fn shuffle(mut self, shuffle: bool) -> Self {
self.shuffle = shuffle;
self
}
pub fn num_workers(mut self, num_workers: usize) -> Self {
self.num_workers = num_workers;
self
}
pub fn pin_memory(mut self, pin_memory: bool) -> Self {
self.pin_memory = pin_memory;
self
}
pub fn drop_last(mut self, drop_last: bool) -> Self {
self.drop_last = drop_last;
self
}
pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn generator(mut self, seed: u64) -> Self {
self.generator = Some(seed);
self
}
pub fn prefetch_buffer_size(mut self, size: usize) -> Self {
self.prefetch_buffer_size = size;
self
}
pub fn persistent_workers(mut self, persistent: bool) -> Self {
self.persistent_workers = persistent;
self
}
pub fn pinning_config(mut self, config: PinningConfig) -> Self {
self.pinning_config = Some(config);
self
}
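    /// Preset for training: shuffled batches of 32, pinned memory,
    /// prefetching, and persistent workers. The worker count comes from
    /// `workers::utils::optimal_worker_count(false)`.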
pub fn for_training() -> Self {
Self::new()
.batch_size(32)
.shuffle(true)
.num_workers(workers::utils::optimal_worker_count(false))
.pin_memory(true)
.prefetch_buffer_size(4)
.persistent_workers(true)
}
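    /// Preset for latency-oriented inference: single-sample batches, no
    /// shuffling, a small prefetch buffer, and no persistent workers.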
pub fn for_inference() -> Self {
Self::new()
.batch_size(1)
.shuffle(false)
.num_workers(workers::utils::optimal_worker_count(true))
.pin_memory(false)
.prefetch_buffer_size(2)
.persistent_workers(false)
}
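    /// Preset for evaluation/testing: deterministic order, batches of 32,
    /// and `drop_last` disabled so every sample is seen exactly once.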
pub fn for_evaluation() -> Self {
Self::new()
.batch_size(32)
.shuffle(false)
.num_workers(workers::utils::optimal_worker_count(false))
.pin_memory(false)
.prefetch_buffer_size(2)
.persistent_workers(false)
.drop_last(false)
}
}
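/// Heuristics for choosing, validating, and tuning loader configurations.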
pub mod utils {
use super::*;
use torsh_core::device::DeviceType;
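    /// Build a configuration from dataset size, a scenario name
    /// ("training"/"train", "inference"/"infer", "evaluation"/"eval"/"test",
    /// matched case-insensitively), and an optional target device.
    ///
    /// A usage sketch (ignored doctest; mirrors `test_utils_optimal_config`):
    ///
    /// ```ignore
    /// let config = optimal_config(1000, "training", Some(DeviceType::Cuda(0)));
    /// assert!(config.shuffle && config.pin_memory);
    /// ```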
pub fn optimal_config(
dataset_size: usize,
scenario: &str,
target_device: Option<DeviceType>,
) -> DataLoaderConfig {
        // Unknown scenario strings fall back to the plain defaults.
        let mut config = match scenario.to_lowercase().as_str() {
            "training" | "train" => DataLoaderConfig::for_training(),
            "inference" | "infer" => DataLoaderConfig::for_inference(),
            "evaluation" | "eval" | "test" => DataLoaderConfig::for_evaluation(),
            _ => DataLoaderConfig::new(),
        };
        // The size heuristic overrides the preset's batch size so small
        // datasets never request batches larger than they can fill.
        if dataset_size < 100 {
            config = config.batch_size(dataset_size.min(8));
        } else if dataset_size < 1000 {
            config = config.batch_size(16);
        } else {
            config = config.batch_size(32);
        }
        // Pinning only pays off for host-to-device copies, so it is enabled
        // for CUDA targets and disabled otherwise (overriding any preset).
        if let Some(device) = target_device {
            match device {
                DeviceType::Cuda(device_id) => {
                    config = config
                        .pin_memory(true)
                        .pinning_config(PinningConfig::cuda(device_id));
                }
                _ => {
                    config = config.pin_memory(false);
                }
            }
        }
config
}
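    /// Check a configuration against a dataset size and collect
    /// human-readable warnings. Currently always returns `Ok`; problems are
    /// reported as warnings rather than hard errors.
    ///
    /// A sketch (ignored doctest; mirrors `test_utils_validate_config`):
    ///
    /// ```ignore
    /// let config = DataLoaderConfig::new().batch_size(100);
    /// let warnings = validate_config(&config, 50)?;
    /// assert!(warnings[0].contains("Batch size"));
    /// ```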
pub fn validate_config(config: &DataLoaderConfig, dataset_size: usize) -> Result<Vec<String>> {
let mut warnings = Vec::new();
if config.batch_size > dataset_size {
warnings.push(format!(
"Batch size ({}) is larger than dataset size ({})",
config.batch_size, dataset_size
));
}
        let available = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        if config.num_workers > available * 2 {
            warnings.push(format!(
                "Number of workers ({}) may be too high for system capabilities",
                config.num_workers
            ));
        }
if config.prefetch_buffer_size > 0 && config.num_workers == 0 {
warnings.push(
"Prefetching enabled but no workers configured, may not improve performance"
.to_string(),
);
}
if config.pin_memory && config.pinning_config.is_none() {
warnings
.push("Memory pinning enabled but no pinning configuration provided".to_string());
}
Ok(warnings)
}
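    /// Suggest tuning steps for the given dataset size, batch size, and
    /// scenario (same aliases as [`optimal_config`]). The worker-count and
    /// memory-pinning suggestions are always included, so the result is
    /// never empty.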
pub fn performance_recommendations(
dataset_size: usize,
batch_size: usize,
scenario: &str,
) -> Vec<String> {
let mut recommendations = Vec::new();
        // Match scenario aliases the same way `optimal_config` does.
        let scenario = scenario.to_lowercase();
        let is_training = matches!(scenario.as_str(), "training" | "train");
        let is_inference = matches!(scenario.as_str(), "inference" | "infer");
        if is_training && batch_size < 16 {
            recommendations.push(
                "Consider increasing batch size for training (recommended: 16-32)".to_string(),
            );
        }
        // Integer division: datasets under 10 samples always trigger this.
        if batch_size > dataset_size / 10 {
            recommendations.push(
                "Large batch size relative to dataset may reduce training effectiveness"
                    .to_string(),
            );
        }
        let optimal_workers = workers::utils::optimal_worker_count(is_inference);
        recommendations.push(format!(
            "Consider using {} workers for optimal performance",
            optimal_workers
        ));
        if is_training {
            recommendations.push(
                "Enable prefetching (buffer size 2-4) for better training performance".to_string(),
            );
        }
recommendations.push("Enable memory pinning if transferring data to GPU".to_string());
recommendations
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dataset::TensorDataset;
use torsh_core::device::DeviceType;
#[test]
fn test_dataloader_config() {
let config = DataLoaderConfig::new()
.batch_size(16)
.shuffle(true)
.num_workers(4)
.prefetch_buffer_size(2);
assert_eq!(config.batch_size, 16);
assert!(config.shuffle);
assert_eq!(config.num_workers, 4);
assert_eq!(config.prefetch_buffer_size, 2);
}
#[test]
fn test_training_config() {
let config = DataLoaderConfig::for_training();
assert!(config.shuffle);
assert!(config.persistent_workers);
assert!(config.pin_memory);
assert!(config.prefetch_buffer_size > 0);
}
#[test]
fn test_inference_config() {
let config = DataLoaderConfig::for_inference();
assert!(!config.shuffle);
assert!(!config.persistent_workers);
assert!(!config.pin_memory);
}
#[test]
fn test_evaluation_config() {
let config = DataLoaderConfig::for_evaluation();
assert!(!config.shuffle);
assert!(!config.drop_last);
assert!(!config.persistent_workers);
}
#[test]
fn test_utils_optimal_config() {
let config = utils::optimal_config(1000, "training", Some(DeviceType::Cuda(0)));
assert!(config.shuffle);
assert!(config.pin_memory);
assert!(config.pinning_config.is_some());
}
#[test]
fn test_utils_validate_config() {
let config = DataLoaderConfig::new().batch_size(100);
let warnings = utils::validate_config(&config, 50).expect("utils should succeed");
assert!(!warnings.is_empty());
assert!(warnings[0].contains("Batch size"));
}
#[test]
fn test_utils_performance_recommendations() {
let recommendations = utils::performance_recommendations(1000, 8, "training");
assert!(!recommendations.is_empty());
assert!(recommendations.iter().any(|r| r.contains("batch size")));
}
#[test]
fn test_backward_compatibility() {
use torsh_tensor::Tensor;
let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
.expect("Tensor should succeed");
let dataset = TensorDataset::from_tensor(tensor);
let dataloader =
simple_dataloader(dataset, 2, false).expect("simple dataloader should succeed");
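        // Five samples with batch size 2 and drop_last = false => 3 batches.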
assert_eq!(dataloader.len(), 3);
}
#[test]
fn test_prefetch_integration() {
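        // Exercises the default builder pipeline end to end; batches of 2
        // from a 5-element tensor yield a first batch shaped [2, 1].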
use torsh_tensor::Tensor;
let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
.expect("Tensor should succeed");
let dataset = TensorDataset::from_tensor(tensor);
let dataloader = DataLoader::builder(dataset)
.batch_size(2)
.build()
.expect("build should succeed with valid configuration");
let mut iter = dataloader.iter();
let first_batch = iter
.next()
.expect("iterator should have a next element")
.expect("operation should succeed");
assert_eq!(first_batch.len(), 1);
assert_eq!(first_batch[0].shape().dims(), &[2, 1]);
}
}