// burn_core/data/dataloader/split.rs
use std::sync::Arc;

use burn_tensor::backend::Backend;

use super::DataLoader;

7pub fn split_dataloader<B: Backend, O>(
9 dataloader: Arc<dyn DataLoader<B, O>>,
10 devices: &[B::Device],
11) -> Vec<Arc<dyn DataLoader<B, O>>> {
12 let num_splits = devices.len();
13 if num_splits > 1 {
14 let num_items = dataloader.num_items();
15 let mut dataloaders = Vec::with_capacity(num_splits);
16
17 let mut start = 0;
18 let step = num_items / num_splits;
19 for (i, device) in devices.iter().enumerate() {
20 let end = if i == (num_splits - 1) {
21 num_items
22 } else {
23 start + step
24 };
25 let dataloader = dataloader.slice(start, end).to_device(device);
26 dataloaders.push(dataloader);
27 start = end;
28 }
29 dataloaders
30 } else {
31 vec![dataloader]
32 }
33}
34
35#[cfg(test)]
36mod tests {
37 use std::collections::HashSet;
38
39 use super::*;
40 use crate::TestBackend;
41 use crate::data::dataloader::batcher::Batcher;
42 use crate::data::dataloader::{BatchDataLoader, FixBatchStrategy};
43 use crate::data::dataset::FakeDataset;
44
45 #[test]
46 fn test_split_batch_dataloader() {
47 type TestDevice = <TestBackend as Backend>::Device;
48
49 #[derive(new, Clone)]
50 pub struct TestBatcher;
51
52 #[cfg(test)]
53 impl<I> Batcher<TestBackend, I, (Vec<I>, TestDevice)> for TestBatcher {
54 fn batch(&self, items: Vec<I>, device: &TestDevice) -> (Vec<I>, TestDevice) {
55 (items, *device)
56 }
57 }
58
59 let batcher = Arc::new(TestBatcher::new());
60 let dataset = Arc::new(FakeDataset::<String>::new(11));
61
62 #[allow(clippy::arc_with_non_send_sync)]
63 let dataloader = Arc::new(BatchDataLoader::new(
64 Box::new(FixBatchStrategy::new(5)),
65 dataset.clone(),
66 batcher,
67 Default::default(),
68 None,
69 ));
70
71 #[cfg(all(
72 test,
73 not(feature = "test-tch"),
74 not(feature = "test-wgpu"),
75 not(feature = "test-cuda")
76 ))]
77 let (device1, device2) = (
79 burn_ndarray::NdArrayDevice::Cpu,
80 burn_ndarray::NdArrayDevice::Cpu,
81 );
82
83 #[cfg(all(test, feature = "test-tch"))]
84 let (device1, device2) = (
85 burn_tch::LibTorchDevice::Cuda(0),
86 burn_tch::LibTorchDevice::Cuda(1),
87 );
88
89 #[cfg(all(test, feature = "test-wgpu"))]
90 let (device1, device2) = (
91 burn_wgpu::WgpuDevice::DiscreteGpu(0),
92 burn_wgpu::WgpuDevice::DiscreteGpu(1),
93 );
94
95 #[cfg(all(test, feature = "test-cuda"))]
96 let (device1, device2) = (burn_cuda::CudaDevice::new(0), burn_cuda::CudaDevice::new(1));
97
98 let dataloaders = split_dataloader(dataloader.clone(), &[device1, device2]);
99
100 assert_eq!(dataloaders.len(), 2);
101
102 let [dataloader_1, dataloader_2] = match dataloaders.try_into() {
103 Ok(arr) => arr,
104 Err(_) => unreachable!(),
105 };
106 assert_eq!(dataloader_1.num_items(), 5);
107 assert_eq!(dataloader_2.num_items(), 6);
108
109 let mut items_dataloader = HashSet::new();
110 let mut items_dataloader_split = HashSet::new();
111
112 for (items, _device) in dataloader.iter() {
113 for item in items {
114 items_dataloader.insert(item);
115 }
116 }
117
118 for (items, device) in dataloader_1.iter() {
119 assert_eq!(device, device1);
120 for item in items {
121 items_dataloader_split.insert(item);
122 }
123 }
124
125 for (items, device) in dataloader_2.iter() {
126 assert_eq!(device, device2);
127 for item in items {
128 items_dataloader_split.insert(item);
129 }
130 }
131
132 assert_eq!(items_dataloader, items_dataloader_split);
133 }
134}