burn_core/data/dataloader/builder.rs

1use super::{
2    BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
3    batcher::Batcher,
4};
5use burn_dataset::Dataset;
6use burn_tensor::backend::Backend;
7use rand::{SeedableRng, rngs::StdRng};
8use std::sync::Arc;
9
/// A builder for data loaders.
pub struct DataLoaderBuilder<B: Backend, I, O> {
    /// Batching strategy; when unset, `build` falls back to a fixed batch size of 1.
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    /// Maps a vector of input items `I` into a batch `O` on a given device.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Number of worker threads; `None` selects the single-threaded loader.
    num_threads: Option<usize>,
    /// Shuffle seed; `None` disables shuffling.
    shuffle: Option<u64>,
    /// Device batches are loaded on; `None` falls back to `B::Device::default()`.
    device: Option<B::Device>,
}
18
19impl<B, I, O> DataLoaderBuilder<B, I, O>
20where
21    B: Backend,
22    I: Send + Sync + Clone + std::fmt::Debug + 'static,
23    O: Send + Clone + std::fmt::Debug + 'static,
24{
25    /// Creates a new data loader builder.
26    ///
27    /// # Arguments
28    ///
29    /// * `batcher` - The batcher.
30    ///
31    /// # Returns
32    ///
33    /// The data loader builder.
34    pub fn new<Bt>(batcher: Bt) -> Self
35    where
36        Bt: Batcher<B, I, O> + 'static,
37    {
38        Self {
39            batcher: Arc::new(batcher),
40            strategy: None,
41            num_threads: None,
42            shuffle: None,
43            device: None,
44        }
45    }
46
47    /// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy)
48    /// will be used.
49    ///
50    /// # Arguments
51    ///
52    /// * `batch_size` - The batch size.
53    ///
54    /// # Returns
55    ///
56    /// The data loader builder.
57    pub fn batch_size(mut self, batch_size: usize) -> Self {
58        self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
59        self
60    }
61
62    /// Sets the seed for shuffling.
63    ///
64    /// Each time the dataloader starts a new iteration, the dataset will be shuffled.
65    ///
66    /// # Arguments
67    ///
68    /// * `seed` - The seed.
69    ///
70    /// # Returns
71    ///
72    /// The data loader builder.
73    pub fn shuffle(mut self, seed: u64) -> Self {
74        self.shuffle = Some(seed);
75        self
76    }
77
78    /// Sets the number of workers.
79    ///
80    /// # Arguments
81    ///
82    /// * `num_workers` - The number of workers.
83    ///
84    /// # Returns
85    ///
86    /// The data loader builder.
87    pub fn num_workers(mut self, num_workers: usize) -> Self {
88        self.num_threads = Some(num_workers);
89        self
90    }
91
92    /// Sets the data loader device.
93    ///
94    /// # Arguments
95    ///
96    /// * `device` - The device to use when loading a batch.
97    ///
98    /// # Returns
99    ///
100    /// The data loader builder.
101    pub fn set_device(mut self, device: B::Device) -> Self {
102        self.device = Some(device);
103        self
104    }
105
106    /// Builds the data loader.
107    ///
108    /// # Arguments
109    ///
110    /// * `dataset` - The dataset.
111    ///
112    /// # Returns
113    ///
114    /// The data loader.
115    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
116    where
117        D: Dataset<I> + 'static,
118    {
119        let dataset = Arc::new(dataset);
120
121        let device = self.device.unwrap_or_default();
122        let rng = self.shuffle.map(StdRng::seed_from_u64);
123        let strategy = match self.strategy {
124            Some(strategy) => strategy,
125            None => Box::new(FixBatchStrategy::new(1)),
126        };
127        if let Some(num_threads) = self.num_threads {
128            return Arc::new(MultiThreadDataLoader::new(
129                strategy,
130                dataset,
131                self.batcher,
132                num_threads,
133                device,
134                rng,
135            ));
136        }
137
138        Arc::new(BatchDataLoader::new(
139            strategy,
140            dataset,
141            self.batcher,
142            device,
143            rng,
144        ))
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataset::FakeDataset;
    use crate::{TestBackend, data::dataloader::batcher::Batcher};

    type TestDevice = <TestBackend as Backend>::Device;

    /// A batcher that returns the device it was asked to batch on, letting the
    /// tests observe which device each batch targets.
    #[derive(new, Clone)]
    pub struct TestBatcher;

    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcher {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    #[test]
    fn test_dataloader_default_device() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcher::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        // With no explicit device configured, every batch lands on the default device.
        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_slice_multi_device() {
        let dataloader = DataLoaderBuilder::new(TestBatcher::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(11));

        // Pick a device pair matching the backend under test. The `test`
        // predicate is implied by the enclosing `#[cfg(test)]` module, so only
        // the feature gates are spelled out.
        #[cfg(not(any(
            feature = "test-tch",
            feature = "test-wgpu",
            feature = "test-cuda"
        )))]
        // Only one device exists...
        let (device1, device2) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(feature = "test-tch")]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(dataloader.num_items(), 11);
        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);

        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());

        for _ in 0..5 {
            assert_eq!(iterator_1.next(), Some(device1));
            assert_eq!(iterator_2.next(), Some(device2));
        }

        assert_eq!(iterator_1.next(), None);
        // For uneven split, the last dataloader (partial dataset) will have the remaining item
        assert_eq!(iterator_2.next(), Some(device2));
        assert_eq!(iterator_2.next(), None);
    }
}