//! Batched data loading on top of the [`Dataset`] abstraction.
//!
//! Provides [`DataLoader`] for `(data, target)` tensor pairs, a
//! [`GenericDataLoader`] parameterized over a [`Collate`] function, and
//! [`GpuPrefetchIter`] for staging batches onto a device from a background
//! thread.

use crate::collate::{Collate, stack_tensors};
use crate::dataset::Dataset;
use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
use axonml_core::Device;
use axonml_tensor::Tensor;
use rayon::prelude::*;
use std::marker::PhantomData;
use std::sync::mpsc;
use std::thread;
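/// A mini-batch of samples: stacked input tensors, stacked targets, and the
/// batch size taken from the leading axis of `data`.
///
/// A minimal construction sketch (doctest `ignore`d, since the crate path
/// used by doctests may differ; mirrors `test_batch_struct` below):
///
/// ```ignore
/// let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
/// let batch = Batch::new(data, targets);
/// assert_eq!(batch.len(), 2);
/// ```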
#[derive(Debug, Clone)]
pub struct Batch {
    /// Stacked input tensors with shape `[batch, ...]`.
    pub data: Tensor<f32>,
    /// Stacked target tensors for the batch.
    pub targets: Tensor<f32>,
    /// Number of samples in the batch.
    pub size: usize,
}
impl Batch {
    /// Creates a batch from pre-stacked data and target tensors.
    ///
    /// # Panics
    ///
    /// Panics if `data` is zero-dimensional, since the batch size is read
    /// from the first axis of `data.shape()`.
    #[must_use]
    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
        let size = data.shape()[0];
Self {
data,
targets,
size,
}
}
    /// Returns the number of samples in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
self.size
}
    /// Returns `true` if the batch contains no samples.
    #[must_use]
    pub fn is_empty(&self) -> bool {
self.size == 0
}
}
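/// Batches `(data, target)` tensor pairs from a [`Dataset`], with optional
/// shuffling, partial-batch dropping, and rayon-parallel sample fetching.
///
/// A minimal usage sketch (doctest `ignore`d; the `TensorDataset`
/// construction follows the tests at the bottom of this module):
///
/// ```ignore
/// let x = Tensor::from_vec((0..20).map(|i| i as f32).collect(), &[10, 2]).unwrap();
/// let y = Tensor::from_vec(vec![0.0; 10], &[10]).unwrap();
/// let loader = DataLoader::new(TensorDataset::new(x, y), 3)
///     .shuffle(true)
///     .drop_last(false)
///     .num_workers(2);
/// for batch in loader.iter() {
///     assert!(batch.len() <= 3);
/// }
/// ```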
pub struct DataLoader<D>
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
dataset: D,
batch_size: usize,
shuffle: bool,
drop_last: bool,
num_workers: usize,
}
impl<D> DataLoader<D>
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Creates a loader over `dataset` that yields batches of `batch_size`.
    ///
    /// Shuffling and `drop_last` are off and `num_workers` is `0` by default;
    /// use the builder methods to change them.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero.
    pub fn new(dataset: D, batch_size: usize) -> Self {
        assert!(batch_size > 0, "batch_size must be non-zero");
Self {
dataset,
batch_size,
shuffle: false,
drop_last: false,
num_workers: 0,
}
}
    /// Enables or disables random shuffling of the sample order each epoch.
    #[must_use]
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }
    /// When enabled, drops a trailing batch smaller than `batch_size`.
    #[must_use]
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }
    /// Sets the worker count: `0` fetches samples on the calling thread,
    /// anything greater fetches each batch's samples in parallel via rayon.
    #[must_use]
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }
    /// Returns the configured batch size.
    #[must_use]
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }
    /// Returns the number of batches per epoch.
    ///
    /// With `drop_last`, a trailing partial batch is not counted; otherwise
    /// it is, via ceiling division.
    #[must_use]
    pub fn len(&self) -> usize {
        let total = self.dataset.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            total.div_ceil(self.batch_size)
        }
    }
    /// Returns `true` if the underlying dataset has no samples.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.dataset.is_empty()
    }
    /// Returns the number of samples in the underlying dataset.
    #[must_use]
    pub fn dataset_len(&self) -> usize {
        self.dataset.len()
    }
    /// Builds a fresh iterator over batches, resampling the index order.
    ///
    /// Shuffled loaders draw a new random permutation on every call, so each
    /// epoch sees a different ordering.
    pub fn iter(&self) -> DataLoaderIter<'_, D> {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            let sampler = SequentialSampler::new(self.dataset.len());
            sampler.iter().collect()
        };
DataLoaderIter {
dataset: &self.dataset,
indices,
batch_size: self.batch_size,
drop_last: self.drop_last,
position: 0,
num_workers: self.num_workers,
}
}
}
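/// Epoch iterator produced by [`DataLoader::iter`].
///
/// Holds a borrowed dataset plus the materialized index order for this
/// epoch, and yields [`Batch`]es until the indices are exhausted.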
pub struct DataLoaderIter<'a, D>
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
dataset: &'a D,
indices: Vec<usize>,
batch_size: usize,
drop_last: bool,
position: usize,
num_workers: usize,
}
impl<D> Iterator for DataLoaderIter<'_, D>
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
type Item = Batch;
    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.indices.len() {
            return None;
        }
        let end = (self.position + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.position..end];
        // Skip a trailing partial batch entirely when drop_last is set.
        if batch_indices.len() < self.batch_size && self.drop_last {
            return None;
        }
        // With workers configured, fetch this batch's samples in parallel on
        // the rayon thread pool; otherwise fetch them inline.
        let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
            batch_indices
                .par_iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        } else {
            batch_indices
                .iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        };
        // Every index failed to resolve; treat the epoch as exhausted.
        if samples.is_empty() {
            return None;
        }
let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
let data = stack_tensors(&data_samples);
let targets = stack_tensors(&target_samples);
self.position = end;
Some(Batch::new(data, targets))
}
}
impl<D> DataLoaderIter<'_, D>
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
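    /// Returns how many batches are left in this epoch, honoring `drop_last`.
    ///
    /// A sketch of the countdown, assuming a `loader` built as in
    /// [`DataLoader::new`] (doctest `ignore`d; mirrors
    /// `test_dataloader_remaining` below):
    ///
    /// ```ignore
    /// let mut iter = loader.iter(); // 10 samples, batch_size 3
    /// assert_eq!(iter.remaining(), 4);
    /// iter.next();
    /// assert_eq!(iter.remaining(), 3);
    /// ```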
#[must_use]
pub fn remaining(&self) -> usize {
let remaining_samples = self.indices.len().saturating_sub(self.position);
if self.drop_last {
remaining_samples / self.batch_size
} else {
remaining_samples.div_ceil(self.batch_size)
}
}
}
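/// Iterator that stages batches onto a target [`Device`] from a background
/// thread.
///
/// A bounded `sync_channel(2)` keeps at most two batches in flight, so
/// host-side loading and device transfer overlap with consumption without
/// unbounded buffering; dropping the iterator closes the channel, which
/// stops the worker at its next send.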
pub struct GpuPrefetchIter {
    receiver: mpsc::Receiver<Batch>,
    // Producer thread handle; retained but never joined. The worker exits on
    // its own once the indices run out or the channel closes.
    _worker: Option<thread::JoinHandle<()>>,
}
impl GpuPrefetchIter {
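    /// Spawns the producer thread: it assembles batches exactly as
    /// [`DataLoaderIter`] does, moves each to `device`, and sends it over a
    /// bounded channel until the indices run out or the receiver is dropped.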
fn new_streaming<D>(
dataset: D,
indices: Vec<usize>,
batch_size: usize,
drop_last: bool,
num_workers: usize,
device: Device,
) -> Self
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)> + 'static,
{
        // Bounded channel: at most two batches are staged ahead of the
        // consumer, overlapping host-side loading with device compute
        // without unbounded memory growth.
        let (tx, rx) = mpsc::sync_channel(2);
        let worker = thread::spawn(move || {
let mut position = 0;
while position < indices.len() {
let end = (position + batch_size).min(indices.len());
let batch_indices = &indices[position..end];
if batch_indices.len() < batch_size && drop_last {
break;
}
let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if num_workers > 0 {
batch_indices
.par_iter()
.filter_map(|&idx| dataset.get(idx))
.collect()
} else {
batch_indices
.iter()
.filter_map(|&idx| dataset.get(idx))
.collect()
};
if samples.is_empty() {
break;
}
let data_samples: Vec<Tensor<f32>> =
samples.iter().map(|(x, _)| x.clone()).collect();
let target_samples: Vec<Tensor<f32>> =
samples.iter().map(|(_, y)| y.clone()).collect();
let data = stack_tensors(&data_samples);
let targets = stack_tensors(&target_samples);
                // Fall back to the host tensor if the device transfer fails,
                // so CPU-only builds still iterate (see the fallback tests).
                let gpu_data = match data.to_device(device) {
                    Ok(t) => t,
                    Err(_) => data,
                };
                let gpu_targets = match targets.to_device(device) {
                    Ok(t) => t,
                    Err(_) => targets,
                };
                // A send error means the receiver was dropped; stop producing.
                if tx.send(Batch::new(gpu_data, gpu_targets)).is_err() {
                    break;
                }
position = end;
}
});
Self {
receiver: rx,
_worker: Some(worker),
}
}
}
impl Iterator for GpuPrefetchIter {
type Item = Batch;
    fn next(&mut self) -> Option<Self::Item> {
        // recv() errors once the worker finishes and drops the sender.
        self.receiver.recv().ok()
}
}
impl<D> DataLoader<D>
where
D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
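    /// Streams this loader's batches through a background thread that moves
    /// each batch to `device`, falling back to host tensors if the transfer
    /// fails.
    ///
    /// A minimal sketch, assuming a `dataset` value is in scope (doctest
    /// `ignore`d; `Device::Cpu` exercises the fallback path, as in the tests
    /// below):
    ///
    /// ```ignore
    /// let loader = DataLoader::new(dataset, 32).shuffle(true);
    /// for batch in loader.prefetch_to_gpu(Device::Cpu) {
    ///     // batch.data / batch.targets already live on the target device
    /// }
    /// ```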
pub fn prefetch_to_gpu(&self, device: Device) -> GpuPrefetchIter
where
D: Clone + 'static,
{
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            let sampler = SequentialSampler::new(self.dataset.len());
            sampler.iter().collect()
        };
GpuPrefetchIter::new_streaming(
self.dataset.clone(),
indices,
self.batch_size,
self.drop_last,
self.num_workers,
device,
)
}
}
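/// A [`DataLoader`] generalized over the dataset item type `T` and a
/// [`Collate`] strategy that turns a `Vec<T>` of samples into a batch.
///
/// A minimal sketch with the crate's [`DefaultCollate`], assuming a
/// `dataset` value is in scope (doctest `ignore`d; mirrors
/// `test_generic_dataloader` below):
///
/// ```ignore
/// let loader = GenericDataLoader::new(dataset, DefaultCollate::new(), 2)
///     .shuffle(false)
///     .num_workers(0);
/// for batch in loader.iter() {
///     // `batch` has type `C::Output`, as produced by the collate function
/// }
/// ```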
pub struct GenericDataLoader<D, C, T>
where
D: Dataset<Item = T>,
C: Collate<T>,
T: Send,
{
dataset: D,
collate_fn: C,
batch_size: usize,
shuffle: bool,
drop_last: bool,
num_workers: usize,
_phantom: PhantomData<T>,
}
impl<D, C, T> GenericDataLoader<D, C, T>
where
D: Dataset<Item = T>,
C: Collate<T>,
T: Send,
{
    /// Creates a loader that groups `dataset` items and batches them with
    /// `collate_fn`.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero.
    pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
        assert!(batch_size > 0, "batch_size must be non-zero");
        Self {
dataset,
collate_fn,
batch_size,
shuffle: false,
drop_last: false,
num_workers: 0,
_phantom: PhantomData,
}
}
    /// Sets the worker count: `0` fetches samples on the calling thread,
    /// anything greater fetches each batch's samples in parallel via rayon.
    #[must_use]
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }
    /// Enables or disables random shuffling of the sample order each epoch.
    #[must_use]
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }
    /// When enabled, drops a trailing batch smaller than `batch_size`.
    #[must_use]
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }
    /// Returns the number of batches per epoch, honoring `drop_last`.
    #[must_use]
    pub fn len(&self) -> usize {
let total = self.dataset.len();
if self.drop_last {
total / self.batch_size
} else {
total.div_ceil(self.batch_size)
}
}
    /// Returns `true` if the underlying dataset has no samples.
    #[must_use]
    pub fn is_empty(&self) -> bool {
self.dataset.is_empty()
}
    /// Builds a fresh iterator over collated batches, resampling the index
    /// order (a new random permutation per call when shuffling).
    #[allow(clippy::iter_not_returning_iterator)]
    pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            let sampler = SequentialSampler::new(self.dataset.len());
            sampler.iter().collect()
        };
GenericDataLoaderIter {
dataset: &self.dataset,
collate_fn: &self.collate_fn,
indices,
batch_size: self.batch_size,
drop_last: self.drop_last,
position: 0,
num_workers: self.num_workers,
_phantom: PhantomData,
}
}
}
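/// Epoch iterator produced by [`GenericDataLoader::iter`]; yields
/// `C::Output` batches collated from groups of `T` samples.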
pub struct GenericDataLoaderIter<'a, D, C, T>
where
D: Dataset<Item = T>,
C: Collate<T>,
T: Send,
{
dataset: &'a D,
collate_fn: &'a C,
indices: Vec<usize>,
batch_size: usize,
drop_last: bool,
position: usize,
num_workers: usize,
_phantom: PhantomData<T>,
}
impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
where
D: Dataset<Item = T>,
C: Collate<T>,
T: Send + Sync,
{
type Item = C::Output;
fn next(&mut self) -> Option<Self::Item> {
if self.position >= self.indices.len() {
return None;
}
let end = (self.position + self.batch_size).min(self.indices.len());
let batch_indices = &self.indices[self.position..end];
        // Skip a trailing partial batch entirely when drop_last is set.
        if batch_indices.len() < self.batch_size && self.drop_last {
            return None;
        }
        // Fetch samples in parallel when workers are configured.
        let samples: Vec<T> = if self.num_workers > 0 {
batch_indices
.par_iter()
.filter_map(|&idx| self.dataset.get(idx))
.collect()
} else {
batch_indices
.iter()
.filter_map(|&idx| self.dataset.get(idx))
.collect()
};
if samples.is_empty() {
return None;
}
self.position = end;
Some(self.collate_fn.collate(samples))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::collate::DefaultCollate;
use crate::dataset::TensorDataset;
    /// Builds a dataset of `size` two-feature rows with values
    /// `0.0, 1.0, 2.0, ...` and alternating `0.0`/`1.0` targets.
    fn create_test_dataset(size: usize) -> TensorDataset {
let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();
let x = Tensor::from_vec(data, &[size, 2]).unwrap();
let y = Tensor::from_vec(targets, &[size]).unwrap();
TensorDataset::new(x, y)
}
#[test]
fn test_dataloader_basic() {
let dataset = create_test_dataset(10);
let loader = DataLoader::new(dataset, 3);
assert_eq!(loader.batch_size(), 3);
assert_eq!(loader.len(), 4);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches.len(), 4);
assert_eq!(batches[0].len(), 3);
assert_eq!(batches[1].len(), 3);
assert_eq!(batches[2].len(), 3);
assert_eq!(batches[3].len(), 1);
}
#[test]
fn test_dataloader_drop_last() {
let dataset = create_test_dataset(10);
let loader = DataLoader::new(dataset, 3).drop_last(true);
assert_eq!(loader.len(), 3);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches.len(), 3);
for batch in &batches {
assert_eq!(batch.len(), 3);
}
}
    #[test]
    fn test_dataloader_shuffle() {
        // Smoke test: a shuffled loader must still produce batches. The order
        // is random, so batch contents are not asserted here.
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);
        let batch1: Vec<Batch> = loader.iter().take(1).collect();
let batch2: Vec<Batch> = loader.iter().take(1).collect();
assert!(!batch1.is_empty());
assert!(!batch2.is_empty());
}
#[test]
fn test_dataloader_exact_batches() {
let dataset = create_test_dataset(9);
let loader = DataLoader::new(dataset, 3);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches.len(), 3);
for batch in &batches {
assert_eq!(batch.len(), 3);
}
}
#[test]
fn test_batch_struct() {
let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
let batch = Batch::new(data, targets);
assert_eq!(batch.len(), 2);
assert!(!batch.is_empty());
}
#[test]
fn test_dataloader_empty() {
let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
let y = Tensor::from_vec(vec![], &[0]).unwrap();
let dataset = TensorDataset::new(x, y);
let loader = DataLoader::new(dataset, 3);
assert!(loader.is_empty());
let batches: Vec<Batch> = loader.iter().collect();
assert!(batches.is_empty());
}
#[test]
fn test_dataloader_single_item() {
let dataset = create_test_dataset(1);
let loader = DataLoader::new(dataset, 3);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches.len(), 1);
assert_eq!(batches[0].len(), 1);
}
#[test]
fn test_dataloader_iteration_order() {
let dataset = create_test_dataset(6);
let loader = DataLoader::new(dataset, 2).shuffle(false);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
}
#[test]
fn test_generic_dataloader() {
let dataset = create_test_dataset(6);
let collate = DefaultCollate::new();
let loader = GenericDataLoader::new(dataset, collate, 2);
let batches: Vec<_> = loader.iter().collect();
assert_eq!(batches.len(), 3);
}
#[test]
fn test_dataloader_remaining() {
let dataset = create_test_dataset(10);
let loader = DataLoader::new(dataset, 3);
let mut iter = loader.iter();
assert_eq!(iter.remaining(), 4);
iter.next();
assert_eq!(iter.remaining(), 3);
iter.next();
assert_eq!(iter.remaining(), 2);
}
#[test]
fn test_parallel_dataloader() {
let dataset = create_test_dataset(100);
let loader = DataLoader::new(dataset, 10).num_workers(4);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches.len(), 10);
let total_samples: usize = batches.iter().map(|b| b.len()).sum();
assert_eq!(total_samples, 100);
}
#[test]
fn test_parallel_vs_sequential_equivalence() {
let dataset_seq = create_test_dataset(50);
let dataset_par = create_test_dataset(50);
let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
let batches_seq: Vec<Batch> = loader_seq.iter().collect();
let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
let batches_par: Vec<Batch> = loader_par.iter().collect();
assert_eq!(batches_seq.len(), batches_par.len());
for i in 0..batches_seq.len() {
assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
assert_eq!(
batches_seq[i].targets.to_vec(),
batches_par[i].targets.to_vec()
);
}
}
#[test]
fn test_parallel_dataloader_drop_last() {
let dataset = create_test_dataset(95);
let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);
let batches: Vec<Batch> = loader.iter().collect();
assert_eq!(batches.len(), 9);
for batch in &batches {
assert_eq!(batch.len(), 10);
}
}
#[test]
fn test_parallel_generic_dataloader() {
let dataset = create_test_dataset(60);
let collate = DefaultCollate::new();
let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);
let batches: Vec<_> = loader.iter().collect();
assert_eq!(batches.len(), 6);
}
#[test]
fn test_gpu_prefetch_cpu_fallback() {
use axonml_core::Device;
let dataset = create_test_dataset(10);
let loader = DataLoader::new(dataset, 3);
let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
assert_eq!(batches.len(), 4);
assert_eq!(batches[0].len(), 3);
assert_eq!(batches[1].len(), 3);
assert_eq!(batches[2].len(), 3);
assert_eq!(batches[3].len(), 1);
}
#[test]
fn test_gpu_prefetch_data_integrity() {
use axonml_core::Device;
let dataset = create_test_dataset(6);
let loader = DataLoader::new(dataset, 2).shuffle(false);
let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
}
    #[test]
    fn test_gpu_prefetch_early_drop() {
        // Dropping the iterator early must cleanly stop the background
        // worker: its next send fails once the receiver is gone.
        use axonml_core::Device;
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10);
        let mut iter = loader.prefetch_to_gpu(Device::Cpu);
let first = iter.next();
assert!(first.is_some());
assert_eq!(first.unwrap().len(), 10);
drop(iter);
}
}