ghostflow_data/
dataloader.rs1use ghostflow_core::Tensor;
4use crate::dataset::Dataset;
5use crate::sampler::{SequentialSampler, RandomSampler};
6use rayon::prelude::*;
7
8pub struct DataLoader<D: Dataset> {
10 dataset: D,
11 batch_size: usize,
12 shuffle: bool,
13 drop_last: bool,
14 num_workers: usize,
15}
16
17impl<D: Dataset> DataLoader<D> {
18 pub fn new(dataset: D, batch_size: usize) -> Self {
19 DataLoader {
20 dataset,
21 batch_size,
22 shuffle: false,
23 drop_last: false,
24 num_workers: 0,
25 }
26 }
27
28 pub fn shuffle(mut self, shuffle: bool) -> Self {
29 self.shuffle = shuffle;
30 self
31 }
32
33 pub fn drop_last(mut self, drop_last: bool) -> Self {
34 self.drop_last = drop_last;
35 self
36 }
37
38 pub fn num_workers(mut self, num_workers: usize) -> Self {
39 self.num_workers = num_workers;
40 self
41 }
42
43 pub fn len(&self) -> usize {
45 let n = self.dataset.len();
46 if self.drop_last {
47 n / self.batch_size
48 } else {
49 (n + self.batch_size - 1) / self.batch_size
50 }
51 }
52
53 pub fn is_empty(&self) -> bool {
54 self.len() == 0
55 }
56
57 pub fn iter(&self) -> DataLoaderIter<'_, D> {
59 let indices: Vec<usize> = if self.shuffle {
60 RandomSampler::new(self.dataset.len()).collect()
61 } else {
62 SequentialSampler::new(self.dataset.len()).collect()
63 };
64
65 DataLoaderIter {
66 loader: self,
67 indices,
68 current_batch: 0,
69 }
70 }
71}
72
73pub struct DataLoaderIter<'a, D: Dataset> {
75 loader: &'a DataLoader<D>,
76 indices: Vec<usize>,
77 current_batch: usize,
78}
79
80impl<'a, D: Dataset> Iterator for DataLoaderIter<'a, D> {
81 type Item = (Tensor, Tensor);
82
83 fn next(&mut self) -> Option<Self::Item> {
84 let start = self.current_batch * self.loader.batch_size;
85
86 if start >= self.indices.len() {
87 return None;
88 }
89
90 let end = (start + self.loader.batch_size).min(self.indices.len());
91
92 if self.loader.drop_last && end - start < self.loader.batch_size {
93 return None;
94 }
95
96 let batch_indices = &self.indices[start..end];
97 self.current_batch += 1;
98
99 let samples: Vec<(Tensor, Tensor)> = if self.loader.num_workers > 0 {
101 batch_indices
102 .par_iter()
103 .map(|&idx| self.loader.dataset.get(idx))
104 .collect()
105 } else {
106 batch_indices
107 .iter()
108 .map(|&idx| self.loader.dataset.get(idx))
109 .collect()
110 };
111
112 Some(collate_batch(samples))
114 }
115}
116
117fn collate_batch(samples: Vec<(Tensor, Tensor)>) -> (Tensor, Tensor) {
119 let batch_size = samples.len();
120
121 if batch_size == 0 {
122 return (Tensor::zeros(&[0]), Tensor::zeros(&[0]));
123 }
124
125 let data_shape = samples[0].0.dims().to_vec();
127 let target_shape = samples[0].1.dims().to_vec();
128 let first_data_numel = samples[0].0.numel();
129 let first_target_numel = samples[0].1.numel();
130
131 let mut data_vec: Vec<f32> = Vec::with_capacity(batch_size * first_data_numel);
133 let mut target_vec: Vec<f32> = Vec::with_capacity(batch_size * first_target_numel);
134
135 for (data, target) in samples {
136 data_vec.extend(data.data_f32());
137 target_vec.extend(target.data_f32());
138 }
139
140 let mut batch_data_shape = vec![batch_size];
142 batch_data_shape.extend(&data_shape);
143
144 let mut batch_target_shape = vec![batch_size];
145 batch_target_shape.extend(&target_shape);
146
147 (
148 Tensor::from_slice(&data_vec, &batch_data_shape).unwrap(),
149 Tensor::from_slice(&target_vec, &batch_target_shape).unwrap(),
150 )
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::TensorDataset;

    /// 100 samples in batches of 16: six full batches plus one batch of four.
    #[test]
    fn test_dataloader() {
        let data = Tensor::randn(&[100, 10]);
        let targets = Tensor::randn(&[100, 1]);
        let dataset = TensorDataset::new(data, targets);

        let loader = DataLoader::new(dataset, 16);

        let mut count = 0;
        let mut seen = 0;
        for (batch_data, _batch_target) in loader.iter() {
            assert!(batch_data.dims()[0] <= 16);
            seen += batch_data.dims()[0];
            count += 1;
        }

        assert_eq!(count, 7);
        // `len()` must agree with what the iterator actually yields.
        assert_eq!(loader.len(), count);
        assert_eq!(seen, 100);
    }

    /// With `drop_last`, the trailing batch of four samples is discarded.
    #[test]
    fn test_dataloader_drop_last() {
        let data = Tensor::randn(&[100, 10]);
        let targets = Tensor::randn(&[100, 1]);
        let dataset = TensorDataset::new(data, targets);

        let loader = DataLoader::new(dataset, 16).drop_last(true);

        let mut count = 0;
        for (batch_data, batch_target) in loader.iter() {
            assert_eq!(batch_data.dims()[0], 16);
            assert_eq!(batch_target.dims()[0], 16);
            count += 1;
        }

        assert_eq!(count, 6);
        assert_eq!(loader.len(), count);
    }

    /// Shuffling changes the order only: ten samples still form two batches of five.
    #[test]
    fn test_dataloader_shuffle() {
        let data = Tensor::arange(0.0, 10.0, 1.0).reshape(&[10, 1]).unwrap();
        let targets = Tensor::zeros(&[10, 1]);
        let dataset = TensorDataset::new(data, targets);

        let loader = DataLoader::new(dataset, 5).shuffle(true);

        let mut count = 0;
        for (batch_data, batch_target) in loader.iter() {
            assert_eq!(batch_data.dims()[0], 5);
            assert_eq!(batch_target.dims()[0], 5);
            count += 1;
        }

        assert_eq!(count, 2);
        assert_eq!(loader.len(), count);
    }
}