ghostflow_data/dataloader.rs

//! DataLoader for batching and iterating over datasets

use ghostflow_core::Tensor;
use crate::dataset::Dataset;
use crate::sampler::{SequentialSampler, RandomSampler};
use rayon::prelude::*;

/// DataLoader for efficient batch loading
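///
/// Illustrative usage (a sketch; `TensorDataset::new` is the constructor
/// exercised in the tests at the bottom of this file):
///
/// ```ignore
/// let dataset = TensorDataset::new(data, targets);
/// let loader = DataLoader::new(dataset, 32)
///     .shuffle(true)
///     .drop_last(true)
///     .num_workers(4);
///
/// for (inputs, targets) in loader.iter() {
///     // `inputs` and `targets` gain a leading batch dimension of size 32
/// }
/// ```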
pub struct DataLoader<D: Dataset> {
    dataset: D,
    batch_size: usize,
    shuffle: bool,
    drop_last: bool,
    num_workers: usize,
}

impl<D: Dataset> DataLoader<D> {
    /// Create a loader that yields batches of `batch_size` samples
    /// (`batch_size` must be non-zero).
    pub fn new(dataset: D, batch_size: usize) -> Self {
        DataLoader {
            dataset,
            batch_size,
            shuffle: false,
            drop_last: false,
            num_workers: 0,
        }
    }

    /// Shuffle the sample order each time `iter` is called.
    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Drop the final batch when it has fewer than `batch_size` samples.
    pub fn drop_last(mut self, drop_last: bool) -> Self {
        self.drop_last = drop_last;
        self
    }

    /// Load samples on rayon's thread pool when `num_workers > 0`
    /// (the value currently acts only as an on/off switch).
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_workers = num_workers;
        self
    }

    /// Get number of batches
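    ///
    /// Uses ceiling division unless `drop_last` is set, so a trailing partial
    /// batch counts as one batch (e.g. 100 samples at batch size 16 give 7
    /// batches, or 6 with `drop_last`).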
    pub fn len(&self) -> usize {
        let n = self.dataset.len();
        if self.drop_last {
            n / self.batch_size
        } else {
            (n + self.batch_size - 1) / self.batch_size
        }
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Create an iterator over batches
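    ///
    /// A fresh index order is drawn on every call when `shuffle` is enabled.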
    pub fn iter(&self) -> DataLoaderIter<'_, D> {
        let indices: Vec<usize> = if self.shuffle {
            RandomSampler::new(self.dataset.len()).collect()
        } else {
            SequentialSampler::new(self.dataset.len()).collect()
        };

        DataLoaderIter {
            loader: self,
            indices,
            current_batch: 0,
        }
    }
}

/// Iterator over DataLoader batches
pub struct DataLoaderIter<'a, D: Dataset> {
    loader: &'a DataLoader<D>,
    indices: Vec<usize>,
    current_batch: usize,
}

impl<'a, D: Dataset> Iterator for DataLoaderIter<'a, D> {
    type Item = (Tensor, Tensor);

    fn next(&mut self) -> Option<Self::Item> {
        let start = self.current_batch * self.loader.batch_size;

        if start >= self.indices.len() {
            return None;
        }

        let end = (start + self.loader.batch_size).min(self.indices.len());

        // Skip a trailing partial batch when drop_last is requested
        if self.loader.drop_last && end - start < self.loader.batch_size {
            return None;
        }

        let batch_indices = &self.indices[start..end];
        self.current_batch += 1;

        // Collect batch samples, in parallel on rayon's global pool when
        // num_workers > 0 (the parallel path needs the dataset to be `Sync`)
        let samples: Vec<(Tensor, Tensor)> = if self.loader.num_workers > 0 {
            batch_indices
                .par_iter()
                .map(|&idx| self.loader.dataset.get(idx))
                .collect()
        } else {
            batch_indices
                .iter()
                .map(|&idx| self.loader.dataset.get(idx))
                .collect()
        };

        // Stack into batch tensors
        Some(collate_batch(samples))
    }
}

/// Collate samples into a batch
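///
/// Samples are assumed to share the shape of the first sample; each sample's
/// data is flattened and stacked along a new leading batch dimension.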
fn collate_batch(samples: Vec<(Tensor, Tensor)>) -> (Tensor, Tensor) {
    let batch_size = samples.len();

    if batch_size == 0 {
        return (Tensor::zeros(&[0]), Tensor::zeros(&[0]));
    }

    // Get shapes from first sample
    let data_shape = samples[0].0.dims().to_vec();
    let target_shape = samples[0].1.dims().to_vec();
    let first_data_numel = samples[0].0.numel();
    let first_target_numel = samples[0].1.numel();

    // Collect all data
    let mut data_vec: Vec<f32> = Vec::with_capacity(batch_size * first_data_numel);
    let mut target_vec: Vec<f32> = Vec::with_capacity(batch_size * first_target_numel);

    for (data, target) in samples {
        data_vec.extend(data.data_f32());
        target_vec.extend(target.data_f32());
    }

    // Create batch shapes
    let mut batch_data_shape = vec![batch_size];
    batch_data_shape.extend(&data_shape);

    let mut batch_target_shape = vec![batch_size];
    batch_target_shape.extend(&target_shape);

    (
        Tensor::from_slice(&data_vec, &batch_data_shape).unwrap(),
        Tensor::from_slice(&target_vec, &batch_target_shape).unwrap(),
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::TensorDataset;

    #[test]
    fn test_dataloader() {
        let data = Tensor::randn(&[100, 10]);
        let targets = Tensor::randn(&[100, 1]);
        let dataset = TensorDataset::new(data, targets);

        let loader = DataLoader::new(dataset, 16);

        let mut count = 0;
        for (batch_data, _batch_target) in loader.iter() {
            assert!(batch_data.dims()[0] <= 16);
            count += 1;
        }

        assert_eq!(count, 7); // ceil(100/16) = 7
    }

    #[test]
    fn test_dataloader_shuffle() {
        let data = Tensor::arange(0.0, 10.0, 1.0).reshape(&[10, 1]).unwrap();
        let targets = Tensor::zeros(&[10, 1]);
        let dataset = TensorDataset::new(data, targets);

        let loader = DataLoader::new(dataset, 5).shuffle(true);

        // Just verify it runs without error
        for _ in loader.iter() {}
    }
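
    // Illustrative extra test for `drop_last` (a sketch using only the
    // constructors already exercised above): with 100 samples and batch size
    // 16, the trailing partial batch of 4 is skipped, leaving 6 full batches.
    #[test]
    fn test_dataloader_drop_last() {
        let data = Tensor::randn(&[100, 10]);
        let targets = Tensor::randn(&[100, 1]);
        let dataset = TensorDataset::new(data, targets);

        let loader = DataLoader::new(dataset, 16).drop_last(true);
        assert_eq!(loader.len(), 6); // 100 / 16 = 6 full batches

        let mut count = 0;
        for (batch_data, _batch_target) in loader.iter() {
            assert_eq!(batch_data.dims()[0], 16); // every yielded batch is full
            count += 1;
        }
        assert_eq!(count, 6);
    }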
}