Skip to main content

burn_core/data/dataloader/
split.rs

1use std::sync::Arc;
2
3use burn_tensor::backend::Backend;
4
5use super::DataLoader;
6
7/// Splits a dataloader into multiple partial dataloaders (one per device).
8pub fn split_dataloader<B: Backend, O>(
9    dataloader: Arc<dyn DataLoader<B, O>>,
10    devices: &[B::Device],
11) -> Vec<Arc<dyn DataLoader<B, O>>> {
12    let num_splits = devices.len();
13    if num_splits > 1 {
14        let num_items = dataloader.num_items();
15        let mut dataloaders = Vec::with_capacity(num_splits);
16
17        let mut start = 0;
18        let step = num_items / num_splits;
19        for (i, device) in devices.iter().enumerate() {
20            let end = if i == (num_splits - 1) {
21                num_items
22            } else {
23                start + step
24            };
25            let dataloader = dataloader.slice(start, end).to_device(device);
26            dataloaders.push(dataloader);
27            start = end;
28        }
29        dataloaders
30    } else {
31        vec![dataloader]
32    }
33}
34
#[cfg(test)]
mod tests {
    use burn_tensor::Device;
    use std::collections::HashSet;

    use super::*;
    use crate::TestBackend;
    use crate::data::dataloader::batcher::Batcher;
    use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};
    use crate::data::dataset::FakeDataset;

    /// Splits an 11-item batch dataloader across two devices and checks that the halves hold
    /// 5 and 6 items, report the device they were assigned, and together yield the same set
    /// of items as the unsplit dataloader.
    #[test]
    fn test_split_batch_dataloader() {
        type TestDevice = Device<TestBackend>;

        /// Batcher that returns the raw items together with the device it was asked to use.
        #[derive(new, Clone)]
        pub struct TestBatcher;

        impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {
            fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {
                (items, *device)
            }
        }

        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(11));

        #[allow(clippy::arc_with_non_send_sync)]
        let dataloader = Arc::new(BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        ));

        // Default test backend: only one device exists, so both splits share it.
        #[cfg(not(any(feature = "test-tch", feature = "test-wgpu", feature = "test-cuda")))]
        let (first_device, second_device) = (burn_flex::FlexDevice, burn_flex::FlexDevice);

        #[cfg(feature = "test-tch")]
        let (first_device, second_device) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (first_device, second_device) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (first_device, second_device) =
            (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        let parts = split_dataloader(dataloader.clone(), &[first_device, second_device]);

        assert_eq!(parts.len(), 2);

        let [first_part, second_part] = match parts.try_into() {
            Ok(pair) => pair,
            Err(_) => unreachable!(),
        };
        assert_eq!(first_part.num_items(), 5);
        assert_eq!(second_part.num_items(), 6);

        // Reference set: everything the unsplit dataloader yields.
        let mut expected_items = HashSet::new();
        for (items, _device) in dataloader.iter() {
            expected_items.extend(items);
        }

        // Every batch of a split must carry that split's device; gather the items as we go.
        let mut split_items = HashSet::new();
        let split_parts = [(first_part, first_device), (second_part, second_device)];
        for (part, expected_device) in split_parts {
            for (items, device) in part.iter() {
                assert_eq!(device, expected_device);
                split_items.extend(items);
            }
        }

        assert_eq!(expected_items, split_items);
    }
}