// burn_dispatch/backend.rs
1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4
5use burn_backend::quantization::QuantScheme;
6use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive};
7
8#[cfg(feature = "autodiff")]
9use burn_autodiff::grads::Gradients;
10#[cfg(feature = "autodiff")]
11use burn_backend::AutodiffBackend;
12
13#[allow(unused)]
14use crate::BackendId;
15use crate::DispatchTensorKind;
16use crate::backends::*;
17use crate::{DispatchDevice, DispatchTensor};
18
19/// The main execution backend in Burn.
20///
21/// [`Dispatch`] acts as a global backend that can manage multiple underlying
22/// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.).  
23/// It is responsible for:
24/// - Dispatching tensor operations to the appropriate backend.
25/// - Managing cross-backend tensor transfers.
26///
27/// Essentially, [`Dispatch`] is the single entry point for executing tensor operations
28/// in a backend-agnostic way. It allows Burn to provide a unified, global backend
29/// for users while still leveraging multiple specialized backends under the hood.
30///
31/// # Example
32///
33/// ```ignore
34/// use burn::Dispatch;
35/// use burn::DispatchDevice;
36///
37/// // Select the device to execute operations on
38/// let device = DispatchDevice::Cuda(Default::default());
39///
40/// // Create a tensor using the global backend
41/// let t = Tensor::<Dispatch, 2>::zeros([128, 128], &device);
42/// ```
43#[derive(Debug, Default, Clone)]
44pub struct Dispatch;
45
impl Backend for Dispatch {
    // The device selects which underlying backend executes an operation.
    type Device = DispatchDevice;

    type FloatTensorPrimitive = DispatchTensor;

    // TODO: either allow default dtype generic or remove associated types entirely?
    type FloatElem = f32;

    type IntTensorPrimitive = DispatchTensor;

    type IntElem = i32;

    type BoolTensorPrimitive = DispatchTensor;

    type BoolElem = u8;

    type QuantizedTensorPrimitive = DispatchTensor;

    /// Backend name, rendered as `dispatch<inner>` where `inner` is the name
    /// reported by the backend the device dispatches to.
    fn name(device: &Self::Device) -> String {
        let inner = dispatch_device!(device, |device| B::name(device));
        format!("dispatch<{inner}>")
    }

    /// Seeds the RNG of the backend selected by `device`.
    fn seed(device: &Self::Device, seed: u64) {
        dispatch_device!(device, |device| B::seed(device, seed))
    }

