// burn_core/data/dataloader/builder.rs
use super::{
    BatchDataLoader, BatchStrategy, DataLoader, FixBatchStrategy, MultiThreadDataLoader,
    batcher::Batcher,
};
use burn_dataset::Dataset;
use burn_tensor::backend::Backend;
use rand::{SeedableRng, rngs::StdRng};
use std::sync::Arc;

/// Builder for configuring and constructing a [`DataLoader`].
///
/// `I` is the raw item type produced by the dataset; `O` is the batched
/// output type produced by the [`Batcher`].
pub struct DataLoaderBuilder<B: Backend, I, O> {
    /// Batching strategy; `build` falls back to a fixed batch size of 1 when unset.
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    /// Batcher that maps a `Vec<I>` of items into an output batch `O`.
    batcher: Arc<dyn Batcher<B, I, O>>,
    /// Worker-thread count; when set, `build` returns a multi-threaded loader.
    num_threads: Option<usize>,
    /// RNG seed for shuffling; `None` means no shuffling.
    shuffle: Option<u64>,
    /// Device batches are created on; `build` falls back to `B::Device::default()`.
    device: Option<B::Device>,
}
18
19impl<B, I, O> DataLoaderBuilder<B, I, O>
20where
21 B: Backend,
22 I: Send + Sync + Clone + std::fmt::Debug + 'static,
23 O: Send + Clone + std::fmt::Debug + 'static,
24{
25 pub fn new<Bt>(batcher: Bt) -> Self
35 where
36 Bt: Batcher<B, I, O> + 'static,
37 {
38 Self {
39 batcher: Arc::new(batcher),
40 strategy: None,
41 num_threads: None,
42 shuffle: None,
43 device: None,
44 }
45 }
46
47 pub fn batch_size(mut self, batch_size: usize) -> Self {
58 self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
59 self
60 }
61
62 pub fn shuffle(mut self, seed: u64) -> Self {
74 self.shuffle = Some(seed);
75 self
76 }
77
78 pub fn num_workers(mut self, num_workers: usize) -> Self {
88 self.num_threads = Some(num_workers);
89 self
90 }
91
92 pub fn set_device(mut self, device: B::Device) -> Self {
102 self.device = Some(device);
103 self
104 }
105
106 pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<B, O>>
116 where
117 D: Dataset<I> + 'static,
118 {
119 let dataset = Arc::new(dataset);
120
121 let device = self.device.unwrap_or_default();
122 let rng = self.shuffle.map(StdRng::seed_from_u64);
123 let strategy = match self.strategy {
124 Some(strategy) => strategy,
125 None => Box::new(FixBatchStrategy::new(1)),
126 };
127 if let Some(num_threads) = self.num_threads {
128 return Arc::new(MultiThreadDataLoader::new(
129 strategy,
130 dataset,
131 self.batcher,
132 num_threads,
133 device,
134 rng,
135 ));
136 }
137
138 Arc::new(BatchDataLoader::new(
139 strategy,
140 dataset,
141 self.batcher,
142 device,
143 rng,
144 ))
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::dataset::FakeDataset;
    use crate::{TestBackend, data::dataloader::batcher::Batcher};

    type TestDevice = <TestBackend as Backend>::Device;

    /// Batcher that ignores its items and echoes back the device it was
    /// handed, so the tests can observe which device each batch targets.
    /// Shared by both tests below (previously duplicated in each).
    #[derive(new, Clone)]
    pub struct TestBatcher;

    impl<I> Batcher<TestBackend, I, TestDevice> for TestBatcher {
        fn batch(&self, _items: Vec<I>, device: &TestDevice) -> TestDevice {
            *device
        }
    }

    #[test]
    fn test_dataloader_default_device() {
        let default_device = TestDevice::default();
        let dataloader = DataLoaderBuilder::new(TestBatcher::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(9));

        assert_eq!(dataloader.num_items(), 9);

        // Without an explicit `set_device`, every batch lands on the default.
        for device in dataloader.iter() {
            assert_eq!(device, default_device)
        }
    }

    #[test]
    fn test_dataloader_slice_multi_device() {
        let dataloader = DataLoaderBuilder::new(TestBatcher::new())
            .batch_size(1)
            .num_workers(1)
            .build(FakeDataset::<String>::new(11));

        // Pick a device pair matching the backend feature under test. The
        // `test` predicate is implied by the enclosing `#[cfg(test)]` module.
        #[cfg(all(
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda")
        ))]
        let (device1, device2) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(feature = "test-tch")]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        assert_eq!(dataloader.num_items(), 11);
        // Split 11 items into 5 + 6 and pin each half to its own device.
        let dataloader_1 = dataloader.slice(0, 5).to_device(&device1);
        let dataloader_2 = dataloader.slice(5, 11).to_device(&device2);

        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let (mut iterator_1, mut iterator_2) = (dataloader_1.iter(), dataloader_2.iter());

        for _ in 0..5 {
            assert_eq!(iterator_1.next(), Some(device1));
            assert_eq!(iterator_2.next(), Some(device2));
        }

        // The second slice holds one extra item.
        assert_eq!(iterator_1.next(), None);
        assert_eq!(iterator_2.next(), Some(device2));
        assert_eq!(iterator_2.next(), None);
    }
}