1use burn_backend::{
2 DType, Shape, TensorData,
3 backend::{Backend, DeviceId, DeviceOps, ExecutionError},
4 try_read_sync,
5};
6use burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr};
7use burn_std::future::DynFut;
8
9use crate::{
10 ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel,
11 RunnerClient,
12};
13
// Generates a module containing all the router types needed to combine a fixed
// number of backends: a backend-tagged tensor `Handle`, a backend-tagged
// `MultiDevice`, a dispatching `MultiRunnerClient`, plus `RunnerChannel` and
// `MultiBackendBridge` implementations.
//
// `$module_name`   — name of the generated module (e.g. `duo`).
// `$DefaultBackend`— first backend; supplies the element types and the default device.
// `$OtherBackend…` — the remaining backends of the tuple.
macro_rules! impl_multi_backend_types {
    ($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        /// Generated multi-backend router types for this backend arity.
        pub mod $module_name {
            use super::*;

            /// Tensor handle tagged with the backend that owns it.
            pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Handle),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Handle),
                )+
            }

            /// Device tagged with the backend it belongs to.
            #[derive(Clone, Debug)]
            pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Device),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Device),
                )+
            }
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn eq(&self, other: &Self) -> bool {
                    // Devices are equal only when they belong to the same backend
                    // AND compare equal on that backend; cross-backend pairs are
                    // never equal.
                    match (self, other) {
                        (Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs,
                        $(
                            (Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs,
                        )+
                        _ => false,
                    }
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                // The default multi-device is the default backend's default device.
                fn default() -> Self {
                    Self::$DefaultBackend($DefaultBackend::Device::default())
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                // NOTE(review): the id is discarded and the default device is
                // returned — confirm that id-based device lookup is intentionally
                // unsupported for multi-backend devices.
                fn from_id(_device_id: DeviceId) -> Self {
                    Default::default()
                }

                // Delegate to the id of whichever backend's device is wrapped.
                fn to_id(&self) -> DeviceId {
                    match self {
                        Self::$DefaultBackend(device) => device.id(),
                        $(
                            Self::$OtherBackend(device) => device.id(),
                        )+
                    }
                }

                // NOTE(review): hard-coded to 1 regardless of `_type_id` — verify
                // callers don't rely on real per-backend device counts here.
                fn device_count(_type_id: u16) -> usize {
                    1
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {}

            /// Client that routes each call to the runner of the backend the
            /// client was created for.
            #[derive(Clone)]
            pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend(Runner<$DefaultBackend>),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend(Runner<$OtherBackend>),
                )+
            }

            // Every method below is a straight delegation to the wrapped runner;
            // the match arms are structurally identical across variants.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>
            {
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn register_op(&self, op: OperationIr) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.register_op(op),
                        $(
                            Self::$OtherBackend(runner) => runner.register_op(op),
                        )+
                    }
                }

                fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor),
                        $(
                            Self::$OtherBackend(runner) => runner.read_tensor_async(tensor),
                        )+
                    }
                }

                fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
                    // Register on the owning runner, then wrap the resulting
                    // descriptor in a router tensor bound to this client.
                    match self {
                        Self::$DefaultBackend(runner) => {
                            let desc = runner.register_tensor_data_desc(data);
                            RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                        }
                        $(
                            Self::$OtherBackend(runner) => {
                                let desc = runner.register_tensor_data_desc(data);
                                RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                            }
                        )+
                    }
                }

                fn device(&self) -> Self::Device {
                    // Re-tag the runner's device with the matching multi-device variant.
                    match self {
                        Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()),
                        $(
                            Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()),
                        )+
                    }
                }

                fn sync(&self) -> Result<(), ExecutionError> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.sync(),
                        $(
                            Self::$OtherBackend(runner) => runner.sync(),
                        )+
                    }
                }

                fn seed(&self, seed: u64) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.seed(seed),
                        $(
                            Self::$OtherBackend(runner) => runner.seed(seed),
                        )+
                    }
                }

                fn create_empty_handle(&self) -> TensorId {
                    match self {
                        Self::$DefaultBackend(runner) => runner.create_empty_handle(),
                        $(
                            Self::$OtherBackend(runner) => runner.create_empty_handle(),
                        )+
                    }
                }

                fn supports_dtype(&self, dtype: burn_std::DType) -> bool {
                    match self {
                        Self::$DefaultBackend(runner) => runner.supports_dtype(dtype),
                        $(
                            Self::$OtherBackend(runner) => runner.supports_dtype(dtype),
                        )+
                    }
                }
            }

            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br>
            where
                Br: MultiBackendBridge<TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>,
            {
                type Device = Br::Device;

                type Bridge = Br;

                // Element types come from the default (first) backend of the tuple.
                type FloatElem = $DefaultBackend::FloatElem;
                type IntElem = $DefaultBackend::IntElem;
                type BoolElem = $DefaultBackend::BoolElem;

                type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>;

                // Create a fresh runner for whichever backend owns the device.
                fn init_client(device: &Self::Device) -> Self::Client {
                    match device {
                        MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())),
                        $(
                            MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())),
                        )+
                    }
                }

                fn get_tensor_handle(
                    tensor: &TensorIr,
                    client: &Self::Client,
                ) -> <Self::Bridge as MultiBackendBridge>::TensorHandle {
                    // Fetch the raw handle from the owning runner and tag it with
                    // the backend variant.
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)),
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)),
                        )+
                    }
                }

                fn register_tensor(
                    client: &Self::Client,
                    handle: <Self::Bridge as MultiBackendBridge>::TensorHandle,
                    shape: Shape,
                    dtype: DType,
                ) -> RouterTensor<Self::Client> {
                    // The handle's backend tag must match the client's; a mismatch
                    // means the bridge failed to convert the tensor first.
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => match handle {
                            Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                            _ => unreachable!("Can't register tensor handle for another backend."),
                        },
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => match handle {
                                Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                                _ => unreachable!("Can't register tensor handle for another backend."),
                            },
                        )+
                    }
                }

                // Builds a name like `direct<(NameA, NameB, ...)>` from each
                // backend's name on its default device; the actual device passed
                // in is unused.
                fn name(_device: &Self::Device) -> String {
                    let mut name = format!("{}", $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default()));
                    $(
                        name.push_str(&format!(", {}", $OtherBackend::name(&<$OtherBackend::Device as Default>::default())));
                    )+
                    format!("direct<({})>", name)
                }
            }

            // Bridge that moves tensors between backends by reading the data out
            // of the source backend and re-registering it on the target.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> {
                type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>;
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn change_backend_float(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

                // NOTE(review): int and bool transfers expand to the same
                // `bridge!` path as float, which uses the backends' float tensor
                // ops — confirm handles are dtype-agnostic for these backends.
                fn change_backend_int(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

                fn change_backend_bool(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

            }
        }
    };
}
282
// Moves a tensor between devices, producing a backend-tagged `Handle`.
macro_rules! bridge {
    // Same-backend arm: rebuild the tensor from its raw handle and simply
    // transfer it to the target device of that backend.
    ($Backend:ident, $handle:expr, $device:expr, $shape:expr) => {{
        let src = $Backend::float_tensor(TensorHandle {
            handle: $handle,
            shape: $shape,
        });
        let moved = $Backend::float_to_device(src, $device);
        Handle::$Backend($Backend::float_tensor_handle(moved))
    }};
    // Cross-backend arm: read the data out of `$BackendA` synchronously and
    // re-create the tensor on `$BackendB`'s target device.
    ($BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr) => {{
        let src = $BackendA::float_tensor(TensorHandle { handle: $handle, shape: $shape });
        let data = try_read_sync($BackendA::float_into_data(src)).unwrap().expect(
            "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM."
        );
        let dst = $BackendB::float_from_data(data, $device);
        Handle::$BackendB($BackendB::float_tensor_handle(dst))
    }};
}
305
// Expands to a single `match` over `(handle, device)` covering the full
// cartesian product of backends: same-backend pairs use the 4-argument
// `bridge!` arm (device transfer), mixed pairs use the 5-argument arm
// (data copy across backends). The `@step` rules recursively accumulate the
// arms for every remaining backend pair.
macro_rules! multi_backend_match {
    // Entry point: seed the accumulator with the default backend's arms
    // (default↔default, default↔other, other↔other diagonal), then recurse
    // over the remaining backends.
    ($shape:expr, ($handle:expr, $device:expr) : $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        multi_backend_match! (
            @step
            $shape,
            ($handle, $device);
            {
                (Handle::$DefaultBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($DefaultBackend, handle, device, $shape),
                $(
                    (Handle::$DefaultBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($DefaultBackend, $OtherBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($OtherBackend, $DefaultBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($OtherBackend, handle, device, $shape),
                )+
            };
            $($OtherBackend),+
        )
    };

    // Recursive step: pair the head backend `$BackendA` with every remaining
    // backend (both directions) and append those arms to the accumulator,
    // then recurse on the tail.
    (@step
        $shape:expr,
        $pats:tt;
        { $($arms:tt)* };
        $BackendA:ident,
        $($OtherBackend:ident),+
    ) => {
        multi_backend_match! (
            @step
            $shape,
            $pats;
            {
                $($arms)*
                $(
                    (Handle::$BackendA(handle), MultiDevice::$OtherBackend(device)) => bridge!($BackendA, $OtherBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$BackendA(device)) => bridge!($OtherBackend, $BackendA, handle, device, $shape),
                )*
            };
            $($OtherBackend),*
        )
    };

    // Base case: one (or zero) backend left — every pair has been emitted, so
    // materialize the final `match` from the accumulated arms.
    (@step
        $shape:expr,
        ($handle:expr, $device:expr);
        { $($arms:tt)* };
        $($BackendA:ident)?
    ) => {
        match ($handle, $device) {
            $($arms)*
        }
    };
}
357
// Concrete multi-backend modules for tuples of 2, 3 and 4 backends.
impl_multi_backend_types!(duo, B1, B2);
impl_multi_backend_types!(trio, B1, B2, B3);
impl_multi_backend_types!(quad, B1, B2, B3, B4);
362
#[cfg(not(target_os = "windows"))]
#[cfg(test)]
mod tests {
    use burn_tensor::{Tensor, backend::Backend};

    use super::*;
    use crate::tests::{TestBackend, TestBackend1, TestBackend2};

    #[test]
    fn should_support_dual_byte_bridge() {
        // One device per backend of the `duo` pairing.
        let dev_b1 = duo::MultiDevice::B1(<TestBackend1 as Backend>::Device::default());
        let dev_b2 = duo::MultiDevice::B2(<TestBackend2 as Backend>::Device::default());

        let on_b1 = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &dev_b1);
        let on_b2 = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &dev_b2);

        // Bridging across backends must preserve the contents, in both directions.
        let crossed = on_b1.clone().to_device(&dev_b2);
        on_b1.into_data().assert_eq(&crossed.into_data(), true);

        let crossed_back = on_b2.clone().to_device(&dev_b1);
        on_b2.into_data().assert_eq(&crossed_back.into_data(), true);
    }
}