//! `DataLoader` - Batched Data Iteration
//!
//! # File
//! `crates/axonml-data/src/dataloader.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use crate::collate::{Collate, stack_tensors};
use crate::dataset::Dataset;
use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
use axonml_core::Device;
use axonml_tensor::Tensor;
use rayon::prelude::*;
use std::marker::PhantomData;
use std::sync::mpsc;
use std::thread;

// =============================================================================
// Batch Type
// =============================================================================

/// A batch of data from the `DataLoader`.
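///
/// # Example
/// A minimal sketch of building a batch by hand; `Batch::new` reads the batch
/// size from `shape()[0]`, so the leading dimension of `data` is assumed to be
/// the batch dimension.
/// ```ignore
/// let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
/// let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();
/// let batch = Batch::new(data, targets);
/// assert_eq!(batch.len(), 2);
/// ```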
#[derive(Debug, Clone)]
pub struct Batch {
    /// Batched input data.
    pub data: Tensor<f32>,
    /// Batched targets.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch.
    pub size: usize,
}

impl Batch {
    /// Creates a new `Batch`, taking its size from the leading dimension of `data`.
    #[must_use]
    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
        let size = data.shape()[0];
        Self {
            data,
            targets,
            size,
        }
    }

    /// Returns the batch size.
    #[must_use]
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns true if the batch is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

// =============================================================================
// DataLoader
// =============================================================================

/// `DataLoader` for batched iteration over datasets.
///
/// Supports a configurable batch size, per-epoch shuffling, dropping of the
/// last incomplete batch, and parallel sample loading via `num_workers`.
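///
/// # Example
/// A minimal sketch, assuming `dataset` yields `(Tensor<f32>, Tensor<f32>)`
/// pairs:
/// ```ignore
/// let loader = DataLoader::new(dataset, 32)
///     .shuffle(true)
///     .drop_last(true);
///
/// for batch in loader.iter() {
///     // batch.data and batch.targets are stacked along the leading dimension
/// }
/// ```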
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset.
    dataset: D,
    /// Batch size.
    batch_size: usize,
    /// Whether to shuffle data each epoch.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// Number of worker threads; values above zero enable parallel sample
    /// collection via rayon.
    num_workers: usize,
}

impl<D> DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Creates a new `DataLoader` with the specified batch size.
    pub fn new(dataset: D, batch_size: usize) -> Self {
        Self {
            dataset,
            batch_size,
            shuffle: false,
            drop_last: false,
            num_workers: 0,
        }
    }

    /// Enables or disables shuffling.
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets whether to drop the last incomplete batch.
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Sets the number of worker threads for parallel data loading.
    /// A value of 0 collects samples sequentially; any positive value enables
    /// parallel collection on rayon's global thread pool.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Returns the batch size.
    pub fn batch_size(&self) -> usize {
        self.batch_size
    }

    /// Returns the number of batches.
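    /// For example, 10 samples with a batch size of 3 yield 4 batches, or 3
    /// when `drop_last` is enabled.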
    pub fn len(&self) -> usize {
        let total = self.dataset.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            total.div_ceil(self.batch_size)
        }
    }

    /// Returns true if the `DataLoader` is empty.
    pub fn is_empty(&self) -> bool {
        self.dataset.is_empty()
    }

    /// Returns the dataset length.
    pub fn dataset_len(&self) -> usize {
        self.dataset.len()
    }

    /// Creates an iterator over batches.
    pub fn iter(&self) -> DataLoaderIter<'_, D> {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            let sampler = SequentialSampler::new(self.dataset.len());
            sampler.iter().collect()
        };

        DataLoaderIter {
            dataset: &self.dataset,
            indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            position: 0,
            num_workers: self.num_workers,
        }
    }
}

// =============================================================================
// DataLoaderIter
// =============================================================================

/// Iterator over batches from a `DataLoader`.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    dataset: &'a D,
    indices: Vec<usize>,
    batch_size: usize,
    drop_last: bool,
    position: usize,
    num_workers: usize,
}

impl<D> Iterator for DataLoaderIter<'_, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    type Item = Batch;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.indices.len() {
            return None;
        }

        let end = (self.position + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.position..end];

        // Check if this is an incomplete batch
        if batch_indices.len() < self.batch_size && self.drop_last {
            return None;
        }

        // Collect samples for this batch (parallel when num_workers > 0)
        let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
            // Parallel sample collection using rayon
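            // (rayon's `collect` preserves index order, so a parallel batch
            // matches the sequential path sample-for-sample)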
            batch_indices
                .par_iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        } else {
            // Sequential fallback for num_workers = 0
            batch_indices
                .iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        };

        if samples.is_empty() {
            return None;
        }

        // Separate data and targets for stacking
        let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
        let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();

        // Stack samples into batches
        let data = stack_tensors(&data_samples);
        let targets = stack_tensors(&target_samples);

        self.position = end;

        Some(Batch::new(data, targets))
    }
}

impl<D> DataLoaderIter<'_, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Returns the number of remaining batches.
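    /// A trailing partial batch counts as one remaining batch unless
    /// `drop_last` is set.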
    #[must_use]
    pub fn remaining(&self) -> usize {
        let remaining_samples = self.indices.len().saturating_sub(self.position);
        if self.drop_last {
            remaining_samples / self.batch_size
        } else {
            remaining_samples.div_ceil(self.batch_size)
        }
    }
}

// =============================================================================
// GPU Prefetch Iterator
// =============================================================================

/// A wrapper iterator that prefetches batches onto a GPU device in a background
/// thread, overlapping CPU data loading with GPU computation.
///
/// When the training loop calls `next()`, it receives a batch that is already
/// resident on the target GPU device. Meanwhile, the background thread is
/// loading and transferring the next batch.
///
/// # Usage
/// ```ignore
/// let loader = DataLoader::new(dataset, 64).shuffle(true).num_workers(4);
/// let device = Device::Cuda(0);
///
/// for batch in loader.prefetch_to_gpu(device) {
///     // batch.data and batch.targets are already on the GPU
///     let output = model.forward(&batch.data);
/// }
/// ```
pub struct GpuPrefetchIter {
    /// Receiver for pre-transferred GPU batches.
    receiver: mpsc::Receiver<Batch>,
    /// Handle to the background prefetch thread. The worker exits on its own
    /// once the receiver is dropped and the channel closes.
    _worker: Option<thread::JoinHandle<()>>,
}

impl GpuPrefetchIter {
    /// Creates a GPU prefetch iterator that streams batches lazily.
    ///
    /// Spawns a background thread that produces batches one at a time from the
    /// dataset, transfers them to `device`, and sends them through a bounded
    /// channel. At most two batches are buffered at any time; nothing is
    /// materialized eagerly.
    fn new_streaming<D>(
        dataset: D,
        indices: Vec<usize>,
        batch_size: usize,
        drop_last: bool,
        num_workers: usize,
        device: Device,
    ) -> Self
    where
        D: Dataset<Item = (Tensor<f32>, Tensor<f32>)> + 'static,
    {
        // Bounded channel: at most 2 batches buffered (current + next).
        let (tx, rx) = mpsc::sync_channel(2);
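        // `sync_channel` also provides backpressure: once two batches are
        // queued, the worker blocks on `send` until the consumer catches up,
        // so prefetching never runs unboundedly ahead of training.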

        let worker = thread::spawn(move || {
            let mut position = 0;
            while position < indices.len() {
                let end = (position + batch_size).min(indices.len());
                let batch_indices = &indices[position..end];

                if batch_indices.len() < batch_size && drop_last {
                    break;
                }

                // Collect samples for this batch
                let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if num_workers > 0 {
                    batch_indices
                        .par_iter()
                        .filter_map(|&idx| dataset.get(idx))
                        .collect()
                } else {
                    batch_indices
                        .iter()
                        .filter_map(|&idx| dataset.get(idx))
                        .collect()
                };

                if samples.is_empty() {
                    break;
                }

                let data_samples: Vec<Tensor<f32>> =
                    samples.iter().map(|(x, _)| x.clone()).collect();
                let target_samples: Vec<Tensor<f32>> =
                    samples.iter().map(|(_, y)| y.clone()).collect();

                let data = stack_tensors(&data_samples);
                let targets = stack_tensors(&target_samples);

                // Transfer to GPU, falling back to the CPU tensor if the
                // transfer fails.
                let gpu_data = match data.to_device(device) {
                    Ok(t) => t,
                    Err(_) => data,
                };
                let gpu_targets = match targets.to_device(device) {
                    Ok(t) => t,
                    Err(_) => targets,
                };

                // Stop if the receiver side has been dropped.
                if tx.send(Batch::new(gpu_data, gpu_targets)).is_err() {
                    break;
                }

                position = end;
            }
        });

        Self {
            receiver: rx,
            _worker: Some(worker),
        }
    }
}

impl Iterator for GpuPrefetchIter {
    type Item = Batch;

    fn next(&mut self) -> Option<Self::Item> {
        self.receiver.recv().ok()
    }
}

impl<D> DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Returns an iterator that prefetches batches onto the target GPU device
    /// in a background thread.
    ///
    /// Batches are produced lazily: only two are buffered at any time, avoiding
    /// the O(dataset_size) memory spike of eager collection. The background
    /// thread overlaps CPU data loading with GPU computation.
    ///
    /// # Arguments
    /// * `device` - Target GPU device (e.g., `Device::Cuda(0)`)
    ///
    /// # Example
    /// ```ignore
    /// let loader = DataLoader::new(dataset, 64).shuffle(true);
    /// for batch in loader.prefetch_to_gpu(Device::Cuda(0)) {
    ///     // batch.data and batch.targets are already on GPU
    /// }
    /// ```
    pub fn prefetch_to_gpu(&self, device: Device) -> GpuPrefetchIter
    where
        D: Clone + 'static,
    {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            (0..self.dataset.len()).collect()
        };

        GpuPrefetchIter::new_streaming(
            self.dataset.clone(),
            indices,
            self.batch_size,
            self.drop_last,
            self.num_workers,
            device,
        )
    }
}

// =============================================================================
// GenericDataLoader
// =============================================================================

/// A more flexible `DataLoader` that works with any `Dataset` and `Collate`
/// function.
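///
/// # Example
/// A minimal sketch using the crate's `DefaultCollate`; the iterator yields
/// whatever type the collate function produces:
/// ```ignore
/// let loader = GenericDataLoader::new(dataset, DefaultCollate::new(), 16)
///     .shuffle(true);
///
/// for batch in loader.iter() {
///     // `batch` is the collate function's output type
/// }
/// ```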
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    dataset: D,
    collate_fn: C,
    batch_size: usize,
    shuffle: bool,
    drop_last: bool,
    num_workers: usize,
    _phantom: PhantomData<T>,
}

impl<D, C, T> GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Creates a new `GenericDataLoader`.
    pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
        Self {
            dataset,
            collate_fn,
            batch_size,
            shuffle: false,
            drop_last: false,
            num_workers: 0,
            _phantom: PhantomData,
        }
    }

    /// Sets the number of worker threads for parallel data loading.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Enables or disables shuffling.
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Sets whether to drop the last incomplete batch.
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Returns the number of batches.
    pub fn len(&self) -> usize {
        let total = self.dataset.len();
        if self.drop_last {
            total / self.batch_size
        } else {
            total.div_ceil(self.batch_size)
        }
    }

    /// Returns true if empty.
    pub fn is_empty(&self) -> bool {
        self.dataset.is_empty()
    }

    /// Creates an iterator over batches.
    #[allow(clippy::iter_not_returning_iterator)]
    pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
        let indices: Vec<usize> = if self.shuffle {
            let sampler = RandomSampler::new(self.dataset.len());
            sampler.iter().collect()
        } else {
            (0..self.dataset.len()).collect()
        };

        GenericDataLoaderIter {
            dataset: &self.dataset,
            collate_fn: &self.collate_fn,
            indices,
            batch_size: self.batch_size,
            drop_last: self.drop_last,
            position: 0,
            num_workers: self.num_workers,
            _phantom: PhantomData,
        }
    }
}

/// Iterator for `GenericDataLoader`.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    dataset: &'a D,
    collate_fn: &'a C,
    indices: Vec<usize>,
    batch_size: usize,
    drop_last: bool,
    position: usize,
    num_workers: usize,
    _phantom: PhantomData<T>,
}

impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send + Sync,
{
    type Item = C::Output;

    fn next(&mut self) -> Option<Self::Item> {
        if self.position >= self.indices.len() {
            return None;
        }

        let end = (self.position + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.position..end];

        if batch_indices.len() < self.batch_size && self.drop_last {
            return None;
        }

        // Collect samples (parallel when num_workers > 0)
        let samples: Vec<T> = if self.num_workers > 0 {
            batch_indices
                .par_iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        } else {
            batch_indices
                .iter()
                .filter_map(|&idx| self.dataset.get(idx))
                .collect()
        };

        if samples.is_empty() {
            return None;
        }

        self.position = end;

        Some(self.collate_fn.collate(samples))
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        assert_eq!(loader.len(), 4); // ceil(10/3) = 4

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // First 3 batches have size 3, last has size 1
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        assert_eq!(loader.len(), 3); // floor(10/3) = 3

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        // All batches have full size
        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // Run multiple iterations and collect first batch data
        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Due to shuffling, batches should (usually) be different.
        // We can't guarantee this, but the loader should work.
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffle, data should be in order
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // Verify all samples are present
        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Create two identical datasets
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        // Sequential
        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        // Parallel
        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        // Same number of batches
        assert_eq!(batches_seq.len(), batches_par.len());

        // Same data (no shuffle, so order should be the same)
        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(
                batches_seq[i].targets.to_vec(),
                batches_par[i].targets.to_vec()
            );
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 9); // 95 / 10 = 9 full batches

        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }

    #[test]
    fn test_gpu_prefetch_cpu_fallback() {
        // Test that prefetch_to_gpu works on a CPU device (no-op transfer)
        use axonml_core::Device;

        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        // prefetch_to_gpu with a CPU device should act as a pass-through
        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();
        assert_eq!(batches.len(), 4); // ceil(10/3) = 4

        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_gpu_prefetch_data_integrity() {
        // Verify that data remains correct through the prefetch pipeline
        use axonml_core::Device;

        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.prefetch_to_gpu(Device::Cpu).collect();

        // Without shuffle, data should be in order (same as regular iter)
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_gpu_prefetch_early_drop() {
        // Test that dropping the iterator early doesn't leak or deadlock
        use axonml_core::Device;

        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10);

        let mut iter = loader.prefetch_to_gpu(Device::Cpu);
        let first = iter.next();
        assert!(first.is_some());
        assert_eq!(first.unwrap().len(), 10);

        // Dropping the iterator early should not hang
        drop(iter);
    }
}