use crate::core::error::BellandeError;
use crate::core::tensor::Tensor;
use crate::data::{dataset::Dataset, sampler::Sampler};
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
/// Groups samples from a [`Dataset`] into batches and collates each batch
/// into a `(data, labels)` tensor pair.
pub struct DataLoader {
    /// Source of samples, shared behind an `Arc`.
    dataset: Arc<Box<dyn Dataset>>,
    /// Number of sample indices requested per batch.
    batch_size: usize,
    /// When `true` and no `sampler` is set, batch indices are chosen randomly
    /// instead of sequentially.
    shuffle: bool,
    /// Only compared against 1: values greater than 1 fetch the samples of a
    /// batch with rayon's parallel iterator. The value itself does not size
    /// any thread pool.
    num_workers: usize,
    /// Optional custom index sampler; when present it takes precedence over
    /// `shuffle`. Wrapped in a `Mutex` because sampling needs `&mut` access.
    sampler: Option<Arc<Mutex<Box<dyn Sampler>>>>,
    /// When `true`, a batch with fewer than `batch_size` indices is not
    /// yielded and iteration stops instead.
    drop_last: bool,
}
impl DataLoader {
    /// Creates a new data loader over `dataset`.
    ///
    /// * `batch_size` - number of samples per batch (the last batch of an
    ///   epoch may be shorter unless `drop_last` is set).
    /// * `shuffle` - when `true` and no `sampler` is given, each call to
    ///   [`DataLoader::iter`] visits every sample exactly once in a fresh
    ///   random order.
    /// * `num_workers` - values greater than 1 fetch the samples of a batch in
    ///   parallel via rayon.
    /// * `sampler` - optional custom sampler; when present it takes precedence
    ///   over `shuffle` and is asked for `batch_size` indices per batch.
    /// * `drop_last` - when `true`, a trailing batch smaller than `batch_size`
    ///   is not yielded.
    pub fn new(
        dataset: Box<dyn Dataset>,
        batch_size: usize,
        shuffle: bool,
        num_workers: usize,
        sampler: Option<Box<dyn Sampler>>,
        drop_last: bool,
    ) -> Self {
        DataLoader {
            dataset: Arc::new(dataset),
            batch_size,
            shuffle,
            num_workers,
            sampler: sampler.map(|s| Arc::new(Mutex::new(s))),
            drop_last,
        }
    }

    /// Returns an iterator over one epoch of collated `(data, labels)` batches.
    ///
    /// When shuffling, the permutation is drawn once here. Every sample then
    /// appears exactly once per epoch, rather than being resampled with
    /// replacement for each batch.
    pub fn iter(&self) -> DataLoaderIterator<'_> {
        let order = if self.sampler.is_none() && self.shuffle {
            use rand::seq::SliceRandom;
            let mut rng = rand::thread_rng();
            let mut indices: Vec<usize> = (0..self.dataset.len()).collect();
            indices.shuffle(&mut rng);
            Some(indices)
        } else {
            None
        };
        DataLoaderIterator {
            dataloader: self,
            index: 0,
            order,
        }
    }
}

/// Iterator over the batches of one epoch of a [`DataLoader`].
///
/// Created by [`DataLoader::iter`].
pub struct DataLoaderIterator<'a> {
    dataloader: &'a DataLoader,
    /// Sample position of the next batch within the epoch.
    index: usize,
    /// Shuffled sample order for this epoch; `None` when iterating
    /// sequentially or when a sampler supplies the indices.
    order: Option<Vec<usize>>,
}

impl<'a> Iterator for DataLoaderIterator<'a> {
    type Item = Result<(Tensor, Tensor), BellandeError>;

