ai_dataloader/iterable/
dataloader.rs

1//! # Iterable `DataLoader`
2//!
//! An iterable `DataLoader` is a `DataLoader` that operates on an iterable dataset.
//! An iterable dataset is just a type that implements `IntoIterator`.
5
6mod builder;
7use builder::Builder;
8use rand::{seq::SliceRandom, thread_rng};
9
10use crate::collate::{Collate, DefaultCollate};
11
/// Batching loader over an iterable dataset.
///
/// For an iterable dataset, the `DataLoader` yields batches until the underlying iterator
/// returns `None`. Whether the dataset can be iterated several times depends on how it is
/// iterated: by-value iteration (`into_iter`) consumes the loader, while by-reference
/// iteration only borrows it.
#[derive(Debug)]
pub struct DataLoader<D, C> {
    /// The dataset we will iterate over.
    dataset: D,
    /// The number of samples a batch will contain.
    batch_size: usize,
    /// If `true`, the last batch is dropped when it holds
    /// fewer than `batch_size` samples.
    drop_last: bool,
    /// Collate function, applied to each batch of samples to build the yielded item.
    collate_fn: C,
    /// If `true`, the samples are shuffled inside each batch.
    shuffle: bool,
}
28
impl<D> DataLoader<D, DefaultCollate>
where
    D: IntoIterator,
    DefaultCollate: Collate<D::Item>,
{
    /// Returns a [`DataLoader`] builder for `dataset`.
    ///
    /// The builder is created with `Builder::new` and is typed with `DefaultCollate`
    /// as its collate function.
    pub fn builder(dataset: D) -> Builder<D, DefaultCollate> {
        Builder::new(dataset)
    }
}
39
40// we want to use dataloader in for loop
41// A dataset is something we can turn into an iterator.
42// We make a an iterator that consume this iterator and yield only batches of it.
43impl<D, C> IntoIterator for DataLoader<D, C>
44where
45    D: IntoIterator,
46    C: Collate<<D as IntoIterator>::Item>,
47{
48    // We yield batch of dataset element (which can be transformed by the collate function).
49    type Item = C::Output;
50    type IntoIter = IntoIter<D::IntoIter, C>;
51
52    fn into_iter(self) -> Self::IntoIter {
53        IntoIter {
54            batch_size: self.batch_size,
55            dataset_iter: self.dataset.into_iter(),
56            drop_last: self.drop_last,
57            collate_fn: self.collate_fn,
58            shuffle: self.shuffle,
59        }
60    }
61}
62
/// Iterator returned by the `into_iter` function of an owned `DataLoader`.
///
/// It owns the dataset iterator and the collate function, and yields collated batches.
#[derive(Debug)]
pub struct IntoIter<D, C> {
    /// Maximum number of samples pulled from the dataset iterator for each batch.
    batch_size: usize,
    /// Iterator over the dataset the batches are drawn from.
    dataset_iter: D,
    /// If `true`, a final batch holding fewer than `batch_size` samples is not yielded.
    drop_last: bool,
    /// Collate function applied to every batch before it is yielded.
    collate_fn: C,
    /// If `true`, the samples are shuffled inside each batch before collation.
    shuffle: bool,
}
72
73impl<D, C> Iterator for IntoIter<D, C>
74where
75    D: Iterator,
76    C: Collate<D::Item>,
77{
78    type Item = C::Output;
79    fn next(&mut self) -> Option<Self::Item> {
80        let mut batch = self
81            .dataset_iter
82            .by_ref()
83            .take(self.batch_size)
84            .collect::<Vec<_>>();
85
86        if batch.is_empty() {
87            return None;
88        }
89
90        if batch.len() == self.batch_size || (batch.len() != self.batch_size && !self.drop_last) {
91            if self.shuffle {
92                batch.shuffle(&mut thread_rng());
93            }
94            return Some(self.collate_fn.collate(batch));
95        }
96        None
97    }
98    fn size_hint(&self) -> (usize, Option<usize>) {
99        let (lower, _) = self.dataset_iter.size_hint();
100        let lower = if self.drop_last {
101            lower / self.batch_size
102        } else {
103            (lower + self.batch_size - 1) / self.batch_size
104        };
105        (lower, Some(lower))
106    }
107}
108
109impl<D, C> ExactSizeIterator for IntoIter<D, C>
110where
111    D: Iterator + ExactSizeIterator,
112    C: Collate<D::Item>,
113{
114}
115
/// Iterator returned by the `iter` function, and by `into_iter` on a `&DataLoader`.
///
/// It borrows the collate function from the `DataLoader` for the `'dataset` lifetime.
#[derive(Debug)]
pub struct Iter<'dataset, D, C> {
    /// Maximum number of samples pulled from the dataset iterator for each batch.
    batch_size: usize,
    /// Iterator over the dataset the batches are drawn from.
    // NOTE(review): presumably allowed because the field name repeats the struct name
    // (`Iter`) — confirm this is still needed.
    #[allow(clippy::struct_field_names)]
    dataset_iter: D,
    /// If `true`, a final batch holding fewer than `batch_size` samples is not yielded.
    drop_last: bool,
    /// Borrowed collate function applied to every batch before it is yielded.
    collate_fn: &'dataset C,
    /// If `true`, the samples are shuffled inside each batch before collation.
    shuffle: bool,
}
126
127impl<'dataset, D, C> IntoIterator for &'dataset DataLoader<D, C>
128where
129    D: 'dataset,
130    &'dataset D: IntoIterator,
131    C: Collate<<&'dataset D as IntoIterator>::Item>,
132{
133    type Item = C::Output;
134    type IntoIter = Iter<'dataset, <&'dataset D as IntoIterator>::IntoIter, C>;
135
136    fn into_iter(self) -> Self::IntoIter {
137        Iter {
138            batch_size: self.batch_size,
139            dataset_iter: self.dataset.into_iter(),
140            drop_last: self.drop_last,
141            collate_fn: &self.collate_fn,
142            shuffle: self.shuffle,
143        }
144    }
145}
146
147impl<'dataset, D, C> DataLoader<D, C>
148where
149    D: 'dataset,
150    &'dataset D: IntoIterator,
151    C: Collate<<&'dataset D as IntoIterator>::Item>,
152{
153    /// Iterate over the dataloader without consuming the underlying dataset.
154    /// As it make no sens to collate reference into a tensor, by default element are copied.
155    pub fn iter(&'dataset self) -> Iter<'_, <&'dataset D as IntoIterator>::IntoIter, C> {
156        Iter {
157            batch_size: self.batch_size,
158            dataset_iter: self.dataset.into_iter(),
159            drop_last: self.drop_last,
160            collate_fn: &self.collate_fn,
161            shuffle: self.shuffle,
162        }
163    }
164}
165
166impl<'dataset, D, C> Iterator for Iter<'dataset, D, C>
167where
168    D: Iterator,
169    C: Collate<D::Item>,
170{
171    type Item = C::Output;
172    fn next(&mut self) -> Option<Self::Item> {
173        let mut batch = self
174            .dataset_iter
175            .by_ref()
176            .take(self.batch_size)
177            .collect::<Vec<_>>();
178
179        if batch.is_empty() {
180            return None;
181        }
182
183        if batch.len() == self.batch_size || (batch.len() != self.batch_size && !self.drop_last) {
184            if self.shuffle {
185                batch.shuffle(&mut thread_rng());
186            }
187            return Some(self.collate_fn.collate(batch));
188        }
189        None
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    use crate::collate::NoOpCollate;
    use ndarray::array;

    /// The eleven integers `0..=10` used as dataset by most tests; with a batch size
    /// of 2 this leaves a trailing batch holding a single element.
    fn sample_dataset() -> Vec<i32> {
        vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    }

    /// Borrowing iteration can be repeated, and by-value iteration yields every batch,
    /// including the short trailing one.
    #[test]
    fn multiple_iteration() {
        let loader = DataLoader::builder(sample_dataset()).batch_size(2).build();

        for sample in loader.iter() {
            dbg!(sample);
        }

        for sample in &loader {
            dbg!(sample);
        }

        let batches: Vec<_> = loader.into_iter().collect();
        assert_eq!(
            batches,
            vec![
                array![0, 1],
                array![2, 3],
                array![4, 5],
                array![6, 7],
                array![8, 9],
                array![10],
            ]
        );
    }

    /// With `drop_last`, the incomplete trailing batch is not yielded.
    #[test]
    fn drop_last() {
        let loader = DataLoader::builder(sample_dataset())
            .batch_size(2)
            .drop_last()
            .build();

        let batches: Vec<_> = loader.into_iter().collect();
        assert_eq!(
            batches,
            vec![
                array![0, 1],
                array![2, 3],
                array![4, 5],
                array![6, 7],
                array![8, 9],
            ]
        );
    }

    /// A custom collate function (`NoOpCollate`) yields the batches as plain vectors.
    #[test]
    fn custom_collate() {
        let loader = DataLoader::builder(sample_dataset())
            .batch_size(2)
            .collate_fn(NoOpCollate)
            .build();

        let batches: Vec<_> = loader.into_iter().collect();
        assert_eq!(
            batches,
            vec![
                vec![0, 1],
                vec![2, 3],
                vec![4, 5],
                vec![6, 7],
                vec![8, 9],
                vec![10],
            ]
        );
    }

    /// Batches of `(label, tokens)` tuples are collated component by component.
    #[test]
    fn vec_of_token() {
        let dataset = vec![
            (0, vec![1, 23, 4, 0]),
            (1, vec![4, 0, 0, 0]),
            (1, vec![8, 23, 12, 3]),
            (0, vec![2, 45, 4, 0]),
        ];

        let loader = DataLoader::builder(dataset).batch_size(2).build();

        for el in &loader {
            dbg!(el);
        }

        let expected_first = (
            array![0, 1],
            vec![array![1, 4], array![23, 0], array![4, 0], array![0, 0]],
        );
        assert_eq!(loader.iter().next(), Some(expected_first));
    }

    /// `len` reports the number of batches left, with and without `drop_last`.
    #[test]
    fn len() {
        let dropping = DataLoader::builder(sample_dataset())
            .batch_size(2)
            .drop_last()
            .build();
        assert_eq!(dropping.into_iter().len(), 5);

        let keeping = DataLoader::builder(sample_dataset()).batch_size(2).build();
        let mut batches = keeping.into_iter();
        assert_eq!(batches.len(), 6);
        batches.next();
        assert_eq!(batches.len(), 5);
    }
}