//! burn_router/types.rs
//!
//! Multi-backend router types: enums with one variant per backend, generated
//! by the `impl_multi_backend_types!` macro below.
1use burn_backend::{
2    DType, Shape, TensorData,
3    backend::{Backend, DeviceId, DeviceOps, ExecutionError},
4    try_read_sync,
5};
6use burn_ir::{BackendIr, OperationIr, TensorHandle, TensorId, TensorIr};
7use burn_std::future::DynFut;
8
9use crate::{
10    ByteBridge, DirectChannel, MultiBackendBridge, RouterTensor, Runner, RunnerChannel,
11    RunnerClient,
12};
13
/// Implement multi backend types, with enums having one variant per backend.
macro_rules! impl_multi_backend_types {
    // Match the default backend and at least one other backend, with rest being optional
    ($module_name:ident, $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        /// Module containing the essential types for multi-backend operations.
        ///
        /// - `Handle`: the type used to point to a tensor (defined for all backends).
        /// - `MultiRunnerClient`: a client for multiple runners (each responsible to execute tensor operations on a given backend).
        /// - `DirectChannel`: a local channel with direct connection to the backend runner clients.
        /// - `ByteBridge`: a simple multi-backend bridge that transfers tensors via the underlying [tensor data](burn_backend::TensorData).
        ///
        /// Each enum type is defined with backend identifiers as variant names (e.g., `B1` and `B2` for dual backends).
        pub mod $module_name {
            use super::*;

            /// The type that can be used to point to a tensor of any kind.
            /// Each backend has its own variant.
            pub enum Handle<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Handle),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Handle),
                )+
            }

            /// The device type used by a backend.
            /// Each backend has its own variant.
            #[derive(Clone, Debug)]
            pub enum MultiDevice<$DefaultBackend: Backend, $($OtherBackend: Backend),+> {
                #[allow(missing_docs)]
                $DefaultBackend($DefaultBackend::Device),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend($OtherBackend::Device),
                )+
            }
            // Two devices are equal only when they belong to the same backend
            // variant *and* the inner devices compare equal; cross-backend
            // comparisons always yield `false`.
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> PartialEq for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn eq(&self, other: &Self) -> bool {
                    match (self, other) {
                        (Self::$DefaultBackend(lhs), Self::$DefaultBackend(rhs)) => lhs == rhs,
                        $(
                            (Self::$OtherBackend(lhs), Self::$OtherBackend(rhs)) => lhs == rhs,
                        )+
                        _ => false,
                    }
                }
            }

            // Default implementation always returns the first backend's device
            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> Default for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn default() -> Self {
                    Self::$DefaultBackend($DefaultBackend::Device::default())
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> burn_std::device::Device for MultiDevice<$DefaultBackend, $($OtherBackend),+> {
                fn from_id(_device_id: DeviceId) -> Self {
                    // TODO: Should be fixed with the new router backend. The id is
                    // currently ignored and the default backend's default device is
                    // returned instead.
                    Default::default()
                }

                fn to_id(&self) -> DeviceId {
                    match self {
                        Self::$DefaultBackend(device) => device.id(),
                        $(
                            Self::$OtherBackend(device) => device.id(),
                        )+
                    }
                }

                // NOTE(review): always reports a single device regardless of the
                // type id — presumably a placeholder until proper multi-device
                // enumeration exists; confirm.
                fn device_count(_type_id: u16) -> usize {
                    1
                }
            }

            impl<$DefaultBackend: Backend, $($OtherBackend: Backend),+> DeviceOps for MultiDevice<$DefaultBackend, $($OtherBackend),+> {}

            /// A local client with multiple runners (each responsible to execute tensor operations on a given backend).
            #[derive(Clone)]
            pub enum MultiRunnerClient<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> {
                #[allow(missing_docs)]
                $DefaultBackend(Runner<$DefaultBackend>),
                $(
                    #[allow(missing_docs)]
                    $OtherBackend(Runner<$OtherBackend>),
                )+
            }

            // Every method below simply dispatches to the runner stored in the
            // active variant; no cross-backend logic happens here.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> RunnerClient for MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>
            {
               type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn register_op(&self, op: OperationIr) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.register_op(op),
                        $(
                            Self::$OtherBackend(runner) => runner.register_op(op),
                        )+
                    }
                }

                fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.read_tensor_async(tensor),
                        $(
                            Self::$OtherBackend(runner) => runner.read_tensor_async(tensor),
                        )+
                    }
                }

                // Registers the data with the matching runner, then wraps the
                // resulting descriptor (id/shape/dtype) in a `RouterTensor`
                // bound to a clone of this client.
                fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self> {
                    match self {
                        Self::$DefaultBackend(runner) => {
                            let desc = runner.register_tensor_data_desc(data);
                            RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                        }
                        $(
                            Self::$OtherBackend(runner) => {
                                let desc = runner.register_tensor_data_desc(data);
                                RouterTensor::new(desc.id, desc.shape, desc.dtype, self.clone())
                            }
                        )+
                    }
                }

                fn device(&self) -> Self::Device {
                    match self {
                        Self::$DefaultBackend(runner) => MultiDevice::$DefaultBackend(runner.device()),
                        $(
                            Self::$OtherBackend(runner) => MultiDevice::$OtherBackend(runner.device()),
                        )+
                    }
                }

                fn sync(&self) -> Result<(), ExecutionError> {
                    match self {
                        Self::$DefaultBackend(runner) => runner.sync(),
                        $(
                            Self::$OtherBackend(runner) => runner.sync(),
                        )+
                    }
                }

                fn seed(&self, seed: u64) {
                    match self {
                        Self::$DefaultBackend(runner) => runner.seed(seed),
                        $(
                            Self::$OtherBackend(runner) => runner.seed(seed),
                        )+
                    }
                }

                fn create_empty_handle(&self) -> TensorId {
                    match self {
                        Self::$DefaultBackend(runner) => runner.create_empty_handle(),
                        $(
                            Self::$OtherBackend(runner) => runner.create_empty_handle(),
                        )+
                    }
                }

                fn supports_dtype(&self, dtype: burn_std::DType) -> bool {
                    match self {
                        Self::$DefaultBackend(runner) => runner.supports_dtype(dtype),
                        $(
                            Self::$OtherBackend(runner) => runner.supports_dtype(dtype),
                        )+
                    }
                }
            }

            // Direct (in-process) channel: clients are created locally and the
            // element types are taken from the *default* backend.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+, Br> RunnerChannel for DirectChannel<($DefaultBackend, $($OtherBackend),+), Br>
            where
                Br: MultiBackendBridge<TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>, Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>>,
            {
                type Device = Br::Device;

                type Bridge = Br;

                type FloatElem = $DefaultBackend::FloatElem;
                type IntElem = $DefaultBackend::IntElem;
                type BoolElem = $DefaultBackend::BoolElem;

                type Client = MultiRunnerClient<$DefaultBackend, $($OtherBackend),+>;

                fn init_client(device: &Self::Device) -> Self::Client {
                    match device {
                        MultiDevice::$DefaultBackend(device) => MultiRunnerClient::$DefaultBackend(Runner::new(device.clone())),
                        $(
                            MultiDevice::$OtherBackend(device) => MultiRunnerClient::$OtherBackend(Runner::new(device.clone())),
                        )+
                    }
                }

                fn get_tensor_handle(
                    tensor: &TensorIr,
                    client: &Self::Client,
                ) -> <Self::Bridge as MultiBackendBridge>::TensorHandle {
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => Handle::$DefaultBackend(runner.get_tensor_handle(tensor)),
                        $(
                            MultiRunnerClient::$OtherBackend(runner) => Handle::$OtherBackend(runner.get_tensor_handle(tensor)),
                        )+
                    }
                }

                // A handle can only be registered with the client of the same
                // backend; mismatched (client, handle) pairs are a routing bug,
                // hence `unreachable!`.
                fn register_tensor(
                    client: &Self::Client,
                    handle: <Self::Bridge as MultiBackendBridge>::TensorHandle,
                    shape: Shape,
                    dtype: DType,
                ) -> RouterTensor<Self::Client> {
                    match client {
                        MultiRunnerClient::$DefaultBackend(runner) => match handle {
                            Handle::$DefaultBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                            _ => unreachable!("Can't register tensor handle for another backend."),
                        },
                        $(
                            MultiRunnerClient::$OtherBackend(runner) =>  match handle {
                                Handle::$OtherBackend(handle) => runner.register_tensor(handle, shape, dtype, client.clone()),
                                _ => unreachable!("Can't register tensor handle for another backend."),
                            },
                        )+
                    }
                }

                // NOTE(review): `_device` is ignored — the name is built from each
                // backend's *default* device (e.g. `direct<(ndarray, wgpu)>`);
                // confirm that is intended for devices other than the default.
                fn name(_device: &Self::Device) -> String {
                    let mut name = format!("{}", $DefaultBackend::name(&<$DefaultBackend::Device as Default>::default()));
                    $(
                        name.push_str(&format!(", {}", $OtherBackend::name(&<$OtherBackend::Device as Default>::default())));
                    )+
                    format!("direct<({})>", name)
                }
            }

            // NOTE(review): all three `change_backend_*` variants expand to the
            // same `multi_backend_match!`/`bridge!` code, which goes through the
            // backend's *float* tensor ops — confirm int/bool tensors are handled
            // correctly on this path.
            impl<$DefaultBackend: BackendIr, $($OtherBackend: BackendIr),+> MultiBackendBridge for ByteBridge<($DefaultBackend, $($OtherBackend),+)> {
                type TensorHandle = Handle<$DefaultBackend, $($OtherBackend),+>;
                type Device = MultiDevice<$DefaultBackend, $($OtherBackend),+>;

                fn change_backend_float(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

                fn change_backend_int(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

                fn change_backend_bool(
                    tensor: Self::TensorHandle,
                    shape: Shape,
                    target_device: &Self::Device,
                ) -> Self::TensorHandle {
                    multi_backend_match!(shape, (tensor, target_device) : $DefaultBackend, $($OtherBackend),+)
                }

            }
        }
    };
}
282
/// Move a tensor handle to a target device, expanding to one of two forms:
///
/// - `bridge!(B, handle, device, shape)`: same-backend move via `float_to_device`.
/// - `bridge!(A, B, handle, device, shape)`: cross-backend "byte bridge" — read
///   the tensor back into `TensorData` on backend `A`, then recreate it on
///   backend `B` at the target device.
///
/// Both arms evaluate to a `Handle` variant for the destination backend.
macro_rules! bridge {
    ($Backend:ident, $handle:expr, $device:expr, $shape:expr) => {{
        // Bridge for the same backend
        let tensor = $Backend::float_tensor(TensorHandle {
            handle: $handle,
            shape: $shape,
        });
        let tensor = $Backend::float_to_device(tensor, $device);
        let handle = $Backend::float_tensor_handle(tensor);
        Handle::$Backend(handle)
    }};
    ($BackendA:ident, $BackendB:ident, $handle:expr, $device:expr, $shape:expr) => {{
        // Byte bridge between two backends
        let tensor = $BackendA::float_tensor(TensorHandle { handle: $handle, shape: $shape });
        // Blocking read of the source tensor; panics on platforms that cannot
        // block on futures (e.g. WASM) or when the read itself fails.
        let data = try_read_sync($BackendA::float_into_data(tensor)).unwrap().expect(
            "Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM."
        );
        let tensor = $BackendB::float_from_data(data, $device);
        let handle = $BackendB::float_tensor_handle(tensor);
        Handle::$BackendB(handle)
    }};
}
305
/// Generate an exhaustive `match` over every (source `Handle`, target
/// `MultiDevice`) backend combination, mapping each pair to the right
/// `bridge!` arm (same-backend move, or cross-backend byte bridge).
///
/// The entry arm emits all pairs that involve the default backend, then
/// recurses via the internal `@step` rules: each step adds the cross-backend
/// arms between the current head backend and every remaining backend, until
/// at most one backend is left (the base case, which emits the accumulated
/// arms as a single `match`).
macro_rules! multi_backend_match {
    ($shape:expr, ($handle:expr, $device:expr) : $DefaultBackend:ident, $($OtherBackend:ident),+) => {
        multi_backend_match! (
            @step
            $shape,
            ($handle, $device);
            {
                // Default-backend pairs: same-backend plus both directions of
                // default <-> other.
                (Handle::$DefaultBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($DefaultBackend, handle, device, $shape),
                $(
                    (Handle::$DefaultBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($DefaultBackend, $OtherBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$DefaultBackend(device)) => bridge!($OtherBackend, $DefaultBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$OtherBackend(device)) => bridge!($OtherBackend, handle, device, $shape),
                )+
            };
            $($OtherBackend),+
        )
    };

    // Recursive step: `$BackendA` is the current head; append its cross-backend
    // arms against each remaining backend, then recurse on the tail.
    (@step
        $shape:expr,
        $pats:tt;
        { $($arms:tt)* };
        $BackendA:ident,
        $($OtherBackend:ident),+
    ) => {
        multi_backend_match! (
            @step
            $shape,
            $pats;
            {
                $($arms)*
                $(
                    (Handle::$BackendA(handle), MultiDevice::$OtherBackend(device)) => bridge!($BackendA, $OtherBackend, handle, device, $shape),
                    (Handle::$OtherBackend(handle), MultiDevice::$BackendA(device)) => bridge!($OtherBackend, $BackendA, handle, device, $shape),
                )*
            };
            $($OtherBackend),*
        )
    };

    // Base case: zero or one backend left — emit the collected arms.
    (@step
        $shape:expr,
        ($handle:expr, $device:expr);
        { $($arms:tt)* };
        $($BackendA:ident)?
    ) => {
        match ($handle, $device) {
            $($arms)*
        }
    };
}
357
// Implement multi-backend types and byte bridge for up to 4 backends.
// This generates the `duo`, `trio` and `quad` modules, each exposing
// `Handle`, `MultiDevice` and `MultiRunnerClient` for that backend count.
impl_multi_backend_types!(duo, B1, B2);
impl_multi_backend_types!(trio, B1, B2, B3);
impl_multi_backend_types!(quad, B1, B2, B3, B4);
362
#[cfg(not(target_os = "windows"))] // cannot find a wgpu adapter on windows CI
#[cfg(test)]
mod tests {
    use burn_tensor::{Tensor, backend::Backend};

    use super::*;
    use crate::tests::{TestBackend, TestBackend1, TestBackend2};

    /// Round-tripping a tensor through the byte bridge in either direction
    /// must preserve its data exactly.
    #[test]
    fn should_support_dual_byte_bridge() {
        let first_device = duo::MultiDevice::B1(<TestBackend1 as Backend>::Device::default());
        let second_device = duo::MultiDevice::B2(<TestBackend2 as Backend>::Device::default());

        let on_first = Tensor::<TestBackend, 1>::from_floats([1.0, 2.0, 3.0, 4.0], &first_device);
        let on_second = Tensor::<TestBackend, 1>::from_floats([5.0, 6.0, 7.0, 8.0], &second_device);

        // B1 -> B2: the transferred tensor must match the source values.
        let moved_to_second = on_first.clone().to_device(&second_device);
        on_first.into_data().assert_eq(&moved_to_second.into_data(), true);

        // B2 -> B1: same check in the opposite direction.
        let moved_to_first = on_second.clone().to_device(&first_device);
        on_second.into_data().assert_eq(&moved_to_first.into_data(), true);
    }
}