ai_dataloader/indexable/
dataloader.rs

1//! Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
2
3use super::fetch::{Fetcher, MapDatasetFetcher};
4use crate::{
5    collate::{Collate, DefaultCollate},
6    sampler::{BatchIterator, BatchSampler, Sampler, SequentialSampler},
7    Dataset, Len,
8};
9
10mod builder;
11use builder::Builder;
12
13/// Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
14///
15///
16/// ```rust
17/// use ai_dataloader::indexable::DataLoader;
18///
19/// let loader = DataLoader::builder(vec![(0, "hola"), (1, "hello"), (2, "hallo"), (3, "bonjour")]).batch_size(2).shuffle().build();
20///
21/// for (label, text) in &loader {
22///     println!("Label {label:?}");
23///     println!("Text {text:?}");
24/// }
25/// ```
26///
27#[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq, Ord)]
28pub struct DataLoader<D, S = SequentialSampler, C = DefaultCollate> {
29    /// Dataset from which to load the data.
30    dataset: D,
31    /// Return a batch of indices at a time.
32    batch_sampler: BatchSampler<S>,
33    /// Collate function.
34    collate_fn: C,
35}
36
37impl<D> DataLoader<D, SequentialSampler, DefaultCollate>
38where
39    D: Dataset,
40    DefaultCollate: Collate<D::Sample>,
41{
42    /// Helper to return a [`DataLoader`] builder.
43    pub fn builder(dataset: D) -> Builder<D, SequentialSampler, DefaultCollate> {
44        Builder::new(dataset)
45    }
46}
47
48impl<D, S, C> DataLoader<D, S, C>
49where
50    D: Dataset + Sync,
51    S: Sampler,
52    C: Collate<D::Sample>,
53    D::Sample: Send,
54{
55    /// Return not owning iterator over the dataloader.
56    pub fn iter(&self) -> SingleProcessDataLoaderIter<'_, D, S, C> {
57        SingleProcessDataLoaderIter::new(self)
58    }
59}
60
61impl<D, S, C> Len for DataLoader<D, S, C>
62where
63    D: Dataset,
64    S: Sampler,
65    C: Collate<D::Sample>,
66{
67    /// Return the number of batch that contain the dataloader.
68    fn len(&self) -> usize {
69        self.batch_sampler.len()
70    }
71}
72
73/// Iterate over the dataloader with a single thread.
74#[derive(Debug)]
75pub struct SingleProcessDataLoaderIter<'dataset, D, S = SequentialSampler, C = DefaultCollate>
76where
77    D: Dataset + Sync,
78    S: Sampler,
79    C: Collate<D::Sample>,
80{
81    /// The batch iterator of this iterator.
82    sampler_iter: BatchIterator<S::IntoIter>,
83    /// Number of sample yielded.
84    num_yielded: u64,
85    /// Used to fetch the data from the dataset.
86    data_fetcher: MapDatasetFetcher<'dataset, D, C>,
87}
88
89impl<'dataset, D, S, C> SingleProcessDataLoaderIter<'dataset, D, S, C>
90where
91    D: Dataset + Sync,
92    S: Sampler,
93    C: Collate<D::Sample>,
94    D::Sample: Send,
95{
96    fn new(loader: &DataLoader<D, S, C>) -> SingleProcessDataLoaderIter<'_, D, S, C> {
97        SingleProcessDataLoaderIter {
98            sampler_iter: loader.batch_sampler.iter(),
99            num_yielded: 0,
100            data_fetcher: MapDatasetFetcher {
101                dataset: &loader.dataset,
102                collate_fn: &loader.collate_fn,
103            },
104        }
105    }
106    fn next_index(&mut self) -> Option<Vec<usize>> {
107        self.sampler_iter.next()
108    }
109    fn next_data(&mut self) -> Option<C::Output> {
110        let index = self.next_index();
111        if let Some(index) = index {
112            let data = self.data_fetcher.fetch(index);
113            return Some(data);
114        }
115        None
116    }
117}
118
119impl<'dataset, D, S, C> Iterator for SingleProcessDataLoaderIter<'dataset, D, S, C>
120where
121    D: Dataset + Sync,
122    S: Sampler,
123    C: Collate<D::Sample>,
124    D::Sample: Send,
125{
126    type Item = C::Output;
127    fn next(&mut self) -> Option<Self::Item> {
128        let data = self.next_data();
129
130        if let Some(data) = data {
131            self.num_yielded += 1;
132            return Some(data);
133        }
134        None
135    }
136    fn size_hint(&self) -> (usize, Option<usize>) {
137        let (lower, upper) = self.sampler_iter.size_hint();
138        (lower, upper)
139    }
140}
141
142impl<'dataset, D, S, C> IntoIterator for &'dataset DataLoader<D, S, C>
143where
144    D: Dataset + Sync,
145    S: Sampler,
146    C: Collate<D::Sample>,
147    D::Sample: Send,
148{
149    type Item = C::Output;
150    type IntoIter = SingleProcessDataLoaderIter<'dataset, D, S, C>;
151
152    fn into_iter(self) -> Self::IntoIter {
153        self.iter()
154    }
155}
156
157impl<'dataset, D, S, C> ExactSizeIterator for SingleProcessDataLoaderIter<'dataset, D, S, C>
158where
159    D: Dataset + Sync,
160    S: Sampler,
161    S::IntoIter: ExactSizeIterator,
162    C: Collate<D::Sample>,
163    D::Sample: Send,
164{
165}
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170    use crate::collate::NoOpCollate;
171    use crate::sampler::RandomSampler;
172    use crate::sampler::SequentialSampler;
173    use crate::Len;
174    use crate::NdarrayDataset;
175    use ndarray::{arr0, array, Array, Array1, Array4, Axis, Ix1, Ix4, Slice};
176    use ndarray_rand::rand_distr::{Normal, Uniform};
177    use ndarray_rand::RandomExt;
178    use std::collections::HashMap;
179
180    #[test]
181    fn len() {
182        let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
183        let dataloader = DataLoader::builder(dataset)
184            .batch_size(2)
185            .drop_last()
186            .build();
187        assert_eq!(dataloader.len(), dataloader.batch_sampler.len());
188        assert_eq!(dataloader.len(), 5);
189        let mut iter = dataloader.iter();
190        assert_eq!(iter.len(), 5);
191        iter.next();
192        assert_eq!(iter.len(), 4);
193    }
194
195    #[test]
196    fn one_dimension_basic() {
197        let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
198        let dataloader = DataLoader::builder(dataset).batch_size(2).build();
199
200        let mut iter = dataloader.iter();
201        assert_eq!(iter.next(), Some(array![1, 2]));
202        assert_eq!(iter.next(), Some(array![3, 4]));
203        assert_eq!(iter.next(), Some(array![5, 6]));
204        assert_eq!(iter.next(), Some(array![7, 8]));
205        assert_eq!(iter.next(), Some(array![9, 10]));
206        assert_eq!(iter.next(), None);
207    }
208
209    #[test]
210    fn two_iteration() {
211        let dataset = vec![1, 2, 3, 4];
212        let dataloader = DataLoader::builder(dataset).batch_size(2).build();
213
214        let mut iter = dataloader.iter();
215        assert_eq!(iter.next(), Some(array![1, 2]));
216        assert_eq!(iter.next(), Some(array![3, 4]));
217        assert_eq!(iter.next(), None);
218        let mut iter = dataloader.iter();
219        assert_eq!(iter.next(), Some(array![1, 2]));
220        assert_eq!(iter.next(), Some(array![3, 4]));
221        assert_eq!(iter.next(), None);
222    }
223
224    #[test]
225    fn one_dimension_basic_string() {
226        let dataset = vec![String::from("a"), String::from("b")];
227        let dataloader = DataLoader::builder(dataset).build();
228
229        let mut iter = dataloader.iter();
230        assert_eq!(iter.next(), Some(vec![String::from("a")]));
231        assert_eq!(iter.next(), Some(vec![String::from("b")]));
232        assert_eq!(iter.next(), None);
233    }
234    #[test]
235    fn collate() {
236        let dataset = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
237
238        let dataloader = DataLoader::builder(dataset)
239            .batch_size(2)
240            .collate_fn(NoOpCollate)
241            .build();
242
243        let mut iter = dataloader.iter();
244
245        assert_eq!(iter.next(), Some(vec![1, 2]));
246        assert_eq!(iter.next(), Some(vec![3, 4]));
247        assert_eq!(iter.next(), Some(vec![5, 6]));
248        assert_eq!(iter.next(), Some(vec![7, 8]));
249        assert_eq!(iter.next(), Some(vec![9, 10]));
250        assert_eq!(iter.next(), None);
251    }
252    struct TestDataLoader<S: Sampler> {
253        loader: DataLoader<NdarrayDataset<f64, f64, Ix4, Ix1>, S>,
254        data: Array4<f64>,
255        labels: Array1<f64>,
256        dataset: NdarrayDataset<f64, f64, Ix4, Ix1>,
257    }
258    enum TestDataLoaderData {
259        Sequential(TestDataLoader<SequentialSampler>),
260        Random(TestDataLoader<RandomSampler>),
261    }
262    fn get_loader_with_dummy_data(batch_size: usize, shuffle: bool) -> TestDataLoaderData {
263        // We use a normal distribution for the random numbers
264        let normal: Normal<f64> = Normal::new(0.0, 1.0).unwrap();
265        // We create a 4-dimensional array populated with random value
266        let data = Array::random((100, 2, 3, 5), normal);
267        // We create a 1-dimensional array populated with random value
268        let labels = Array::random(100, Uniform::<f64>::new(0., 50.));
269        // Basic Test dataset
270        let dataset = NdarrayDataset {
271            ndarrays: (data.clone(), labels.clone()),
272        };
273
274        if shuffle {
275            let loader = DataLoader::builder(dataset.clone())
276                .batch_size(batch_size)
277                .shuffle()
278                .build();
279
280            TestDataLoaderData::Random(TestDataLoader {
281                loader,
282                data,
283                labels,
284                dataset,
285            })
286        } else {
287            let loader = DataLoader::builder(dataset.clone())
288                .batch_size(batch_size)
289                .build();
290
291            TestDataLoaderData::Sequential(TestDataLoader {
292                loader,
293                data,
294                labels,
295                dataset,
296            })
297        }
298    }
299
300    #[test]
301    fn sequential_non_batch() {
302        let batch_size = 1;
303        let test_dataloader_data = tests::get_loader_with_dummy_data(batch_size, false);
304        let test_data;
305        if let TestDataLoaderData::Sequential(test_dataloader_data) = test_dataloader_data {
306            test_data = test_dataloader_data;
307        } else {
308            panic!("Expected a sequential loader")
309        }
310        let mut current_idx = 0;
311
312        for (idx, (sample, target)) in test_data.loader.iter().enumerate() {
313            assert_eq!(
314                sample,
315                test_data
316                    .data
317                    .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
318            );
319            assert_eq!(
320                target,
321                test_data
322                    .labels
323                    .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
324            );
325            current_idx = idx;
326        }
327        assert_eq!(current_idx, test_data.dataset.len() - 1);
328    }
329
330    #[test]
331    fn sequential_batch() {
332        let batch_size = 2;
333        let test_dataloader_data = tests::get_loader_with_dummy_data(2, false);
334        let test_data;
335        if let TestDataLoaderData::Sequential(test_dataloader_data) = test_dataloader_data {
336            test_data = test_dataloader_data;
337        } else {
338            panic!("Expected a sequential loader")
339        }
340
341        let mut current_i = 0;
342
343        for (i, (sample, target)) in test_data.loader.iter().enumerate() {
344            let idx = i * batch_size;
345            assert_eq!(
346                sample,
347                test_data
348                    .data
349                    .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
350            );
351            assert_eq!(
352                target,
353                test_data
354                    .labels
355                    .slice_axis(Axis(0), Slice::from(idx..idx + batch_size))
356            );
357            current_i = i;
358        }
359        assert_eq!(current_i, (test_data.dataset.len() - 1) / batch_size);
360    }
361
362    #[test]
363    fn shuffle_non_batch() {
364        let test_dataloader_data = tests::get_loader_with_dummy_data(1, true);
365        let test_data;
366        if let TestDataLoaderData::Random(test_dataloader_data) = test_dataloader_data {
367            test_data = test_dataloader_data;
368        } else {
369            panic!("Expected a random loader")
370        }
371        // 2 maps to keep track on what we have iterated.
372        let mut found_data: HashMap<_, _> = (0..test_data.data.len())
373            .zip(vec![0; test_data.data.len()])
374            .collect();
375        let mut found_labels: HashMap<_, _> = (0..test_data.labels.len())
376            .zip(vec![0; test_data.labels.len()])
377            .collect();
378        let mut current_i = 0;
379        for (i, (sample, target)) in test_data.loader.iter().enumerate() {
380            current_i = i;
381            let mut current_data_point_idx = 0;
382            // We iterate over the original data, finding the data corresponding to the one the dataloader just yield us
383            for (data_point_idx, data_point) in test_data.data.outer_iter().enumerate() {
384                current_data_point_idx = data_point_idx;
385                // We need to take the inner of the sample (It's not automatically done like in python)
386                if data_point == sample.index_axis(Axis(0), 0) {
387                    assert_eq!(found_data[&data_point_idx], 0);
388                    *found_data.get_mut(&data_point_idx).unwrap() += 1;
389                    break;
390                }
391            }
392
393            assert_eq!(
394                arr0(target[0]),
395                test_data.labels.index_axis(Axis(0), current_data_point_idx)
396            );
397            *found_labels.get_mut(&current_data_point_idx).unwrap() += 1;
398            assert_eq!(found_data.values().sum::<usize>(), i + 1);
399            assert_eq!(found_labels.values().sum::<usize>(), i + 1);
400        }
401        assert_eq!(current_i, test_data.dataset.len() - 1);
402    }
403
404    #[test]
405    fn shuffle_batch() {
406        let batch_size = 2;
407        let test_dataloader_data = tests::get_loader_with_dummy_data(batch_size, true);
408        let test_data;
409        if let TestDataLoaderData::Random(test_dataloader_data) = test_dataloader_data {
410            test_data = test_dataloader_data;
411        } else {
412            panic!("Expected a random loader")
413        }
414        let mut found_data: HashMap<_, _> = (0..test_data.data.len())
415            .zip(vec![0; test_data.data.len()])
416            .collect();
417        let mut found_labels: HashMap<_, _> = (0..test_data.labels.len())
418            .zip(vec![0; test_data.labels.len()])
419            .collect();
420        let mut current_i = 0;
421        for (i, (batch_samples, batch_targets)) in test_data.loader.iter().enumerate() {
422            current_i = i;
423            for (sample, target) in batch_samples.outer_iter().zip(batch_targets) {
424                let mut current_data_point_idx = 0;
425                for (data_point_idx, data_point) in test_data.data.outer_iter().enumerate() {
426                    current_data_point_idx = data_point_idx;
427                    if data_point == sample {
428                        assert_eq!(found_data[&data_point_idx], 0);
429                        *found_data.get_mut(&data_point_idx).unwrap() += 1;
430                        break;
431                    }
432                }
433                assert_eq!(
434                    arr0(target),
435                    test_data.labels.index_axis(Axis(0), current_data_point_idx)
436                );
437                *found_labels.get_mut(&current_data_point_idx).unwrap() += 1;
438            }
439            assert_eq!(found_data.values().sum::<usize>(), (i + 1) * batch_size);
440            assert_eq!(found_labels.values().sum::<usize>(), (i + 1) * batch_size);
441        }
442        assert_eq!(current_i, (test_data.dataset.len() - 1) / batch_size);
443    }
444
445    #[test]
446    fn vec_of_token() {
447        let dataset = vec![
448            (0, vec![1, 23, 4, 0]),
449            (1, vec![4, 0, 0, 0]),
450            (1, vec![8, 23, 12, 3]),
451            (0, vec![2, 45, 4, 0]),
452        ];
453
454        let loader = DataLoader::builder(dataset).batch_size(2).build();
455
456        let mut iter = loader.iter();
457
458        assert_eq!(
459            iter.next(),
460            Some((
461                array![0, 1],
462                vec![array![1, 4], array![23, 0], array![4, 0], array![0, 0]]
463            ))
464        );
465    }
466}