Skip to main content

burn_core/data/dataloader/
split.rs

1use std::sync::Arc;
2
3use burn_tensor::backend::Backend;
4
5use super::DataLoader;
6
7/// Splits a dataloader into multiple partial dataloaders (one per device).
8pub fn split_dataloader<B: Backend, O>(
9    dataloader: Arc<dyn DataLoader<B, O>>,
10    devices: &[B::Device],
11) -> Vec<Arc<dyn DataLoader<B, O>>> {
12    let num_splits = devices.len();
13    if num_splits > 1 {
14        let num_items = dataloader.num_items();
15        let mut dataloaders = Vec::with_capacity(num_splits);
16
17        let mut start = 0;
18        let step = num_items / num_splits;
19        for (i, device) in devices.iter().enumerate() {
20            let end = if i == (num_splits - 1) {
21                num_items
22            } else {
23                start + step
24            };
25            let dataloader = dataloader.slice(start, end).to_device(device);
26            dataloaders.push(dataloader);
27            start = end;
28        }
29        dataloaders
30    } else {
31        vec![dataloader]
32    }
33}
34
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;
    use crate::TestBackend;
    use crate::data::dataloader::batcher::Batcher;
    use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};
    use crate::data::dataset::FakeDataset;

    /// Splits an 11-item batch dataloader across two devices and checks that
    /// the halves hold 5 and 6 items, that each half yields batches tagged with
    /// its own device, and that together they yield the same set of items as
    /// the unsplit dataloader.
    #[test]
    fn test_split_batch_dataloader() {
        type TestDevice = <TestBackend as Backend>::Device;

        /// Batcher that returns the raw items paired with the device it was given,
        /// so the test can observe which device each batch was built for.
        #[derive(new, Clone)]
        pub struct TestBatcher;

        impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {
            fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {
                (items, *device)
            }
        }

        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(11));

        #[allow(clippy::arc_with_non_send_sync)]
        let dataloader = Arc::new(BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        ));

        // Exactly one of the following device pairs is compiled in, depending on
        // which test backend feature is enabled.

        // Only one device exists...
        #[cfg(all(
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda")
        ))]
        let (first_device, second_device) = (
            burn_ndarray::NdArrayDevice::Cpu,
            burn_ndarray::NdArrayDevice::Cpu,
        );

        #[cfg(feature = "test-tch")]
        let (first_device, second_device) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (first_device, second_device) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (first_device, second_device) = (
            burn_cuda::CudaDevice::new(0),
            burn_cuda::CudaDevice::new(1),
        );

        let dataloaders = split_dataloader(dataloader.clone(), &[first_device, second_device]);

        assert_eq!(dataloaders.len(), 2);

        // The length was just asserted, so the conversion to a pair cannot fail.
        let Ok([first_half, second_half]) = <[_; 2]>::try_from(dataloaders) else {
            unreachable!()
        };
        assert_eq!(first_half.num_items(), 5);
        assert_eq!(second_half.num_items(), 6);

        // Items produced by the unsplit dataloader.
        let mut expected_items = HashSet::new();
        for (items, _device) in dataloader.iter() {
            expected_items.extend(items);
        }

        // Items produced by the two halves, each checked against its own device.
        let mut split_items = HashSet::new();
        for (items, device) in first_half.iter() {
            assert_eq!(device, first_device);
            split_items.extend(items);
        }
        for (items, device) in second_half.iter() {
            assert_eq!(device, second_device);
            split_items.extend(items);
        }

        assert_eq!(expected_items, split_items);
    }
}