use std::sync::mpsc;
use std::thread;
use crate::data::collate::{collate_batch, shuffled_indices};
use crate::data::dataset::{Batch, Dataset};
use crate::error::Result;
use numr::runtime::cpu::{CpuDevice, CpuRuntime};
/// Mini-batch loader over a [`Dataset`], with deterministic per-epoch
/// shuffling and optional background prefetching.
pub struct DataLoader<D> {
    // Underlying dataset that batches are drawn from.
    dataset: D,
    // Samples per batch; a trailing partial batch is dropped each epoch.
    batch_size: usize,
    // Base RNG seed; combined with the epoch number for shuffling.
    seed: u64,
    // Channel capacity used by `prefetch_iter` (batches buffered ahead).
    prefetch_count: usize,
    // Device on which collated batch tensors are created.
    device: CpuDevice,
}
impl<D> DataLoader<D>
where
    D: Dataset<CpuRuntime>,
{
    /// Creates a loader yielding batches of `batch_size` samples from
    /// `dataset`; shuffle order is derived from `seed` and the epoch number.
    /// The prefetch buffer defaults to 2 batches.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` is zero (which would otherwise cause an opaque
    /// divide-by-zero panic in `num_batches`/`iter`).
    pub fn new(dataset: D, batch_size: usize, seed: u64, device: CpuDevice) -> Self {
        assert!(batch_size > 0, "batch_size must be non-zero");
        Self {
            dataset,
            batch_size,
            seed,
            prefetch_count: 2,
            device,
        }
    }

    /// Sets how many batches the prefetching iterator may buffer ahead.
    pub fn with_prefetch(mut self, count: usize) -> Self {
        self.prefetch_count = count;
        self
    }

    /// Number of full batches per epoch; a trailing partial batch is dropped.
    pub fn num_batches(&self) -> usize {
        self.dataset.len() / self.batch_size
    }

    /// Borrows the underlying dataset.
    pub fn dataset(&self) -> &D {
        &self.dataset
    }

    /// Returns a synchronous iterator over the batches of `epoch`.
    ///
    /// Each epoch uses a deterministic shuffle seeded by
    /// `seed.wrapping_add(epoch)`, so re-running an epoch reproduces the same
    /// order while different epochs see different orders.
    pub fn iter(&self, epoch: u64) -> DataLoaderIter<'_, D> {
        let mut indices = shuffled_indices(self.dataset.len(), self.seed.wrapping_add(epoch));
        let num_batches = indices.len() / self.batch_size;
        // Drop the trailing partial batch in place — `truncate` avoids the
        // slice-then-`to_vec` re-allocation of the whole index vector.
        indices.truncate(num_batches * self.batch_size);
        DataLoaderIter {
            loader: self,
            indices,
            batch_idx: 0,
            num_batches,
        }
    }
}
/// Borrowing iterator over one epoch's batches, produced by
/// [`DataLoader::iter`].
pub struct DataLoaderIter<'a, D> {
    // Loader that owns the dataset, batch size, and device.
    loader: &'a DataLoader<D>,
    // Shuffled sample indices for this epoch, truncated to whole batches.
    indices: Vec<usize>,
    // Index of the next batch to yield.
    batch_idx: usize,
    // Total number of full batches in this epoch.
    num_batches: usize,
}
impl<'a, D> DataLoaderIter<'a, D>
where
    D: Dataset<CpuRuntime>,
{
    /// Number of batches not yet yielded.
    pub fn remaining(&self) -> usize {
        self.num_batches - self.batch_idx
    }

    /// Collates and returns the next batch, or `Ok(None)` when the epoch is
    /// exhausted; collation errors are propagated.
    fn advance(&mut self) -> Result<Option<Batch<CpuRuntime>>> {
        if self.remaining() == 0 {
            return Ok(None);
        }
        let size = self.loader.batch_size;
        let offset = self.batch_idx * size;
        let slice = &self.indices[offset..offset + size];
        let batch = collate_batch(&self.loader.dataset, slice, &self.loader.device)?;
        self.batch_idx += 1;
        Ok(Some(batch))
    }
}
impl<D> Iterator for DataLoaderIter<'_, D>
where
    D: Dataset<CpuRuntime>,
{
    type Item = Result<Batch<CpuRuntime>>;

    fn next(&mut self) -> Option<Self::Item> {
        // `Result<Option<T>> -> Option<Result<T>>`: `transpose` replaces the
        // hand-written three-arm match with identical semantics.
        self.advance().transpose()
    }

    /// Exact: the remaining batch count is known up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.remaining();
        (remaining, Some(remaining))
    }
}
// Sound because `size_hint` returns an exact bound (lower == upper).
impl<D> ExactSizeIterator for DataLoaderIter<'_, D> where D: Dataset<CpuRuntime> {}
/// Iterator over batches collated on a background worker thread, created by
/// [`DataLoader::prefetch_iter`].
pub struct PrefetchIter {
    // Receives collated batches (or collation errors) from the worker.
    receiver: mpsc::Receiver<Result<Batch<CpuRuntime>>>,
    // Worker handle, joined on drop so the thread never outlives the iterator.
    _handle: Option<thread::JoinHandle<()>>,
}
impl<D> DataLoader<D>
where
    D: Dataset<CpuRuntime> + Clone + 'static,
{
    /// Returns an iterator whose batches are collated on a background thread,
    /// buffering up to `prefetch_count` (minimum 1) batches ahead.
    ///
    /// Uses the same per-epoch deterministic shuffle as [`DataLoader::iter`].
    /// The worker exits once every batch has been sent or the channel is
    /// disconnected (receiver dropped).
    pub fn prefetch_iter(&self, epoch: u64) -> PrefetchIter {
        let mut indices = shuffled_indices(self.dataset.len(), self.seed.wrapping_add(epoch));
        let num_batches = indices.len() / self.batch_size;
        // Drop the trailing partial batch in place — `truncate` avoids the
        // slice-then-`to_vec` re-allocation of the whole index vector.
        indices.truncate(num_batches * self.batch_size);
        // Bounded channel provides backpressure: the worker blocks once
        // `capacity` batches are waiting to be consumed.
        let capacity = self.prefetch_count.max(1);
        let (tx, rx) = mpsc::sync_channel::<Result<Batch<CpuRuntime>>>(capacity);
        let dataset = self.dataset.clone();
        let batch_size = self.batch_size;
        let device = self.device.clone();
        let handle = thread::spawn(move || {
            for batch_idx in 0..num_batches {
                let start = batch_idx * batch_size;
                let batch_indices = &indices[start..start + batch_size];
                let result = collate_batch(&dataset, batch_indices, &device);
                // `send` fails only when the receiver is gone; stop working.
                if tx.send(result).is_err() {
                    break;
                }
            }
        });
        PrefetchIter {
            receiver: rx,
            _handle: Some(handle),
        }
    }
}
impl Iterator for PrefetchIter {
    type Item = Result<Batch<CpuRuntime>>;

