//! `DataLoader` - Batched Data Iteration
//!
//! # File
//! `crates/axonml-data/src/dataloader.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of
//! any kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::collate::{Collate, stack_tensors};
use crate::dataset::Dataset;
use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
use axonml_core::Device;
use axonml_tensor::Tensor;
use rayon::prelude::*;
use std::marker::PhantomData;
use std::sync::mpsc;
use std::thread;

// =============================================================================
// Batch Type
// =============================================================================

/// A batch of data from the `DataLoader`.
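///
/// # Example
/// A minimal sketch, mirroring the shapes used in the tests below:
/// ```ignore
/// let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
/// let batch = Batch::new(data, targets);
/// assert_eq!(batch.len(), 2);
/// ```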
#[derive(Debug, Clone)]
pub struct Batch {
    /// Batched input data.
    pub data: Tensor<f32>,
    /// Batched targets.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch.
    pub size: usize,
}

impl Batch {
    /// Creates a new `Batch`, taking the batch size from the first dimension
    /// of `data`.
    #[must_use]
    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
        let size = data.shape()[0];
        Self {
            data,
            targets,
            size,
        }
    }

    /// Returns the batch size.
    #[must_use]
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns true if the batch is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

// =============================================================================
// DataLoader
// =============================================================================

/// `DataLoader` for batched iteration over datasets.
///
/// Provides configurable batching, shuffling, and iteration over datasets.
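///
/// # Example
/// A minimal sketch of the builder API (mirrors the tests below):
/// ```ignore
/// let loader = DataLoader::new(dataset, 32)
///     .shuffle(true)
///     .drop_last(true)
///     .num_workers(4);
/// for batch in loader.iter() {
///     // batch.data and batch.targets hold batch.len() stacked samples
/// }
/// ```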
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset.
    dataset: D,
    /// Batch size.
    batch_size: usize,
    /// Whether to shuffle data each epoch.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// Number of worker threads; values above zero enable parallel sample
    /// collection via rayon.
    num_workers: usize,
}

impl<D> DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Creates a new `DataLoader` with the specified batch size.
    pub fn new(dataset: D, batch_size: usize) -> Self {
        Self {
            dataset,
            batch_size,
            shuffle: false,
            drop_last: false,
            num_workers: 0,
        }
    }

    /// Enables or disables shuffling.
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets whether to drop the last incomplete batch.
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Sets the number of worker threads for parallel data loading.
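    ///
    /// With `num_workers == 0`, samples are loaded sequentially on the calling
    /// thread. Any value above zero currently switches sample collection to
    /// rayon's global thread pool; the exact count is not used to size the pool.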
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Returns the batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns the number of batches: `floor(len / batch_size)` when
    /// `drop_last` is set, `ceil(len / batch_size)` otherwise.
    pub fn len(&self) -> usize {
        let total = self.dataset.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            total.div_ceil(self.batch_size)
        }
    }

    /// Returns true if the `DataLoader` is empty.
    pub fn is_empty(&self) -> bool {
        self.dataset.is_empty()
    }

    /// Returns the dataset length.
    pub fn dataset_len(&self) -> usize {
        self.dataset.len()
    }

    /// Creates an iterator over batches.
    pub fn iter(&self) -> DataLoaderIter<'_, D> {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            let sampler = SequentialSampler::new(self.dataset.len());
            sampler.iter().collect()
        };

        DataLoaderIter {
            dataset: &self.dataset,
            indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            position: 0,
            num_workers: self.num_workers,
        }
    }
}

// =============================================================================
// DataLoaderIter
// =============================================================================

/// Iterator over batches from a `DataLoader`.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    dataset: &'a D,
    indices: Vec<usize>,
    batch_size: usize,
    drop_last: bool,
    position: usize,
    num_workers: usize,
}

impl<D> Iterator for DataLoaderIter<'_, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    type Item = Batch;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.indices.len() {
            return None;
        }

        let end = (self.position + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.position..end];

        // Check if this is an incomplete batch
        if batch_indices.len() < self.batch_size && self.drop_last {
            return None;
        }

        // Collect samples for this batch (parallel when num_workers > 0)
        let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
            // Parallel sample collection using rayon
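            // Note: rayon's `collect()` preserves the order of the source
            // iterator, so parallel and sequential loading yield identical
            // batch contents (exercised by test_parallel_vs_sequential_equivalence).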
            batch_indices
                .par_iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        } else {
            // Sequential fallback for num_workers = 0
            batch_indices
                .iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        };

        if samples.is_empty() {
            return None;
        }

        // Separate data and targets for stacking
        let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
        let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();

        // Stack samples into batches
        let data = stack_tensors(&data_samples);
        let targets = stack_tensors(&target_samples);

        self.position = end;

        Some(Batch::new(data, targets))
    }
}

impl<D> DataLoaderIter<'_, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Returns the number of remaining batches.
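    ///
    /// A minimal sketch, mirroring `test_dataloader_remaining` below:
    /// ```ignore
    /// let mut it = loader.iter(); // 10 samples, batch_size 3
    /// assert_eq!(it.remaining(), 4);
    /// it.next();
    /// assert_eq!(it.remaining(), 3);
    /// ```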
    #[must_use]
    pub fn remaining(&self) -> usize {
        let remaining_samples = self.indices.len().saturating_sub(self.position);
        if self.drop_last {
            remaining_samples / self.batch_size
        } else {
            remaining_samples.div_ceil(self.batch_size)
        }
    }
}

// =============================================================================
// GPU Prefetch Iterator
// =============================================================================

/// A wrapper iterator that prefetches batches onto a GPU device in a background
/// thread, overlapping CPU data loading with GPU computation.
///
/// When the training loop calls `next()`, it receives a batch that is already
/// resident on the target GPU device. Meanwhile, the background thread is
/// loading and transferring the next batch.
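///
/// If a device transfer fails, the batch is yielded unchanged on its original
/// (CPU) device rather than aborting iteration.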
///
/// # Usage
/// ```ignore
/// let loader = DataLoader::new(dataset, 64).shuffle(true).num_workers(4);
/// let device = Device::Cuda(0);
///
/// for batch in loader.prefetch_to_gpu(device) {
///     // batch.data and batch.targets are already on the GPU
///     let output = model.forward(&batch.data);
/// }
/// ```
pub struct GpuPrefetchIter {
    /// Receiver for pre-transferred GPU batches.
    receiver: mpsc::Receiver<Batch>,
    /// Handle to the background prefetch thread. Dropping this handle detaches
    /// the thread; it exits on its own once the receiver is dropped and its
    /// next `send` fails.
    _worker: Option<thread::JoinHandle<()>>,
}

impl GpuPrefetchIter {
    /// Creates a GPU prefetch iterator that streams batches lazily.
    ///
    /// Spawns a background thread that produces batches one at a time from the
    /// dataset, transfers them to `device`, and sends them through a bounded
    /// channel. At most two finished batches wait in the channel at any time
    /// (a third may be in flight in the worker), so the whole dataset is never
    /// eagerly materialized.
    fn new_streaming<D>(
        dataset: D,
        indices: Vec<usize>,
        batch_size: usize,
        drop_last: bool,
        num_workers: usize,
        device: Device,
    ) -> Self
    where
        D: Dataset<Item = (Tensor<f32>, Tensor<f32>)> + 'static,
    {
        // Bounded channel: at most 2 batches buffered (current + next).
        let (tx, rx) = mpsc::sync_channel(2);

        let worker = thread::spawn(move || {
            let mut position = 0;
            while position < indices.len() {
                let end = (position + batch_size).min(indices.len());
                let batch_indices = &indices[position..end];

                if batch_indices.len() < batch_size && drop_last {
                    break;
                }

                // Collect samples for this batch
                let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if num_workers > 0 {
                    batch_indices
                        .par_iter()
                        .filter_map(|&idx| dataset.get(idx))
                        .collect()
                } else {
                    batch_indices
                        .iter()
                        .filter_map(|&idx| dataset.get(idx))
                        .collect()
                };

                if samples.is_empty() {
                    break;
                }

                let data_samples: Vec<Tensor<f32>> =
                    samples.iter().map(|(x, _)| x.clone()).collect();
                let target_samples: Vec<Tensor<f32>> =
                    samples.iter().map(|(_, y)| y.clone()).collect();

                let data = stack_tensors(&data_samples);
                let targets = stack_tensors(&target_samples);

                // Transfer to GPU
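                // On failure, fall back to the untransferred (CPU) tensor so
                // iteration continues instead of aborting mid-epoch.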
                let gpu_data = match data.to_device(device) {
                    Ok(t) => t,
                    Err(_) => data,
                };
                let gpu_targets = match targets.to_device(device) {
                    Ok(t) => t,
                    Err(_) => targets,
                };

                if tx.send(Batch::new(gpu_data, gpu_targets)).is_err() {
                    break;
                }

                position = end;
            }
        });

        Self {
            receiver: rx,
            _worker: Some(worker),
        }
    }
}

impl Iterator for GpuPrefetchIter {
    type Item = Batch;

    fn next(&mut self) -> Option<Self::Item> {
        self.receiver.recv().ok()
    }
}

impl<D> DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Returns an iterator that prefetches batches onto the target GPU device
    /// in a background thread.
    ///
    /// Batches are produced lazily, with only a couple buffered at any time,
    /// avoiding the O(dataset_size) memory spike of eager collection. The
    /// background thread overlaps CPU data loading with GPU computation.
    ///
    /// # Arguments
    /// * `device` - Target GPU device (e.g., `Device::Cuda(0)`)
    ///
    /// # Example
    /// ```ignore
    /// let loader = DataLoader::new(dataset, 64).shuffle(true);
    /// for batch in loader.prefetch_to_gpu(Device::Cuda(0)) {
    ///     // batch.data and batch.targets are already on GPU
    /// }
    /// ```
    pub fn prefetch_to_gpu(&self, device: Device) -> GpuPrefetchIter
    where
        D: Clone + 'static,
    {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            (0..self.dataset.len()).collect()
        };

        GpuPrefetchIter::new_streaming(
            self.dataset.clone(),
            indices,
            self.batch_size,
            self.drop_last,
            self.num_workers,
            device,
        )
    }
}

// =============================================================================
// GenericDataLoader
// =============================================================================

/// A more flexible `DataLoader` that works with any `Dataset` and `Collate` function.
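///
/// # Example
/// A minimal sketch using the crate's `DefaultCollate`, mirroring the tests below:
/// ```ignore
/// let loader = GenericDataLoader::new(dataset, DefaultCollate::new(), 16)
///     .shuffle(true);
/// for batch in loader.iter() {
///     // `batch` has the collate function's output type, `C::Output`
/// }
/// ```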
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    dataset: D,
    collate_fn: C,
    batch_size: usize,
    shuffle: bool,
    drop_last: bool,
    num_workers: usize,
    _phantom: PhantomData<T>,
}

impl<D, C, T> GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Creates a new `GenericDataLoader`.
    pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
        Self {
            dataset,
            collate_fn,
            batch_size,
            shuffle: false,
            drop_last: false,
            num_workers: 0,
            _phantom: PhantomData,
        }
    }

    /// Sets the number of worker threads for parallel data loading.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Enables or disables shuffling.
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets whether to drop the last incomplete batch.
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Returns the number of batches.
    pub fn len(&self) -> usize {
        let total = self.dataset.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            total.div_ceil(self.batch_size)
        }
    }

    /// Returns true if empty.
    pub fn is_empty(&self) -> bool {
        self.dataset.is_empty()
    }

    /// Creates an iterator over batches.
    #[allow(clippy::iter_not_returning_iterator)]
    pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            (0..self.dataset.len()).collect()
        };

        GenericDataLoaderIter {
            dataset: &self.dataset,
            collate_fn: &self.collate_fn,
            indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            position: 0,
            num_workers: self.num_workers,
            _phantom: PhantomData,
        }
    }
}

/// Iterator for `GenericDataLoader`.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    dataset: &'a D,
    collate_fn: &'a C,
    indices: Vec<usize>,
    batch_size: usize,
    drop_last: bool,
    position: usize,
    num_workers: usize,
    _phantom: PhantomData<T>,
}

impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send + Sync,
{
    type Item = C::Output;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.indices.len() {
            return None;
        }

        let end = (self.position + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.position..end];

        if batch_indices.len() < self.batch_size && self.drop_last {
            return None;
        }

        // Collect samples (parallel when num_workers > 0)
        let samples: Vec<T> = if self.num_workers > 0 {
            batch_indices
                .par_iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        } else {
            batch_indices
                .iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        };

        if samples.is_empty() {
            return None;
        }

        self.position = end;

        Some(self.collate_fn.collate(samples))
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        assert_eq!(loader.len(), 4); // ceil(10/3) = 4

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // First 3 batches have size 3, last has size 1
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        assert_eq!(loader.len(), 3); // floor(10/3) = 3

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        // All batches have full size
        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // Run multiple iterations and collect first batch data
        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Due to shuffling, batches should (usually) be different
        // We can't guarantee this, but the loader should work
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffle, data should be in order
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // Verify all samples are present
        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Create two identical datasets
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        // Sequential
        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        // Parallel
        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        // Same number of batches
        assert_eq!(batches_seq.len(), batches_par.len());

        // Same data (no shuffle, so order should be same)
        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(
                batches_seq[i].targets.to_vec(),
                batches_par[i].targets.to_vec()
            );
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 9); // 95 / 10 = 9 full batches

        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }

    #[test]
    fn test_gpu_prefetch_cpu_fallback() {
        // Test that prefetch_to_gpu works on CPU device (no-op transfer)
        use axonml_core::Device;

        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        // prefetch_to_gpu with CPU device should act as a pass-through
        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
        assert_eq!(batches.len(), 4); // ceil(10/3) = 4

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_gpu_prefetch_data_integrity() {
        // Verify that data remains correct through the prefetch pipeline
        use axonml_core::Device;

        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();

        // Without shuffle, data should be in order (same as regular iter)
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_gpu_prefetch_early_drop() {
        // Test that dropping the iterator early doesn't leak or deadlock
        use axonml_core::Device;

        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10);

        let mut iter = loader.prefetch_to_gpu(Device::Cpu);
        let first = iter.next();
        assert!(first.is_some());
        assert_eq!(first.unwrap().len(), 10);

        // Drop the iterator early - should not hang
        drop(iter);
    }
}