// burn_core/data/dataloader/split.rs

use std::sync::Arc;

use burn_tensor::backend::Backend;

use super::DataLoader;
7pub fn split_dataloader<B: Backend, O>(
9 dataloader: Arc<dyn DataLoader<B, O>>,
10 devices: &[B::Device],
11) -> Vec<Arc<dyn DataLoader<B, O>>> {
12 let num_splits = devices.len();
13 if num_splits > 1 {
14 let num_items = dataloader.num_items();
15 let mut dataloaders = Vec::with_capacity(num_splits);
16
17 let mut start = 0;
18 let step = num_items / num_splits;
19 for (i, device) in devices.iter().enumerate() {
20 let end = if i == (num_splits - 1) {
21 num_items
22 } else {
23 start + step
24 };
25 let dataloader = dataloader.slice(start, end).to_device(device);
26 dataloaders.push(dataloader);
27 start = end;
28 }
29 dataloaders
30 } else {
31 vec![dataloader]
32 }
33}
34
#[cfg(test)]
mod tests {
    use burn_tensor::Device;
    use std::collections::HashSet;

    use super::*;
    use crate::TestBackend;
    use crate::data::dataloader::batcher::Batcher;
    use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};
    use crate::data::dataset::FakeDataset;

    /// Splits an 11-item dataloader over two devices and checks that:
    /// - exactly two loaders are produced (5 and 6 items),
    /// - each split yields batches on its assigned device,
    /// - the union of the split items equals the original items.
    #[test]
    fn test_split_batch_dataloader() {
        type TestDevice = Device<TestBackend>;

        // Batcher that echoes the device it was handed, so each split's
        // device assignment can be asserted from the batches it yields.
        #[derive(new, Clone)]
        pub struct TestBatcher;

        // NOTE: the `test` cfg predicate is implied everywhere in this module
        // (it is gated by `#[cfg(test)]`), so it is omitted from the
        // attributes below.
        impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {
            fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {
                (items, *device)
            }
        }

        let batcher = Arc::new(TestBatcher::new());
        let dataset = Arc::new(FakeDataset::<String>::new(11));

        #[allow(clippy::arc_with_non_send_sync)]
        let dataloader = Arc::new(BatchDataLoader::new(
            Box::new(FixBatchStrategy::new(5)),
            dataset.clone(),
            batcher,
            Default::default(),
            None,
        ));

        // Pick a device pair matching whichever test backend feature is on.
        #[cfg(all(
            not(feature = "test-tch"),
            not(feature = "test-wgpu"),
            not(feature = "test-cuda")
        ))]
        let (device1, device2) = (burn_flex::FlexDevice, burn_flex::FlexDevice);

        #[cfg(feature = "test-tch")]
        let (device1, device2) = (
            burn_tch::LibTorchDevice::Cuda(0),
            burn_tch::LibTorchDevice::Cuda(1),
        );

        #[cfg(feature = "test-wgpu")]
        let (device1, device2) = (
            burn_wgpu::WgpuDevice::DiscreteGpu(0),
            burn_wgpu::WgpuDevice::DiscreteGpu(1),
        );

        #[cfg(feature = "test-cuda")]
        let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));

        let dataloaders = split_dataloader(dataloader.clone(), &[device1, device2]);

        assert_eq!(dataloaders.len(), 2);

        // 11 items over 2 devices: 11 / 2 = 5 for the first split; the last
        // split absorbs the remainder (6 items).
        let [dataloader_1, dataloader_2] = match dataloaders.try_into() {
            Ok(arr) => arr,
            Err(_) => unreachable!(),
        };
        assert_eq!(dataloader_1.num_items(), 5);
        assert_eq!(dataloader_2.num_items(), 6);

        let mut items_dataloader = HashSet::new();
        let mut items_dataloader_split = HashSet::new();

        for (items, _device) in dataloader.iter() {
            for item in items {
                items_dataloader.insert(item);
            }
        }

        for (items, device) in dataloader_1.iter() {
            assert_eq!(device, device1);
            for item in items {
                items_dataloader_split.insert(item);
            }
        }

        for (items, device) in dataloader_2.iter() {
            assert_eq!(device, device2);
            for item in items {
                items_dataloader_split.insert(item);
            }
        }

        // Splitting must neither drop nor duplicate items.
        assert_eq!(items_dataloader, items_dataloader_split);
    }
}