use crate::{
collate::{Collate, DefaultCollate},
dataset::Dataset,
sampler::{BatchSampler, BatchingSampler, RandomSampler, SequentialSampler},
};
use scirs2_core::parallel_ops::*;
use torsh_core::error::Result;
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, vec::Vec};
/// Object-safe view of a data loader that hides the concrete sampler type.
///
/// `DataLoaderBuilder::build_auto` returns a boxed `dyn DataLoaderTrait` so
/// callers need not name the sampler chosen at runtime. `D` is the dataset
/// type and `C` the collate function applied to its items.
pub trait DataLoaderTrait<D, C>
where
    D: Dataset,
    C: Collate<D::Item>,
{
    /// Number of batches in one pass (for `DataLoader`, the batch sampler's length).
    fn len(&self) -> usize;

    /// Returns `true` when a pass yields no batches.
    fn is_empty(&self) -> bool;
}
/// Iterates over a dataset in collated batches.
///
/// A `DataLoader` combines three parts. The dataset `D` supplies individual
/// samples. The batch sampler `S` produces index batches. The collate function
/// `C` merges the samples of one batch into a single output. Construct one
/// through `DataLoader::builder` or `DataLoaderBuilder`.
pub struct DataLoader<D, S, C> {
    /// Source of individual samples.
    dataset: D,
    /// Produces one index batch per yielded batch.
    sampler: S,
    /// Merges the samples fetched for one batch.
    collate_fn: C,
    /// Samples are fetched in parallel when this is greater than 1.
    num_workers: usize,
    // Recorded from the builder but not consulted anywhere in this module yet.
    #[allow(dead_code)]
    pin_memory: bool,
    // Already handed to the batch sampler at build time; kept for reference.
    #[allow(dead_code)]
    drop_last: bool,
    // Recorded from the builder but not consulted anywhere in this module yet.
    #[allow(dead_code)]
    timeout: Option<std::time::Duration>,
}
impl<D: Dataset> DataLoader<D, (), ()> {
    /// Starts configuring a loader for `dataset`.
    ///
    /// Equivalent to `DataLoaderBuilder::new(dataset)`. The `()` type
    /// parameters are placeholders. The builder's `build*` methods choose the
    /// real sampler and collate types.
    pub fn builder(dataset: D) -> DataLoaderBuilder<D> {
        DataLoaderBuilder::<D>::new(dataset)
    }
}
impl<D, S, C> DataLoader<D, S, C>
where
    D: Dataset,
    S: BatchSampler,
    C: Collate<D::Item>,
{
    /// Returns an iterator that yields one collated batch per sampler batch.
    ///
    /// The loader is only borrowed. Each call asks the sampler for a new
    /// index iterator.
    pub fn iter(&self) -> DataLoaderIterator<'_, D, S, C> {
        let Self {
            dataset,
            sampler,
            collate_fn,
            num_workers,
            ..
        } = self;
        DataLoaderIterator {
            dataset,
            sampler_iter: sampler.iter(),
            collate_fn,
            num_workers: *num_workers,
        }
    }

    /// Number of batches in one pass, as reported by the batch sampler.
    pub fn len(&self) -> usize {
        self.sampler.len()
    }

    /// Returns `true` when the batch sampler reports no batches.
    pub fn is_empty(&self) -> bool {
        self.sampler.is_empty()
    }

    /// Borrows the underlying dataset.
    pub fn dataset(&self) -> &D {
        &self.dataset
    }

    /// Borrows the batch sampler.
    pub fn sampler(&self) -> &S {
        &self.sampler
    }

    /// Borrows the collate function.
    pub fn collate_fn(&self) -> &C {
        &self.collate_fn
    }

    /// Configured worker count. Values above 1 switch sample loading to the
    /// parallel path.
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }
}
/// Type-erasable view of `DataLoader`.
///
/// The `Sync`/`Send` bounds mirror those the batch iterator needs, so a boxed
/// loader can be shared across threads.
impl<D, S, C> DataLoaderTrait<D, C> for DataLoader<D, S, C>
where
    D: Dataset + Sync,
    S: BatchSampler + Sync,
    C: Collate<D::Item> + Sync,
    D::Item: Send,
    C::Output: Send,
    S::Iter: Iterator<Item = Vec<usize>>,
{
    /// Batch count, taken directly from the sampler.
    fn len(&self) -> usize {
        self.sampler.len()
    }

    /// Emptiness, taken directly from the sampler.
    fn is_empty(&self) -> bool {
        self.sampler.is_empty()
    }
}
/// Borrowing iterator returned by `DataLoader::iter`.
///
/// It pulls index batches from the sampler, loads the matching samples from
/// the dataset and hands them to the collate function.
pub struct DataLoaderIterator<'a, D, S, C>
where
    D: Dataset,
    S: BatchSampler,
    C: Collate<D::Item>,
{
    /// Dataset borrowed from the owning loader.
    dataset: &'a D,
    /// Remaining index batches for this pass.
    sampler_iter: S::Iter,
    /// Collate function borrowed from the owning loader.
    collate_fn: &'a C,
    /// Copied from the loader; values above 1 select parallel loading.
    num_workers: usize,
}
impl<D, S, C> Iterator for DataLoaderIterator<'_, D, S, C>
where
    D: Dataset + Sync,
    D::Item: Send,
    S: BatchSampler,
    S::Iter: Iterator<Item = Vec<usize>>,
    C: Collate<D::Item> + Sync,
    C::Output: Send,
{
    type Item = Result<C::Output>;

    /// Fetches and collates the next batch.
    ///
    /// Returns `None` once the sampler has no more index batches. If loading
    /// any sample or collating the batch fails, the error is yielded as
    /// `Some(Err(_))`. A later call moves on to the following index batch.
    fn next(&mut self) -> Option<Self::Item> {
        let indices = self.sampler_iter.next()?;
        // Copy the `&D` out so the parallel closure borrows only the dataset,
        // not the whole iterator (whose sampler state is not required to be `Sync`).
        let dataset = self.dataset;
        // `num_workers` only selects the path. It is not used to size the
        // thread pool behind `parallel_ops`.
        let samples: Result<Vec<_>> = if self.num_workers > 1 {
            indices.into_par_iter().map(|idx| dataset.get(idx)).collect()
        } else {
            // Collecting into `Result` stops at the first sample that fails to load.
            indices.into_iter().map(|idx| dataset.get(idx)).collect()
        };
        Some(samples.and_then(|samples| self.collate_fn.collate(samples)))
    }

    /// Exactly one item is produced per index batch, so the sampler
    /// iterator's bounds carry over unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.sampler_iter.size_hint()
    }
}
/// Fluent configuration for a `DataLoader`.
///
/// Every setter consumes and returns the builder. Finish with `build`,
/// `build_with_random_sampling` or `build_auto`.
pub struct DataLoaderBuilder<D: Dataset> {
    /// Dataset handed to the finished loader.
    dataset: D,
    /// Samples per batch; `None` is treated as 1 at build time.
    batch_size: Option<usize>,
    /// Within this module, only read by `build_auto` to pick sequential vs. random order.
    shuffle: bool,
    /// Copied into the loader; values above 1 enable parallel sample loading.
    num_workers: usize,
    /// Copied into the loader.
    pin_memory: bool,
    /// Passed to the batch sampler and copied into the loader.
    drop_last: bool,
    /// Copied into the loader.
    timeout: Option<std::time::Duration>,
    /// Optional seed applied to the random sampler.
    generator: Option<u64>,
}
impl<D: Dataset> DataLoaderBuilder<D> {
    /// Creates a builder for `dataset` with the following defaults:
    /// - no explicit batch size (1 at build time)
    /// - no shuffling and 0 workers
    /// - `pin_memory` and `drop_last` disabled
    /// - no timeout and no seed
    pub fn new(dataset: D) -> Self {
        Self {
            dataset,
            batch_size: None,
            shuffle: false,
            num_workers: 0,
            pin_memory: false,
            drop_last: false,
            timeout: None,
            generator: None,
        }
    }