    /// Synchronizes the backend selected by `device`.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        dispatch_device!(device, |device| B::sync(device))
    }

    /// Forwards the dtype-usage query to the backend selected by `device`.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        dispatch_device!(device, |device| B::dtype_usage(device, dtype))
    }

    /// Autodiff is enabled only for `Autodiff` dispatch devices, and only
    /// when the `autodiff` feature is compiled in.
    fn ad_enabled(device: &Self::Device) -> bool {
        match device {
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(_) => true,
            _ => false,
        }
    }

    /// Number of devices available for the encoded `type_id`.
    ///
    /// The id packs both the dispatch backend id and a backend-local type id;
    /// it is decoded here and the query forwarded to the matching backend.
    fn device_count(type_id: u16) -> usize {
        let (dispatch_id, backend_type_id) = DispatchDevice::decode_type_id(type_id);
        match dispatch_id {
            #[cfg(feature = "cpu")]
            BackendId::Cpu => Cpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "cuda")]
            BackendId::Cuda => Cuda::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_metal)]
            BackendId::Metal => Metal::<f32>::device_count(backend_type_id),
            #[cfg(feature = "rocm")]
            BackendId::Rocm => Rocm::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_vulkan)]
            BackendId::Vulkan => Vulkan::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_webgpu)]
            BackendId::Wgpu => Wgpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "ndarray")]
            BackendId::NdArray => NdArray::<f32>::device_count(backend_type_id),
            #[cfg(feature = "tch")]
            BackendId::LibTorch => LibTorch::<f32>::device_count(backend_type_id),
        }
    }
}
111
#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
    // `Dispatch` is its own inner backend: whether a tensor tracks gradients
    // is encoded in its kind (`DispatchTensorKind::Autodiff`), not in a
    // separate backend type.
    type InnerBackend = Dispatch;

    type Gradients = Gradients;

    /// Runs the backward pass on an autodiff tensor and returns its gradients.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if an autodiff tensor
    /// wraps another autodiff tensor (an invalid state).
    fn backward(tensor: DispatchTensor) -> Self::Gradients {
        let DispatchTensor { kind, .. } = tensor;
        match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor.autodiff().backward(),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Retrieves the gradient of `tensor` from `grads` without removing it.
    ///
    /// Returns `None` when no gradient is registered for the tensor. The
    /// returned tensor keeps the original tensor's checkpointing flag.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor.
    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    /// Removes and returns the gradient of `tensor` from `grads`.
    ///
    /// Same contract as [`Self::grad`], except the gradient is taken out of
    /// the gradient container.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor.
    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    /// Replaces the registered gradient of `tensor` in `grads` with `grad`.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if the two tensors do
    /// not live on the same underlying backend.
    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let DispatchTensor {
            kind: grad,
            checkpointing: grad_ckp,
        } = grad;
        debug_assert_eq!(checkpointing, &grad_ckp);

        match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match (&**inner_kind, grad) {
                #[cfg(feature = "cpu")]
                (DispatchTensorKind::Cpu(tensor), DispatchTensorKind::Cpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "cuda")]
                (DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_metal)]
                (DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "rocm")]
                (DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_vulkan)]
                (DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_webgpu)]
                (DispatchTensorKind::Wgpu(tensor), DispatchTensorKind::Wgpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "ndarray")]
                (DispatchTensorKind::NdArray(tensor), DispatchTensorKind::NdArray(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                // Bug fix: this arm was missing, so with the `tch` feature a
                // matching LibTorch pair fell through to the backend-mismatch
                // panic below instead of replacing the gradient. Every sibling
                // method (`grad`, `grad_remove`, `inner`, `from_inner`)
                // handles LibTorch.
                #[cfg(feature = "tch")]
                (DispatchTensorKind::LibTorch(tensor), DispatchTensorKind::LibTorch(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                (DispatchTensorKind::Autodiff(_), _) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
                (t, g) => panic!(
                    "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
                ),
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Unwraps an autodiff tensor into its non-tracking inner float tensor,
    /// preserving the checkpointing flag.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor.
    fn inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match *inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Cpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Wgpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::NdArray(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    // Int tensors do not track gradients; unwrap is the identity.
    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    // Bool tensors do not track gradients; unwrap is the identity.
    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    // Quantized tensors do not track gradients; unwrap is the identity.
    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    /// Wraps a non-tracking float tensor into a gradient-tracking autodiff
    /// tensor, preserving the checkpointing flag.
    ///
    /// # Panics
    /// Panics if `tensor` is already an autodiff tensor.
    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Cpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cuda(crate::BackendTensor::Autodiff(
                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Metal(crate::BackendTensor::Autodiff(
                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Rocm(crate::BackendTensor::Autodiff(
                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Vulkan(crate::BackendTensor::Autodiff(
                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Wgpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Wgpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::NdArray(crate::BackendTensor::Autodiff(
                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => {
                DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Autodiff(Autodiff::<LibTorch<f32>>::from_inner(
                        tensor.float(),
                    )),
                )))
            }
            DispatchTensorKind::Autodiff(_) => {
                panic!("Autodiff should not wrap an autodiff tensor.")
            }
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    // Int tensors do not track gradients; wrapping is the identity.
    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    // Bool tensors do not track gradients; wrapping is the identity.
    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    // Quantized tensors do not track gradients; wrapping is the identity.
    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }
}
464
impl DispatchTensorKind {
    /// Returns the dispatch device holding this tensor: queries the wrapped
    /// backend tensor for its device and wraps it in the matching
    /// `DispatchDevice` variant.
    ///
    /// NOTE(review): the `Autodiff` arm delegates to `DispatchDevice::autodiff`
    /// on the inner kind's device; the checkpointing flag is not set here —
    /// callers that need it use `DispatchTensor::device` instead.
    pub(crate) fn device(&self) -> DispatchDevice {
        match self {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchDevice::Wgpu(tensor.device()),
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
        }
    }
}
489
490impl DispatchTensor {
491    pub(crate) fn device(&self) -> DispatchDevice {
492        #[allow(unused_mut)]
493        let mut device = self.kind.device();
494
495        #[cfg(feature = "autodiff")]
496        if let DispatchDevice::Autodiff(device) = &mut device {
497            device.checkpointing = self.checkpointing;
498        }
499
500        device
501    }
502}
503
impl Dispatch {
    /// Returns the default tensor quantization scheme for the device.
    ///
    /// Delegates to the selected backend's quantized-tensor primitive via
    /// `QTensorPrimitive::default_scheme`. For autodiff devices the query
    /// recurses on the wrapped inner device, since quantization schemes are
    /// a property of the concrete execution backend.
    // TODO: replace this + QTensorPrimitive trait method with better API.
    // This is temporary, for test purposes.
    pub fn default_quant_scheme(device: &<Self as Backend>::Device) -> QuantScheme {
        match device {
            #[cfg(feature = "cpu")]
            DispatchDevice::Cpu(_) => {
                <<Cpu as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "cuda")]
            DispatchDevice::Cuda(_) => {
                <<Cuda as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_metal)]
            DispatchDevice::Metal(_) => {
                <<Metal as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "rocm")]
            DispatchDevice::Rocm(_) => {
                <<Rocm as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_vulkan)]
            DispatchDevice::Vulkan(_) => {
                <<Vulkan as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_webgpu)]
            DispatchDevice::Wgpu(_) => {
                <<Wgpu as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "ndarray")]
            DispatchDevice::NdArray(_) => {
                <<NdArray as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "tch")]
            DispatchDevice::LibTorch(_) => {
                <<LibTorch as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(ad_device) => Self::default_quant_scheme(&ad_device.inner),
        }
    }
}
546}