1use crate::collate::{stack_tensors, Collate};
10use crate::dataset::Dataset;
11use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
12use axonml_tensor::Tensor;
13use rayon::prelude::*;
14use std::marker::PhantomData;
15
/// A single mini-batch of samples produced by [`DataLoader`] iteration.
///
/// `data` and `targets` hold the per-sample tensors combined via
/// `stack_tensors` (see the collate module), in matching sample order.
#[derive(Debug, Clone)]
pub struct Batch {
    /// Stacked input tensors; the leading dimension is the batch size.
    pub data: Tensor<f32>,
    /// Stacked target tensors, in the same sample order as `data`.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (taken from `data`'s leading dimension).
    pub size: usize,
}
30
31impl Batch {
32 #[must_use] pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
34 let size = data.shape()[0];
35 Self {
36 data,
37 targets,
38 size,
39 }
40 }
41
42 #[must_use] pub fn len(&self) -> usize {
44 self.size
45 }
46
47 #[must_use] pub fn is_empty(&self) -> bool {
49 self.size == 0
50 }
51}
52
/// Batches samples drawn from a [`Dataset`] of `(data, target)` tensor pairs.
///
/// Configure with the builder-style methods (`shuffle`, `drop_last`,
/// `num_workers`), then call `iter()` to produce [`Batch`]es.
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    // Source of `(data, target)` samples.
    dataset: D,
    // Number of samples per batch.
    batch_size: usize,
    // When true, indices are randomly permuted on each `iter()` call.
    shuffle: bool,
    // When true, a trailing partial batch is discarded.
    drop_last: bool,
    // When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