    /// Sets the number of samples per batch.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = Some(batch_size);
        self
    }

    /// Requests random order.
    ///
    /// Only `build_auto` reads this flag. `build` is always sequential and
    /// `build_with_random_sampling` is always random.
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets the worker count; values above 1 load samples in parallel.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Records the `pin_memory` option on the resulting loader.
    pub fn pin_memory(mut self, pin_memory: bool) -> Self {
        self.pin_memory = pin_memory;
        self
    }

    /// Controls whether the batch sampler drops a trailing partial batch.
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Records a timeout on the resulting loader.
    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Seeds the random sampler used by `build_with_random_sampling` / `build_auto`.
    pub fn generator(mut self, seed: u64) -> Self {
        self.generator = Some(seed);
        self
    }

    /// Builds a loader that visits the dataset in index order.
    ///
    /// The `shuffle` flag is not consulted here. Use
    /// `build_with_random_sampling` or `build_auto` for shuffled order.
    ///
    /// # Errors
    /// This body has no failure path; it always returns `Ok`.
    pub fn build(
        self,
    ) -> Result<DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>> {
        let batch_size = self.batch_size.unwrap_or(1);
        let base_sampler = SequentialSampler::new(self.dataset.len());
        let batch_sampler = BatchingSampler::new(base_sampler, batch_size, self.drop_last);
        Ok(self.assemble(batch_sampler))
    }

    /// Builds a loader that visits the dataset in random order.
    ///
    /// The sampler is seeded when `generator` was set. The `shuffle` flag is
    /// not consulted; this method always samples randomly.
    ///
    /// # Errors
    /// This body has no failure path; it always returns `Ok`.
    pub fn build_with_random_sampling(
        self,
    ) -> Result<DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>> {
        let batch_size = self.batch_size.unwrap_or(1);
        let base_sampler = RandomSampler::new(self.dataset.len(), None, false);
        let base_sampler = match self.generator {
            Some(seed) => base_sampler.with_generator(seed),
            None => base_sampler,
        };
        let batch_sampler = BatchingSampler::new(base_sampler, batch_size, self.drop_last);
        Ok(self.assemble(batch_sampler))
    }

    /// Builds a type-erased loader, choosing random order when `shuffle` is set.
    ///
    /// # Errors
    /// Propagates any error from the selected `build*` method.
    pub fn build_auto(self) -> Result<Box<dyn DataLoaderTrait<D, DefaultCollate> + Send + Sync>>
    where
        D: Send + Sync + 'static,
        D::Item: Send + Sync + 'static,
        DefaultCollate: Collate<D::Item>,
        <DefaultCollate as Collate<D::Item>>::Output: Send,
    {
        if self.shuffle {
            Ok(Box::new(self.build_with_random_sampling()?))
        } else {
            Ok(Box::new(self.build()?))
        }
    }

    /// Moves the dataset and the recorded options into a loader driven by
    /// `sampler`, using the default collate function.
    fn assemble<S>(self, sampler: S) -> DataLoader<D, S, DefaultCollate> {
        DataLoader {
            dataset: self.dataset,
            sampler,
            collate_fn: DefaultCollate,
            num_workers: self.num_workers,
            pin_memory: self.pin_memory,
            drop_last: self.drop_last,
            timeout: self.timeout,
        }
    }
}
/// Loader type returned by `DataLoaderBuilder::build`: sequential order, default collation.
pub type SimpleDataLoader<D> = DataLoader<D, BatchingSampler<SequentialSampler>, DefaultCollate>;
/// Loader type returned by `DataLoaderBuilder::build_with_random_sampling`: random order,
/// default collation.
pub type RandomDataLoader<D> = DataLoader<D, BatchingSampler<RandomSampler>, DefaultCollate>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::TensorDataset;

    /// `DataLoaderBuilder::new` starts from the expected default options.
    #[test]
    fn test_dataloader_builder() {
        let ones = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let builder = DataLoaderBuilder::new(TensorDataset::from_tensor(ones));

        assert_eq!(builder.batch_size, None);
        assert!(!builder.shuffle);
        assert_eq!(builder.num_workers, 0);
        assert!(!builder.pin_memory);
        assert!(!builder.drop_last);
    }

    /// Every fluent setter stores its value on the builder.
    #[test]
    fn test_dataloader_builder_configuration() {
        let ones = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let builder = DataLoaderBuilder::new(TensorDataset::from_tensor(ones))
            .batch_size(2)
            .shuffle(true)
            .num_workers(4)
            .pin_memory(true)
            .drop_last(true);

        assert_eq!(builder.batch_size, Some(2));
        assert!(builder.shuffle);
        assert_eq!(builder.num_workers, 4);
        assert!(builder.pin_memory);
        assert!(builder.drop_last);
    }

    /// Five samples in batches of two give three batches, the last one partial.
    #[test]
    fn test_dataloader_sequential_build() {
        let ones = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let loader = DataLoaderBuilder::new(TensorDataset::from_tensor(ones))
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(loader.len(), 3);
        assert!(!loader.is_empty());
    }

    /// A seeded random loader reports the same batch count as a sequential one.
    #[test]
    fn test_dataloader_random_build() {
        let ones = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let loader = DataLoaderBuilder::new(TensorDataset::from_tensor(ones))
            .batch_size(2)
            .generator(42)
            .build_with_random_sampling()
            .expect("operation should succeed");

        assert_eq!(loader.len(), 3);
        assert!(!loader.is_empty());
    }

    /// Four samples in batches of two yield exactly two collated batches.
    #[test]
    fn test_dataloader_iteration() {
        let ones = torsh_tensor::creation::ones::<f32>(&[4]).expect("operation should succeed");
        let loader = DataLoaderBuilder::new(TensorDataset::from_tensor(ones))
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        let mut batches = loader.iter();
        let first = batches
            .next()
            .expect("iterator should have a next element")
            .expect("operation should succeed");
        let second = batches
            .next()
            .expect("iterator should have a next element")
            .expect("operation should succeed");
        assert!(batches.next().is_none());

        assert_eq!(first.len(), 1);
        assert_eq!(second.len(), 1);
        assert_eq!(first[0].shape().dims(), &[2, 1]);
        assert_eq!(second[0].shape().dims(), &[2, 1]);
    }

    /// With `drop_last`, the trailing partial batch is not counted.
    #[test]
    fn test_dataloader_drop_last() {
        let ones = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let loader = DataLoaderBuilder::new(TensorDataset::from_tensor(ones))
            .batch_size(2)
            .drop_last(true)
            .build()
            .expect("operation should succeed");

        assert_eq!(loader.len(), 2);
    }

    /// The trait methods agree with the inherent ones.
    #[test]
    fn test_dataloader_trait_implementation() {
        let ones = torsh_tensor::creation::ones::<f32>(&[5]).expect("operation should succeed");
        let loader = DataLoaderBuilder::new(TensorDataset::from_tensor(ones))
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(DataLoaderTrait::len(&loader), 3);
        assert!(!DataLoaderTrait::is_empty(&loader));
    }

    /// An empty dataset produces an empty loader whose iterator ends immediately.
    #[test]
    fn test_empty_dataloader() {
        let no_tensors: Vec<torsh_tensor::Tensor<f32>> = vec![];
        let loader = DataLoaderBuilder::new(TensorDataset::new(no_tensors))
            .batch_size(2)
            .build()
            .expect("operation should succeed");

        assert_eq!(loader.len(), 0);
        assert!(loader.is_empty());

        let mut batches = loader.iter();
        assert!(batches.next().is_none());
    }
}