1use crate::collate::{Collate, stack_tensors};
18use crate::dataset::Dataset;
19use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
20use axonml_core::Device;
21use axonml_tensor::Tensor;
22use rayon::prelude::*;
23use std::marker::PhantomData;
24use std::sync::mpsc;
25use std::thread;
26
/// A single mini-batch of samples produced by a [`DataLoader`].
///
/// Holds the stacked input tensors, the matching stacked targets, and the
/// number of samples in the batch.
#[derive(Debug, Clone)]
pub struct Batch {
    /// Stacked input data; the leading dimension is the batch dimension.
    pub data: Tensor<f32>,
    /// Stacked targets, aligned with `data` along the batch dimension.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (leading dimension of `data`).
    pub size: usize,
}
41
42impl Batch {
43 #[must_use]
45 pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
46 let size = data.shape()[0];
47 Self {
48 data,
49 targets,
50 size,
51 }
52 }
53
54 #[must_use]
56 pub fn len(&self) -> usize {
57 self.size
58 }
59
60 #[must_use]
62 pub fn is_empty(&self) -> bool {
63 self.size == 0
64 }
65}
66
/// Iterates over a [`Dataset`] of `(data, target)` tensor pairs in mini-batches.
///
/// Configured through a builder-style API: batch size, shuffling, dropping the
/// trailing partial batch, and the number of rayon workers used to fetch
/// samples within a batch.
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Source of individual `(data, target)` samples.
    dataset: D,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether indices are reshuffled on each `iter` call.
    shuffle: bool,
    /// Whether a trailing batch smaller than `batch_size` is discarded.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
89
90impl<D> DataLoader<D>
91where
92 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
93{
94 pub fn new(dataset: D, batch_size: usize) -> Self {
96 Self {
97 dataset,
98 batch_size,
99 shuffle: false,
100 drop_last: false,
101 num_workers: 0,
102 }
103 }
104
105 pub fn shuffle(mut self, shuffle: bool) -> Self {
107 self.shuffle = shuffle;
108 self
109 }
110
111 pub fn drop_last(mut self, drop_last: bool) -> Self {
113 self.drop_last = drop_last;
114 self
115 }
116
117 pub fn num_workers(mut self, num_workers: usize) -> Self {
119 self.num_workers = num_workers;
120 self
121 }
122
123 pub fn batch_size(&self) -> usize {
125 self.batch_size
126 }
127
128 pub fn len(&self) -> usize {
130 let total = self.dataset.len();
131 if self.drop_last {
132 total / self.batch_size
133 } else {
134 total.div_ceil(self.batch_size)
135 }
136 }
137
138 pub fn is_empty(&self) -> bool {
140 self.dataset.is_empty()
141 }
142
143 pub fn dataset_len(&self) -> usize {
145 self.dataset.len()
146 }
147
148 pub fn iter(&self) -> DataLoaderIter<'_, D> {
150 let indices: Vec<usize> = if self.shuffle {
151 let sampler = RandomSampler::new(self.dataset.len());
152 sampler.iter().collect()
153 } else {
154 let sampler = SequentialSampler::new(self.dataset.len());
155 sampler.iter().collect()
156 };
157
158 DataLoaderIter {
159 dataset: &self.dataset,
160 indices,
161 batch_size: self.batch_size,
162 drop_last: self.drop_last,
163 position: 0,
164 num_workers: self.num_workers,
165 }
166 }
167}
168
/// Borrowing iterator over the batches of a [`DataLoader`].
///
/// The full index order is materialized when the iterator is created, so
/// shuffling is fixed for the lifetime of one iteration.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Borrowed source dataset.
    dataset: &'a D,
    /// Pre-computed sample order for this pass over the dataset.
    indices: Vec<usize>,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether a trailing partial batch is discarded.
    drop_last: bool,
    /// Cursor into `indices`; start of the next batch.
    position: usize,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
