use crate::data::Dataset;
use crate::error::Result;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::num_integer::div_ceil;
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use scirs2_core::random::rngs::SmallRng;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{thread_rng, SeedableRng};
use std::fmt::Debug;
use std::marker::PhantomData;
/// Outcome of loading one batch: on success, a pair of dynamic-dimensional
/// arrays (the `x` batch followed by the `y` batch assembled by
/// `DataLoader::load_batch`); otherwise the crate error.
type BatchResult<F> = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
/// Walks a [`Dataset`] in fixed-size batches.
///
/// The loader keeps a list of sample indices and a cursor into that list;
/// each batch consumes the next `batch_size` entries. It also implements
/// [`Iterator`], yielding one batch per step.
pub struct DataLoader<F, D>
where
    F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
    D: Dataset<F> + Send + Sync,
{
    /// Dataset the samples are read from.
    pub dataset: D,
    /// Number of samples requested per batch.
    pub batch_size: usize,
    /// When `true`, `reset` randomly permutes the visiting order.
    pub shuffle: bool,
    /// When `true`, a trailing batch with fewer than `batch_size` samples is
    /// not produced.
    pub drop_last: bool,
    /// Dataset indices in the order they will be visited.
    indices: Vec<usize>,
    /// Offset into `indices` at which the next batch starts.
    position: usize,
    /// Marker tying the element type `F` to the loader without storing one.
    _phantom: PhantomData<F>,
}
impl<
        F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
        D: Dataset<F> + Send + Sync,
    > DataLoader<F, D>
{
    /// Creates a loader over `dataset`.
    ///
    /// The visiting order starts out as `0..dataset.len()`. It is only
    /// permuted by [`reset`](Self::reset), so call `reset` before the first
    /// pass if that pass should already be shuffled.
    pub fn new(dataset: D, batch_size: usize, shuffle: bool, drop_last: bool) -> Self {
        let indices: Vec<usize> = (0..dataset.len()).collect();
        Self {
            dataset,
            batch_size,
            shuffle,
            drop_last,
            indices,
            position: 0,
            _phantom: PhantomData,
        }
    }

    /// Rewinds the loader to the first batch.
    ///
    /// When `shuffle` is set, the visiting order is randomly permuted first,
    /// using a `SmallRng` seeded from the thread-local generator.
    pub fn reset(&mut self) {
        if self.shuffle {
            let mut rng = SmallRng::from_rng(&mut thread_rng());
            self.indices.shuffle(&mut rng);
        }
        self.position = 0;
    }

    /// Number of batches a full pass produces.
    ///
    /// With `drop_last` this is `len / batch_size` (the partial trailing
    /// batch is not counted); otherwise it is the ceiling of that division.
    /// A `batch_size` of zero yields `0` instead of dividing by zero.
    pub fn num_batches(&self) -> usize {
        if self.batch_size == 0 {
            return 0;
        }
        let total = self.dataset.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            div_ceil(total, self.batch_size)
        }
    }

    /// Number of samples in the underlying dataset (not the number of batches).
    pub fn len(&self) -> usize {
        self.dataset.len()
    }

    /// Returns `true` when the underlying dataset holds no samples.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Loads the next batch and advances the cursor.
    ///
    /// Returns `None` when
    /// * all samples have been consumed,
    /// * fewer than `batch_size` samples remain and `drop_last` is set, or
    /// * `batch_size` is zero (no batch can ever be formed).
    ///
    /// Otherwise returns `Some` with the result of assembling the batch; an
    /// error from the dataset is passed through inside the `Some`.
    pub fn next_batch(&mut self) -> Option<BatchResult<F>> {
        // A zero batch size would otherwise ask `load_batch` for an empty
        // batch on every call without ever advancing the cursor.
        if self.batch_size == 0 {
            return None;
        }
        let total = self.dataset.len();
        if self.position >= total {
            return None;
        }
        let remaining = total - self.position;
        if self.drop_last && remaining < self.batch_size {
            return None;
        }
        let start = self.position;
        let end = start + remaining.min(self.batch_size);
        // Advance before loading so the cursor moves even if loading fails,
        // matching the order of operations callers already observe.
        self.position = end;
        Some(self.load_batch(&self.indices[start..end]))
    }

    /// Fetches the samples at `indices` and stacks them along a new leading
    /// axis.
    ///
    /// The output shapes are `[indices.len(), ..sample_shape]`, where the
    /// sample shape is taken from the first sample. Samples of any rank are
    /// supported, including 0-dimensional ones. An empty `indices` slice
    /// produces a pair of zero-length arrays.
    ///
    /// # Errors
    /// Propagates the first error returned by `Dataset::get`.
    ///
    /// # Panics
    /// `assign` is expected to panic if a later sample's shape cannot be
    /// broadcast to the shape of the first sample.
    fn load_batch(&self, indices: &[usize]) -> Result<(Array<F, IxDyn>, Array<F, IxDyn>)> {
        let (first_idx, rest) = match indices.split_first() {
            Some((&first, rest)) => (first, rest),
            None => {
                return Ok((Array::zeros(IxDyn(&[0])), Array::zeros(IxDyn(&[0]))));
            }
        };

        // The first sample determines the per-sample shape and is reused as
        // row 0 instead of being fetched a second time.
        let (first_x, first_y) = self.dataset.get(first_idx)?;
        let batch_x_shape: Vec<usize> = std::iter::once(indices.len())
            .chain(first_x.shape().iter().copied())
            .collect();
        let batch_y_shape: Vec<usize> = std::iter::once(indices.len())
            .chain(first_y.shape().iter().copied())
            .collect();
        let mut batch_x = Array::zeros(IxDyn(&batch_x_shape));
        let mut batch_y = Array::zeros(IxDyn(&batch_y_shape));

        // Index along axis 0 rather than slicing with `s![i, ..]`: that slice
        // description is fixed at two dimensions and is rejected at runtime
        // for any batch array whose rank is not exactly 2.
        let batch_axis = scirs2_core::ndarray::Axis(0);
        batch_x.index_axis_mut(batch_axis, 0).assign(&first_x);
        batch_y.index_axis_mut(batch_axis, 0).assign(&first_y);
        for (offset, &idx) in rest.iter().enumerate() {
            let row = offset + 1;
            let (x, y) = self.dataset.get(idx)?;
            batch_x.index_axis_mut(batch_axis, row).assign(&x);
            batch_y.index_axis_mut(batch_axis, row).assign(&y);
        }
        Ok((batch_x, batch_y))
    }
}
impl<
F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
D: Dataset<F> + Send + Sync,
> Iterator for DataLoader<F, D>
{
type Item = Result<(Array<F, IxDyn>, Array<F, IxDyn>)>;
fn next(&mut self) -> Option<Self::Item> {
self.next_batch()
}
}
/// Free-function shorthand for building a [`DataLoader`].
///
/// Forwards all four arguments, unchanged and in the same order, to
/// `DataLoader::new` and returns the resulting loader.
#[allow(dead_code)]
pub fn iter_batches<F, D>(
    dataset: D,
    batch_size: usize,
    shuffle: bool,
    drop_last: bool,
) -> DataLoader<F, D>
where
    F: Float + NumAssign + Debug + ScalarOperand + FromPrimitive + Send + Sync,
    D: Dataset<F> + Send + Sync,
{
    DataLoader::<F, D>::new(dataset, batch_size, shuffle, drop_last)
}