75
76impl<D> DataLoader<D>
77where
78 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
79{
80 pub fn new(dataset: D, batch_size: usize) -> Self {
82 Self {
83 dataset,
84 batch_size,
85 shuffle: false,
86 drop_last: false,
87 num_workers: 0,
88 }
89 }
90
91 pub fn shuffle(mut self, shuffle: bool) -> Self {
93 self.shuffle = shuffle;
94 self
95 }
96
97 pub fn drop_last(mut self, drop_last: bool) -> Self {
99 self.drop_last = drop_last;
100 self
101 }
102
103 pub fn num_workers(mut self, num_workers: usize) -> Self {
105 self.num_workers = num_workers;
106 self
107 }
108
109 pub fn batch_size(&self) -> usize {
111 self.batch_size
112 }
113
114 pub fn len(&self) -> usize {
116 let total = self.dataset.len();
117 if self.drop_last {
118 total / self.batch_size
119 } else {
120 total.div_ceil(self.batch_size)
121 }
122 }
123
124 pub fn is_empty(&self) -> bool {
126 self.dataset.is_empty()
127 }
128
129 pub fn dataset_len(&self) -> usize {
131 self.dataset.len()
132 }
133
134 pub fn iter(&self) -> DataLoaderIter<'_, D> {
136 let indices: Vec<usize> = if self.shuffle {
137 let sampler = RandomSampler::new(self.dataset.len());
138 sampler.iter().collect()
139 } else {
140 let sampler = SequentialSampler::new(self.dataset.len());
141 sampler.iter().collect()
142 };
143
144 DataLoaderIter {
145 dataset: &self.dataset,
146 indices,
147 batch_size: self.batch_size,
148 drop_last: self.drop_last,
149 position: 0,
150 num_workers: self.num_workers,
151 }
152 }
153}
154
/// Borrowing iterator over [`Batch`]es, created by `DataLoader::iter`.
///
/// The index order (shuffled or sequential) is frozen at creation time.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    // Borrowed source of samples.
    dataset: &'a D,
    // Pre-computed sample order for this pass over the dataset.
    indices: Vec<usize>,
    // Number of samples per batch.
    batch_size: usize,
    // When true, a trailing partial batch is discarded.
    drop_last: bool,
    // Cursor into `indices`: start of the next batch.
    position: usize,
    // When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
}
171
172impl<D> Iterator for DataLoaderIter<'_, D>
173where
174 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
175{
176 type Item = Batch;
177
178 fn next(&mut self) -> Option<Self::Item> {
179 if self.position >= self.indices.len() {
180 return None;
181 }
182
183 let end = (self.position + self.batch_size).min(self.indices.len());
184 let batch_indices = &self.indices[self.position..end];
185
186 if batch_indices.len() < self.batch_size && self.drop_last {
188 return None;
189 }
190
191 let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
193 batch_indices
195 .par_iter()
196 .filter_map(|&idx| self.dataset.get(idx))
197 .collect()
198 } else {
199 batch_indices
201 .iter()
202 .filter_map(|&idx| self.dataset.get(idx))
203 .collect()
204 };
205
206 if samples.is_empty() {
207 return None;
208 }
209
210 let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
212 let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
213
214 let data = stack_tensors(&data_samples);
216 let targets = stack_tensors(&target_samples);
217
218 self.position = end;
219
220 Some(Batch::new(data, targets))
221 }
222}
223
224impl<D> DataLoaderIter<'_, D>
225where
226 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
227{
228 #[must_use] pub fn remaining(&self) -> usize {
230 let remaining_samples = self.indices.len().saturating_sub(self.position);
231 if self.drop_last {
232 remaining_samples / self.batch_size
233 } else {
234 remaining_samples.div_ceil(self.batch_size)
235 }
236 }
237}
238
/// Batches items of any type `T` from a [`Dataset`], combining each batch
/// with a user-supplied [`Collate`] function.
///
/// Generic counterpart of [`DataLoader`], which is specialized to
/// `(Tensor<f32>, Tensor<f32>)` pairs.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    // Source of items.
    dataset: D,
    // Combines a `Vec<T>` of samples into one collated batch value.
    collate_fn: C,
    // Number of samples per batch.
    batch_size: usize,
    // When true, indices are randomly permuted on each `iter()` call.
    shuffle: bool,
    // When true, a trailing partial batch is discarded.
    drop_last: bool,
    // When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
    // Ties the unused type parameter `T` to the struct.
    _phantom: PhantomData<T>,
}
258
259impl<D, C, T> GenericDataLoader<D, C, T>
260where
261 D: Dataset<Item = T>,
262 C: Collate<T>,
263 T: Send,
264{
265 pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
267 Self {
268 dataset,
269 collate_fn,
270 batch_size,
271 shuffle: false,
272 drop_last: false,
273 num_workers: 0,
274 _phantom: PhantomData,
275 }
276 }
277
278 pub fn num_workers(mut self, num_workers: usize) -> Self {
280 self.num_workers = num_workers;
281 self
282 }
283
284 pub fn shuffle(mut self, shuffle: bool) -> Self {
286 self.shuffle = shuffle;
287 self
288 }
289
290 pub fn drop_last(mut self, drop_last: bool) -> Self {
292 self.drop_last = drop_last;
293 self
294 }
295
296 pub fn len(&self) -> usize {
298 let total = self.dataset.len();
299 if self.drop_last {
300 total / self.batch_size
301 } else {
302 total.div_ceil(self.batch_size)
303 }
304 }
305
306 pub fn is_empty(&self) -> bool {
308 self.dataset.is_empty()
309 }
310
311 pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
313 let indices: Vec<usize> = if self.shuffle {
314 let sampler = RandomSampler::new(self.dataset.len());
315 sampler.iter().collect()
316 } else {
317 (0..self.dataset.len()).collect()
318 };
319
320 GenericDataLoaderIter {
321 dataset: &self.dataset,
322 collate_fn: &self.collate_fn,
323 indices,
324 batch_size: self.batch_size,
325 drop_last: self.drop_last,
326 position: 0,
327 num_workers: self.num_workers,
328 _phantom: PhantomData,
329 }
330 }
331}
332
/// Borrowing iterator over collated batches, created by
/// `GenericDataLoader::iter`.
///
/// The index order (shuffled or sequential) is frozen at creation time.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    // Borrowed source of items.
    dataset: &'a D,
    // Borrowed collate function applied to each batch of samples.
    collate_fn: &'a C,
    // Pre-computed sample order for this pass over the dataset.
    indices: Vec<usize>,
    // Number of samples per batch.
    batch_size: usize,
    // When true, a trailing partial batch is discarded.
    drop_last: bool,
    // Cursor into `indices`: start of the next batch.
    position: usize,
    // When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
    // Ties the unused type parameter `T` to the struct.
    _phantom: PhantomData<T>,
}
349
350impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
351where
352 D: Dataset<Item = T>,
353 C: Collate<T>,
354 T: Send + Sync,
355{
356 type Item = C::Output;
357
358 fn next(&mut self) -> Option<Self::Item> {
359 if self.position >= self.indices.len() {
360 return None;
361 }
362
363 let end = (self.position + self.batch_size).min(self.indices.len());
364 let batch_indices = &self.indices[self.position..end];
365
366 if batch_indices.len() < self.batch_size && self.drop_last {
367 return None;
368 }
369
370 let samples: Vec<T> = if self.num_workers > 0 {
372 batch_indices
373 .par_iter()
374 .filter_map(|&idx| self.dataset.get(idx))
375 .collect()
376 } else {
377 batch_indices
378 .iter()
379 .filter_map(|&idx| self.dataset.get(idx))
380 .collect()
381 };
382
383 if samples.is_empty() {
384 return None;
385 }
386
387 self.position = end;
388
389 Some(self.collate_fn.collate(samples))
390 }
391}
392
// Unit tests: batching arithmetic, shuffling, drop_last, parallel loading,
// and the generic (collate-based) loader.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    // Builds a dataset of `size` samples: inputs are a [size, 2] tensor of
    // counting floats, targets alternate 0.0 / 1.0.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        // 10 samples / batch size 3 -> 3 full batches + 1 partial.
        assert_eq!(loader.len(), 4); let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        // The trailing 1-sample batch is dropped: 10 / 3 = 3 full batches.
        assert_eq!(loader.len(), 3); let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Two shuffled draws may legitimately coincide, so only assert that
        // shuffled iteration produces batches at all.
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        // Dataset size divides evenly: no partial batch either way.
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        // Batch size is read from the leading dimension of `data`.
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        // An empty dataset yields no batches.
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        // One sample with batch_size 3 -> a single partial batch.
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        // Without shuffle, samples come out in sequential dataset order.
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        // `remaining()` counts down by whole batches as `next()` is called.
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        // num_workers > 0 switches to rayon fetching; totals must match.
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Parallel fetching must preserve per-batch sample order exactly.
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        assert_eq!(batches_seq.len(), batches_par.len());

        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(batches_seq[i].targets.to_vec(), batches_par[i].targets.to_vec());
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        // 95 samples, batch 10, drop_last: the 5-sample tail is discarded.
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10)
            .drop_last(true)
            .num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 9); for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }
}