    /// Yields the next collated batch.
    ///
    /// Returns `None` once the epoch is exhausted, when no indices are
    /// produced, or when `drop_last` is set and the batch is short. Yields
    /// `Some(Err(LockError))` if the sampler mutex is poisoned, and
    /// `Some(Err(_))` if collation fails.
    fn next(&mut self) -> Option<Self::Item> {
        let loader = self.dataloader;
        let len = loader.dataset.len();
        if self.index >= len {
            return None;
        }
        // Saturate so a very large `batch_size` cannot overflow the cursor.
        let next_index = self.index.saturating_add(loader.batch_size);
        let end = next_index.min(len);

        let batch_indices: Vec<usize> = if let Some(sampler) = &loader.sampler {
            match sampler.lock() {
                Ok(mut sampler) => sampler.sample(loader.batch_size),
                Err(_) => {
                    // A poisoned mutex stays poisoned: advance the cursor so the
                    // epoch still terminates instead of yielding errors forever.
                    self.index = next_index;
                    return Some(Err(BellandeError::LockError));
                }
            }
        } else if let Some(order) = &self.order {
            // `skip`/`take` instead of slicing, so a dataset whose length
            // changed since `iter()` cannot cause a panic.
            order
                .iter()
                .copied()
                .skip(self.index)
                .take(end - self.index)
                .collect()
        } else {
            (self.index..end).collect()
        };

        if batch_indices.is_empty()
            || (loader.drop_last && batch_indices.len() < loader.batch_size)
        {
            return None;
        }

        // Capture only the dataset in the (possibly parallel) closures.
        let dataset = &loader.dataset;
        let batch: Vec<(Tensor, Tensor)> = if loader.num_workers > 1 {
            batch_indices
                .par_iter()
                .map(|&idx| dataset.get(idx))
                .collect()
        } else {
            batch_indices.iter().map(|&idx| dataset.get(idx)).collect()
        };

        self.index = next_index;
        // `batch` is non-empty here because `batch_indices` was non-empty.
        Some(collate_batch(batch))
    }
}
/// Computes the shape of a stacked batch: the number of tensors followed by
/// the shape shared by every tensor.
///
/// # Errors
///
/// * [`BellandeError::InvalidInputs`] when `tensors` is empty.
/// * [`BellandeError::ShapeMismatch`] naming the first tensor whose shape
///   differs from that of tensor 0.
fn get_batch_shape(tensors: &[Tensor]) -> Result<Vec<usize>, BellandeError> {
    let (head, tail) = match tensors.split_first() {
        Some(parts) => parts,
        None => {
            return Err(BellandeError::InvalidInputs(
                "Empty tensor batch".to_string(),
            ))
        }
    };
    let reference = head.shape();

    // Positions in `tail` are offset by one relative to `tensors`.
    let mismatch = tail
        .iter()
        .enumerate()
        .map(|(offset, tensor)| (offset + 1, tensor))
        .find(|(_, tensor)| tensor.shape() != reference);

    if let Some((position, tensor)) = mismatch {
        return Err(BellandeError::ShapeMismatch(format!(
            "tensor 0 has shape {:?} but tensor {} has shape {:?}",
            reference,
            position,
            tensor.shape()
        )));
    }

    let mut batch_shape = vec![tensors.len()];
    batch_shape.extend(reference);
    Ok(batch_shape)
}
/// Stacks a batch of `(data, label)` pairs into one data tensor and one label
/// tensor, each with a new leading batch dimension.
///
/// # Errors
///
/// * [`BellandeError::InvalidInputs`] when `batch` is empty.
/// * Any error from `get_batch_shape` (data checked before labels) or from
///   `copy_tensor_slice` while filling the stacked tensors.
fn collate_batch(batch: Vec<(Tensor, Tensor)>) -> Result<(Tensor, Tensor), BellandeError> {
    if batch.is_empty() {
        return Err(BellandeError::InvalidInputs(
            "Empty batch provided".to_string(),
        ));
    }

    let (inputs, targets): (Vec<Tensor>, Vec<Tensor>) = batch.into_iter().unzip();

    // Both shapes are validated before any allocation or copying happens.
    let input_shape = get_batch_shape(&inputs)?;
    let target_shape = get_batch_shape(&targets)?;

    let mut stacked_inputs = Tensor::zeros(&input_shape);
    let mut stacked_targets = Tensor::zeros(&target_shape);

    inputs
        .iter()
        .enumerate()
        .try_for_each(|(row, tensor)| copy_tensor_slice(&mut stacked_inputs, row, tensor))?;
    targets
        .iter()
        .enumerate()
        .try_for_each(|(row, tensor)| copy_tensor_slice(&mut stacked_targets, row, tensor))?;

    Ok((stacked_inputs, stacked_targets))
}
/// Copies `source`'s elements into row `batch_idx` of the stacked tensor
/// `dest`. The row length is taken from the leading stride of `dest`.
///
/// # Errors
///
/// * [`BellandeError::DimensionMismatch`] if `dest` has no leading dimension,
///   or if `source` does not contain exactly one row's worth of elements.
/// * [`BellandeError::IndexOutOfBounds`] if row `batch_idx` lies outside
///   `dest`'s data, including when the offset computation would overflow.
fn copy_tensor_slice(
    dest: &mut Tensor,
    batch_idx: usize,
    source: &Tensor,
) -> Result<(), BellandeError> {
    // A zero-dimensional `dest` has no batch stride; report it instead of
    // panicking on `stride()[0]`.
    let batch_stride = dest
        .stride()
        .first()
        .copied()
        .ok_or(BellandeError::DimensionMismatch)?;
    let start_idx = batch_idx
        .checked_mul(batch_stride)
        .ok_or(BellandeError::IndexOutOfBounds)?;
    let end_idx = start_idx
        .checked_add(batch_stride)
        .ok_or(BellandeError::IndexOutOfBounds)?;

    // Bounds are checked before the source length, matching the error
    // precedence callers have seen so far.
    if end_idx > dest.data().len() {
        return Err(BellandeError::IndexOutOfBounds);
    }
    if source.data().len() != batch_stride {
        return Err(BellandeError::DimensionMismatch);
    }

    dest.data_mut()[start_idx..end_idx].copy_from_slice(source.data());
    Ok(())
}