1use crate::collate::{Collate, stack_tensors};
18use crate::dataset::Dataset;
19use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
20use axonml_core::Device;
21use axonml_tensor::Tensor;
22use rayon::prelude::*;
23use std::marker::PhantomData;
24use std::sync::mpsc;
25use std::thread;
26
/// A single mini-batch of samples produced by a [`DataLoader`].
#[derive(Debug, Clone)]
pub struct Batch {
    /// Stacked input tensors; the leading dimension is the batch dimension.
    pub data: Tensor<f32>,
    /// Stacked target tensors, aligned with `data` along the batch dimension.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (taken from `data.shape()[0]`).
    pub size: usize,
}
41
42impl Batch {
43 #[must_use]
45 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
46 let size = data.shape()[0];
47 Self {
48 data,
49 targets,
50 size,
51 }
52 }
53
54 #[must_use]
56 pub fn len(&self) -> usize {
57 self.size
58 }
59
60 #[must_use]
62 pub fn is_empty(&self) -> bool {
63 self.size == 0
64 }
65}
66
/// Iterates over a [`Dataset`] of `(data, targets)` tensor pairs in
/// mini-batches, with optional shuffling and parallel sample fetching.
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset being batched.
    dataset: D,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether sample order is shuffled at the start of each `iter()` call.
    shuffle: bool,
    /// Whether a trailing batch smaller than `batch_size` is discarded.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
