burn_core/data/dataloader/
builder.rs1use super::{
2 BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
3 batcher::Batcher,
4};
5use burn_dataset::Dataset;
6use burn_tensor::backend::Backend;
7use rand::{SeedableRng, rngs::StdRng};
8use std::sync::Arc;
9
10pub struct DataLoaderBuilder<B: Backend, I, O> {
12 strategy: Option<Box<dyn BatchStrategy<I>>>,
13 batcher: Arc<dyn Batcher<B, I, O>>,
14 num_threads: Option<usize>,
15 shuffle: Option<u64>,
16 device: Option<B::Device>,
17}
18
19impl<B, I, O> DataLoaderBuilder<B, I, O>
20where
21 B: Backend,
22 I: Send + Sync + Clone + std::fmt::Debug + 'static,
23 O: Send + Clone + std::fmt::Debug + 'static,
24{
25 pub fn new<Bt>(batcher: Bt) -> Self
35 where
36 Bt: Batcher<B, I, O> + 'static,
37 {
38 Self {
39 batcher: Arc::new(batcher),
40 strategy: None,
41 num_threads: None,
42 shuffle: None,
43 device: None,
44 }
45 }
46
47 pub fn batch_size(mut self, batch_size: usize) -> Self {
59 self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
60 self
61 }
62
63 pub fn shuffle(mut self, seed: u64) -> Self {
75 self.shuffle = Some(seed);
76 self
77 }
78
79 pub fn num_workers(mut self, num_workers: usize) -> Self {
95 self.num_threads = Some(num_workers);
96 self
97 }
98
99 pub fn set_device(mut self, device: B::Device) -> Self {
109 self.device = Some(device);
110 self
111 }
112
113 pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
123 where
124 D: Dataset<I> + 'static,
125 {
126 let dataset = Arc::new(dataset);
127
128 let device = self.device.unwrap_or_default();
129 let rng = self.shuffle.map(StdRng::seed_from_u64);
130 let strategy = match self.strategy {
131 Some(strategy) => strategy,
132 None => Box::new(FixBatchStrategy::new(1)),
133 };
134
135 if let Some(num_threads) = self.num_threads
136 && num_threads > 0
137 {
138 return Arc::new(MultiThreadDataLoader::new(
139 strategy,
140 dataset,
141 self.batcher,
142 num_threads,
143 device,
144 rng,
145 ));
146 }
147
148 Arc::new(BatchDataLoader::new(
149 strategy,
150 dataset,
151 self.batcher,
152 device,
153 rng,
154 ))
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::data::dataset::FakeDataset;
    use burn_tensor::Device;

    /// Device type of the backend the tests run on.
    type TestDevice = Device<TestBackend>;

    /// Batcher whose "batch" is simply the device it was asked to batch on,
    /// so a test can observe which device the loader passed in.
    #[derive(new, Clone)]
    struct TestBatcherDevice;

    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcherDevice {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    /// Builds a loader with batch size 1 over a fake dataset of `num_items` strings.
    ///
    /// `num_workers` is only called on the builder when `workers` is `Some`.
    fn build_loader(
        num_items: usize,
        workers: Option<usize>,
    ) -> Arc<dyn DataLoader<TestBackend, TestDevice>> {
        let builder = DataLoaderBuilder::new(TestBatcherDevice::new()).batch_size(1);
        let builder = match workers {
            Some(count) => builder.num_workers(count),
            None => builder,
        };

        builder.build(FakeDataset::<String>::new(num_items))
    }

    #[test]
    fn test_dataloader_no_workers() {
        let expected = TestDevice::default();
        let loader = build_loader(9, None);

        assert_eq!(loader.num_items(), 9);

        // No device was set on the builder, so every batch must report the default one.
        for batch_device in loader.iter() {
            assert_eq!(batch_device, expected);
        }
    }

    #[test]
    fn test_dataloader_default_device() {
        let expected = TestDevice::default();
        let loader = build_loader(9, Some(1));

        assert_eq!(loader.num_items(), 9);

        // Same expectation as the single-threaded case, now with one worker.
        for batch_device in loader.iter() {
            assert_eq!(batch_device, expected);
        }
    }

    #[test]
    fn test_dataloader_slice_multi_device() {
        let loader = build_loader(11, Some(1));

        // Pick a pair of devices matching whichever test backend feature is enabled.
        #[cfg(not(any(
            feature = "test-tch",
            feature = "test-wgpu",
            feature = "test-cuda"
        )))]
        let (first_device, second_device) = (burn_flex::FlexDevice, burn_flex::FlexDevice);

        #[cfg(feature = "test-tch")]
        let (first_device, second_device) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (first_device, second_device) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (first_device, second_device) =
            (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(loader.num_items(), 11);
        let first_half = loader.slice(0, 5).to_device(&first_device);
        let second_half = loader.slice(5, 11).to_device(&second_device);

        assert_eq!(first_half.num_items(), 5);
        assert_eq!(second_half.num_items(), 6);

        let mut first_batches = first_half.iter();
        let mut second_batches = second_half.iter();

        // Both slices yield a batch for each of the first five steps.
        for _ in 0..5 {
            assert_eq!(first_batches.next(), Some(first_device));
            assert_eq!(second_batches.next(), Some(second_device));
        }

        // The first slice is exhausted; the second slice holds one extra item.
        assert_eq!(first_batches.next(), None);
        assert_eq!(second_batches.next(), Some(second_device));
        assert_eq!(second_batches.next(), None);
    }
}