axonml_data/dataloader.rs

1//! `DataLoader` - Batched Data Iteration
2//!
3//! Provides efficient batched iteration over datasets with optional
4//! shuffling and parallel data loading.
5//!
6//! @version 0.1.0
7//! @author `AutomataNexus` Development Team
8
9use crate::collate::{stack_tensors, Collate};
10use crate::dataset::Dataset;
11use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
12use axonml_tensor::Tensor;
13use rayon::prelude::*;
14use std::marker::PhantomData;
15
16// =============================================================================
17// Batch Type
18// =============================================================================
19
/// A batch of data from the `DataLoader`.
///
/// Produced by [`DataLoaderIter`]; `data` and `targets` are stacked by
/// `stack_tensors` in the same sample order.
#[derive(Debug, Clone)]
pub struct Batch {
    /// Batched input data; `size` is taken from its leading dimension.
    pub data: Tensor<f32>,
    /// Batched targets, in the same sample order as `data`.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (leading dimension of `data`).
    pub size: usize,
}
30
31impl Batch {
32    /// Creates a new Batch.
33    #[must_use] pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
34        let size = data.shape()[0];
35        Self {
36            data,
37            targets,
38            size,
39        }
40    }
41
42    /// Returns the batch size.
43    #[must_use] pub fn len(&self) -> usize {
44        self.size
45    }
46
47    /// Returns true if the batch is empty.
48    #[must_use] pub fn is_empty(&self) -> bool {
49        self.size == 0
50    }
51}
52
53// =============================================================================
54// DataLoader
55// =============================================================================
56
/// `DataLoader` for batched iteration over datasets.
///
/// Provides configurable batching, shuffling, and iteration over datasets.
/// Construct with [`DataLoader::new`] and configure via the builder-style
/// `shuffle`, `drop_last`, and `num_workers` methods.
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset samples are fetched from.
    dataset: D,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether to shuffle sample order each time `iter()` is called.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
