use crate::data::dataset::{DataError, Dataset};
use crate::data::sampler::Sampler;
use crate::error::RusTorchError;
use crate::tensor::Tensor;
use num_traits::Float;
use rand::seq::SliceRandom;
use rand::thread_rng;
use rayon::prelude::*;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// Groups items of a borrowed dataset into batches, following the index
/// order produced by a sampler.
///
/// The loader borrows the dataset for `'a` and owns its sampler.
pub struct DataLoader<'a, T, D>
where
    T: Float,
    D: Dataset<T>,
{
    /// Dataset the items are read from; borrowed, never owned.
    dataset: &'a D,
    /// Source of the item indices; boxed so any `Send + Sync` sampler fits.
    sampler: Box<dyn Sampler + Send + Sync>,
    /// Upper bound on the number of indices drawn for one batch.
    batch_size: usize,
    /// When `true`, a trailing batch with fewer than `batch_size` indices is
    /// discarded instead of being returned.
    drop_last: bool,
    /// Configured worker count. NOTE(review): the visible code only stores and
    /// reports this value — confirm whether anything else consumes it.
    num_workers: usize,
    /// Ties the element type `T` to the loader without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
impl<'a, T: Float, D: Dataset<T>> DataLoader<'a, T, D> {
    /// Creates a loader with the default options: `drop_last` is `false` and
    /// `num_workers` is `1`.
    pub fn new(dataset: &'a D, sampler: Box<dyn Sampler + Send + Sync>, batch_size: usize) -> Self {
        Self::with_options(dataset, sampler, batch_size, false, 1)
    }

    /// Creates a loader with every option given explicitly.
    ///
    /// * `dataset` - dataset the batches are read from.
    /// * `sampler` - produces the indices passed to the dataset.
    /// * `batch_size` - maximum number of indices drawn per batch.
    /// * `drop_last` - discard a trailing batch that is shorter than `batch_size`.
    /// * `num_workers` - stored as-is and reported by [`Self::num_workers`].
    pub fn with_options(
        dataset: &'a D,
        sampler: Box<dyn Sampler + Send + Sync>,
        batch_size: usize,
        drop_last: bool,
        num_workers: usize,
    ) -> Self {
        let _phantom = std::marker::PhantomData;
        Self {
            dataset,
            sampler,
            batch_size,
            drop_last,
            num_workers,
            _phantom,
        }
    }

    /// Draws up to `batch_size` indices from the sampler and returns the
    /// corresponding dataset items.
    ///
    /// Returns `None` when:
    /// * the sampler yields no index at all,
    /// * `drop_last` is set and fewer than `batch_size` indices were drawn
    ///   (those indices have already been consumed from the sampler), or
    /// * none of the drawn indices could be loaded.
    ///
    /// Items whose `get_item` call fails are skipped without reporting the
    /// error, so a returned batch may hold fewer items than indices drawn.
    pub fn next_batch(&mut self) -> Option<Vec<T>> {
        let wanted = self.batch_size;

        // Phase 1: pull indices until the batch is full or the sampler stops.
        // All indices are drawn before any item is fetched.
        let sampler = &mut self.sampler;
        let drawn: Vec<_> = std::iter::from_fn(|| sampler.sample())
            .take(wanted)
            .collect();

        let is_short = drawn.len() < wanted;
        if drawn.is_empty() || (self.drop_last && is_short) {
            return None;
        }

        // Phase 2: fetch the items, silently skipping indices that fail.
        let dataset = self.dataset;
        let items: Vec<T> = drawn
            .into_iter()
            .filter_map(|index| dataset.get_item(index).ok())
            .collect();

        if items.is_empty() {
            return None;
        }
        Some(items)
    }

    /// Resets the sampler by calling its `reset` method.
    pub fn reset(&mut self) {
        self.sampler.reset();
    }

    /// Returns whatever the sampler's `is_empty` reports.
    pub fn is_empty(&self) -> bool {
        self.sampler.is_empty()
    }

    /// Returns the configured batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns the configured worker count.
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }
}