    /// Blocks until the worker delivers the next batch; `None` once the
    /// worker has disconnected (epoch complete).
    fn next(&mut self) -> Option<Self::Item> {
        match self.receiver.recv() {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }
}
impl Drop for PrefetchIter {
fn drop(&mut self) {
if let Some(handle) = self._handle.take() {
let _ = handle.join();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::collate::shuffled_indices;
    use crate::data::dataset::Batch;
    use numr::runtime::cpu::CpuDevice;
    use numr::tensor::Tensor;

    /// Toy dataset: sample `i` yields inputs `[i, i + 0.1]` and targets
    /// shifted up by one.
    #[derive(Clone)]
    struct SeqDataset {
        size: usize,
    }

    impl Dataset<CpuRuntime> for SeqDataset {
        fn len(&self) -> usize {
            self.size
        }

        fn get(&self, idx: usize, device: &CpuDevice) -> Result<Batch<CpuRuntime>> {
            let v = idx as f32;
            Ok(Batch {
                inputs: Tensor::from_slice(&[v, v + 0.1], &[2], device),
                targets: Tensor::from_slice(&[v + 1.0, v + 1.1], &[2], device),
            })
        }
    }

    #[test]
    fn test_shuffled_indices_deterministic() {
        let first = shuffled_indices(100, 42);
        let second = shuffled_indices(100, 42);
        assert_eq!(first, second);
        let other_seed = shuffled_indices(100, 43);
        assert_ne!(first, other_seed);
    }

    #[test]
    fn test_shuffled_indices_permutation() {
        let mut shuffled = shuffled_indices(10, 123);
        shuffled.sort();
        let expected: Vec<usize> = (0..10).collect();
        assert_eq!(shuffled, expected);
    }

    #[test]
    fn test_dataloader_basic() {
        let loader = DataLoader::new(SeqDataset { size: 10 }, 3, 0, CpuDevice::new());
        assert_eq!(loader.num_batches(), 3);
        let mut seen = 0;
        for item in loader.iter(0) {
            let batch = item.expect("batch should not error");
            assert_eq!(batch.inputs.shape(), &[3, 2]);
            assert_eq!(batch.targets.shape(), &[3, 2]);
            seen += 1;
        }
        assert_eq!(seen, 3);
    }

    #[test]
    fn test_dataloader_different_epochs_different_order() {
        let loader = DataLoader::new(SeqDataset { size: 10 }, 5, 42, CpuDevice::new());
        let collect_epoch = |epoch: u64| -> Vec<f32> {
            loader
                .iter(epoch)
                .flat_map(|b| b.unwrap().inputs.to_vec::<f32>())
                .collect()
        };
        let first = collect_epoch(0);
        let second = collect_epoch(1);
        assert_eq!(first.len(), second.len());
        assert_ne!(first, second);
    }

    #[test]
    fn test_dataloader_prefetch() {
        let loader =
            DataLoader::new(SeqDataset { size: 10 }, 3, 0, CpuDevice::new()).with_prefetch(2);
        let mut seen = 0;
        for item in loader.prefetch_iter(0) {
            let batch = item.expect("prefetch batch should not error");
            assert_eq!(batch.inputs.shape(), &[3, 2]);
            seen += 1;
        }
        assert_eq!(seen, 3);
    }

    #[test]
    fn test_dataloader_empty() {
        // Dataset smaller than the batch size yields zero batches.
        let loader = DataLoader::new(SeqDataset { size: 2 }, 5, 0, CpuDevice::new());
        assert_eq!(loader.num_batches(), 0);
        assert!(loader.iter(0).next().is_none());
    }

    #[test]
    fn test_dataloader_exact_size() {
        let loader = DataLoader::new(SeqDataset { size: 10 }, 3, 0, CpuDevice::new());
        assert_eq!(loader.iter(0).len(), 3);
    }
}