use super::core::DataLoader;
use crate::{
collate::DefaultCollate,
dataset::Dataset,
sampler::{BatchingSampler, RandomSampler, SequentialSampler},
};
use torsh_core::error::Result;
/// Alias for a [`DataLoader`] over dataset `D` whose sampler is
/// `BatchingSampler<SequentialSampler>` and whose collate function is
/// [`DefaultCollate`].
pub type SimpleDataLoader<D> = DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>;
/// Alias for a [`DataLoader`] over dataset `D` whose sampler is
/// `BatchingSampler<RandomSampler>` and whose collate function is
/// [`DefaultCollate`].
pub type SimpleRandomDataLoader<D> = DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>;
/// Creates a [`SimpleDataLoader`] over `dataset` with the given `batch_size`.
///
/// The `_shuffle` argument is accepted but **ignored**: the builder is always
/// configured with `shuffle(false)`, and the returned type is fixed to a
/// sequential sampler. Use [`simple_random_dataloader`] when shuffled
/// iteration is needed.
///
/// # Errors
///
/// Propagates whatever error the underlying builder's `build()` returns.
pub fn simple_dataloader<D>(
    dataset: D,
    batch_size: usize,
    _shuffle: bool,
) -> Result<SimpleDataLoader<D>>
where
    D: Dataset,
{
    // Shuffling is pinned off regardless of the caller's flag (see docs above).
    let sequential = DataLoader::builder(dataset)
        .batch_size(batch_size)
        .shuffle(false);
    sequential.build()
}
/// Creates a [`SimpleRandomDataLoader`] over `dataset` with the given
/// `batch_size` and shuffling enabled on the builder.
///
/// When `generator` is `Some(seed)`, the seed is forwarded to the builder's
/// `generator` setter; when it is `None`, the builder is left unseeded.
///
/// # Errors
///
/// Propagates whatever error the builder's `build_with_random_sampling()`
/// returns.
pub fn simple_random_dataloader<D>(
    dataset: D,
    batch_size: usize,
    generator: Option<u64>,
) -> Result<SimpleRandomDataLoader<D>>
where
    D: Dataset,
{
    let shuffled = DataLoader::builder(dataset)
        .batch_size(batch_size)
        .shuffle(true);

    // Only touch the generator setting when the caller supplied a seed.
    let configured = match generator {
        Some(seed) => shuffled.generator(seed),
        None => shuffled,
    };

    configured.build_with_random_sampling()
}
/// Plain bag of data-loader options with chainable setters.
///
/// NOTE(review): nothing in this file consumes this struct, so the field
/// descriptions below follow the field names — confirm against the consumer.
#[derive(Debug, Clone)]
pub struct SimpleConfig {
    /// Requested batch size. Defaults to `1`.
    pub batch_size: usize,
    /// Whether shuffling is requested. Defaults to `false`.
    pub shuffle: bool,
    /// Requested number of workers. Defaults to `0`.
    pub num_workers: usize,
    /// Whether the last batch should be dropped (presumably the trailing
    /// partial batch — confirm against the consumer). Defaults to `false`.
    pub drop_last: bool,
    /// Optional seed. Defaults to `None`.
    pub generator: Option<u64>,
}

impl Default for SimpleConfig {
    /// Same value as [`SimpleConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl SimpleConfig {
    /// Returns a configuration with `batch_size = 1`, `num_workers = 0`,
    /// both flags `false`, and no seed.
    pub fn new() -> Self {
        Self {
            batch_size: 1,
            shuffle: false,
            num_workers: 0,
            drop_last: false,
            generator: None,
        }
    }

    /// Returns the configuration with `batch_size` replaced.
    pub fn batch_size(self, batch_size: usize) -> Self {
        Self { batch_size, ..self }
    }

    /// Returns the configuration with `shuffle` replaced.
    pub fn shuffle(self, shuffle: bool) -> Self {
        Self { shuffle, ..self }
    }