89
90impl<D> DataLoader<D>
91where
92 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
93{
94 pub fn new(dataset: D, batch_size: usize) -> Self {
96 Self {
97 dataset,
98 batch_size,
99 shuffle: false,
100 drop_last: false,
101 num_workers: 0,
102 }
103 }
104
105 pub fn shuffle(mut self, shuffle: bool) -> Self {
107 self.shuffle = shuffle;
108 self
109 }
110
111 pub fn drop_last(mut self, drop_last: bool) -> Self {
113 self.drop_last = drop_last;
114 self
115 }
116
117 pub fn num_workers(mut self, num_workers: usize) -> Self {
119 self.num_workers = num_workers;
120 self
121 }
122
123 pub fn batch_size(&self) -> usize {
125 self.batch_size
126 }
127
128 pub fn len(&self) -> usize {
130 let total = self.dataset.len();
131 if self.drop_last {
132 total / self.batch_size
133 } else {
134 total.div_ceil(self.batch_size)
135 }
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.dataset.is_empty()
141 }
142
143 pub fn dataset_len(&self) -> usize {
145 self.dataset.len()
146 }
147
148 pub fn iter(&self) -> DataLoaderIter<'_, D> {
150 let indices: Vec<usize> = if self.shuffle {
151 let sampler = RandomSampler::new(self.dataset.len());
152 sampler.iter().collect()
153 } else {
154 let sampler = SequentialSampler::new(self.dataset.len());
155 sampler.iter().collect()
156 };
157
158 DataLoaderIter {
159 dataset: &self.dataset,
160 indices,
161 batch_size: self.batch_size,
162 drop_last: self.drop_last,
163 position: 0,
164 num_workers: self.num_workers,
165 }
166 }
167}
168
/// Borrowing iterator over a [`DataLoader`], yielding [`Batch`]es.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Dataset being read; borrowed for the iterator's lifetime.
    dataset: &'a D,
    /// Pre-drawn sample order for this pass (shuffled or sequential).
    indices: Vec<usize>,
    /// Target number of samples per batch.
    batch_size: usize,
    /// Whether to discard a trailing partial batch.
    drop_last: bool,
    /// Index into `indices` of the next batch's first sample.
    position: usize,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
}
185
186impl<D> Iterator for DataLoaderIter<'_, D>
187where
188 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
189{
190 type Item = Batch;
191
192 fn next(&mut self) -> Option<Self::Item> {
193 if self.position >= self.indices.len() {
194 return None;
195 }
196
197 let end = (self.position + self.batch_size).min(self.indices.len());
198 let batch_indices = &self.indices[self.position..end];
199
200 if batch_indices.len() < self.batch_size && self.drop_last {
202 return None;
203 }
204
205 let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
207 batch_indices
209 .par_iter()
210 .filter_map(|&idx| self.dataset.get(idx))
211 .collect()
212 } else {
213 batch_indices
215 .iter()
216 .filter_map(|&idx| self.dataset.get(idx))
217 .collect()
218 };
219
220 if samples.is_empty() {
221 return None;
222 }
223
224 let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
226 let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
227
228 let data = stack_tensors(&data_samples);
230 let targets = stack_tensors(&target_samples);
231
232 self.position = end;
233
234 Some(Batch::new(data, targets))
235 }
236}
237
238impl<D> DataLoaderIter<'_, D>
239where
240 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
241{
242 #[must_use]
244 pub fn remaining(&self) -> usize {
245 let remaining_samples = self.indices.len().saturating_sub(self.position);
246 if self.drop_last {
247 remaining_samples / self.batch_size
248 } else {
249 remaining_samples.div_ceil(self.batch_size)
250 }
251 }
252}
253
/// Iterator that transfers batches to a target device on a background
/// thread, staging one batch ahead of the consumer.
pub struct GpuPrefetchIter {
    /// Receives device-resident batches from the worker thread.
    receiver: mpsc::Receiver<Batch>,
    /// Worker handle; never joined — the thread exits on its own once all
    /// batches are sent or the receiver is dropped.
    _worker: Option<thread::JoinHandle<()>>,
}
281
282impl GpuPrefetchIter {
283 fn new(batches: Vec<Batch>, device: Device) -> Self {
288 let (tx, rx) = mpsc::sync_channel(1);
292
293 let worker = thread::spawn(move || {
294 for batch in batches {
295 let gpu_data = match batch.data.to_device(device) {
297 Ok(t) => t,
298 Err(_) => batch.data, };
300 let gpu_targets = match batch.targets.to_device(device) {
301 Ok(t) => t,
302 Err(_) => batch.targets,
303 };
304
305 let gpu_batch = Batch::new(gpu_data, gpu_targets);
306 if tx.send(gpu_batch).is_err() {
308 break;
309 }
310 }
311 });
312
313 Self {
314 receiver: rx,
315 _worker: Some(worker),
316 }
317 }
318}
319
320impl Iterator for GpuPrefetchIter {
321 type Item = Batch;
322
323 fn next(&mut self) -> Option<Self::Item> {
324 self.receiver.recv().ok()
325 }
326}
327
328impl<D> DataLoader<D>
329where
330 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
331{
332 pub fn prefetch_to_gpu(&self, device: Device) -> GpuPrefetchIter {
352 let batches: Vec<Batch> = self.iter().collect();
354 GpuPrefetchIter::new(batches, device)
355 }
356}
357
/// A batching loader generic over the dataset item type `T` and the
/// [`Collate`] strategy used to combine items into a batch.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// The underlying dataset.
    dataset: D,
    /// Combines a `Vec<T>` of samples into one collated output.
    collate_fn: C,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether sample order is shuffled each `iter()` call.
    shuffle: bool,
    /// Whether a trailing partial batch is discarded.
    drop_last: bool,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
    /// Marks `T` as logically owned without storing a value of it.
    _phantom: PhantomData<T>,
}
377
378impl<D, C, T> GenericDataLoader<D, C, T>
379where
380 D: Dataset<Item = T>,
381 C: Collate<T>,
382 T: Send,
383{
384 pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
386 Self {
387 dataset,
388 collate_fn,
389 batch_size,
390 shuffle: false,
391 drop_last: false,
392 num_workers: 0,
393 _phantom: PhantomData,
394 }
395 }
396
397 pub fn num_workers(mut self, num_workers: usize) -> Self {
399 self.num_workers = num_workers;
400 self
401 }
402
403 pub fn shuffle(mut self, shuffle: bool) -> Self {
405 self.shuffle = shuffle;
406 self
407 }
408
409 pub fn drop_last(mut self, drop_last: bool) -> Self {
411 self.drop_last = drop_last;
412 self
413 }
414
415 pub fn len(&self) -> usize {
417 let total = self.dataset.len();
418 if self.drop_last {
419 total / self.batch_size
420 } else {
421 total.div_ceil(self.batch_size)
422 }
423 }
424
425 pub fn is_empty(&self) -> bool {
427 self.dataset.is_empty()
428 }
429
430 #[allow(clippy::iter_not_returning_iterator)]
432 pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
433 let indices: Vec<usize> = if self.shuffle {
434 let sampler = RandomSampler::new(self.dataset.len());
435 sampler.iter().collect()
436 } else {
437 (0..self.dataset.len()).collect()
438 };
439
440 GenericDataLoaderIter {
441 dataset: &self.dataset,
442 collate_fn: &self.collate_fn,
443 indices,
444 batch_size: self.batch_size,
445 drop_last: self.drop_last,
446 position: 0,
447 num_workers: self.num_workers,
448 _phantom: PhantomData,
449 }
450 }
451}
452
/// Borrowing iterator over a [`GenericDataLoader`], yielding collated
/// batch outputs of type `C::Output`.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Dataset being read; borrowed for the iterator's lifetime.
    dataset: &'a D,
    /// Collate function used to combine each batch of items.
    collate_fn: &'a C,
    /// Pre-drawn sample order for this pass (shuffled or sequential).
    indices: Vec<usize>,
    /// Target number of samples per batch.
    batch_size: usize,
    /// Whether to discard a trailing partial batch.
    drop_last: bool,
    /// Index into `indices` of the next batch's first sample.
    position: usize,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
    /// Marks `T` as logically used without storing a value of it.
    _phantom: PhantomData<T>,
}
469
470impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
471where
472 D: Dataset<Item = T>,
473 C: Collate<T>,
474 T: Send + Sync,
475{
476 type Item = C::Output;
477
478 fn next(&mut self) -> Option<Self::Item> {
479 if self.position >= self.indices.len() {
480 return None;
481 }
482
483 let end = (self.position + self.batch_size).min(self.indices.len());
484 let batch_indices = &self.indices[self.position..end];
485
486 if batch_indices.len() < self.batch_size && self.drop_last {
487 return None;
488 }
489
490 let samples: Vec<T> = if self.num_workers > 0 {
492 batch_indices
493 .par_iter()
494 .filter_map(|&idx| self.dataset.get(idx))
495 .collect()
496 } else {
497 batch_indices
498 .iter()
499 .filter_map(|&idx| self.dataset.get(idx))
500 .collect()
501 };
502
503 if samples.is_empty() {
504 return None;
505 }
506
507 self.position = end;
508
509 Some(self.collate_fn.collate(samples))
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a dataset of `size` rows with two features each and
    /// alternating 0/1 targets.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        // 10 samples in batches of 3 => three full batches plus one partial.
        assert_eq!(loader.len(), 4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        // The trailing single-sample batch is dropped.
        assert_eq!(loader.len(), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // Shuffled order is random, so only check that iteration produces
        // batches at all.
        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Sequential sampling preserves row order of the backing tensor.
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        // Parallel fetching must not change batch contents or order.
        assert_eq!(batches_seq.len(), batches_par.len());

        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(
                batches_seq[i].targets.to_vec(),
                batches_par[i].targets.to_vec()
            );
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        // 95 samples / 10 per batch with drop_last => 9 full batches.
        assert_eq!(batches.len(), 9);
        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }

    #[test]
    fn test_gpu_prefetch_cpu_fallback() {
        use axonml_core::Device;

        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
        assert_eq!(batches.len(), 4);

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_gpu_prefetch_data_integrity() {
        use axonml_core::Device;

        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();

        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_gpu_prefetch_early_drop() {
        use axonml_core::Device;

        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10);

        let mut iter = loader.prefetch_to_gpu(Device::Cpu);
        let first = iter.next();
        assert!(first.is_some());
        assert_eq!(first.unwrap().len(), 10);

        // Dropping the iterator early must not hang or panic: the worker
        // observes the disconnected channel and exits.
        drop(iter);
    }
}