// axonml_data/dataloader.rs
1//! `DataLoader` - Batched Data Iteration
2//!
3//! Provides efficient batched iteration over datasets with optional
4//! shuffling and parallel data loading.
5//!
6//! @version 0.1.0
7//! @author `AutomataNexus` Development Team
8
9use crate::collate::{stack_tensors, Collate};
10use crate::dataset::Dataset;
11use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
12use axonml_tensor::Tensor;
13use rayon::prelude::*;
14use std::marker::PhantomData;
15
16// =============================================================================
17// Batch Type
18// =============================================================================
19
/// A batch of data from the `DataLoader`.
///
/// Holds the already-stacked inputs and targets for one iteration step,
/// plus the sample count (cached from the leading dimension of `data`).
#[derive(Debug, Clone)]
pub struct Batch {
    /// Batched input data.
    pub data: Tensor<f32>,
    /// Batched targets.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (leading dimension of `data`).
    pub size: usize,
}
30
31impl Batch {
32    /// Creates a new Batch.
33    #[must_use]
34    pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
35        let size = data.shape()[0];
36        Self {
37            data,
38            targets,
39            size,
40        }
41    }
42
43    /// Returns the batch size.
44    #[must_use]
45    pub fn len(&self) -> usize {
46        self.size
47    }
48
49    /// Returns true if the batch is empty.
50    #[must_use]
51    pub fn is_empty(&self) -> bool {
52        self.size == 0
53    }
54}
55
56// =============================================================================
57// DataLoader
58// =============================================================================
59
/// `DataLoader` for batched iteration over datasets.
///
/// Provides configurable batching, shuffling, and iteration over datasets.
/// Construct with `DataLoader::new`, configure via the builder-style
/// methods (`shuffle`, `drop_last`, `num_workers`), then call `iter()`.
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset.
    dataset: D,
    /// Batch size.
    batch_size: usize,
    /// Whether to shuffle data each epoch.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// Number of worker threads; when > 0, samples within a batch are
    /// fetched in parallel via rayon (see the iterator's `next`).
    num_workers: usize,
}
78
79impl<D> DataLoader<D>
80where
81    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
82{
83    /// Creates a new `DataLoader` with the specified batch size.
84    pub fn new(dataset: D, batch_size: usize) -> Self {
85        Self {
86            dataset,
87            batch_size,
88            shuffle: false,
89            drop_last: false,
90            num_workers: 0,
91        }
92    }
93
94    /// Enables or disables shuffling.
95    pub fn shuffle(mut self, shuffle: bool) -> Self {
96        self.shuffle = shuffle;
97        self
98    }
99
100    /// Sets whether to drop the last incomplete batch.
101    pub fn drop_last(mut self, drop_last: bool) -> Self {
102        self.drop_last = drop_last;
103        self
104    }
105
106    /// Sets the number of worker threads for parallel data loading.
107    pub fn num_workers(mut self, num_workers: usize) -> Self {
108        self.num_workers = num_workers;
109        self
110    }
111
112    /// Returns the batch size.
113    pub fn batch_size(&self) -> usize {
114        self.batch_size
115    }
116
117    /// Returns the number of batches.
118    pub fn len(&self) -> usize {
119        let total = self.dataset.len();
120        if self.drop_last {
121            total / self.batch_size
122        } else {
123            total.div_ceil(self.batch_size)
124        }
125    }
126
127    /// Returns true if the `DataLoader` is empty.
128    pub fn is_empty(&self) -> bool {
129        self.dataset.is_empty()
130    }
131
132    /// Returns the dataset length.
133    pub fn dataset_len(&self) -> usize {
134        self.dataset.len()
135    }
136
137    /// Creates an iterator over batches.
138    pub fn iter(&self) -> DataLoaderIter<'_, D> {
139        let indices: Vec<usize> = if self.shuffle {
140            let sampler = RandomSampler::new(self.dataset.len());
141            sampler.iter().collect()
142        } else {
143            let sampler = SequentialSampler::new(self.dataset.len());
144            sampler.iter().collect()
145        };
146
147        DataLoaderIter {
148            dataset: &self.dataset,
149            indices,
150            batch_size: self.batch_size,
151            drop_last: self.drop_last,
152            position: 0,
153            num_workers: self.num_workers,
154        }
155    }
156}
157
158// =============================================================================
159// DataLoaderIter
160// =============================================================================
161
/// Iterator over batches from a `DataLoader`.
///
/// Created by `DataLoader::iter`; yields `Batch` values until the
/// (possibly shuffled) index list is exhausted.
pub struct DataLoaderIter<'a, D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// Borrowed dataset to fetch samples from.
    dataset: &'a D,
    /// Sample indices in iteration order (pre-shuffled if requested).
    indices: Vec<usize>,
    /// Maximum number of samples per batch.
    batch_size: usize,
    /// Whether a trailing incomplete batch is dropped.
    drop_last: bool,
    /// Offset into `indices` of the next batch's first sample.
    position: usize,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
}
174
175impl<D> Iterator for DataLoaderIter<'_, D>
176where
177    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
178{
179    type Item = Batch;
180
181    fn next(&mut self) -> Option<Self::Item> {
182        if self.position >= self.indices.len() {
183            return None;
184        }
185
186        let end = (self.position + self.batch_size).min(self.indices.len());
187        let batch_indices = &self.indices[self.position..end];
188
189        // Check if this is an incomplete batch
190        if batch_indices.len() < self.batch_size && self.drop_last {
191            return None;
192        }
193
194        // Collect samples for this batch (parallel when num_workers > 0)
195        let samples: Vec<(Tensor<f32>, Tensor<f32>)> = if self.num_workers > 0 {
196            // Parallel sample collection using rayon
197            batch_indices
198                .par_iter()
199                .filter_map(|&idx| self.dataset.get(idx))
200                .collect()
201        } else {
202            // Sequential fallback for num_workers = 0
203            batch_indices
204                .iter()
205                .filter_map(|&idx| self.dataset.get(idx))
206                .collect()
207        };
208
209        if samples.is_empty() {
210            return None;
211        }
212
213        // Separate data and targets for stacking
214        let data_samples: Vec<Tensor<f32>> = samples.iter().map(|(x, _)| x.clone()).collect();
215        let target_samples: Vec<Tensor<f32>> = samples.iter().map(|(_, y)| y.clone()).collect();
216
217        // Stack samples into batches
218        let data = stack_tensors(&data_samples);
219        let targets = stack_tensors(&target_samples);
220
221        self.position = end;
222
223        Some(Batch::new(data, targets))
224    }
225}
226
227impl<D> DataLoaderIter<'_, D>
228where
229    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
230{
231    /// Returns the number of remaining batches.
232    #[must_use]
233    pub fn remaining(&self) -> usize {
234        let remaining_samples = self.indices.len().saturating_sub(self.position);
235        if self.drop_last {
236            remaining_samples / self.batch_size
237        } else {
238            remaining_samples.div_ceil(self.batch_size)
239        }
240    }
241}
242
243// =============================================================================
244// GenericDataLoader
245// =============================================================================
246
/// A more flexible `DataLoader` that works with any Dataset and Collate function.
pub struct GenericDataLoader<D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// The underlying dataset.
    dataset: D,
    /// Collate function used to combine raw samples into a batch.
    collate_fn: C,
    /// Maximum number of samples per batch.
    batch_size: usize,
    /// Whether to shuffle sample order each epoch.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    /// Marker for the item type `T` (not stored directly).
    _phantom: PhantomData<T>,
}
262
263impl<D, C, T> GenericDataLoader<D, C, T>
264where
265    D: Dataset<Item = T>,
266    C: Collate<T>,
267    T: Send,
268{
269    /// Creates a new `GenericDataLoader`.
270    pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
271        Self {
272            dataset,
273            collate_fn,
274            batch_size,
275            shuffle: false,
276            drop_last: false,
277            num_workers: 0,
278            _phantom: PhantomData,
279        }
280    }
281
282    /// Sets the number of worker threads for parallel data loading.
283    pub fn num_workers(mut self, num_workers: usize) -> Self {
284        self.num_workers = num_workers;
285        self
286    }
287
288    /// Enables or disables shuffling.
289    pub fn shuffle(mut self, shuffle: bool) -> Self {
290        self.shuffle = shuffle;
291        self
292    }
293
294    /// Sets whether to drop the last incomplete batch.
295    pub fn drop_last(mut self, drop_last: bool) -> Self {
296        self.drop_last = drop_last;
297        self
298    }
299
300    /// Returns the number of batches.
301    pub fn len(&self) -> usize {
302        let total = self.dataset.len();
303        if self.drop_last {
304            total / self.batch_size
305        } else {
306            total.div_ceil(self.batch_size)
307        }
308    }
309
310    /// Returns true if empty.
311    pub fn is_empty(&self) -> bool {
312        self.dataset.is_empty()
313    }
314
315    /// Creates an iterator over batches.
316    pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
317        let indices: Vec<usize> = if self.shuffle {
318            let sampler = RandomSampler::new(self.dataset.len());
319            sampler.iter().collect()
320        } else {
321            (0..self.dataset.len()).collect()
322        };
323
324        GenericDataLoaderIter {
325            dataset: &self.dataset,
326            collate_fn: &self.collate_fn,
327            indices,
328            batch_size: self.batch_size,
329            drop_last: self.drop_last,
330            position: 0,
331            num_workers: self.num_workers,
332            _phantom: PhantomData,
333        }
334    }
335}
336
/// Iterator for `GenericDataLoader`.
///
/// Created by `GenericDataLoader::iter`; yields collated batches until the
/// index list is exhausted.
pub struct GenericDataLoaderIter<'a, D, C, T>
where
    D: Dataset<Item = T>,
    C: Collate<T>,
    T: Send,
{
    /// Borrowed dataset to fetch samples from.
    dataset: &'a D,
    /// Borrowed collate function that turns `Vec<T>` into one batch.
    collate_fn: &'a C,
    /// Sample indices in iteration order (pre-shuffled if requested).
    indices: Vec<usize>,
    /// Maximum number of samples per batch.
    batch_size: usize,
    /// Whether a trailing incomplete batch is dropped.
    drop_last: bool,
    /// Offset into `indices` of the next batch's first sample.
    position: usize,
    /// When > 0, samples within a batch are fetched in parallel via rayon.
    num_workers: usize,
    /// Marker for the item type `T` (not stored directly).
    _phantom: PhantomData<T>,
}
353
354impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
355where
356    D: Dataset<Item = T>,
357    C: Collate<T>,
358    T: Send + Sync,
359{
360    type Item = C::Output;
361
362    fn next(&mut self) -> Option<Self::Item> {
363        if self.position >= self.indices.len() {
364            return None;
365        }
366
367        let end = (self.position + self.batch_size).min(self.indices.len());
368        let batch_indices = &self.indices[self.position..end];
369
370        if batch_indices.len() < self.batch_size && self.drop_last {
371            return None;
372        }
373
374        // Collect samples (parallel when num_workers > 0)
375        let samples: Vec<T> = if self.num_workers > 0 {
376            batch_indices
377                .par_iter()
378                .filter_map(|&idx| self.dataset.get(idx))
379                .collect()
380        } else {
381            batch_indices
382                .iter()
383                .filter_map(|&idx| self.dataset.get(idx))
384                .collect()
385        };
386
387        if samples.is_empty() {
388            return None;
389        }
390
391        self.position = end;
392
393        Some(self.collate_fn.collate(samples))
394    }
395}
396
397// =============================================================================
398// Tests
399// =============================================================================
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a `TensorDataset` of `size` rows with 2 features each
    /// (values 0..2*size) and binary targets (i % 2).
    fn create_test_dataset(size: usize) -> TensorDataset {
        let data: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let targets: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(data, &[size, 2]).unwrap();
        let y = Tensor::from_vec(targets, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        assert_eq!(loader.batch_size(), 3);
        assert_eq!(loader.len(), 4); // ceil(10/3) = 4

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // First 3 batches have size 3, last has size 1
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3).drop_last(true);

        assert_eq!(loader.len(), 3); // floor(10/3) = 3

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        // All batches have full size
        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_dataloader_shuffle() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).shuffle(true);

        // Run multiple iterations and collect first batch data
        let batch1: Vec<Batch> = loader.iter().take(1).collect();
        let batch2: Vec<Batch> = loader.iter().take(1).collect();

        // Due to shuffling, batches should (usually) be different
        // We can't guarantee this, but the loader should work
        assert!(!batch1.is_empty());
        assert!(!batch2.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        // 9 samples / batch 3: divides evenly, no partial final batch
        let dataset = create_test_dataset(9);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);

        for batch in &batches {
            assert_eq!(batch.len(), 3);
        }
    }

    #[test]
    fn test_batch_struct() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let targets = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(data, targets);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        // Zero-length dataset: loader reports empty and yields no batches
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let dataset = TensorDataset::new(x, y);
        let loader = DataLoader::new(dataset, 3);

        assert!(loader.is_empty());
        let batches: Vec<Batch> = loader.iter().collect();
        assert!(batches.is_empty());
    }

    #[test]
    fn test_dataloader_single_item() {
        let dataset = create_test_dataset(1);
        let loader = DataLoader::new(dataset, 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let dataset = create_test_dataset(6);
        let loader = DataLoader::new(dataset, 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffle, data should be in order
        assert_eq!(batches[0].data.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
        assert_eq!(batches[1].data.to_vec(), vec![4.0, 5.0, 6.0, 7.0]);
        assert_eq!(batches[2].data.to_vec(), vec![8.0, 9.0, 10.0, 11.0]);
    }

    #[test]
    fn test_generic_dataloader() {
        let dataset = create_test_dataset(6);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 2);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        // remaining() counts down as batches are consumed
        let dataset = create_test_dataset(10);
        let loader = DataLoader::new(dataset, 3);

        let mut iter = loader.iter();
        assert_eq!(iter.remaining(), 4);

        iter.next();
        assert_eq!(iter.remaining(), 3);

        iter.next();
        assert_eq!(iter.remaining(), 2);
    }

    #[test]
    fn test_parallel_dataloader() {
        let dataset = create_test_dataset(100);
        let loader = DataLoader::new(dataset, 10).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 10);

        // Verify all samples are present
        let total_samples: usize = batches.iter().map(|b| b.len()).sum();
        assert_eq!(total_samples, 100);
    }

    #[test]
    fn test_parallel_vs_sequential_equivalence() {
        // Create two identical datasets
        let dataset_seq = create_test_dataset(50);
        let dataset_par = create_test_dataset(50);

        // Sequential
        let loader_seq = DataLoader::new(dataset_seq, 5).num_workers(0);
        let batches_seq: Vec<Batch> = loader_seq.iter().collect();

        // Parallel
        let loader_par = DataLoader::new(dataset_par, 5).num_workers(4);
        let batches_par: Vec<Batch> = loader_par.iter().collect();

        // Same number of batches
        assert_eq!(batches_seq.len(), batches_par.len());

        // Same data (no shuffle, so order should be same)
        for i in 0..batches_seq.len() {
            assert_eq!(batches_seq[i].data.to_vec(), batches_par[i].data.to_vec());
            assert_eq!(
                batches_seq[i].targets.to_vec(),
                batches_par[i].targets.to_vec()
            );
        }
    }

    #[test]
    fn test_parallel_dataloader_drop_last() {
        let dataset = create_test_dataset(95);
        let loader = DataLoader::new(dataset, 10).drop_last(true).num_workers(4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 9); // 95 / 10 = 9 full batches

        for batch in &batches {
            assert_eq!(batch.len(), 10);
        }
    }

    #[test]
    fn test_parallel_generic_dataloader() {
        let dataset = create_test_dataset(60);
        let collate = DefaultCollate::new();
        let loader = GenericDataLoader::new(dataset, collate, 10).num_workers(4);

        let batches: Vec<_> = loader.iter().collect();
        assert_eq!(batches.len(), 6);
    }
}