    /// Returns the configuration with `num_workers` replaced.
    pub fn num_workers(self, num_workers: usize) -> Self {
        Self {
            num_workers,
            ..self
        }
    }

    /// Returns the configuration with `drop_last` replaced.
    pub fn drop_last(self, drop_last: bool) -> Self {
        Self { drop_last, ..self }
    }

    /// Returns the configuration with `generator` set to `Some(seed)`.
    pub fn generator(self, seed: u64) -> Self {
        Self {
            generator: Some(seed),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::TensorDataset;

    /// Five samples with batch size two yield three batches from the
    /// sequential loader.
    #[test]
    fn test_simple_dataloader() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader =
            simple_dataloader(dataset, 2, false).expect("simple dataloader should succeed");
        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    /// Five samples with batch size two yield three batches from the seeded
    /// random loader.
    #[test]
    fn test_simple_random_dataloader() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader =
            simple_random_dataloader(dataset, 2, Some(42)).expect("operation should succeed");
        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    /// The random loader can also be built without a seed.
    #[test]
    fn test_simple_random_dataloader_no_seed() {
        let tensor = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let dataloader = simple_random_dataloader(dataset, 2, None)
            .expect("simple random dataloader should succeed");
        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    /// Every chainable setter stores the value it was given.
    #[test]
    fn test_simple_config() {
        let config = SimpleConfig::new()
            .batch_size(4)
            .shuffle(true)
            .num_workers(2)
            .drop_last(true)
            .generator(42);
        assert_eq!(config.batch_size, 4);
        assert!(config.shuffle);
        assert_eq!(config.num_workers, 2);
        assert!(config.drop_last);
        assert_eq!(config.generator, Some(42));
    }

    /// `SimpleConfig::new()` exposes the documented defaults.
    #[test]
    fn test_simple_config_defaults() {
        let config = SimpleConfig::new();
        assert_eq!(config.batch_size, 1);
        assert!(!config.shuffle);
        assert_eq!(config.num_workers, 0);
        assert!(!config.drop_last);
        assert_eq!(config.generator, None);
    }

    /// A sequential loader built from the values held in a `SimpleConfig`.
    #[test]
    fn test_simple_configured_dataloader_sequential() {
        use torsh_core::device::DeviceType;
        use torsh_tensor::Tensor;
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let config = SimpleConfig::new()
            .batch_size(2)
            .shuffle(false)
            .drop_last(false);
        // Feed the configured values into the constructor so the config is
        // actually exercised rather than built and discarded.
        let dataloader = simple_dataloader(dataset, config.batch_size, config.shuffle)
            .expect("simple dataloader should succeed");
        assert_eq!(dataloader.len(), 3);
    }

    /// A random loader built from the values held in a `SimpleConfig`.
    #[test]
    fn test_simple_configured_dataloader_random() {
        use torsh_core::device::DeviceType;
        use torsh_tensor::Tensor;
        let tensor = Tensor::from_data(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], vec![5], DeviceType::Cpu)
            .expect("Tensor should succeed");
        let dataset = TensorDataset::from_tensor(tensor);
        let config = SimpleConfig::new()
            .batch_size(2)
            .shuffle(true)
            .generator(42);
        assert_eq!(config.batch_size, 2);
        assert!(config.shuffle);
        assert_eq!(config.generator, Some(42));
        // Build the loader from the config so the dataset is really used.
        let dataloader = simple_random_dataloader(dataset, config.batch_size, config.generator)
            .expect("simple random dataloader should succeed");
        assert_eq!(dataloader.len(), 3);
        assert!(!dataloader.is_empty());
    }

    /// An empty dataset produces a loader with zero batches.
    #[test]
    fn test_empty_dataset_simple() {
        let dataset: TensorDataset<f32> = TensorDataset::new(vec![]);
        let dataloader =
            simple_dataloader(dataset, 2, false).expect("simple dataloader should succeed");
        assert_eq!(dataloader.len(), 0);
        assert!(dataloader.is_empty());
    }
}