Skip to main content

burn_core/data/dataloader/base.rs

1use burn_tensor::backend::Backend;
2
3pub use crate::data::dataset::{Dataset, DatasetIterator};
4use core::iter::Iterator;
5use std::sync::Arc;
6
7/// A progress struct that can be used to track the progress of a data loader.
8#[derive(new, Clone, Debug)]
9pub struct Progress {
10    /// The number of items that have been processed.
11    pub items_processed: usize,
12
13    /// The total number of items that need to be processed.
14    pub items_total: usize,
15}
16
17/// A data loader iterator that can be used to iterate over a data loader.
18pub trait DataLoaderIterator<O>: Iterator<Item = O> {
19    /// Returns the progress of the data loader.
20    fn progress(&self) -> Progress;
21}
22
23/// A data loader that can be used to iterate over a dataset.
24pub trait DataLoader<B: Backend, O>: Send + Sync {
25    /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
26    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
27
28    /// The number of items (not the number of batches nor the number of iterations),
29    /// corresponding to the items_total of the progress returned by the iterator.
30    fn num_items(&self) -> usize;
31
32    /// Move the data loader to the given device, ensuring the batches are assigned to the correct device.
33    fn to_device(&self, device: &B::Device) -> Arc<dyn DataLoader<B, O>>;
34
35    /// Returns a new data loader containing a subset of the data.
36    ///
37    /// The subset includes items from `start` (inclusive) to `end` (exclusive),
38    /// preserving the batch size and ordering of the original data loader.
39    ///
40    /// # Arguments
41    ///
42    /// * `start` - The starting index of the subset (inclusive).
43    /// * `end` - The ending index of the subset (exclusive).
44    ///
45    /// # Returns
46    ///
47    /// A boxed [`DataLoader`] instance containing only the specified range.
48    fn slice(&self, start: usize, end: usize) -> Arc<dyn DataLoader<B, O>>;
49}