1use super::fetch::{Fetcher, MapDatasetFetcher};
4use crate::{
5 collate::{Collate, DefaultCollate},
6 sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler},
7 Dataset, Len,
8};
9
10mod builder;
11use builder::Builder;
12
/// Combines a dataset, a batch sampler and a collate function, yielding
/// collated batches of samples (see [`SingleProcessDataLoaderIter`]).
#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
    /// The dataset batches of samples are fetched from.
    dataset: D,
    /// Groups the indices produced by the sampler `S` into batches.
    batch_sampler: BatchSampler<S>,
    /// Merges the fetched samples into a single collated batch.
    collate_fn: C,
}
36
impl<D> DataLoader<D, SequentialSampler, DefaultCollate>
where
    D: Dataset,
    DefaultCollate: Collate<D::Sample>,
{
    /// Entry point for constructing a loader: returns a [`Builder`] over
    /// `dataset`, preconfigured with the default sequential sampler and the
    /// default collate function (both can be changed on the builder).
    pub fn builder(dataset: D) -> Builder<D, SequentialSampler, DefaultCollate> {
        Builder::new(dataset)
    }
}
47
impl<D, S, C> DataLoader<D, S, C>
where
    D: Dataset + Sync,
    S: Sampler,
    C: Collate<D::Sample>,
    D::Sample: Send,
{
    /// Returns a fresh iterator over collated batches. The iterator only
    /// borrows the loader, so this can be called repeatedly to traverse
    /// the data again from the start.
    pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> {
        SingleProcessDataLoaderIter::new(self)
    }
}
60
impl<D, S, C> Len for DataLoader<D, S, C>
where
    D: Dataset,
    S: Sampler,
    C: Collate<D::Sample>,
{
    /// Number of batches the loader yields per pass, delegated to the
    /// batch sampler (which knows the batch size and drop-last policy).
    fn len(&self) -> usize {
        self.batch_sampler.len()
    }
}
72
/// Iterator returned by [`DataLoader::iter`]: fetches and collates one
/// batch per `next` call, on the calling thread.
#[derive(Debug)]
pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate>
where
    D: Dataset + Sync,
    S: Sampler,
    C: Collate<D::Sample>,
{
    /// Source of the index batches that drive fetching.
    sampler_iter: BatchIterator<S::IntoIter>,
    /// Count of batches yielded so far by this iterator.
    num_yielded: u64,
    /// Borrows the loader's dataset and collate function to fetch batches.
    data_fetcher: MapDatasetFetcher<'dataset, D, C>,
}
88
89impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C>
90where
91 D: Dataset + Sync,
92 S: Sampler,
93 C: Collate<D::Sample>,
94 D::Sample: Send,
95{
96 fn new(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<'_, D, S, C> {
97 SingleProcessDataLoaderIter {
98 sampler_iter: loader.batch_sampler.iter(),
99 num_yielded: 0,
100 data_fetcher: MapDatasetFetcher {
101 dataset: &loader.dataset,
102 collate_fn: &loader.collate_fn,
103 },
104 }
105 }
106 fn next_index(&mut self) -> Option<Vec<usize>> {
107 self.sampler_iter.next()
108 }
109 fn next_data(&mut self) -> Option<C::Output> {
110 let index = self.next_index();
111 if let Some(index) = index {
112 let data = self.data_fetcher.fetch(index);
113 return Some(data);
114 }
115 None
116 }
117}
118
119impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C>
120where
121 D: Dataset + Sync,
122 S: Sampler,
123 C: Collate<D::Sample>,
124 D::Sample: Send,
125{
126 type Item = C::Output;
127 fn next(&mut self) -> Option<Self::Item> {
128 let data = self.next_data();
129
130 if let Some(data) = data {
131 self.num_yielded += 1;
132 return Some(data);
133 }
134 None
135 }
136 fn size_hint(&self) -> (usize, Option<usize>) {
137 let (lower, upper) = self.sampler_iter.size_hint();
138 (lower, upper)
139 }
140}
141
142impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader<D, S, C>
143where
144 D: Dataset + Sync,
145 S: Sampler,
146 C: Collate<D::Sample>,
147 D::Sample: Send,
148{
149 type Item = C::Output;
150 type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>;
151
152 fn into_iter(self) -> Self::IntoIter {
153 self.iter()
154 }
155}
156
// Marker impl: when the sampler's iterator is exactly sized, the default
// `len` (derived from `size_hint`, which delegates to the batch sampler
// iterator) is usable. NOTE(review): this relies on `BatchIterator`
// propagating the underlying exact size — confirm in the sampler module.
impl<'dataset, D, S, C> ExactSizeIterator for SingleProcessDataLoaderIter<'dataset, D, S, C>
where
    D: Dataset + Sync,
    S: Sampler,
    S::IntoIter: ExactSizeIterator,
    C: Collate<D::Sample>,
    D::Sample: Send,
{
}
166
167#[cfg(test)]
168mod tests {
169 use super::*;
170 use crate::collate::NoOpCollate;
171 use crate::sampler::RandomSampler;
172 use crate::sampler::SequentialSampler;
173 use crate::Len;
174 use crate::NdarrayDataset;
175 use ndarray::{arr0, array, Array, Array1, Array4, Axis, Ix1, Ix4, Slice};
176 use ndarray_rand::rand_distr::{Normal, Uniform};
177 use ndarray_rand::RandomExt;
178 use std::collections::HashMap;
179
180 #[test]
181 fn len() {
182 let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
183 let dataloader = DataLoader::builder(dataset)
184 .batch_size(2)
185 .drop_last()
186 .build();
187 assert_eq!(dataloader.len(), dataloader.batch_sampler.len());
188 assert_eq!(dataloader.len(), 5);
189 let mut iter = dataloader.iter();
190 assert_eq!(iter.len(), 5);
191 iter.next();
192 assert_eq!(iter.len(), 4);
193 }
194
195 #[test]
196 fn one_dimension_basic() {
197 let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
198 let dataloader = DataLoader::builder(dataset).batch_size(2).build();
199
200 let mut iter = dataloader.iter();
201 assert_eq!(iter.next(), Some(array![1, 2]));
202 assert_eq!(iter.next(), Some(array![3, 4]));
203 assert_eq!(iter.next(), Some(array![5, 6]));
204 assert_eq!(iter.next(), Some(array![7, 8]));
205 assert_eq!(iter.next(), Some(array![9, 10]));
206 assert_eq!(iter.next(), None);
207 }
208
209 #[test]
210 fn two_iteration() {
211 let dataset = vec![1, 2, 3, 4];
212 let dataloader = DataLoader::builder(dataset).batch_size(2).build();
213
214 let mut iter = dataloader.iter();
215 assert_eq!(iter.next(), Some(array![1, 2]));
216 assert_eq!(iter.next(), Some(array![3, 4]));
217 assert_eq!(iter.next(), None);
218 let mut iter = dataloader.iter();
219 assert_eq!(iter.next(), Some(array![1, 2]));
220 assert_eq!(iter.next(), Some(array![3, 4]));
221 assert_eq!(iter.next(), None);
222 }
223
224 #[test]
225 fn one_dimension_basic_string() {
226 let dataset = vec![String::from("a"), String::from("b")];
227 let dataloader = DataLoader::builder(dataset).build();
228
229 let mut iter = dataloader.iter();
230 assert_eq!(iter.next(), Some(vec![String::from("a")]));
231 assert_eq!(iter.next(), Some(vec![String::from("b")]));
232 assert_eq!(iter.next(), None);
233 }
234 #[test]
235 fn collate() {
236 let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
237
238 let dataloader = DataLoader::builder(dataset)
239 .batch_size(2)
240 .collate_fn(NoOpCollate)
241 .build();
242
243 let mut iter = dataloader.iter();
244
245 assert_eq!(iter.next(), Some(vec![1, 2]));
246 assert_eq!(iter.next(), Some(vec![3, 4]));
247 assert_eq!(iter.next(), Some(vec![5, 6]));
248 assert_eq!(iter.next(), Some(vec![7, 8]));
249 assert_eq!(iter.next(), Some(vec![9, 10]));
250 assert_eq!(iter.next(), None);
251 }
    /// Bundles a loader with the raw arrays it was built from, so tests can
    /// compare yielded batches against the source data.
    struct TestDataLoader<S: Sampler> {
        // Loader under test, generic over the sampler variant.
        loader: DataLoader<NdarrayDataset<f64, f64, Ix4, Ix1>, S>,
        // Raw samples (first element of the dataset's ndarray tuple).
        data: Array4<f64>,
        // Raw labels (second element of the dataset's ndarray tuple).
        labels: Array1<f64>,
        // The dataset the loader iterates over.
        dataset: NdarrayDataset<f64, f64, Ix4, Ix1>,
    }
    /// Distinguishes the two loader configurations produced by
    /// `get_loader_with_dummy_data`, since the two loaders differ in type.
    enum TestDataLoaderData {
        Sequential(TestDataLoader<SequentialSampler>),
        Random(TestDataLoader<RandomSampler>),
    }
262 fn get_loader_with_dummy_data(batch_size: usize, shuffle: bool) -> TestDataLoaderData {
263 let normal: Normal<f64> = Normal::new(0.0, 1.0).unwrap();
265 let data = Array::random((100, 2, 3, 5), normal);
267 let labels = Array::random(100, Uniform::<f64>::new(0., 50.));
269 let dataset = NdarrayDataset {
271 ndarrays: (data.clone(), labels.clone()),
272 };
273
274 if shuffle {
275 let loader = DataLoader::builder(dataset.clone())
276 .batch_size(batch_size)
277 .shuffle()
278 .build();
279
280 TestDataLoaderData::Random(TestDataLoader {
281 loader,
282 data,
283 labels,
284 dataset,
285 })
286 } else {
287 let loader = DataLoader::builder(dataset.clone())
288 .batch_size(batch_size)
289 .build();
290
291 TestDataLoaderData::Sequential(TestDataLoader {
292 loader,
293 data,
294 labels,
295 dataset,
296 })
297 }
298 }
299
300 #[test]
301 fn sequential_non_batch() {
302 let batch_size = 1;
303 let test_dataloader_data = tests::get_loader_with_dummy_data(batch_size, false);
304 let test_data;
305 if let TestDataLoaderData::Sequential(test_dataloader_data) = test_dataloader_data {
306 test_data = test_dataloader_data;
307 } else {
308 panic!("Expected a sequential loader")
309 }
310 let mut current_idx = 0;
311
312 for (idx, (sample, target)) in test_data.loader.iter().enumerate() {
313 assert_eq!(
314 sample,
315 test_data
316 .data
317 .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
318 );
319 assert_eq!(
320 target,
321 test_data
322 .labels
323 .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
324 );
325 current_idx = idx;
326 }
327 assert_eq!(current_idx, test_data.dataset.len() - 1);
328 }
329
330 #[test]
331 fn sequential_batch() {
332 let batch_size = 2;
333 let test_dataloader_data = tests::get_loader_with_dummy_data(2, false);
334 let test_data;
335 if let TestDataLoaderData::Sequential(test_dataloader_data) = test_dataloader_data {
336 test_data = test_dataloader_data;
337 } else {
338 panic!("Expected a sequential loader")
339 }
340
341 let mut current_i = 0;
342
343 for (i, (sample, target)) in test_data.loader.iter().enumerate() {
344 let idx = i * batch_size;
345 assert_eq!(
346 sample,
347 test_data
348 .data
349 .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
350 );
351 assert_eq!(
352 target,
353 test_data
354 .labels
355 .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
356 );
357 current_i = i;
358 }
359 assert_eq!(current_i, (test_data.dataset.len() - 1) / batch_size);
360 }
361
362 #[test]
363 fn shuffle_non_batch() {
364 let test_dataloader_data = tests::get_loader_with_dummy_data(1, true);
365 let test_data;
366 if let TestDataLoaderData::Random(test_dataloader_data) = test_dataloader_data {
367 test_data = test_dataloader_data;
368 } else {
369 panic!("Expected a random loader")
370 }
371 let mut found_data: HashMap<_, _> = (0..test_data.data.len())
373 .zip(vec![0; test_data.data.len()])
374 .collect();
375 let mut found_labels: HashMap<_, _> = (0..test_data.labels.len())
376 .zip(vec![0; test_data.labels.len()])
377 .collect();
378 let mut current_i = 0;
379 for (i, (sample, target)) in test_data.loader.iter().enumerate() {
380 current_i = i;
381 let mut current_data_point_idx = 0;
382 for (data_point_idx, data_point) in test_data.data.outer_iter().enumerate() {
384 current_data_point_idx = data_point_idx;
385 if data_point == sample.index_axis(Axis(0), 0) {
387 assert_eq!(found_data[&data_point_idx], 0);
388 *found_data.get_mut(&data_point_idx).unwrap() += 1;
389 break;
390 }
391 }
392
393 assert_eq!(
394 arr0(target[0]),
395 test_data.labels.index_axis(Axis(0), current_data_point_idx)
396 );
397 *found_labels.get_mut(¤t_data_point_idx).unwrap() += 1;
398 assert_eq!(found_data.values().sum::<usize>(), i + 1);
399 assert_eq!(found_labels.values().sum::<usize>(), i + 1);
400 }
401 assert_eq!(current_i, test_data.dataset.len() - 1);
402 }
403
404 #[test]
405 fn shuffle_batch() {
406 let batch_size = 2;
407 let test_dataloader_data = tests::get_loader_with_dummy_data(batch_size, true);
408 let test_data;
409 if let TestDataLoaderData::Random(test_dataloader_data) = test_dataloader_data {
410 test_data = test_dataloader_data;
411 } else {
412 panic!("Expected a random loader")
413 }
414 let mut found_data: HashMap<_, _> = (0..test_data.data.len())
415 .zip(vec![0; test_data.data.len()])
416 .collect();
417 let mut found_labels: HashMap<_, _> = (0..test_data.labels.len())
418 .zip(vec![0; test_data.labels.len()])
419 .collect();
420 let mut current_i = 0;
421 for (i, (batch_samples, batch_targets)) in test_data.loader.iter().enumerate() {
422 current_i = i;
423 for (sample, target) in batch_samples.outer_iter().zip(batch_targets) {
424 let mut current_data_point_idx = 0;
425 for (data_point_idx, data_point) in test_data.data.outer_iter().enumerate() {
426 current_data_point_idx = data_point_idx;
427 if data_point == sample {
428 assert_eq!(found_data[&data_point_idx], 0);
429 *found_data.get_mut(&data_point_idx).unwrap() += 1;
430 break;
431 }
432 }
433 assert_eq!(
434 arr0(target),
435 test_data.labels.index_axis(Axis(0), current_data_point_idx)
436 );
437 *found_labels.get_mut(¤t_data_point_idx).unwrap() += 1;
438 }
439 assert_eq!(found_data.values().sum::<usize>(), (i + 1) * batch_size);
440 assert_eq!(found_labels.values().sum::<usize>(), (i + 1) * batch_size);
441 }
442 assert_eq!(current_i, (test_data.dataset.len() - 1) / batch_size);
443 }
444
445 #[test]
446 fn vec_of_token() {
447 let dataset = vec![
448 (0, vec![1, 23, 4, 0]),
449 (1, vec![4, 0, 0, 0]),
450 (1, vec![8, 23, 12, 3]),
451 (0, vec![2, 45, 4, 0]),
452 ];
453
454 let loader = DataLoader::builder(dataset).batch_size(2).build();
455
456 let mut iter = loader.iter();
457
458 assert_eq!(
459 iter.next(),
460 Some((
461 array![0, 1],
462 vec![array![1, 4], array![23, 0], array![4, 0], array![0, 0]]
463 ))
464 );
465 }
466}