Skip to main content

shrew_data/
async_loader.rs

// AsyncDataLoader — prefetching data loader with background workers
//
// Spawns a pool of background threads that pre-load and transform batches ahead
// of the consumer. The consumer pulls ready batches from a bounded channel,
// overlapping data loading with compute (e.g. on a GPU).
//
// Usage:
//
//   let mut loader = AsyncDataLoader::<CpuBackend>::new(
//       Arc::new(dataset),
//       CpuDevice,
//       AsyncDataLoaderConfig::default()
//           .batch_size(64)
//           .prefetch_factor(2)
//           .num_workers(4),
//   );
//
//   for epoch in 0..num_epochs {
//       for batch in loader.iter_epoch("input", "target") {
//           let batch = batch?;
//           // train on batch ...
//       }
//   }
24
25use std::collections::HashMap;
26use std::sync::mpsc;
27use std::sync::{Arc, Mutex};
28use std::thread;
29
30use rand::rngs::StdRng;
31use rand::seq::SliceRandom;
32use rand::{thread_rng, SeedableRng};
33
34use shrew_core::backend::Backend;
35use shrew_core::error::Error;
36use shrew_core::tensor::Tensor;
37use shrew_core::DType;
38
39use crate::dataset::{Dataset, Sample};
40use crate::transform::Transform;
41
42// Configuration
43
44/// Configuration for the async prefetching data loader.
45#[derive(Debug, Clone)]
46pub struct AsyncDataLoaderConfig {
47    /// Number of samples per batch.
48    pub batch_size: usize,
49    /// Whether to shuffle indices each epoch.
50    pub shuffle: bool,
51    /// Whether to drop the last incomplete batch.
52    pub drop_last: bool,
53    /// DType for the created tensors.
54    pub dtype: DType,
55    /// Number of background workers (threads) for loading + transforming.
56    /// 0 = no background threads (falls back to sync loading, with prefetch
57    /// still happening on a single background thread).
58    pub num_workers: usize,
59    /// How many batches to pre-load ahead of the consumer.
60    /// Total buffered batches = prefetch_factor * max(num_workers, 1).
61    pub prefetch_factor: usize,
62    /// Optional random seed for reproducible shuffling.
63    pub seed: Option<u64>,
64}
65
66impl Default for AsyncDataLoaderConfig {
67    fn default() -> Self {
68        Self {
69            batch_size: 32,
70            shuffle: true,
71            drop_last: false,
72            dtype: DType::F32,
73            num_workers: 2,
74            prefetch_factor: 2,
75            seed: None,
76        }
77    }
78}
79
80impl AsyncDataLoaderConfig {
81    pub fn batch_size(mut self, bs: usize) -> Self {
82        self.batch_size = bs;
83        self
84    }
85    pub fn shuffle(mut self, s: bool) -> Self {
86        self.shuffle = s;
87        self
88    }
89    pub fn drop_last(mut self, d: bool) -> Self {
90        self.drop_last = d;
91        self
92    }
93    pub fn dtype(mut self, d: DType) -> Self {
94        self.dtype = d;
95        self
96    }
97    pub fn num_workers(mut self, n: usize) -> Self {
98        self.num_workers = n;
99        self
100    }
101    pub fn prefetch_factor(mut self, pf: usize) -> Self {
102        self.prefetch_factor = pf;
103        self
104    }
105    pub fn seed(mut self, s: u64) -> Self {
106        self.seed = Some(s);
107        self
108    }
109}
110
111// Batch type alias
112
113/// A single batch: maps names (e.g. "input", "target") to tensors.
114pub type Batch<B> = HashMap<String, Tensor<B>>;
115
116// AsyncDataLoader
117
118/// A data loader that prefetches batches on background threads.
119///
120/// On each call to [`iter_epoch`](AsyncDataLoader::iter_epoch), the loader:
121/// 1. Optionally reshuffles indices.
122/// 2. Spawns worker threads that load, transform, and collate batches.
123/// 3. Returns an iterator that pulls ready batches from a bounded channel.
124///
125/// The channel capacity is `prefetch_factor * max(num_workers, 1)`, so at most
126/// that many batches are materialised in memory at any time.
127///
128/// The dataset is held via `Arc<dyn Dataset>` so it can be safely shared with
129/// background worker threads.
130pub struct AsyncDataLoader<B: Backend> {
131    dataset: Arc<dyn Dataset>,
132    config: AsyncDataLoaderConfig,
133    transforms: Vec<Arc<dyn Transform>>,
134    device: B::Device,
135    indices: Vec<usize>,
136}
137
138impl<B: Backend> AsyncDataLoader<B>
139where
140    B::Device: Clone + Send + Sync + 'static,
141{
142    /// Create a new async data loader.
143    pub fn new(
144        dataset: Arc<dyn Dataset>,
145        device: B::Device,
146        config: AsyncDataLoaderConfig,
147    ) -> Self {
148        let n = dataset.len();
149        let indices: Vec<usize> = (0..n).collect();
150        Self {
151            dataset,
152            config,
153            transforms: Vec::new(),
154            device,
155            indices,
156        }
157    }
158
159    /// Add a transform.
160    pub fn with_transform(mut self, t: Arc<dyn Transform>) -> Self {
161        self.transforms.push(t);
162        self
163    }
164
165    /// Number of batches per epoch.
166    pub fn num_batches(&self) -> usize {
167        if self.config.drop_last {
168            self.dataset.len() / self.config.batch_size
169        } else {
170            self.dataset.len().div_ceil(self.config.batch_size)
171        }
172    }
173
174    /// Reshuffle indices.
175    pub fn reshuffle(&mut self) {
176        if self.config.shuffle {
177            match self.config.seed {
178                Some(seed) => {
179                    let mut rng = StdRng::seed_from_u64(seed);
180                    self.indices.shuffle(&mut rng);
181                }
182                None => {
183                    let mut rng = thread_rng();
184                    self.indices.shuffle(&mut rng);
185                }
186            }
187        }
188    }
189
190    /// Iterate over one epoch of prefetched batches.
191    ///
192    /// Spawns background workers that load batches into a bounded channel.
193    /// The returned iterator yields `Result<Batch<B>>` — one per batch.
194    ///
195    /// The background workers are joined when the iterator is dropped.
196    #[allow(clippy::type_complexity)]
197    pub fn iter_epoch(&mut self, input_name: &str, target_name: &str) -> PrefetchIterator<B> {
198        self.reshuffle();
199
200        let bs = self.config.batch_size;
201        let n = self.dataset.len();
202        let num_batches = self.num_batches();
203        let workers = self.config.num_workers.max(1);
204        let capacity = self.config.prefetch_factor * workers;
205
206        // Build the list of batch index ranges
207        let mut batch_ranges: Vec<Vec<usize>> = Vec::with_capacity(num_batches);
208        for b in 0..num_batches {
209            let start = b * bs;
210            let end = (start + bs).min(n);
211            batch_ranges.push(self.indices[start..end].to_vec());
212        }
213
214        let (tx, rx) = mpsc::sync_channel::<Result<Batch<B>, Error>>(capacity);
215
216        // Shared work queue: each worker pops a batch index to process
217        let work_queue: Arc<Mutex<std::vec::IntoIter<(usize, Vec<usize>)>>> = Arc::new(Mutex::new(
218            batch_ranges
219                .into_iter()
220                .enumerate()
221                .collect::<Vec<_>>()
222                .into_iter(),
223        ));
224
225        // Snapshot all the info workers need
226        let dtype = self.config.dtype;
227        let device = self.device.clone();
228        let transforms = self.transforms.clone();
229        let input_name = input_name.to_string();
230        let target_name = target_name.to_string();
231        let dataset = self.dataset.clone();
232
233        let mut handles = Vec::with_capacity(workers);
234        for _ in 0..workers {
235            let wq = work_queue.clone();
236            let tx = tx.clone();
237            let dev = device.clone();
238            let tfs = transforms.clone();
239            let in_name = input_name.clone();
240            let tgt_name = target_name.clone();
241            let ds = dataset.clone();
242
243            let handle = thread::spawn(move || {
244                let dataset: &dyn Dataset = &*ds;
245
246                loop {
247                    // Pop next batch from the shared queue
248                    let item = {
249                        let mut q = wq.lock().unwrap();
250                        q.next()
251                    };
252                    let (_batch_idx, sample_indices) = match item {
253                        Some(x) => x,
254                        None => break, // no more work
255                    };
256
257                    // Fetch and transform samples
258                    let samples: Vec<Sample> = sample_indices
259                        .iter()
260                        .map(|&i| {
261                            let mut s = dataset.get(i);
262                            for t in &tfs {
263                                s = t.apply(s);
264                            }
265                            s
266                        })
267                        .collect();
268
269                    // Collate into a batch of tensors
270                    let result = collate_batch::<B>(&samples, dtype, &dev, &in_name, &tgt_name);
271
272                    // Send to consumer — if receiver is dropped, stop
273                    if tx.send(result).is_err() {
274                        break;
275                    }
276                }
277            });
278            handles.push(handle);
279        }
280
281        // Drop the original sender so the channel closes when all workers finish
282        drop(tx);
283
284        PrefetchIterator {
285            rx,
286            handles: Some(handles),
287            remaining: num_batches,
288        }
289    }
290}
291
292// PrefetchIterator
293
294/// An iterator that yields prefetched batches from background workers.
295///
296/// Workers are joined when the iterator is fully consumed or dropped.
297pub struct PrefetchIterator<B: Backend> {
298    rx: mpsc::Receiver<Result<Batch<B>, Error>>,
299    handles: Option<Vec<thread::JoinHandle<()>>>,
300    remaining: usize,
301}
302
303impl<B: Backend> Iterator for PrefetchIterator<B> {
304    type Item = Result<Batch<B>, Error>;
305
306    fn next(&mut self) -> Option<Self::Item> {
307        if self.remaining == 0 {
308            return None;
309        }
310        match self.rx.recv() {
311            Ok(batch) => {
312                self.remaining -= 1;
313                Some(batch)
314            }
315            Err(_) => {
316                // Channel closed — workers done (possibly early)
317                self.remaining = 0;
318                None
319            }
320        }
321    }
322
323    fn size_hint(&self) -> (usize, Option<usize>) {
324        (self.remaining, Some(self.remaining))
325    }
326}
327
328impl<B: Backend> ExactSizeIterator for PrefetchIterator<B> {}
329
330impl<B: Backend> Drop for PrefetchIterator<B> {
331    fn drop(&mut self) {
332        // Drain the channel to unblock workers
333        while self.rx.try_recv().is_ok() {}
334        // Join all worker threads
335        if let Some(handles) = self.handles.take() {
336            for h in handles {
337                let _ = h.join();
338            }
339        }
340    }
341}
342
343// Collation helper
344
345/// Collate a slice of samples into a batch of tensors.
346fn collate_batch<B: Backend>(
347    samples: &[Sample],
348    dtype: DType,
349    device: &B::Device,
350    input_name: &str,
351    target_name: &str,
352) -> Result<Batch<B>, Error> {
353    let batch_size = samples.len();
354    if batch_size == 0 {
355        return Ok(HashMap::new());
356    }
357
358    // Flatten features
359    let feat_shape = &samples[0].feature_shape;
360    let feat_len: usize = feat_shape.iter().product();
361    let mut features = Vec::with_capacity(batch_size * feat_len);
362    for s in samples {
363        features.extend_from_slice(&s.features);
364    }
365
366    // Flatten targets
367    let tgt_shape = &samples[0].target_shape;
368    let tgt_len: usize = tgt_shape.iter().product();
369    let mut targets = Vec::with_capacity(batch_size * tgt_len);
370    for s in samples {
371        targets.extend_from_slice(&s.target);
372    }
373
374    // Build shapes: [batch_size, ...feature_shape] and [batch_size, ...target_shape]
375    let mut f_shape = vec![batch_size];
376    f_shape.extend_from_slice(feat_shape);
377    let mut t_shape = vec![batch_size];
378    t_shape.extend_from_slice(tgt_shape);
379
380    let input_tensor = Tensor::<B>::from_f64_slice(&features, f_shape, dtype, device)?;
381    let target_tensor = Tensor::<B>::from_f64_slice(&targets, t_shape, dtype, device)?;
382
383    let mut batch = HashMap::with_capacity(2);
384    batch.insert(input_name.to_string(), input_tensor);
385    batch.insert(target_name.to_string(), target_tensor);
386    Ok(batch)
387}