use super::fetch::{Fetcher, MapDatasetFetcher};
use crate::{
collate::{Collate, DefaultCollate},
sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler},
Dataset, Len,
};
mod builder;
use builder::Builder;
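/// Combines a dataset, a sampler, and a collate function, and provides
/// iteration over collated batches in the order driven by the sampler.
///
/// A minimal usage sketch, mirroring the tests below (marked `ignore`
/// because the exact crate path for a doctest is not assumed here):
///
/// ```ignore
/// let loader = DataLoader::builder(vec![1, 2, 3, 4]).batch_size(2).build();
/// let batches: Vec<_> = loader.iter().collect();
/// // With `DefaultCollate`, integer batches come out as `ndarray` arrays:
/// // `array![1, 2]`, then `array![3, 4]`.
/// ```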
#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
    /// The dataset from which samples are loaded.
    dataset: D,
    /// Sampler that yields batches of indices into the dataset.
    batch_sampler: BatchSampler<S>,
    /// Collate function that merges the fetched samples into a single batch.
    collate_fn: C,
}
impl<D> DataLoader<D, SequentialSampler, DefaultCollate>
where
D: Dataset,
DefaultCollate: Collate<D::Sample>,
{
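    /// Returns a [`Builder`] over `dataset`, preconfigured with the default
    /// sequential sampler and the [`DefaultCollate`] collate function.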
pub fn builder(dataset: D) -> Builder<D, SequentialSampler, DefaultCollate> {
Builder::new(dataset)
}
}
impl<D, S, C> DataLoader<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D::Sample: Send,
{
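    /// Returns a single-threaded iterator over collated batches.
    ///
    /// Each call starts a fresh pass over the dataset, so a loader can be
    /// iterated several times (see the `two_iteration` test below).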
pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> {
SingleProcessDataLoaderIter::new(self)
}
}
impl<D, S, C> Len for DataLoader<D, S, C>
where
D: Dataset,
S: Sampler,
C: Collate<D::Sample>,
{
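    /// The length of a `DataLoader` is its number of batches, which the
    /// batch sampler derives from the dataset length, the batch size, and
    /// whether the last incomplete batch is dropped.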
fn len(&self) -> usize {
self.batch_sampler.len()
}
}
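/// Single-threaded iterator returned by [`DataLoader::iter`]: it draws
/// batches of indices from the sampler, fetches the corresponding samples,
/// and yields them collated.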
#[derive(Debug)]
pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
{
    /// Iterator over the batches of indices produced by the sampler.
    sampler_iter: BatchIterator<S::IntoIter>,
    /// Number of batches yielded so far by this iterator.
    num_yielded: u64,
    /// Fetches the samples for a batch of indices and collates them.
    data_fetcher: MapDatasetFetcher<'dataset, D, C>,
}
impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D::Sample: Send,
{
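    /// Creates an iterator borrowing the loader's dataset and collate
    /// function, with a fresh batch-index iterator from its sampler.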
    fn new(loader: &'dataset DataLoader<D, S, C>) -> Self {
SingleProcessDataLoaderIter {
sampler_iter: loader.batch_sampler.iter(),
num_yielded: 0,
data_fetcher: MapDatasetFetcher {
dataset: &loader.dataset,
collate_fn: &loader.collate_fn,
},
}
}
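    /// Returns the indices of the next batch, or `None` once the sampler
    /// is exhausted.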
fn next_index(&mut self) -> Option<Vec<usize>> {
self.sampler_iter.next()
}
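    /// Fetches and collates the next batch of samples, or returns `None`
    /// once there are no indices left.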
    fn next_data(&mut self) -> Option<C::Output> {
        let index = self.next_index()?;
        Some(self.data_fetcher.fetch(index))
    }
}
impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D::Sample: Send,
{
type Item = C::Output;
    fn next(&mut self) -> Option<Self::Item> {
        let data = self.next_data()?;
        self.num_yielded += 1;
        Some(data)
    }
}
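/// Allows `for batch in &loader { ... }` by delegating to [`DataLoader::iter`].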
impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader<D, S, C>
where
D: Dataset + Sync,
S: Sampler,
C: Collate<D::Sample>,
D::Sample: Send,
{
type Item = C::Output;
type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::collate::NoOpCollate;
use crate::sampler::RandomSampler;
use crate::sampler::SequentialSampler;
use crate::Len;
use crate::NdarrayDataset;
use ndarray::{arr0, array, Array, Array1, Array4, Axis, Ix1, Ix4, Slice};
use ndarray_rand::rand_distr::{Normal, Uniform};
use ndarray_rand::RandomExt;
use std::collections::HashMap;
#[test]
fn len() {
let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let dataloader = DataLoader::builder(dataset)
.batch_size(2)
.drop_last()
.build();
assert_eq!(dataloader.len(), dataloader.batch_sampler.len());
assert_eq!(dataloader.len(), 5);
}
#[test]
fn one_dimension_basic() {
let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let dataloader = DataLoader::builder(dataset).batch_size(2).build();
let mut iter = dataloader.iter();
assert_eq!(iter.next(), Some(array![1, 2]));
assert_eq!(iter.next(), Some(array![3, 4]));
assert_eq!(iter.next(), Some(array![5, 6]));
assert_eq!(iter.next(), Some(array![7, 8]));
assert_eq!(iter.next(), Some(array![9, 10]));
assert_eq!(iter.next(), None);
}
#[test]
fn two_iteration() {
let dataset = vec![1, 2, 3, 4];
let dataloader = DataLoader::builder(dataset).batch_size(2).build();
let mut iter = dataloader.iter();
assert_eq!(iter.next(), Some(array![1, 2]));
assert_eq!(iter.next(), Some(array![3, 4]));
assert_eq!(iter.next(), None);
let mut iter = dataloader.iter();
assert_eq!(iter.next(), Some(array![1, 2]));
assert_eq!(iter.next(), Some(array![3, 4]));
assert_eq!(iter.next(), None);
}
#[test]
fn one_dimension_basic_string() {
let dataset = vec![String::from("a"), String::from("b")];
let dataloader = DataLoader::builder(dataset).build();
let mut iter = dataloader.iter();
assert_eq!(iter.next(), Some(vec![String::from("a")]));
assert_eq!(iter.next(), Some(vec![String::from("b")]));
assert_eq!(iter.next(), None);
}
#[test]
fn collate() {
let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
let dataloader = DataLoader::builder(dataset)
.batch_size(2)
.collate_fn(NoOpCollate)
.build();
let mut iter = dataloader.iter();
assert_eq!(iter.next(), Some(vec![1, 2]));
assert_eq!(iter.next(), Some(vec![3, 4]));
assert_eq!(iter.next(), Some(vec![5, 6]));
assert_eq!(iter.next(), Some(vec![7, 8]));
assert_eq!(iter.next(), Some(vec![9, 10]));
assert_eq!(iter.next(), None);
}
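    /// Bundles a loader with the raw data it was built from, so tests can
    /// compare yielded batches against direct slices of the source arrays.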
struct TestDataLoader<S: Sampler> {
loader: DataLoader<NdarrayDataset<f64, f64, Ix4, Ix1>, S>,
data: Array4<f64>,
labels: Array1<f64>,
dataset: NdarrayDataset<f64, f64, Ix4, Ix1>,
}
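    /// A test loader is either sequential or shuffled; the two cases carry
    /// differently typed samplers, hence the enum.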
enum TestDataLoaderData {
Sequential(TestDataLoader<SequentialSampler>),
Random(TestDataLoader<RandomSampler>),
}
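    /// Builds a loader over 100 normally distributed samples of shape
    /// `(2, 3, 5)` with uniform random labels, shuffled or sequential as
    /// requested.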
fn get_loader_with_dummy_data(batch_size: usize, shuffle: bool) -> TestDataLoaderData {
let normal: Normal<f64> = Normal::new(0.0, 1.0).unwrap();
let data = Array::random((100, 2, 3, 5), normal);
let labels = Array::random(100, Uniform::<f64>::new(0., 50.));
let dataset = NdarrayDataset {
ndarrays: (data.clone(), labels.clone()),
};
if shuffle {
let loader = DataLoader::builder(dataset.clone())
.batch_size(batch_size)
.shuffle()
.build();
TestDataLoaderData::Random(TestDataLoader {
loader,
data,
labels,
dataset,
})
} else {
let loader = DataLoader::builder(dataset.clone())
.batch_size(batch_size)
.build();
TestDataLoaderData::Sequential(TestDataLoader {
loader,
data,
labels,
dataset,
})
}
}
#[test]
fn sequential_non_batch() {
let batch_size = 1;
        let TestDataLoaderData::Sequential(test_data) =
            get_loader_with_dummy_data(batch_size, false)
        else {
            panic!("Expected a sequential loader")
        };
let mut current_idx = 0;
for (idx, (sample, target)) in test_data.loader.iter().enumerate() {
assert_eq!(
sample,
test_data
.data
.slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
);
assert_eq!(
target,
test_data
.labels
.slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
);
current_idx = idx;
}
assert_eq!(current_idx, test_data.dataset.len() - 1);
}
#[test]
fn sequential_batch() {
let batch_size = 2;
        let TestDataLoaderData::Sequential(test_data) =
            get_loader_with_dummy_data(batch_size, false)
        else {
            panic!("Expected a sequential loader")
        };
let mut current_i = 0;
for (i, (sample, target)) in test_data.loader.iter().enumerate() {
let idx = i * batch_size;
assert_eq!(
sample,
test_data
.data
.slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
);
assert_eq!(
target,
test_data
.labels
.slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
);
current_i = i;
}
assert_eq!(current_i, (test_data.dataset.len() - 1) / batch_size);
}
#[test]
fn shuffle_non_batch() {
        let TestDataLoaderData::Random(test_data) = get_loader_with_dummy_data(1, true) else {
            panic!("Expected a random loader")
        };
let mut found_data: HashMap<_, _> = (0..test_data.data.len())
.zip(vec![0; test_data.data.len()])
.collect();
let mut found_labels: HashMap<_, _> = (0..test_data.labels.len())
.zip(vec![0; test_data.labels.len()])
.collect();
let mut current_i = 0;
for (i, (sample, target)) in test_data.loader.iter().enumerate() {
current_i = i;
let mut current_data_point_idx = 0;
for (data_point_idx, data_point) in test_data.data.outer_iter().enumerate() {
current_data_point_idx = data_point_idx;
if data_point == sample.index_axis(Axis(0), 0) {
assert_eq!(found_data[&data_point_idx], 0);
*found_data.get_mut(&data_point_idx).unwrap() += 1;
break;
}
}
assert_eq!(
arr0(target[0]),
test_data.labels.index_axis(Axis(0), current_data_point_idx)
);
            *found_labels.get_mut(&current_data_point_idx).unwrap() += 1;
assert_eq!(found_data.values().sum::<usize>(), i + 1);
assert_eq!(found_labels.values().sum::<usize>(), i + 1);
}
assert_eq!(current_i, test_data.dataset.len() - 1);
}
#[test]
fn shuffle_batch() {
let batch_size = 2;
        let TestDataLoaderData::Random(test_data) =
            get_loader_with_dummy_data(batch_size, true)
        else {
            panic!("Expected a random loader")
        };
let mut found_data: HashMap<_, _> = (0..test_data.data.len())
.zip(vec![0; test_data.data.len()])
.collect();
let mut found_labels: HashMap<_, _> = (0..test_data.labels.len())
.zip(vec![0; test_data.labels.len()])
.collect();
let mut current_i = 0;
for (i, (batch_samples, batch_targets)) in test_data.loader.iter().enumerate() {
current_i = i;
for (sample, target) in batch_samples.outer_iter().zip(batch_targets) {
let mut current_data_point_idx = 0;
for (data_point_idx, data_point) in test_data.data.outer_iter().enumerate() {
current_data_point_idx = data_point_idx;
if data_point == sample {
assert_eq!(found_data[&data_point_idx], 0);
*found_data.get_mut(&data_point_idx).unwrap() += 1;
break;
}
}
assert_eq!(
arr0(target),
test_data.labels.index_axis(Axis(0), current_data_point_idx)
);
                *found_labels.get_mut(&current_data_point_idx).unwrap() += 1;
}
assert_eq!(found_data.values().sum::<usize>(), (i + 1) * batch_size);
assert_eq!(found_labels.values().sum::<usize>(), (i + 1) * batch_size);
}
assert_eq!(current_i, (test_data.dataset.len() - 1) / batch_size);
}
#[test]
fn vec_of_token() {
let dataset = vec![
(0, vec![1, 23, 4, 0]),
(1, vec![4, 0, 0, 0]),
(1, vec![8, 23, 12, 3]),
(0, vec![2, 45, 4, 0]),
];
let loader = DataLoader::builder(dataset).batch_size(2).build();
let mut iter = loader.iter();
assert_eq!(
iter.next(),
Some((
array![0, 1],
vec![array![1, 4], array![23, 0], array![4, 0], array![0, 0]]
))
);
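        // The second batch transposes the remaining two items the same way,
        // and the iterator is exhausted afterwards.
        assert_eq!(
            iter.next(),
            Some((
                array![1, 0],
                vec![array![8, 2], array![23, 45], array![12, 4], array![3, 0]]
            ))
        );
        assert_eq!(iter.next(), None);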
}
}