shrew_data/
async_loader.rs1use std::collections::HashMap;
26use std::sync::mpsc;
27use std::sync::{Arc, Mutex};
28use std::thread;
29
30use rand::rngs::StdRng;
31use rand::seq::SliceRandom;
32use rand::{thread_rng, SeedableRng};
33
34use shrew_core::backend::Backend;
35use shrew_core::error::Error;
36use shrew_core::tensor::Tensor;
37use shrew_core::DType;
38
39use crate::dataset::{Dataset, Sample};
40use crate::transform::Transform;
41
/// Configuration for an [`AsyncDataLoader`].
///
/// Build with [`Default::default`] and refine with the chained setters
/// (`batch_size`, `shuffle`, …).
#[derive(Debug, Clone)]
pub struct AsyncDataLoaderConfig {
    // Number of samples collated into each batch.
    pub batch_size: usize,
    // Shuffle the sample order at the start of each epoch.
    pub shuffle: bool,
    // Drop the final batch when it holds fewer than `batch_size` samples.
    pub drop_last: bool,
    // Element dtype of the collated batch tensors.
    pub dtype: DType,
    // Worker threads spawned per epoch (clamped to at least 1 when iterating).
    pub num_workers: usize,
    // Bounded-channel capacity is `prefetch_factor * num_workers` batches.
    pub prefetch_factor: usize,
    // Fixed shuffle seed for reproducibility; `None` uses a thread-local RNG.
    pub seed: Option<u64>,
}
65
66impl Default for AsyncDataLoaderConfig {
67 fn default() -> Self {
68 Self {
69 batch_size: 32,
70 shuffle: true,
71 drop_last: false,
72 dtype: DType::F32,
73 num_workers: 2,
74 prefetch_factor: 2,
75 seed: None,
76 }
77 }
78}
79
80impl AsyncDataLoaderConfig {
81 pub fn batch_size(mut self, bs: usize) -> Self {
82 self.batch_size = bs;
83 self
84 }
85 pub fn shuffle(mut self, s: bool) -> Self {
86 self.shuffle = s;
87 self
88 }
89 pub fn drop_last(mut self, d: bool) -> Self {
90 self.drop_last = d;
91 self
92 }
93 pub fn dtype(mut self, d: DType) -> Self {
94 self.dtype = d;
95 self
96 }
97 pub fn num_workers(mut self, n: usize) -> Self {
98 self.num_workers = n;
99 self
100 }
101 pub fn prefetch_factor(mut self, pf: usize) -> Self {
102 self.prefetch_factor = pf;
103 self
104 }
105 pub fn seed(mut self, s: u64) -> Self {
106 self.seed = Some(s);
107 self
108 }
109}
110
/// A collated mini-batch: tensor name (e.g. the input/target names passed to
/// [`AsyncDataLoader::iter_epoch`]) mapped to its batched tensor.
pub type Batch<B> = HashMap<String, Tensor<B>>;
115
/// Multi-threaded data loader that prefetches collated batches through a
/// bounded channel; each call to `iter_epoch` spawns fresh worker threads.
pub struct AsyncDataLoader<B: Backend> {
    // Shared source of samples; cloned into every worker thread.
    dataset: Arc<dyn Dataset>,
    // Batching / shuffling / worker configuration.
    config: AsyncDataLoaderConfig,
    // Per-sample transforms, applied in insertion order on the workers.
    transforms: Vec<Arc<dyn Transform>>,
    // Device the collated tensors are created on.
    device: B::Device,
    // Current permutation of sample indices; reshuffled per epoch when enabled.
    indices: Vec<usize>,
}
137
impl<B: Backend> AsyncDataLoader<B>
where
    B::Device: Clone + Send + Sync + 'static,
{
    /// Creates a loader over `dataset` that collates batches onto `device`.
    ///
    /// The index permutation starts as the identity `0..dataset.len()`;
    /// [`Self::iter_epoch`] reshuffles it when `config.shuffle` is set.
    pub fn new(
        dataset: Arc<dyn Dataset>,
        device: B::Device,
        config: AsyncDataLoaderConfig,
    ) -> Self {
        let n = dataset.len();
        let indices: Vec<usize> = (0..n).collect();
        Self {
            dataset,
            config,
            transforms: Vec::new(),
            device,
            indices,
        }
    }

    /// Appends a per-sample transform; transforms run in insertion order on
    /// the worker threads, before collation.
    pub fn with_transform(mut self, t: Arc<dyn Transform>) -> Self {
        self.transforms.push(t);
        self
    }

    /// Number of batches in one epoch: floor division when `drop_last`,
    /// ceiling division otherwise.
    pub fn num_batches(&self) -> usize {
        if self.config.drop_last {
            self.dataset.len() / self.config.batch_size
        } else {
            self.dataset.len().div_ceil(self.config.batch_size)
        }
    }

    /// Re-shuffles the index permutation when shuffling is enabled.
    ///
    /// NOTE(review): a fixed `config.seed` re-seeds the RNG from scratch
    /// each call, so every epoch gets the SAME permutation — confirm this
    /// is intended (epoch-varying loaders usually mix the epoch into the
    /// seed).
    pub fn reshuffle(&mut self) {
        if self.config.shuffle {
            match self.config.seed {
                Some(seed) => {
                    let mut rng = StdRng::seed_from_u64(seed);
                    self.indices.shuffle(&mut rng);
                }
                None => {
                    let mut rng = thread_rng();
                    self.indices.shuffle(&mut rng);
                }
            }
        }
    }

    /// Starts one epoch of prefetched iteration.
    ///
    /// Reshuffles, slices the index permutation into per-batch index lists,
    /// then spawns `num_workers.max(1)` threads that pull batches from a
    /// shared work queue, load + transform each sample, collate, and send
    /// the result through a bounded channel of capacity
    /// `prefetch_factor * workers`. Batches may arrive out of order, since
    /// workers race on the queue.
    ///
    /// `input_name` / `target_name` become the keys of each [`Batch`].
    #[allow(clippy::type_complexity)]
    pub fn iter_epoch(&mut self, input_name: &str, target_name: &str) -> PrefetchIterator<B> {
        self.reshuffle();

        let bs = self.config.batch_size;
        let n = self.dataset.len();
        let num_batches = self.num_batches();
        let workers = self.config.num_workers.max(1);
        // Bounded capacity provides backpressure so workers cannot run
        // arbitrarily far ahead of the consumer.
        let capacity = self.config.prefetch_factor * workers;

        // Materialize each batch's sample indices up front so workers only
        // need the shared iterator below, not `self`.
        let mut batch_ranges: Vec<Vec<usize>> = Vec::with_capacity(num_batches);
        for b in 0..num_batches {
            let start = b * bs;
            let end = (start + bs).min(n);
            batch_ranges.push(self.indices[start..end].to_vec());
        }

        let (tx, rx) = mpsc::sync_channel::<Result<Batch<B>, Error>>(capacity);

        // Work queue: a mutex-guarded iterator of (batch_idx, indices);
        // each worker locks it just long enough to take the next item.
        let work_queue: Arc<Mutex<std::vec::IntoIter<(usize, Vec<usize>)>>> = Arc::new(Mutex::new(
            batch_ranges
                .into_iter()
                .enumerate()
                .collect::<Vec<_>>()
                .into_iter(),
        ));

        // Clone everything the workers need so the spawned closures are
        // 'static and independent of `self`.
        let dtype = self.config.dtype;
        let device = self.device.clone();
        let transforms = self.transforms.clone();
        let input_name = input_name.to_string();
        let target_name = target_name.to_string();
        let dataset = self.dataset.clone();

        let mut handles = Vec::with_capacity(workers);
        for _ in 0..workers {
            let wq = work_queue.clone();
            let tx = tx.clone();
            let dev = device.clone();
            let tfs = transforms.clone();
            let in_name = input_name.clone();
            let tgt_name = target_name.clone();
            let ds = dataset.clone();

            let handle = thread::spawn(move || {
                let dataset: &dyn Dataset = &*ds;

                loop {
                    // Take the next batch; the lock is released before any
                    // loading work happens.
                    let item = {
                        let mut q = wq.lock().unwrap();
                        q.next()
                    };
                    // Queue exhausted: this worker is done.
                    let (_batch_idx, sample_indices) = match item {
                        Some(x) => x,
                        None => break, };

                    // Load and transform each sample of the batch.
                    let samples: Vec<Sample> = sample_indices
                        .iter()
                        .map(|&i| {
                            let mut s = dataset.get(i);
                            for t in &tfs {
                                s = t.apply(s);
                            }
                            s
                        })
                        .collect();

                    let result = collate_batch::<B>(&samples, dtype, &dev, &in_name, &tgt_name);

                    // `send` blocks while the channel is full (backpressure)
                    // and fails only once the receiver is gone — then stop.
                    if tx.send(result).is_err() {
                        break;
                    }
                }
            });
            handles.push(handle);
        }

        // Drop the original sender so the channel disconnects once every
        // worker's clone is dropped (i.e. when all workers exit).
        drop(tx);

        PrefetchIterator {
            rx,
            handles: Some(handles),
            remaining: num_batches,
        }
    }
}
291
/// Blocking iterator over one epoch's batches, fed by the worker threads
/// spawned in [`AsyncDataLoader::iter_epoch`].
pub struct PrefetchIterator<B: Backend> {
    // Receiving end of the bounded batch channel; workers hold the senders.
    rx: mpsc::Receiver<Result<Batch<B>, Error>>,
    // Worker thread handles; `take`n and joined on drop.
    handles: Option<Vec<thread::JoinHandle<()>>>,
    // Batches still expected this epoch; drives `size_hint`.
    remaining: usize,
}
302
303impl<B: Backend> Iterator for PrefetchIterator<B> {
304 type Item = Result<Batch<B>, Error>;
305
306 fn next(&mut self) -> Option<Self::Item> {
307 if self.remaining == 0 {
308 return None;
309 }
310 match self.rx.recv() {
311 Ok(batch) => {
312 self.remaining -= 1;
313 Some(batch)
314 }
315 Err(_) => {
316 self.remaining = 0;
318 None
319 }
320 }
321 }
322
323 fn size_hint(&self) -> (usize, Option<usize>) {
324 (self.remaining, Some(self.remaining))
325 }
326}
327
// `size_hint` returns equal bounds from the `remaining` counter, so the
// exact-length contract holds.
impl<B: Backend> ExactSizeIterator for PrefetchIterator<B> {}
329
330impl<B: Backend> Drop for PrefetchIterator<B> {
331 fn drop(&mut self) {
332 while self.rx.try_recv().is_ok() {}
334 if let Some(handles) = self.handles.take() {
336 for h in handles {
337 let _ = h.join();
338 }
339 }
340 }
341}
342
343fn collate_batch<B: Backend>(
347 samples: &[Sample],
348 dtype: DType,
349 device: &B::Device,
350 input_name: &str,
351 target_name: &str,
352) -> Result<Batch<B>, Error> {
353 let batch_size = samples.len();
354 if batch_size == 0 {
355 return Ok(HashMap::new());
356 }
357
358 let feat_shape = &samples[0].feature_shape;
360 let feat_len: usize = feat_shape.iter().product();
361 let mut features = Vec::with_capacity(batch_size * feat_len);
362 for s in samples {
363 features.extend_from_slice(&s.features);
364 }
365
366 let tgt_shape = &samples[0].target_shape;
368 let tgt_len: usize = tgt_shape.iter().product();
369 let mut targets = Vec::with_capacity(batch_size * tgt_len);
370 for s in samples {
371 targets.extend_from_slice(&s.target);
372 }
373
374 let mut f_shape = vec![batch_size];
376 f_shape.extend_from_slice(feat_shape);
377 let mut t_shape = vec![batch_size];
378 t_shape.extend_from_slice(tgt_shape);
379
380 let input_tensor = Tensor::<B>::from_f64_slice(&features, f_shape, dtype, device)?;
381 let target_tensor = Tensor::<B>::from_f64_slice(&targets, t_shape, dtype, device)?;
382
383 let mut batch = HashMap::with_capacity(2);
384 batch.insert(input_name.to_string(), input_tensor);
385 batch.insert(target_name.to_string(), target_tensor);
386 Ok(batch)
387}