// burn_dispatch/backend.rs
1use alloc::format;
2use alloc::string::String;
3
4use burn_backend::Backend;
5use burn_backend::ExecutionError;
6use burn_std::DType;
7
8#[cfg(feature = "autodiff")]
9use burn_autodiff::grads::Gradients;
10#[cfg(feature = "autodiff")]
11use burn_backend::AutodiffBackend;
12
13use crate::backends::*;
14use crate::{DispatchDevice, DispatchTensor};
15
/// The main execution backend in Burn.
///
/// [`Dispatch`] acts as a global backend that can manage multiple underlying
/// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.).
/// It is responsible for:
/// - Dispatching tensor operations to the appropriate backend.
/// - Managing cross-backend tensor transfers.
///
/// Essentially, [`Dispatch`] is the single entry point for executing tensor operations
/// in a backend-agnostic way. It allows Burn to provide a unified, global backend
/// for users while still leveraging multiple specialized backends under the hood.
///
/// [`Dispatch`] itself is a zero-sized marker type: all runtime state lives in
/// the [`DispatchDevice`] and [`DispatchTensor`] enums, and each operation is
/// routed to the backend selected by the device/tensor variant.
///
/// # Example
///
/// ```ignore
/// use burn::Dispatch;
/// use burn::DispatchDevice;
///
/// // Select the device to execute operations on
/// let device = DispatchDevice::Cuda(Default::default());
///
/// // Create a tensor using the global backend
/// let t = Tensor::<Dispatch, 2>::zeros([128, 128], &device);
/// ```
#[derive(Debug, Default, Clone)]
pub struct Dispatch;
42
43impl Backend for Dispatch {
44    type Device = DispatchDevice;
45
46    type FloatTensorPrimitive = DispatchTensor;
47
48    // TODO: either allow default dtype generic or remove associated types entirely?
49    type FloatElem = f32;
50
51    type IntTensorPrimitive = DispatchTensor;
52
53    type IntElem = i32;
54
55    type BoolTensorPrimitive = DispatchTensor;
56
57    type BoolElem = u8;
58
59    type QuantizedTensorPrimitive = DispatchTensor;
60
61    fn name(device: &Self::Device) -> String {
62        let inner = dispatch_device!(device, |device| B::name(device));
63        format!("dispatch<{inner}>")
64    }
65
66    fn seed(device: &Self::Device, seed: u64) {
67        dispatch_device!(device, |device| B::seed(device, seed))
68    }
69
70    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
71        dispatch_device!(device, |device| B::sync(device))
72    }
73
74    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
75        dispatch_device!(device, |device| B::dtype_usage(device, dtype))
76    }
77
78    fn ad_enabled(device: &Self::Device) -> bool {
79        match device {
80            #[cfg(feature = "autodiff")]
81            DispatchDevice::Autodiff(_) => true,
82            _ => false,
83        }
84    }
85}
86
#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
    // `Dispatch` is its own inner backend: stripping the autodiff wrapper just
    // unwraps `DispatchTensor::Autodiff` back into the plain backend variant.
    type InnerBackend = Dispatch;

    type Gradients = Gradients;

    // NOTE(review): the inner matches below have no arm for
    // `DispatchTensor::LibTorch` (cfg `tch`), although `DispatchTensor::device`
    // elsewhere in this file does handle it. Presumably `tch` tensors are never
    // wrapped in `Autodiff` here — confirm, otherwise these matches become
    // non-exhaustive when `tch` and `autodiff` are enabled together.

    /// Runs the backward pass on an autodiff tensor and returns the gradients.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if the autodiff
    /// wrapper itself wraps another `Autodiff` tensor (a broken invariant).
    fn backward(tensor: DispatchTensor) -> Self::Gradients {
        match tensor {
            // This inner cfg is redundant (the whole impl is already gated on
            // the `autodiff` feature) but harmless.
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => tensor.autodiff().backward(),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Looks up the gradient of `tensor` in `grads`, re-wrapped as a plain
    /// (non-autodiff) float tensor on the same backend; `None` if absent.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if it wraps
    /// another `Autodiff` tensor.
    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
        match &tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match &**tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Like [`Self::grad`], but removes the gradient from `grads` and returns
    /// the removed value (wrapped as a plain float tensor), if present.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if it wraps
    /// another `Autodiff` tensor.
    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
        match &tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match &**tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::WebGpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensor::NdArray(crate::BackendTensor::Float(t))),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Replaces the gradient of `tensor` in `grads` with `grad`.
    ///
    /// `grad` must be a plain tensor on the *same* backend variant as the
    /// wrapped autodiff tensor.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, if it wraps another
    /// `Autodiff` tensor, or if `tensor` and `grad` live on different backends.
    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
        match &tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match (&**tensor, grad) {
                #[cfg(feature = "cpu")]
                (DispatchTensor::Cpu(tensor), DispatchTensor::Cpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "cuda")]
                (DispatchTensor::Cuda(tensor), DispatchTensor::Cuda(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_metal)]
                (DispatchTensor::Metal(tensor), DispatchTensor::Metal(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "rocm")]
                (DispatchTensor::Rocm(tensor), DispatchTensor::Rocm(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_vulkan)]
                (DispatchTensor::Vulkan(tensor), DispatchTensor::Vulkan(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_webgpu)]
                (DispatchTensor::WebGpu(tensor), DispatchTensor::WebGpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "ndarray")]
                (DispatchTensor::NdArray(tensor), DispatchTensor::NdArray(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                (DispatchTensor::Autodiff(_), _) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
                // Mismatched (tensor, grad) backend pairs fall through here.
                (t, g) => panic!(
                    "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
                ),
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Strips the autodiff wrapper, returning the untracked float primitive
    /// re-wrapped in the matching plain backend variant.
    ///
    /// # Panics
    /// Panics if `tensor` is not the `Autodiff` variant, or if it wraps
    /// another `Autodiff` tensor.
    fn inner(tensor: DispatchTensor) -> DispatchTensor {
        match tensor {
            #[cfg(feature = "autodiff")]
            DispatchTensor::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensor::Cpu(tensor) => {
                    DispatchTensor::Cpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(feature = "cuda")]
                DispatchTensor::Cuda(tensor) => {
                    DispatchTensor::Cuda(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(wgpu_metal)]
                DispatchTensor::Metal(tensor) => {
                    DispatchTensor::Metal(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(feature = "rocm")]
                DispatchTensor::Rocm(tensor) => {
                    DispatchTensor::Rocm(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(wgpu_vulkan)]
                DispatchTensor::Vulkan(tensor) => {
                    DispatchTensor::Vulkan(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(wgpu_webgpu)]
                DispatchTensor::WebGpu(tensor) => {
                    DispatchTensor::WebGpu(crate::BackendTensor::Float(tensor.autodiff().primitive))
                }
                #[cfg(feature = "ndarray")]
                DispatchTensor::NdArray(tensor) => DispatchTensor::NdArray(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                DispatchTensor::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    // Int, bool, and quantized tensors are not tracked by autodiff, so
    // unwrapping them is the identity.
    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    /// Wraps a plain (untracked) float tensor into the autodiff variant for
    /// its backend, so subsequent operations are recorded for backprop.
    ///
    /// # Panics
    /// Panics if `tensor` is already the `Autodiff` variant.
    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
        match tensor {
            #[cfg(feature = "cpu")]
            DispatchTensor::Cpu(tensor) => DispatchTensor::Autodiff(Box::new(DispatchTensor::Cpu(
                crate::BackendTensor::Autodiff(Autodiff::<Cpu<f32>>::from_inner(tensor.float())),
            ))),
            #[cfg(feature = "cuda")]
            DispatchTensor::Cuda(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Cuda(crate::BackendTensor::Autodiff(
                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_metal)]
            DispatchTensor::Metal(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Metal(crate::BackendTensor::Autodiff(
                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "rocm")]
            DispatchTensor::Rocm(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Rocm(crate::BackendTensor::Autodiff(
                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_vulkan)]
            DispatchTensor::Vulkan(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::Vulkan(crate::BackendTensor::Autodiff(
                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_webgpu)]
            DispatchTensor::WebGpu(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::WebGpu(crate::BackendTensor::Autodiff(
                    Autodiff::<WebGpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "ndarray")]
            DispatchTensor::NdArray(tensor) => DispatchTensor::Autodiff(Box::new(
                DispatchTensor::NdArray(crate::BackendTensor::Autodiff(
                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
                )),
            )),
            DispatchTensor::Autodiff(_) => {
                panic!("Autodiff should not wrap an autodiff tensor.")
            }
        }
    }

    // As above: non-float tensor kinds carry no autodiff state, so wrapping
    // them is the identity.
    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }
}
368
369impl DispatchTensor {
370    pub(crate) fn device(&self) -> DispatchDevice {
371        match self {
372            #[cfg(feature = "cpu")]
373            DispatchTensor::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
374            #[cfg(feature = "cuda")]
375            DispatchTensor::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
376            #[cfg(wgpu_metal)]
377            DispatchTensor::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
378            #[cfg(feature = "rocm")]
379            DispatchTensor::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
380            #[cfg(wgpu_vulkan)]
381            DispatchTensor::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
382            #[cfg(wgpu_webgpu)]
383            DispatchTensor::WebGpu(tensor) => DispatchDevice::WebGpu(tensor.device()),
384            #[cfg(feature = "ndarray")]
385            DispatchTensor::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
386            #[cfg(feature = "tch")]
387            DispatchTensor::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
388            #[cfg(feature = "autodiff")]
389            DispatchTensor::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
390        }
391    }
392}