185
186impl<D> Iterator for DataLoaderIter<'_, D>
187where
188 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
189{
190 type Item = Batch;
191
192 fn next(&mut self) -> Option<Self::Item> {
193 if self.position >= self.indices.len() {
194 return None;
195 }
196
197 let end = (self.position + self.batch_size).min(self.indices.len());
198 let batch_indices = &self.indices[self.position..end];
199
200 if batch_indices.len() < self.batch_size && self.drop_last {
202 return None;
203 }
204
205 let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
207 batch_indices
209 .par_iter()
210 .filter_map(|&idx| self.dataset.get(idx))
211 .collect()
212 } else {
213 batch_indices
215 .iter()
216 .filter_map(|&idx| self.dataset.get(idx))
217 .collect()
218 };
219
220 if samples.is_empty() {
221 return None;
222 }
223
224 let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
226 let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
227
228 let data = stack_tensors(&data_samples);
230 let targets = stack_tensors(&target_samples);
231
232 self.position = end;
233
234 Some(Batch::new(data, targets))
235 }
236}
237
238impl<D> DataLoaderIter<'_, D>
239where
240 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
241{
242 #[must_use]
244 pub fn remaining(&self) -> usize {
245 let remaining_samples = self.indices.len().saturating_sub(self.position);
246 if self.drop_last {
247 remaining_samples / self.batch_size
248 } else {
249 remaining_samples.div_ceil(self.batch_size)
250 }
251 }
252}
253
/// Streams batches assembled on a background thread, transferring each batch
/// to a target device before handing it to the consumer.
///
/// Produced by [`DataLoader::prefetch_to_gpu`]. A bounded channel keeps only a
/// couple of batches in flight at a time.
pub struct GpuPrefetchIter {
    /// Receiving end of the bounded batch channel fed by the worker thread.
    receiver: mpsc::Receiver<Batch>,
    /// Handle of the producer thread. NOTE(review): never joined — the worker
    /// is effectively detached, and any panic in it surfaces only as a closed
    /// channel (iteration simply ends).
    _worker: Option<thread::JoinHandle<()>>,
}
281
impl GpuPrefetchIter {
    /// Spawns a worker thread that walks `indices` in `batch_size` steps,
    /// builds each batch from `dataset`, moves it to `device`, and pushes it
    /// through a bounded channel.
    ///
    /// The channel capacity of 2 bounds how many device-resident batches can
    /// be queued ahead of the consumer. The worker stops when the indices are
    /// exhausted, when `drop_last` discards a trailing partial batch, or when
    /// the receiver has been dropped.
    fn new_streaming<D>(
        dataset: D,
        indices: Vec<usize>,
        batch_size: usize,
        drop_last: bool,
        num_workers: usize,
        device: Device,
    ) -> Self
    where
        D: Dataset<Item = (Tensor<f32>, Tensor<f32>)> + 'static,
    {
        // Bounded to 2 so at most two prefetched batches are in flight.
        let (tx, rx) = mpsc::sync_channel(2);

        let worker = thread::spawn(move || {
            let mut position = 0;
            while position < indices.len() {
                let end = (position + batch_size).min(indices.len());
                let batch_indices = &indices[position..end];

                // Trailing partial batch: stop early when drop_last is set.
                if batch_indices.len() < batch_size && drop_last {
                    break;
                }

                // Fetch samples, in parallel via rayon when workers requested.
                // Indices the dataset cannot resolve are silently skipped.
                let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if num_workers > 0 {
                    batch_indices
                        .par_iter()
                        .filter_map(|&idx| dataset.get(idx))
                        .collect()
                } else {
                    batch_indices
                        .iter()
                        .filter_map(|&idx| dataset.get(idx))
                        .collect()
                };

                if samples.is_empty() {
                    break;
                }

                let data_samples: Vec<Tensor<f32>> =
                    samples.iter().map(|(x, _)| x.clone()).collect();
                let target_samples: Vec<Tensor<f32>> =
                    samples.iter().map(|(_, y)| y.clone()).collect();

                let data = stack_tensors(&data_samples);
                let targets = stack_tensors(&target_samples);

                // Best-effort device transfer: on failure the CPU-resident
                // tensor is forwarded unchanged rather than aborting the run.
                let gpu_data = match data.to_device(device) {
                    Ok(t) => t,
                    Err(_) => data,
                };
                let gpu_targets = match targets.to_device(device) {
                    Ok(t) => t,
                    Err(_) => targets,
                };

                // send blocks while the channel is full; an Err means the
                // consumer dropped the iterator, so stop producing.
                if tx.send(Batch::new(gpu_data, gpu_targets)).is_err() {
                    break;
                }

                position = end;
            }
        });

        Self {
            receiver: rx,
            _worker: Some(worker),
        }
    }
}
361
362impl Iterator for GpuPrefetchIter {
363 type Item = Batch;
364
365 fn next(&mut self) -> Option<Self::Item> {
366 self.receiver.recv().ok()
367 }
368}
369
370impl<D> DataLoader<D>
371where
372 D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
373{
374 pub fn prefetch_to_gpu(&self, device: Device) -> GpuPrefetchIter
392 where
393 D: Clone + 'static,
394 {
395 let indices: Vec<usize> = if self.shuffle {
396 let sampler = RandomSampler::new(self.dataset.len());
397 sampler.iter().collect()
398 } else {
399 (0..self.dataset.len()).collect()
400 };
401
402 GpuPrefetchIter::new_streaming(
403 self.dataset.clone(),
404 indices,
405 self.batch_size,
406 self.drop_last,
407 self.num_workers,
408 device,
409 )
410 }
411}
412
/// Batch loader over an arbitrary item type `T`, combined into batches by a
/// user-supplied [`Collate`] function.
///
/// Generic counterpart of [`DataLoader`], which is fixed to
/// `(Tensor<f32>, Tensor<f32>)` items.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Source of individual samples.
    dataset: D,
    /// Combines a `Vec<T>` of samples into one batch value.
    collate_fn: C,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether indices are reshuffled on each `iter` call.
    shuffle: bool,
    /// Whether a trailing batch smaller than `batch_size` is discarded.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    /// Marks the logical item type `T` without storing one.
    _phantom: PhantomData<T>,
}
432
433impl<D, C, T> GenericDataLoader<D, C, T>
434where
435 D: Dataset<Item = T>,
436 C: Collate<T>,
437 T: Send,
438{
439 pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
441 Self {
442 dataset,
443 collate_fn,
444 batch_size,
445 shuffle: false,
446 drop_last: false,
447 num_workers: 0,
448 _phantom: PhantomData,
449 }
450 }
451
452 pub fn num_workers(mut self, num_workers: usize) -> Self {
454 self.num_workers = num_workers;
455 self
456 }
457
458 pub fn shuffle(mut self, shuffle: bool) -> Self {
460 self.shuffle = shuffle;
461 self
462 }
463
464 pub fn drop_last(mut self, drop_last: bool) -> Self {
466 self.drop_last = drop_last;
467 self
468 }
469
470 pub fn len(&self) -> usize {
472 let total = self.dataset.len();
473 if self.drop_last {
474 total / self.batch_size
475 } else {
476 total.div_ceil(self.batch_size)
477 }
478 }
479
480 pub fn is_empty(&self) -> bool {
482 self.dataset.is_empty()
483 }
484
485 #[allow(clippy::iter_not_returning_iterator)]
487 pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
488 let indices: Vec<usize> = if self.shuffle {
489 let sampler = RandomSampler::new(self.dataset.len());
490 sampler.iter().collect()
491 } else {
492 (0..self.dataset.len()).collect()
493 };
494
495 GenericDataLoaderIter {
496 dataset: &self.dataset,
497 collate_fn: &self.collate_fn,
498 indices,
499 batch_size: self.batch_size,
500 drop_last: self.drop_last,
501 position: 0,
502 num_workers: self.num_workers,
503 _phantom: PhantomData,
504 }
505 }
506}
507
/// Borrowing iterator over the collated batches of a [`GenericDataLoader`].
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Borrowed source dataset.
    dataset: &'a D,
    /// Borrowed collate function applied to each group of samples.
    collate_fn: &'a C,
    /// Pre-computed sample order for this pass over the dataset.
    indices: Vec<usize>,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether a trailing partial batch is discarded.
    drop_last: bool,
    /// Cursor into `indices`; start of the next batch.
    position: usize,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    /// Marks the logical item type `T` without storing one.
    _phantom: PhantomData<T>,
}
524
525impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
526where
527 D: Dataset<Item = T>,
528 C: Collate<T>,
529 T: Send + Sync,
530{
531 type Item = C::Output;
532
533 fn next(&mut self) -> Option<Self::Item> {
534 if self.position >= self.indices.len() {
535 return None;
536 }
537
538 let end = (self.position + self.batch_size).min(self.indices.len());
539 let batch_indices = &self.indices[self.position..end];
540
541 if batch_indices.len() < self.batch_size && self.drop_last {
542 return None;
543 }
544
545 let samples: Vec<T> = if self.num_workers > 0 {
547 batch_indices
548 .par_iter()
549 .filter_map(|&idx| self.dataset.get(idx))
550 .collect()
551 } else {
552 batch_indices
553 .iter()
554 .filter_map(|&idx| self.dataset.get(idx))
555 .collect()
556 };
557
558 if samples.is_empty() {
559 return None;
560 }
561
562 self.position = end;
563
564 Some(self.collate_fn.collate(samples))
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a dataset of `size` samples with 2 features each; targets
    /// alternate 0.0 / 1.0.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        // 10 samples / batch of 3 -> 4 batches (last one partial).
        assert_eq!(loader.len(), 4);
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // Three full batches of 3, then a trailing batch of 1.
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        // With drop_last the trailing partial batch is discarded: 3 batches.
        assert_eq!(loader.len(), 3);
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        // Every yielded batch is exactly full-size.
        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Shuffling is random, so only assert that batches were produced;
        // comparing contents would be flaky.
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        // 9 divides evenly by 3: no partial batch even without drop_last.
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        // Batch size is taken from the leading dimension of `data`.
        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        // An empty dataset yields no batches at all.
        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        // Dataset smaller than one batch: a single partial batch is yielded.
        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffling, samples come back in dataset order (2 features
        // per sample, so each batch of 2 holds 4 consecutive values).
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        // remaining() counts down as batches are consumed.
        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // Parallel fetching must not lose any samples.
        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Two identical datasets, loaded sequentially and in parallel.
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        assert_eq!(batches_seq.len(), batches_par.len());

        // Parallel fetching must preserve both sample order and contents.
        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(
                batches_seq[i].targets.to_vec(),
                batches_par[i].targets.to_vec()
            );
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        // 95 samples / batch of 10 with drop_last -> 9 full batches.
        assert_eq!(batches.len(), 9);
        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }

    #[test]
    fn test_gpu_prefetch_cpu_fallback() {
        use axonml_core::Device;

        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        // Prefetching to the CPU device exercises the streaming path without
        // requiring a GPU.
        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
        // 10 samples / batch of 3 -> 4 batches (last one partial).
        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_gpu_prefetch_data_integrity() {
        use axonml_core::Device;

        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();

        // Streaming through the worker thread must preserve ordering.
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_gpu_prefetch_early_drop() {
        use axonml_core::Device;

        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10);

        let mut iter = loader.prefetch_to_gpu(Device::Cpu);
        let first = iter.next();
        assert!(first.is_some());
        assert_eq!(first.unwrap().len(), 10);

        // Dropping the iterator closes the channel; the worker thread should
        // notice the failed send and exit without hanging the test.
        drop(iter);
    }
}