// axonml_data/dataloader.rs

1//! `DataLoader` - Batched Data Iteration
2//!
3//! Provides efficient batched iteration over datasets with optional
4//! shuffling and parallel data loading.
5//!
6//! @version 0.1.0
7//! @author `AutomataNexus` Development Team
8
9use crate::collate::{stack_tensors, Collate};
10use crate::dataset::Dataset;
11use crate::sampler::{RandomSampler, Sampler, SequentialSampler};
12use axonml_tensor::Tensor;
13use std::marker::PhantomData;
14
15// =============================================================================
16// Batch Type
17// =============================================================================
18
/// A batch of data from the `DataLoader`.
///
/// Holds one mini-batch of inputs and targets (combined via `stack_tensors`
/// in the loader's iterator) together with the number of samples it holds.
#[derive(Debug, Clone)]
pub struct Batch {
    /// Batched input data.
    pub data: Tensor<f32>,
    /// Batched targets.
    pub targets: Tensor<f32>,
    /// Number of samples in this batch (taken from `data.shape()[0]` in `Batch::new`).
    pub size: usize,
}
29
30impl Batch {
31    /// Creates a new Batch.
32    #[must_use] pub fn new(data: Tensor<f32>, targets: Tensor<f32>) -> Self {
33        let size = data.shape()[0];
34        Self {
35            data,
36            targets,
37            size,
38        }
39    }
40
41    /// Returns the batch size.
42    #[must_use] pub fn len(&self) -> usize {
43        self.size
44    }
45
46    /// Returns true if the batch is empty.
47    #[must_use] pub fn is_empty(&self) -> bool {
48        self.size == 0
49    }
50}
51
52// =============================================================================
53// DataLoader
54// =============================================================================
55
/// `DataLoader` for batched iteration over datasets.
///
/// Provides configurable batching, shuffling, and iteration over datasets.
/// Construct with `DataLoader::new`, then configure via the builder-style
/// methods (`shuffle`, `drop_last`, `num_workers`).
pub struct DataLoader<D>
where
    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
{
    /// The underlying dataset.
    dataset: D,
    /// Batch size.
    batch_size: usize,
    /// Whether to shuffle data each epoch.
    shuffle: bool,
    /// Whether to drop the last incomplete batch.
    drop_last: bool,
    /// Number of worker threads (for future parallel loading).
    /// NOTE(review): currently stored but not consulted by `iter`.
    num_workers: usize,
}
74
75impl<D> DataLoader<D>
76where
77    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
78{
79    /// Creates a new `DataLoader` with the specified batch size.
80    pub fn new(dataset: D, batch_size: usize) -> Self {
81        Self {
82            dataset,
83            batch_size,
84            shuffle: false,
85            drop_last: false,
86            num_workers: 0,
87        }
88    }
89
90    /// Enables or disables shuffling.
91    pub fn shuffle(mut self, shuffle: bool) -> Self {
92        self.shuffle = shuffle;
93        self
94    }
95
96    /// Sets whether to drop the last incomplete batch.
97    pub fn drop_last(mut self, drop_last: bool) -> Self {
98        self.drop_last = drop_last;
99        self
100    }
101
102    /// Sets the number of worker threads for parallel data loading.
103    pub fn num_workers(mut self, num_workers: usize) -> Self {
104        self.num_workers = num_workers;
105        self
106    }
107
108    /// Returns the batch size.
109    pub fn batch_size(&self) -> usize {
110        self.batch_size
111    }
112
113    /// Returns the number of batches.
114    pub fn len(&self) -> usize {
115        let total = self.dataset.len();
116        if self.drop_last {
117            total / self.batch_size
118        } else {
119            total.div_ceil(self.batch_size)
120        }
121    }
122
123    /// Returns true if the `DataLoader` is empty.
124    pub fn is_empty(&self) -> bool {
125        self.dataset.is_empty()
126    }
127
128    /// Returns the dataset length.
129    pub fn dataset_len(&self) -> usize {
130        self.dataset.len()
131    }
132
133    /// Creates an iterator over batches.
134    pub fn iter(&self) -> DataLoaderIter<'_, D> {
135        let indices: Vec<usize> = if self.shuffle {
136            let sampler = RandomSampler::new(self.dataset.len());
137            sampler.iter().collect()
138        } else {
139            let sampler = SequentialSampler::new(self.dataset.len());
140            sampler.iter().collect()
141        };
142
143        DataLoaderIter {
144            dataset: &self.dataset,
145            indices,
146            batch_size: self.batch_size,
147            drop_last: self.drop_last,
148            position: 0,
149        }
150    }
151}
152
153// =============================================================================
154// DataLoaderIter
155// =============================================================================
156
157/// Iterator over batches from a `DataLoader`.
158pub struct DataLoaderIter<'a, D>
159where
160    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
161{
162    dataset: &'a D,
163    indices: Vec<usize>,
164    batch_size: usize,
165    drop_last: bool,
166    position: usize,
167}
168
169impl<D> Iterator for DataLoaderIter<'_, D>
170where
171    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
172{
173    type Item = Batch;
174
175    fn next(&mut self) -> Option<Self::Item> {
176        if self.position >= self.indices.len() {
177            return None;
178        }
179
180        let end = (self.position + self.batch_size).min(self.indices.len());
181        let batch_indices = &self.indices[self.position..end];
182
183        // Check if this is an incomplete batch
184        if batch_indices.len() < self.batch_size && self.drop_last {
185            return None;
186        }
187
188        // Collect samples for this batch
189        let mut data_samples = Vec::with_capacity(batch_indices.len());
190        let mut target_samples = Vec::with_capacity(batch_indices.len());
191
192        for &idx in batch_indices {
193            if let Some((x, y)) = self.dataset.get(idx) {
194                data_samples.push(x);
195                target_samples.push(y);
196            }
197        }
198
199        if data_samples.is_empty() {
200            return None;
201        }
202
203        // Stack samples into batches
204        let data = stack_tensors(&data_samples);
205        let targets = stack_tensors(&target_samples);
206
207        self.position = end;
208
209        Some(Batch::new(data, targets))
210    }
211}
212
213impl<D> DataLoaderIter<'_, D>
214where
215    D: Dataset<Item = (Tensor<f32>, Tensor<f32>)>,
216{
217    /// Returns the number of remaining batches.
218    #[must_use] pub fn remaining(&self) -> usize {
219        let remaining_samples = self.indices.len().saturating_sub(self.position);
220        if self.drop_last {
221            remaining_samples / self.batch_size
222        } else {
223            remaining_samples.div_ceil(self.batch_size)
224        }
225    }
226}
227
228// =============================================================================
229// GenericDataLoader
230// =============================================================================
231
232/// A more flexible `DataLoader` that works with any Dataset and Collate function.
233pub struct GenericDataLoader<D, C, T>
234where
235    D: Dataset<Item = T>,
236    C: Collate<T>,
237    T: Send,
238{
239    dataset: D,
240    collate_fn: C,
241    batch_size: usize,
242    shuffle: bool,
243    drop_last: bool,
244    _phantom: PhantomData<T>,
245}
246
247impl<D, C, T> GenericDataLoader<D, C, T>
248where
249    D: Dataset<Item = T>,
250    C: Collate<T>,
251    T: Send,
252{
253    /// Creates a new `GenericDataLoader`.
254    pub fn new(dataset: D, collate_fn: C, batch_size: usize) -> Self {
255        Self {
256            dataset,
257            collate_fn,
258            batch_size,
259            shuffle: false,
260            drop_last: false,
261            _phantom: PhantomData,
262        }
263    }
264
265    /// Enables or disables shuffling.
266    pub fn shuffle(mut self, shuffle: bool) -> Self {
267        self.shuffle = shuffle;
268        self
269    }
270
271    /// Sets whether to drop the last incomplete batch.
272    pub fn drop_last(mut self, drop_last: bool) -> Self {
273        self.drop_last = drop_last;
274        self
275    }
276
277    /// Returns the number of batches.
278    pub fn len(&self) -> usize {
279        let total = self.dataset.len();
280        if self.drop_last {
281            total / self.batch_size
282        } else {
283            total.div_ceil(self.batch_size)
284        }
285    }
286
287    /// Returns true if empty.
288    pub fn is_empty(&self) -> bool {
289        self.dataset.is_empty()
290    }
291
292    /// Creates an iterator over batches.
293    pub fn iter(&self) -> GenericDataLoaderIter<'_, D, C, T> {
294        let indices: Vec<usize> = if self.shuffle {
295            let sampler = RandomSampler::new(self.dataset.len());
296            sampler.iter().collect()
297        } else {
298            (0..self.dataset.len()).collect()
299        };
300
301        GenericDataLoaderIter {
302            dataset: &self.dataset,
303            collate_fn: &self.collate_fn,
304            indices,
305            batch_size: self.batch_size,
306            drop_last: self.drop_last,
307            position: 0,
308            _phantom: PhantomData,
309        }
310    }
311}
312
313/// Iterator for `GenericDataLoader`.
314pub struct GenericDataLoaderIter<'a, D, C, T>
315where
316    D: Dataset<Item = T>,
317    C: Collate<T>,
318    T: Send,
319{
320    dataset: &'a D,
321    collate_fn: &'a C,
322    indices: Vec<usize>,
323    batch_size: usize,
324    drop_last: bool,
325    position: usize,
326    _phantom: PhantomData<T>,
327}
328
329impl<D, C, T> Iterator for GenericDataLoaderIter<'_, D, C, T>
330where
331    D: Dataset<Item = T>,
332    C: Collate<T>,
333    T: Send,
334{
335    type Item = C::Output;
336
337    fn next(&mut self) -> Option<Self::Item> {
338        if self.position >= self.indices.len() {
339            return None;
340        }
341
342        let end = (self.position + self.batch_size).min(self.indices.len());
343        let batch_indices = &self.indices[self.position..end];
344
345        if batch_indices.len() < self.batch_size && self.drop_last {
346            return None;
347        }
348
349        // Collect samples
350        let samples: Vec<T> = batch_indices
351            .iter()
352            .filter_map(|&idx| self.dataset.get(idx))
353            .collect();
354
355        if samples.is_empty() {
356            return None;
357        }
358
359        self.position = end;
360
361        Some(self.collate_fn.collate(samples))
362    }
363}
364
365// =============================================================================
366// Tests
367// =============================================================================
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::collate::DefaultCollate;
    use crate::dataset::TensorDataset;

    /// Builds a dataset of `size` rows with two sequential features each
    /// and an alternating 0/1 target per row.
    fn create_test_dataset(size: usize) -> TensorDataset {
        let features: Vec<f32> = (0..size * 2).map(|i| i as f32).collect();
        let labels: Vec<f32> = (0..size).map(|i| (i % 2) as f32).collect();

        let x = Tensor::from_vec(features, &[size, 2]).unwrap();
        let y = Tensor::from_vec(labels, &[size]).unwrap();

        TensorDataset::new(x, y)
    }

    #[test]
    fn test_dataloader_basic() {
        let loader = DataLoader::new(create_test_dataset(10), 3);

        assert_eq!(loader.batch_size(), 3);
        // 10 samples at batch size 3 -> ceil(10 / 3) = 4 batches.
        assert_eq!(loader.len(), 4);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 4);

        // Three full batches followed by a remainder of one.
        let sizes: Vec<usize> = batches.iter().map(Batch::len).collect();
        assert_eq!(sizes, vec![3, 3, 3, 1]);
    }

    #[test]
    fn test_dataloader_drop_last() {
        let loader = DataLoader::new(create_test_dataset(10), 3).drop_last(true);

        // floor(10 / 3) = 3 full batches; the remainder is dropped.
        assert_eq!(loader.len(), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
        assert!(batches.iter().all(|b| b.len() == 3));
    }

    #[test]
    fn test_dataloader_shuffle() {
        let loader = DataLoader::new(create_test_dataset(100), 10).shuffle(true);

        // The shuffled order cannot be asserted deterministically; verify
        // only that iteration produces batches on each epoch.
        let first_epoch: Vec<Batch> = loader.iter().take(1).collect();
        let second_epoch: Vec<Batch> = loader.iter().take(1).collect();

        assert!(!first_epoch.is_empty());
        assert!(!second_epoch.is_empty());
    }

    #[test]
    fn test_dataloader_exact_batches() {
        // 9 samples divide evenly into 3 batches of 3.
        let loader = DataLoader::new(create_test_dataset(9), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 3);
        assert!(batches.iter().all(|b| b.len() == 3));
    }

    #[test]
    fn test_batch_struct() {
        let inputs = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let labels = Tensor::from_vec(vec![0.0, 1.0], &[2]).unwrap();

        let batch = Batch::new(inputs, labels);
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
    }

    #[test]
    fn test_dataloader_empty() {
        let x = Tensor::from_vec(vec![], &[0, 2]).unwrap();
        let y = Tensor::from_vec(vec![], &[0]).unwrap();
        let loader = DataLoader::new(TensorDataset::new(x, y), 3);

        assert!(loader.is_empty());
        assert_eq!(loader.iter().count(), 0);
    }

    #[test]
    fn test_dataloader_single_item() {
        // Batch size larger than the dataset yields one short batch.
        let loader = DataLoader::new(create_test_dataset(1), 3);

        let batches: Vec<Batch> = loader.iter().collect();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].len(), 1);
    }

    #[test]
    fn test_dataloader_iteration_order() {
        let loader = DataLoader::new(create_test_dataset(6), 2).shuffle(false);

        let batches: Vec<Batch> = loader.iter().collect();

        // Without shuffle, rows come out sequentially.
        let expected = [
            vec![0.0, 1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0, 7.0],
            vec![8.0, 9.0, 10.0, 11.0],
        ];
        for (i, want) in expected.iter().enumerate() {
            assert_eq!(&batches[i].data.to_vec(), want);
        }
    }

    #[test]
    fn test_generic_dataloader() {
        let loader = GenericDataLoader::new(create_test_dataset(6), DefaultCollate::new(), 2);

        assert_eq!(loader.iter().count(), 3);
    }

    #[test]
    fn test_dataloader_remaining() {
        let loader = DataLoader::new(create_test_dataset(10), 3);
        let mut iter = loader.iter();

        // `remaining` counts down one batch per `next` call.
        for expected in [4, 3, 2] {
            assert_eq!(iter.remaining(), expected);
            iter.next();
        }
    }
}