1use crate::collate::{Collate, stack_tensors};
19use crate::dataset::Dataset;
20use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
21use axonml_core::Device;
22use axonml_tensor::Tensor;
23use rayon::prelude::*;
24use std::marker::PhantomData;
25use std::sync::mpsc;
26use std::thread;
27
/// A single collated batch of `(data, targets)` tensors produced by a
/// [`DataLoader`].
#[derive(Debug, Clone)]
pub struct Batch {
    /// Stacked input tensors; the leading dimension is the batch size.
    pub data: Tensor<f32>,
    /// Stacked target tensors, aligned sample-for-sample with `data`.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (taken from `data.shape()[0]` in `Batch::new`).
    pub size: usize,
}
42
43impl Batch {
44 #[must_use]
46 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
47 let size = data.shape()[0];
48 Self {
49 data,
50 targets,
51 size,
52 }
53 }
54
55 #[must_use]
57 pub fn len(&self) -> usize {
58 self.size
59 }
60
61 #[must_use]
63 pub fn is_empty(&self) -> bool {
64 self.size == 0
65 }
66}
67
/// Batches samples from a [`Dataset`] of `(data, target)` tensor pairs.
///
/// Configure with the builder-style methods (`shuffle`, `drop_last`,
/// `num_workers`) and iterate with [`DataLoader::iter`].
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Source dataset samples are drawn from.
    dataset: D,
    /// Number of samples per batch.
    batch_size: usize,
    /// Visit samples in random order (via `RandomSampler`) when `true`.
    shuffle: bool,
    /// Discard a trailing batch smaller than `batch_size` when `true`.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