75
76impl<D> DataLoader<D>
77where
78    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
79{
80    /// Creates a new `DataLoader` with the specified batch size.
81    pub fn new(dataset: D, batch_size: usize) -> Self {
82        Self {
83            dataset,
84            batch_size,
85            shuffle: false,
86            drop_last: false,
87            num_workers: 0,
88        }
89    }
90
91    /// Enables or disables shuffling.
92    pub fn shuffle(mut self, shuffle: bool) -> Self {
93        self.shuffle = shuffle;
94        self
95    }
96
97    /// Sets whether to drop the last incomplete batch.
98    pub fn drop_last(mut self, drop_last: bool) -> Self {
99        self.drop_last = drop_last;
100        self
101    }
102
103    /// Sets the number of worker threads for parallel data loading.
104    pub fn num_workers(mut self, num_workers: usize) -> Self {
105        self.num_workers = num_workers;
106        self
107    }
108
109    /// Returns the batch size.
110    pub fn batch_size(&self) -> usize {
111        self.batch_size
112    }
113
114    /// Returns the number of batches.
115    pub fn len(&self) -> usize {
116        let total = self.dataset.len();
117        if self.drop_last {
118            total / self.batch_size
119        } else {
120            total.div_ceil(self.batch_size)
121        }
122    }
123
124    /// Returns true if the `DataLoader` is empty.
125    pub fn is_empty(&self) -> bool {
126        self.dataset.is_empty()
127    }
128
129    /// Returns the dataset length.
130    pub fn dataset_len(&self) -> usize {
131        self.dataset.len()
132    }
133
134    /// Creates an iterator over batches.
135    pub fn iter(&self) -> DataLoaderIter<'_, D> {
136        let indices: Vec<usize> = if self.shuffle {
137            let sampler = RandomSampler::new(self.dataset.len());
138            sampler.iter().collect()
139        } else {
140            let sampler = SequentialSampler::new(self.dataset.len());
141            sampler.iter().collect()
142        };
143
144        DataLoaderIter {
145            dataset: &self.dataset,
146            indices,
147            batch_size: self.batch_size,
148            drop_last: self.drop_last,
149            position: 0,
150            num_workers: self.num_workers,
151        }
152    }
153}
154
155// =============================================================================
156// DataLoaderIter
157// =============================================================================
158
/// Iterator over batches from a `DataLoader`.
///
/// Holds a borrowed dataset plus a fixed index order computed when
/// `DataLoader::iter` was called; `position` tracks progress through it.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Borrowed dataset samples are fetched from.
    dataset: &'a D,
    /// Pre-computed sample order for this epoch (shuffled or sequential).
    indices: Vec<usize>,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether a trailing incomplete batch is skipped.
    drop_last: bool,
    /// Index into `indices` of the next batch's first sample.
    position: usize,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
}
171
172impl<D> Iterator for DataLoaderIter<'_, D>
173where
174    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
175{
176    type Item = Batch;
177
178    fn next(&mut self) -> Option<Self::Item> {
179        if self.position >= self.indices.len() {
180            return None;
181        }
182
183        let end = (self.position + self.batch_size).min(self.indices.len());
184        let batch_indices = &self.indices[self.position..end];
185
186        // Check if this is an incomplete batch
187        if batch_indices.len() < self.batch_size && self.drop_last {
188            return None;
189        }
190
191        // Collect samples for this batch (parallel when num_workers > 0)
192        let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
193            // Parallel sample collection using rayon
194            batch_indices
195                .par_iter()
196                .filter_map(|&idx| self.dataset.get(idx))
197                .collect()
198        } else {
199            // Sequential fallback for num_workers = 0
200            batch_indices
201                .iter()
202                .filter_map(|&idx| self.dataset.get(idx))
203                .collect()
204        };
205
206        if samples.is_empty() {
207            return None;
208        }
209
210        // Separate data and targets for stacking
211        let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
212        let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
213
214        // Stack samples into batches
215        let data = stack_tensors(&data_samples);
216        let targets = stack_tensors(&target_samples);
217
218        self.position = end;
219
220        Some(Batch::new(data, targets))
221    }
222}
223
224impl<D> DataLoaderIter<'_, D>
225where
226    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
227{
228    /// Returns the number of remaining batches.
229    #[must_use] pub fn remaining(&self) -> usize {
230        let remaining_samples = self.indices.len().saturating_sub(self.position);
231        if self.drop_last {
232            remaining_samples / self.batch_size
233        } else {
234            remaining_samples.div_ceil(self.batch_size)
235        }
236    }
237}
238
239// =============================================================================
240// GenericDataLoader
241// =============================================================================
242
/// A more flexible `DataLoader` that works with any Dataset and Collate function.
///
/// Unlike [`DataLoader`], the item type `T` and the batching strategy
/// (`collate_fn`) are caller-supplied.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// The underlying dataset samples are fetched from.
    dataset: D,
    /// Combines a `Vec<T>` of samples into one batch (`C::Output`).
    collate_fn: C,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether to shuffle sample order each time `iter()` is called.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    // Marks the logical item type `T` without storing one.
    _phantom: PhantomData<T>,
}
258
259impl<D, C, T> GenericDataLoader<D, C, T>
260where
261    D: Dataset<Item = T>,
262    C: Collate<T>,
263    T: Send,
264{
265    /// Creates a new `GenericDataLoader`.
266    pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
267        Self {
268            dataset,
269            collate_fn,
270            batch_size,
271            shuffle: false,
272            drop_last: false,
273            num_workers: 0,
274            _phantom: PhantomData,
275        }
276    }
277
278    /// Sets the number of worker threads for parallel data loading.
279    pub fn num_workers(mut self, num_workers: usize) -> Self {
280        self.num_workers = num_workers;
281        self
282    }
283
284    /// Enables or disables shuffling.
285    pub fn shuffle(mut self, shuffle: bool) -> Self {
286        self.shuffle = shuffle;
287        self
288    }
289
290    /// Sets whether to drop the last incomplete batch.
291    pub fn drop_last(mut self, drop_last: bool) -> Self {
292        self.drop_last = drop_last;
293        self
294    }
295
296    /// Returns the number of batches.
297    pub fn len(&self) -> usize {
298        let total = self.dataset.len();
299        if self.drop_last {
300            total / self.batch_size
301        } else {
302            total.div_ceil(self.batch_size)
303        }
304    }
305
306    /// Returns true if empty.
307    pub fn is_empty(&self) -> bool {
308        self.dataset.is_empty()
309    }
310
311    /// Creates an iterator over batches.
312    pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
313        let indices: Vec<usize> = if self.shuffle {
314            let sampler = RandomSampler::new(self.dataset.len());
315            sampler.iter().collect()
316        } else {
317            (0..self.dataset.len()).collect()
318        };
319
320        GenericDataLoaderIter {
321            dataset: &self.dataset,
322            collate_fn: &self.collate_fn,
323            indices,
324            batch_size: self.batch_size,
325            drop_last: self.drop_last,
326            position: 0,
327            num_workers: self.num_workers,
328            _phantom: PhantomData,
329        }
330    }
331}
332
/// Iterator for `GenericDataLoader`.
///
/// Holds borrowed dataset and collate function plus a fixed index order
/// computed when `GenericDataLoader::iter` was called.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Borrowed dataset samples are fetched from.
    dataset: &'a D,
    /// Borrowed collate function that combines samples into a batch.
    collate_fn: &'a C,
    /// Pre-computed sample order for this epoch.
    indices: Vec<usize>,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether a trailing incomplete batch is skipped.
    drop_last: bool,
    /// Index into `indices` of the next batch's first sample.
    position: usize,
    /// When > 0, samples are fetched in parallel via rayon.
    num_workers: usize,
    // Marks the logical item type `T` without storing one.
    _phantom: PhantomData<T>,
}
349
350impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
351where
352    D: Dataset<Item = T>,
353    C: Collate<T>,
354    T: Send + Sync,
355{
356    type Item = C::Output;
357
358    fn next(&mut self) -> Option<Self::Item> {
359        if self.position >= self.indices.len() {
360            return None;
361        }
362
363        let end = (self.position + self.batch_size).min(self.indices.len());
364        let batch_indices = &self.indices[self.position..end];
365
366        if batch_indices.len() < self.batch_size && self.drop_last {
367            return None;
368        }
369
370        // Collect samples (parallel when num_workers > 0)
371        let samples: Vec<T> = if self.num_workers > 0 {
372            batch_indices
373                .par_iter()
374                .filter_map(|&idx| self.dataset.get(idx))
375                .collect()
376        } else {
377            batch_indices
378                .iter()
379                .filter_map(|&idx| self.dataset.get(idx))
380                .collect()
381        };
382
383        if samples.is_empty() {
384            return None;
385        }
386
387        self.position = end;
388
389        Some(self.collate_fn.collate(samples))
390    }
391}
392
393// =============================================================================
394// Tests
395// =============================================================================
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a `TensorDataset` with `size` samples of 2 features each
    /// (values 0.0, 1.0, 2.0, ...) and alternating 0.0/1.0 targets.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    // Default settings: full batches in order, last batch ragged.
    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        assert_eq!(loader.len(), 4); // ceil(10/3) = 4

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // First 3 batches have size 3, last has size 1
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    // drop_last discards the trailing partial batch entirely.
    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        assert_eq!(loader.len(), 3); // floor(10/3) = 3

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        // All batches have full size
        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    // Shuffling must not break iteration; order itself is nondeterministic.
    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // Run multiple iterations and collect first batch data
        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Due to shuffling, batches should (usually) be different
        // We can't guarantee this, but the loader should work
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    // Dataset size divisible by batch size: no ragged final batch.
    #[test]
    fn test_dataloader_exact_batches() {
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    // Batch::new derives `size` from the leading dimension of `data`.
    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    // An empty dataset yields an empty loader and no batches.
    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    // A dataset smaller than one batch yields a single partial batch.
    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    // Without shuffle, batch contents follow dataset order exactly.
    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffle, data should be in order
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    // GenericDataLoader with DefaultCollate batches like DataLoader.
    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    // remaining() decrements by one full batch per next() call.
    #[test]
    fn test_dataloader_remaining() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    // Parallel fetching (num_workers > 0) must not lose samples.
    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // Verify all samples are present
        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    // Parallel and sequential paths must produce identical batches
    // (rayon's ordered collect preserves sample order).
    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Create two identical datasets
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        // Sequential
        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        // Parallel
        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        // Same number of batches
        assert_eq!(batches_seq.len(), batches_par.len());

        // Same data (no shuffle, so order should be same)
        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(batches_seq[i].targets.to_vec(), batches_par[i].targets.to_vec());
        }
    }

    // drop_last combined with parallel fetching still trims the partial batch.
    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10)
            .drop_last(true)
            .num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 9); // 95 / 10 = 9 full batches

        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    // GenericDataLoader also supports parallel sample fetching.
    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }
}