//! Data loader builder (`burn_core/data/dataloader/builder.rs`).

1use super::{
2    BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
3    batcher::Batcher,
4};
5use burn_dataset::Dataset;
6use burn_tensor::backend::Backend;
7use rand::{SeedableRng, rngs::StdRng};
8use std::sync::Arc;
9
10/// A builder for data loaders.
11pub struct DataLoaderBuilder<B: Backend, I, O> {
12    strategy: Option<Box<dyn BatchStrategy<I>>>,
13    batcher: Arc<dyn Batcher<B, I, O>>,
14    num_threads: Option<usize>,
15    shuffle: Option<u64>,
16    device: Option<B::Device>,
17}
18
19impl<B, I, O> DataLoaderBuilder<B, I, O>
20where
21    B: Backend,
22    I: Send + Sync + Clone + std::fmt::Debug + 'static,
23    O: Send + Clone + std::fmt::Debug + 'static,
24{
25    /// Creates a new data loader builder.
26    ///
27    /// # Arguments
28    ///
29    /// * `batcher` - The batcher.
30    ///
31    /// # Returns
32    ///
33    /// The data loader builder.
34    pub fn new<Bt>(batcher: Bt) -> Self
35    where
36        Bt: Batcher<B, I, O> + 'static,
37    {
38        Self {
39            batcher: Arc::new(batcher),
40            strategy: None,
41            num_threads: None,
42            shuffle: None,
43            device: None,
44        }
45    }
46
47    /// Sets the batch size to a fix number.
48    ///
49    /// The [fix batch strategy](FixBatchStrategy) will be used.
50    ///
51    /// # Arguments
52    ///
53    /// * `batch_size` - The batch size.
54    ///
55    /// # Returns
56    ///
57    /// The data loader builder.
58    pub fn batch_size(mut self, batch_size: usize) -> Self {
59        self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
60        self
61    }
62
63    /// Sets the seed for shuffling.
64    ///
65    /// Each time the dataloader starts a new iteration, the dataset will be shuffled.
66    ///
67    /// # Arguments
68    ///
69    /// * `seed` - The seed.
70    ///
71    /// # Returns
72    ///
73    /// The data loader builder.
74    pub fn shuffle(mut self, seed: u64) -> Self {
75        self.shuffle = Some(seed);
76        self
77    }
78
79    /// Sets the number of workers.
80    ///
81    /// - `Some(0)` or `None`: the dataloader will run without work threads.
82    /// - `Some(n); n > 0`: the dataloader will run with `n` background threads.
83    ///
84    /// A 1-worker threaded dataloader will run loads in a background thread,
85    /// while a 0-worker threaded dataloader will run loads in the main thread.
86    ///
87    /// # Arguments
88    ///
89    /// * `num_workers` - The number of workers.
90    ///
91    /// # Returns
92    ///
93    /// The data loader builder.
94    pub fn num_workers(mut self, num_workers: usize) -> Self {
95        self.num_threads = Some(num_workers);
96        self
97    }
98
99    /// Sets the data loader device.
100    ///
101    /// # Arguments
102    ///
103    /// * `device` - The device to use when loading a batch.
104    ///
105    /// # Returns
106    ///
107    /// The data loader builder.
108    pub fn set_device(mut self, device: B::Device) -> Self {
109        self.device = Some(device);
110        self
111    }
112
113    /// Builds the data loader.
114    ///
115    /// # Arguments
116    ///
117    /// * `dataset` - The dataset.
118    ///
119    /// # Returns
120    ///
121    /// The data loader.
122    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
123    where
124        D: Dataset<I> + 'static,
125    {
126        let dataset = Arc::new(dataset);
127
128        let device = self.device.unwrap_or_default();
129        let rng = self.shuffle.map(StdRng::seed_from_u64);
130        let strategy = match self.strategy {
131            Some(strategy) => strategy,
132            None => Box::new(FixBatchStrategy::new(1)),
133        };
134
135        if let Some(num_threads) = self.num_threads
136            && num_threads > 0
137        {
138            return Arc::new(MultiThreadDataLoader::new(
139                strategy,
140                dataset,
141                self.batcher,
142                num_threads,
143                device,
144                rng,
145            ));
146        }
147
148        Arc::new(BatchDataLoader::new(
149            strategy,
150            dataset,
151            self.batcher,
152            device,
153            rng,
154        ))
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::data::dataset::FakeDataset;

    type TestDevice = <TestBackend as Backend>::Device;

    /// A batcher that ignores the items and returns the device it was asked to
    /// batch on, letting tests assert which device each batch targets.
    #[derive(new, Clone)]
    struct TestBatcherDevice;

    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    #[test]
    fn test_dataloader_no_workers() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_default_device() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_slice_multi_device() {
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(11));

        #[cfg(all(
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda")
        ))]
        // Only one device exists...
        let (device1, device2) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(feature = "test-tch")]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(dataloader.num_items(), 11);
        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);

        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());

        for _ in 0..5 {
            assert_eq!(iterator_1.next(), Some(device1));
            assert_eq!(iterator_2.next(), Some(device2));
        }

        assert_eq!(iterator_1.next(), None);
        // For uneven split, the last dataloader (partial dataset) will have the remaining item
        assert_eq!(iterator_2.next(), Some(device2));
        assert_eq!(iterator_2.next(), None);
    }
}
260}