90
91impl<D> DataLoader<D>
92where
93 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
94{
95 pub fn new(dataset: D, batch_size: usize) -> Self {
97 Self {
98 dataset,
99 batch_size,
100 shuffle: false,
101 drop_last: false,
102 num_workers: 0,
103 }
104 }
105
106 pub fn shuffle(mut self, shuffle: bool) -> Self {
108 self.shuffle = shuffle;
109 self
110 }
111
112 pub fn drop_last(mut self, drop_last: bool) -> Self {
114 self.drop_last = drop_last;
115 self
116 }
117
118 pub fn num_workers(mut self, num_workers: usize) -> Self {
120 self.num_workers = num_workers;
121 self
122 }
123
124 pub fn batch_size(&self) -> usize {
126 self.batch_size
127 }
128
129 pub fn len(&self) -> usize {
131 let total = self.dataset.len();
132 if self.drop_last {
133 total / self.batch_size
134 } else {
135 total.div_ceil(self.batch_size)
136 }
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.dataset.is_empty()
142 }
143
144 pub fn dataset_len(&self) -> usize {
146 self.dataset.len()
147 }
148
149 pub fn iter(&self) -> DataLoaderIter<'_, D> {
151 let indices: Vec<usize> = if self.shuffle {
152 let sampler = RandomSampler::new(self.dataset.len());
153 sampler.iter().collect()
154 } else {
155 let sampler = SequentialSampler::new(self.dataset.len());
156 sampler.iter().collect()
157 };
158
159 DataLoaderIter {
160 dataset: &self.dataset,
161 indices,
162 batch_size: self.batch_size,
163 drop_last: self.drop_last,
164 position: 0,
165 num_workers: self.num_workers,
166 }
167 }
168}
169
/// Borrowing iterator over a [`DataLoader`], yielding one [`Batch`] per step.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Dataset borrowed from the owning loader.
    dataset: &'a D,
    /// Pre-computed sample visiting order for this pass.
    indices: Vec<usize>,
    /// Samples per batch.
    batch_size: usize,
    /// Discard a trailing batch shorter than `batch_size` when `true`.
    drop_last: bool,
    /// Cursor into `indices`; next batch starts here.
    position: usize,
    /// When > 0, fetch samples within a batch in parallel via rayon.
    num_workers: usize,
}
186
187impl<D> Iterator for DataLoaderIter<'_, D>
188where
189 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
190{
191 type Item = Batch;
192
193 fn next(&mut self) -> Option<Self::Item> {
194 if self.position >= self.indices.len() {
195 return None;
196 }
197
198 let end = (self.position + self.batch_size).min(self.indices.len());
199 let batch_indices = &self.indices[self.position..end];
200
201 if batch_indices.len() < self.batch_size && self.drop_last {
203 return None;
204 }
205
206 let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
208 batch_indices
210 .par_iter()
211 .filter_map(|&idx| self.dataset.get(idx))
212 .collect()
213 } else {
214 batch_indices
216 .iter()
217 .filter_map(|&idx| self.dataset.get(idx))
218 .collect()
219 };
220
221 if samples.is_empty() {
222 return None;
223 }
224
225 let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
227 let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
228
229 let data = stack_tensors(&data_samples);
231 let targets = stack_tensors(&target_samples);
232
233 self.position = end;
234
235 Some(Batch::new(data, targets))
236 }
237}
238
239impl<D> DataLoaderIter<'_, D>
240where
241 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
242{
243 #[must_use]
245 pub fn remaining(&self) -> usize {
246 let remaining_samples = self.indices.len().saturating_sub(self.position);
247 if self.drop_last {
248 remaining_samples / self.batch_size
249 } else {
250 remaining_samples.div_ceil(self.batch_size)
251 }
252 }
253}
254
/// Iterator that prefetches batches on a background thread and moves them
/// to a target device before the consumer asks for them.
pub struct GpuPrefetchIter {
    /// Receives finished batches from the worker thread.
    receiver: mpsc::Receiver<Batch>,
    /// Worker handle, retained so the thread is tied to the iterator's
    /// lifetime; dropping the iterator closes the channel, which stops
    /// the worker at its next send.
    _worker: Option<thread::JoinHandle<()>>,
}
282
283impl GpuPrefetchIter {
284 fn new_streaming<D>(
290 dataset: D,
291 indices: Vec<usize>,
292 batch_size: usize,
293 drop_last: bool,
294 num_workers: usize,
295 device: Device,
296 ) -> Self
297 where
298 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)> + 'static,
299 {
300 let (tx, rx) = mpsc::sync_channel(2);
302
303 let worker = thread::spawn(move || {
304 let mut position = 0;
305 while position < indices.len() {
306 let end = (position + batch_size).min(indices.len());
307 let batch_indices = &indices[position..end];
308
309 if batch_indices.len() < batch_size && drop_last {
310 break;
311 }
312
313 let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if num_workers > 0 {
315 batch_indices
316 .par_iter()
317 .filter_map(|&idx| dataset.get(idx))
318 .collect()
319 } else {
320 batch_indices
321 .iter()
322 .filter_map(|&idx| dataset.get(idx))
323 .collect()
324 };
325
326 if samples.is_empty() {
327 break;
328 }
329
330 let data_samples: Vec<Tensor<f32>> =
331 samples.iter().map(|(x, _)| x.clone()).collect();
332 let target_samples: Vec<Tensor<f32>> =
333 samples.iter().map(|(_, y)| y.clone()).collect();
334
335 let data = stack_tensors(&data_samples);
336 let targets = stack_tensors(&target_samples);
337
338 let gpu_data = match data.to_device(device) {
340 Ok(t) => t,
341 Err(_) => data,
342 };
343 let gpu_targets = match targets.to_device(device) {
344 Ok(t) => t,
345 Err(_) => targets,
346 };
347
348 if tx.send(Batch::new(gpu_data, gpu_targets)).is_err() {
349 break;
350 }
351
352 position = end;
353 }
354 });
355
356 Self {
357 receiver: rx,
358 _worker: Some(worker),
359 }
360 }
361}
362
363impl Iterator for GpuPrefetchIter {
364 type Item = Batch;
365
366 fn next(&mut self) -> Option<Self::Item> {
367 self.receiver.recv().ok()
368 }
369}
370
371impl<D> DataLoader<D>
372where
373 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
374{
375 pub fn prefetch_to_gpu(&self, device: Device) -> GpuPrefetchIter
393 where
394 D: Clone + 'static,
395 {
396 let indices: Vec<usize> = if self.shuffle {
397 let sampler = RandomSampler::new(self.dataset.len());
398 sampler.iter().collect()
399 } else {
400 (0..self.dataset.len()).collect()
401 };
402
403 GpuPrefetchIter::new_streaming(
404 self.dataset.clone(),
405 indices,
406 self.batch_size,
407 self.drop_last,
408 self.num_workers,
409 device,
410 )
411 }
412}
413
/// Data loader generic over the sample type `T` and the collate strategy `C`,
/// for datasets whose items are not `(Tensor, Tensor)` pairs.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Source dataset samples are drawn from.
    dataset: D,
    /// Strategy that turns a `Vec<T>` of samples into one collated batch.
    collate_fn: C,
    /// Number of samples per batch.
    batch_size: usize,
    /// Visit samples in random order when `true`.
    shuffle: bool,
    /// Discard a trailing batch smaller than `batch_size` when `true`.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    /// Marks the sample type `T`, which is not stored directly.
    _phantom: PhantomData<T>,
}
433
434impl<D, C, T> GenericDataLoader<D, C, T>
435where
436 D: Dataset<Item = T>,
437 C: Collate<T>,
438 T: Send,
439{
440 pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
442 Self {
443 dataset,
444 collate_fn,
445 batch_size,
446 shuffle: false,
447 drop_last: false,
448 num_workers: 0,
449 _phantom: PhantomData,
450 }
451 }
452
453 pub fn num_workers(mut self, num_workers: usize) -> Self {
455 self.num_workers = num_workers;
456 self
457 }
458
459 pub fn shuffle(mut self, shuffle: bool) -> Self {
461 self.shuffle = shuffle;
462 self
463 }
464
465 pub fn drop_last(mut self, drop_last: bool) -> Self {
467 self.drop_last = drop_last;
468 self
469 }
470
471 pub fn len(&self) -> usize {
473 let total = self.dataset.len();
474 if self.drop_last {
475 total / self.batch_size
476 } else {
477 total.div_ceil(self.batch_size)
478 }
479 }
480
481 pub fn is_empty(&self) -> bool {
483 self.dataset.is_empty()
484 }
485
486 #[allow(clippy::iter_not_returning_iterator)]
488 pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
489 let indices: Vec<usize> = if self.shuffle {
490 let sampler = RandomSampler::new(self.dataset.len());
491 sampler.iter().collect()
492 } else {
493 (0..self.dataset.len()).collect()
494 };
495
496 GenericDataLoaderIter {
497 dataset: &self.dataset,
498 collate_fn: &self.collate_fn,
499 indices,
500 batch_size: self.batch_size,
501 drop_last: self.drop_last,
502 position: 0,
503 num_workers: self.num_workers,
504 _phantom: PhantomData,
505 }
506 }
507}
508
/// Borrowing iterator over a [`GenericDataLoader`], yielding one collated
/// batch (`C::Output`) per step.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Dataset borrowed from the owning loader.
    dataset: &'a D,
    /// Collate strategy borrowed from the owning loader.
    collate_fn: &'a C,
    /// Pre-computed sample visiting order for this pass.
    indices: Vec<usize>,
    /// Samples per batch.
    batch_size: usize,
    /// Discard a trailing batch shorter than `batch_size` when `true`.
    drop_last: bool,
    /// Cursor into `indices`; next batch starts here.
    position: usize,
    /// When > 0, fetch samples within a batch in parallel via rayon.
    num_workers: usize,
    /// Marks the sample type `T`, which is not stored directly.
    _phantom: PhantomData<T>,
}
525
526impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
527where
528 D: Dataset<Item = T>,
529 C: Collate<T>,
530 T: Send + Sync,
531{
532 type Item = C::Output;
533
534 fn next(&mut self) -> Option<Self::Item> {
535 if self.position >= self.indices.len() {
536 return None;
537 }
538
539 let end = (self.position + self.batch_size).min(self.indices.len());
540 let batch_indices = &self.indices[self.position..end];
541
542 if batch_indices.len() < self.batch_size && self.drop_last {
543 return None;
544 }
545
546 let samples: Vec<T> = if self.num_workers > 0 {
548 batch_indices
549 .par_iter()
550 .filter_map(|&idx| self.dataset.get(idx))
551 .collect()
552 } else {
553 batch_indices
554 .iter()
555 .filter_map(|&idx| self.dataset.get(idx))
556 .collect()
557 };
558
559 if samples.is_empty() {
560 return None;
561 }
562
563 self.position = end;
564
565 Some(self.collate_fn.collate(samples))
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a dataset of `size` samples with 2 features each; targets
    /// alternate 0.0 / 1.0.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        // 10 samples / batches of 3 -> 4 batches, last one partial.
        assert_eq!(loader.len(), 4);
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        // Trailing partial batch keeps the single leftover sample.
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        // With drop_last, the leftover sample is discarded: 3 full batches.
        assert_eq!(loader.len(), 3);
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        // Every remaining batch must be full.
        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Shuffle order is random, so only check that batches were produced.
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        // 9 samples divide evenly into 3 full batches.
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        // Batch size comes from data's leading dimension (2).
        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        // A zero-sample dataset must yield no batches at all.
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        // One sample with batch_size 3 -> one partial batch of size 1.
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffling, samples appear in sequential dataset order.
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        // remaining() counts batches left, decreasing by one per next().
        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // Parallel fetching must not lose any samples.
        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Identically-constructed datasets for the two loaders.
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        assert_eq!(batches_seq.len(), batches_par.len());

        // Parallel fetching must preserve batch contents and order exactly.
        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(
                batches_seq[i].targets.to_vec(),
                batches_par[i].targets.to_vec()
            );
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        // 95 samples -> 9 full batches; the 5-sample remainder is dropped.
        assert_eq!(batches.len(), 9);
        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }

    #[test]
    fn test_gpu_prefetch_cpu_fallback() {
        use axonml_core::Device;

        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        // Prefetching to Device::Cpu exercises the streaming path without
        // needing a GPU; batching must match the plain iterator.
        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_gpu_prefetch_data_integrity() {
        use axonml_core::Device;

        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();

        // Prefetched batches must carry the same values, in the same order,
        // as direct iteration.
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_gpu_prefetch_early_drop() {
        use axonml_core::Device;

        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10);

        let mut iter = loader.prefetch_to_gpu(Device::Cpu);
        let first = iter.next();
        assert!(first.is_some());
        assert_eq!(first.unwrap().len(), 10);

        // Dropping the iterator mid-stream must not hang or panic: the
        // worker observes the closed channel and shuts down.
        drop(iter);
    }
}