// burn_core/data/dataloader/builder.rs

1use super::{
2    BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
3    batcher::Batcher,
4};
5use burn_dataset::Dataset;
6use burn_tensor::backend::Backend;
7use rand::{SeedableRng, rngs::StdRng};
8use std::sync::Arc;
9
/// A builder for data loaders.
pub struct DataLoaderBuilder<B: Backend, I, O> {
    /// Batching strategy; when unset, `build` falls back to a fixed batch size of 1.
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    /// Shared batcher that turns a `Vec<I>` of items into an `O` batch on a device.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Number of background worker threads; `None` or `Some(0)` loads on the main thread.
    num_threads: Option<usize>,
    /// Seed for per-iteration dataset shuffling; `None` disables shuffling.
    shuffle: Option<u64>,
    /// Target device for loaded batches; defaults to `B::Device::default()` in `build`.
    device: Option<B::Device>,
}
18
19impl<B, I, O> DataLoaderBuilder<B, I, O>
20where
21    B: Backend,
22    I: Send + Sync + Clone + std::fmt::Debug + 'static,
23    O: Send + Clone + std::fmt::Debug + 'static,
24{
25    /// Creates a new data loader builder.
26    ///
27    /// # Arguments
28    ///
29    /// * `batcher` - The batcher.
30    ///
31    /// # Returns
32    ///
33    /// The data loader builder.
34    pub fn new<Bt>(batcher: Bt) -> Self
35    where
36        Bt: Batcher<B, I, O> + 'static,
37    {
38        Self {
39            batcher: Arc::new(batcher),
40            strategy: None,
41            num_threads: None,
42            shuffle: None,
43            device: None,
44        }
45    }
46
47    /// Sets the batch size to a fix number.
48    ///
49    /// The [fix batch strategy](FixBatchStrategy) will be used.
50    ///
51    /// # Arguments
52    ///
53    /// * `batch_size` - The batch size.
54    ///
55    /// # Returns
56    ///
57    /// The data loader builder.
58    pub fn batch_size(mut self, batch_size: usize) -> Self {
59        self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
60        self
61    }
62
63    /// Sets the seed for shuffling.
64    ///
65    /// Each time the dataloader starts a new iteration, the dataset will be shuffled.
66    ///
67    /// # Arguments
68    ///
69    /// * `seed` - The seed.
70    ///
71    /// # Returns
72    ///
73    /// The data loader builder.
74    pub fn shuffle(mut self, seed: u64) -> Self {
75        self.shuffle = Some(seed);
76        self
77    }
78
79    /// Sets the number of workers.
80    ///
81    /// - `Some(0)` or `None`: the dataloader will run without work threads.
82    /// - `Some(n); n > 0`: the dataloader will run with `n` background threads.
83    ///
84    /// A 1-worker threaded dataloader will run loads in a background thread,
85    /// while a 0-worker threaded dataloader will run loads in the main thread.
86    ///
87    /// # Arguments
88    ///
89    /// * `num_workers` - The number of workers.
90    ///
91    /// # Returns
92    ///
93    /// The data loader builder.
94    pub fn num_workers(mut self, num_workers: usize) -> Self {
95        self.num_threads = Some(num_workers);
96        self
97    }
98
99    /// Sets the data loader device.
100    ///
101    /// # Arguments
102    ///
103    /// * `device` - The device to use when loading a batch.
104    ///
105    /// # Returns
106    ///
107    /// The data loader builder.
108    pub fn set_device(mut self, device: B::Device) -> Self {
109        self.device = Some(device);
110        self
111    }
112
113    /// Builds the data loader.
114    ///
115    /// # Arguments
116    ///
117    /// * `dataset` - The dataset.
118    ///
119    /// # Returns
120    ///
121    /// The data loader.
122    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
123    where
124        D: Dataset<I> + 'static,
125    {
126        let dataset = Arc::new(dataset);
127
128        let device = self.device.unwrap_or_default();
129        let rng = self.shuffle.map(StdRng::seed_from_u64);
130        let strategy = match self.strategy {
131            Some(strategy) => strategy,
132            None => Box::new(FixBatchStrategy::new(1)),
133        };
134
135        if let Some(num_threads) = self.num_threads
136            && num_threads > 0
137        {
138            return Arc::new(MultiThreadDataLoader::new(
139                strategy,
140                dataset,
141                self.batcher,
142                num_threads,
143                device,
144                rng,
145            ));
146        }
147
148        Arc::new(BatchDataLoader::new(
149            strategy,
150            dataset,
151            self.batcher,
152            device,
153            rng,
154        ))
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::data::dataset::FakeDataset;
    use burn_tensor::Device;

    /// Device alias for the backend under test.
    type TestDevice = Device<TestBackend>;

    /// Probe batcher: ignores its items and returns the device it was given,
    /// so tests can observe which device each batch was loaded on.
    #[derive(new, Clone)]
    struct TestBatcherDevice;

    // NOTE: no `#[cfg(test)]` needed here — the enclosing module is already gated.
    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    #[test]
    fn test_dataloader_no_workers() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        // With no workers configured, batches are still loaded on the default device.
        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_default_device() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        // The threaded loader defaults to the backend's default device too.
        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_slice_multi_device() {
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(11));

        #[cfg(all(
            test,
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda")
        ))]
        // Only one device exists...
        let (device1, device2) = (burn_flex::FlexDevice, burn_flex::FlexDevice);

        #[cfg(all(test, feature = "test-tch"))]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(all(test, feature = "test-wgpu"))]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(all(test, feature = "test-cuda"))]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(dataloader.num_items(), 11);
        // Split 11 items into [0, 5) and [5, 11), each pinned to its own device.
        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);

        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());

        for _ in 0..5 {
            assert_eq!(iterator_1.next(), Some(device1));
            assert_eq!(iterator_2.next(), Some(device2));
        }

        assert_eq!(iterator_1.next(), None);
        // For uneven split, the last dataloader (partial dataset) will have the remaining item
        assert_eq!(iterator_2.next(), Some(device2));
        assert_eq!(iterator_2.next(), None);
    }
}
256}