1mod builder;
7use builder::Builder;
8use rand::{seq::SliceRandom, thread_rng};
9
10use crate::collate::{Collate, DefaultCollate};
11
/// Groups the samples of a dataset into batches and merges every batch with a
/// collate function.
///
/// Iterating over `&DataLoader` (or calling `iter`) borrows the dataset, while
/// `into_iter` consumes the loader together with its dataset.
#[derive(Debug)]
pub struct DataLoader<D, C> {
    /// Source of the samples.
    dataset: D,
    /// Maximum number of samples gathered into a single batch.
    batch_size: usize,
    /// When `true`, a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Turns the samples of one batch into the value yielded by the iterators.
    collate_fn: C,
    /// When `true`, the samples inside each batch are shuffled before collation.
    shuffle: bool,
}
28
29impl<D> DataLoader<D, DefaultCollate>
30where
31 D: IntoIterator,
32 DefaultCollate: Collate<D::Item>,
33{
34 pub fn builder(dataset: D) -> Builder<D, DefaultCollate> {
36 Builder::new(dataset)
37 }
38}
39
40impl<D, C> IntoIterator for DataLoader<D, C>
44where
45 D: IntoIterator,
46 C: Collate<<D as IntoIterator>::Item>,
47{
48 type Item = C::Output;
50 type IntoIter = IntoIter<D::IntoIter, C>;
51
52 fn into_iter(self) -> Self::IntoIter {
53 IntoIter {
54 batch_size: self.batch_size,
55 dataset_iter: self.dataset.into_iter(),
56 drop_last: self.drop_last,
57 collate_fn: self.collate_fn,
58 shuffle: self.shuffle,
59 }
60 }
61}
62
/// Owning batch iterator produced by consuming a `DataLoader`.
#[derive(Debug)]
pub struct IntoIter<D, C> {
    /// Maximum number of samples gathered into a single batch.
    batch_size: usize,
    /// Iterator over the samples still to be batched.
    dataset_iter: D,
    /// When `true`, a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Turns the samples of one batch into the yielded value.
    collate_fn: C,
    /// When `true`, the samples inside each batch are shuffled before collation.
    shuffle: bool,
}
72
73impl<D, C> Iterator for IntoIter<D, C>
74where
75 D: Iterator,
76 C: Collate<D::Item>,
77{
78 type Item = C::Output;
79 fn next(&mut self) -> Option<Self::Item> {
80 let mut batch = self
81 .dataset_iter
82 .by_ref()
83 .take(self.batch_size)
84 .collect::<Vec<_>>();
85
86 if batch.is_empty() {
87 return None;
88 }
89
90 if batch.len() == self.batch_size || (batch.len() != self.batch_size && !self.drop_last) {
91 if self.shuffle {
92 batch.shuffle(&mut thread_rng());
93 }
94 return Some(self.collate_fn.collate(batch));
95 }
96 None
97 }
98 fn size_hint(&self) -> (usize, Option<usize>) {
99 let (lower, _) = self.dataset_iter.size_hint();
100 let lower = if self.drop_last {
101 lower / self.batch_size
102 } else {
103 (lower + self.batch_size - 1) / self.batch_size
104 };
105 (lower, Some(lower))
106 }
107}
108
109impl<D, C> ExactSizeIterator for IntoIter<D, C>
110where
111 D: Iterator + ExactSizeIterator,
112 C: Collate<D::Item>,
113{
114}
115
/// Borrowing batch iterator produced from a `&DataLoader`.
///
/// Unlike the owning iterator it only holds a reference to the loader's
/// collate function.
#[derive(Debug)]
pub struct Iter<'dataset, D, C> {
    /// Maximum number of samples gathered into a single batch.
    batch_size: usize,
    /// Iterator over the samples still to be batched.
    // The `_iter` suffix repeats the struct name, which the lint flags; the
    // name is kept so the fields mirror those of `IntoIter`.
    #[allow(clippy::struct_field_names)]
    dataset_iter: D,
    /// When `true`, a trailing batch shorter than `batch_size` is discarded.
    drop_last: bool,
    /// Collate function borrowed from the loader.
    collate_fn: &'dataset C,
    /// When `true`, the samples inside each batch are shuffled before collation.
    shuffle: bool,
}
126
127impl<'dataset, D, C> IntoIterator for &'dataset DataLoader<D, C>
128where
129 D: 'dataset,
130 &'dataset D: IntoIterator,
131 C: Collate<<&'dataset D as IntoIterator>::Item>,
132{
133 type Item = C::Output;
134 type IntoIter = Iter<'dataset, <&'dataset D as IntoIterator>::IntoIter, C>;
135
136 fn into_iter(self) -> Self::IntoIter {
137 Iter {
138 batch_size: self.batch_size,
139 dataset_iter: self.dataset.into_iter(),
140 drop_last: self.drop_last,
141 collate_fn: &self.collate_fn,
142 shuffle: self.shuffle,
143 }
144 }
145}
146
147impl<'dataset, D, C> DataLoader<D, C>
148where
149 D: 'dataset,
150 &'dataset D: IntoIterator,
151 C: Collate<<&'dataset D as IntoIterator>::Item>,
152{
153 pub fn iter(&'dataset self) -> Iter<'_, <&'dataset D as IntoIterator>::IntoIter, C> {
156 Iter {
157 batch_size: self.batch_size,
158 dataset_iter: self.dataset.into_iter(),
159 drop_last: self.drop_last,
160 collate_fn: &self.collate_fn,
161 shuffle: self.shuffle,
162 }
163 }
164}
165
166impl<'dataset, D, C> Iterator for Iter<'dataset, D, C>
167where
168 D: Iterator,
169 C: Collate<D::Item>,
170{
171 type Item = C::Output;
172 fn next(&mut self) -> Option<Self::Item> {
173 let mut batch = self
174 .dataset_iter
175 .by_ref()
176 .take(self.batch_size)
177 .collect::<Vec<_>>();
178
179 if batch.is_empty() {
180 return None;
181 }
182
183 if batch.len() == self.batch_size || (batch.len() != self.batch_size && !self.drop_last) {
184 if self.shuffle {
185 batch.shuffle(&mut thread_rng());
186 }
187 return Some(self.collate_fn.collate(batch));
188 }
189 None
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    use crate::collate::NoOpCollate;
    use ndarray::array;

    /// Eleven consecutive integers: with a batch size of two this gives five
    /// full batches followed by a single-element trailing batch.
    fn dataset() -> Vec<i32> {
        (0..=10).collect()
    }

    /// Borrowing iteration can be repeated, and the loader can still be
    /// consumed afterwards.
    #[test]
    fn multiple_iteration() {
        let loader = DataLoader::builder(dataset()).batch_size(2).build();

        // Both borrowing entry points must walk the whole dataset each time.
        assert_eq!(loader.iter().count(), 6);
        assert_eq!(IntoIterator::into_iter(&loader).count(), 6);

        let mut into_iter = loader.into_iter();
        assert_eq!(into_iter.next(), Some(array![0, 1]));
        assert_eq!(into_iter.next(), Some(array![2, 3]));
        assert_eq!(into_iter.next(), Some(array![4, 5]));
        assert_eq!(into_iter.next(), Some(array![6, 7]));
        assert_eq!(into_iter.next(), Some(array![8, 9]));
        assert_eq!(into_iter.next(), Some(array![10]));
        assert_eq!(into_iter.next(), None);
    }

    /// With `drop_last` the trailing single-element batch is not yielded.
    #[test]
    fn drop_last() {
        let loader = DataLoader::builder(dataset())
            .batch_size(2)
            .drop_last()
            .build();

        let mut into_iter = loader.into_iter();
        assert_eq!(into_iter.next(), Some(array![0, 1]));
        assert_eq!(into_iter.next(), Some(array![2, 3]));
        assert_eq!(into_iter.next(), Some(array![4, 5]));
        assert_eq!(into_iter.next(), Some(array![6, 7]));
        assert_eq!(into_iter.next(), Some(array![8, 9]));
        assert_eq!(into_iter.next(), None);
    }

    /// A user supplied collate function replaces the default one.
    #[test]
    fn custom_collate() {
        let loader = DataLoader::builder(dataset())
            .batch_size(2)
            .collate_fn(NoOpCollate)
            .build();

        let mut into_iter = loader.into_iter();
        assert_eq!(into_iter.next(), Some(vec![0, 1]));
        assert_eq!(into_iter.next(), Some(vec![2, 3]));
        assert_eq!(into_iter.next(), Some(vec![4, 5]));
        assert_eq!(into_iter.next(), Some(vec![6, 7]));
        assert_eq!(into_iter.next(), Some(vec![8, 9]));
        assert_eq!(into_iter.next(), Some(vec![10]));
        assert_eq!(into_iter.next(), None);
    }

    /// Tuples of a label and a token vector are collated field by field.
    #[test]
    fn vec_of_token() {
        let dataset = vec![
            (0, vec![1, 23, 4, 0]),
            (1, vec![4, 0, 0, 0]),
            (1, vec![8, 23, 12, 3]),
            (0, vec![2, 45, 4, 0]),
        ];

        let loader = DataLoader::builder(dataset).batch_size(2).build();

        // Four samples with a batch size of two make exactly two batches.
        assert_eq!(IntoIterator::into_iter(&loader).count(), 2);

        let mut iter = loader.iter();

        assert_eq!(
            iter.next(),
            Some((
                array![0, 1],
                vec![array![1, 4], array![23, 0], array![4, 0], array![0, 0]]
            ))
        );
    }

    /// `len` reports the number of remaining batches.
    #[test]
    fn len() {
        let loader = DataLoader::builder(dataset())
            .batch_size(2)
            .drop_last()
            .build();

        let into_iter = loader.into_iter();
        assert_eq!(into_iter.len(), 5);

        let loader = DataLoader::builder(dataset()).batch_size(2).build();

        let mut into_iter = loader.into_iter();
        assert_eq!(into_iter.len(), 6);
        into_iter.next();
        assert_eq!(into_iter.len(), 5);
    }
}