burn_core/data/dataloader/
builder.rs1use super::{
2 BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
3 batcher::Batcher,
4};
5use burn_dataset::Dataset;
6use burn_tensor::backend::Backend;
7use rand::{SeedableRng, rngs::StdRng};
8use std::sync::Arc;
9
/// Builder that configures batching, shuffling, worker threads and the target
/// device before assembling a [`DataLoader`] via [`DataLoaderBuilder::build`].
pub struct DataLoaderBuilder<B: Backend, I, O> {
    /// Batching strategy; `build` falls back to a fixed batch size of 1 when unset.
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    /// Converts a `Vec<I>` of dataset items into a batched output `O`.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Number of worker threads; `None` (or `Some(0)`) selects the single-threaded loader.
    num_threads: Option<usize>,
    /// RNG seed used to shuffle the dataset; `None` disables shuffling.
    shuffle: Option<u64>,
    /// Device batches are produced on; `build` uses `B::Device::default()` when unset.
    device: Option<B::Device>,
}
18
19impl<B, I, O> DataLoaderBuilder<B, I, O>
20where
21 B: Backend,
22 I: Send + Sync + Clone + std::fmt::Debug + 'static,
23 O: Send + Clone + std::fmt::Debug + 'static,
24{
25 pub fn new<Bt>(batcher: Bt) -> Self
35 where
36 Bt: Batcher<B, I, O> + 'static,
37 {
38 Self {
39 batcher: Arc::new(batcher),
40 strategy: None,
41 num_threads: None,
42 shuffle: None,
43 device: None,
44 }
45 }
46
47 pub fn batch_size(mut self, batch_size: usize) -> Self {
59 self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
60 self
61 }
62
63 pub fn shuffle(mut self, seed: u64) -> Self {
75 self.shuffle = Some(seed);
76 self
77 }
78
79 pub fn num_workers(mut self, num_workers: usize) -> Self {
95 self.num_threads = Some(num_workers);
96 self
97 }
98
99 pub fn set_device(mut self, device: B::Device) -> Self {
109 self.device = Some(device);
110 self
111 }
112
113 pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
123 where
124 D: Dataset<I> + 'static,
125 {
126 let dataset = Arc::new(dataset);
127
128 let device = self.device.unwrap_or_default();
129 let rng = self.shuffle.map(StdRng::seed_from_u64);
130 let strategy = match self.strategy {
131 Some(strategy) => strategy,
132 None => Box::new(FixBatchStrategy::new(1)),
133 };
134
135 if let Some(num_threads) = self.num_threads
136 && num_threads > 0
137 {
138 return Arc::new(MultiThreadDataLoader::new(
139 strategy,
140 dataset,
141 self.batcher,
142 num_threads,
143 device,
144 rng,
145 ));
146 }
147
148 Arc::new(BatchDataLoader::new(
149 strategy,
150 dataset,
151 self.batcher,
152 device,
153 rng,
154 ))
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::data::dataset::FakeDataset;

    /// Device type of the backend under test.
    type TestDevice = <TestBackend as Backend>::Device;

    /// Batcher that ignores the items and echoes back the device it was given,
    /// letting the tests observe which device each batch was produced on.
    #[derive(new, Clone)]
    struct TestBatcherDevice;

    // NOTE: the redundant `#[cfg(test)]` on this impl was removed — the whole
    // module is already gated behind `#[cfg(test)]`.
    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    #[test]
    fn test_dataloader_no_workers() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        // Without `set_device`, every batch lands on the default device.
        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_default_device() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_slice_multi_device() {
        let dataloader = DataLoaderBuilder::new(TestBatcherDevice::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(11));

        // Pick a device pair per backend feature. The `test` predicate inside
        // the cfg attributes was redundant (the module is test-only) and was
        // dropped. The ndarray fallback only has a single CPU device.
        #[cfg(not(any(
            feature = "test-tch",
            feature = "test-wgpu",
            feature = "test-cuda"
        )))]
        let (device1, device2) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(feature = "test-tch")]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(dataloader.num_items(), 11);
        // Split the 11 items across the two devices: [0, 5) and [5, 11).
        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);

        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());

        // With batch size 1, each iterator yields one device per item.
        for _ in 0..5 {
            assert_eq!(iterator_1.next(), Some(device1));
            assert_eq!(iterator_2.next(), Some(device2));
        }

        // The second slice holds one extra item.
        assert_eq!(iterator_1.next(), None);
        assert_eq!(iterator_2.next(), Some(device2));
        assert_eq!(iterator_2.next(), None);
    }
}