use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;

use burn_backend::quantization::QuantScheme;
use burn_backend::tensor::{Device, QuantizedTensor};
use burn_backend::{Backend, BackendTypes, DType, ExecutionError, QTensorPrimitive};

#[cfg(feature = "autodiff")]
use burn_autodiff::grads::Gradients;
#[cfg(feature = "autodiff")]
use burn_backend::AutodiffBackend;

#[allow(unused)]
use crate::BackendId;
use crate::DispatchTensorKind;
use crate::backends::*;
use crate::{DispatchDevice, DispatchTensor};

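/// Backend that forwards every operation to a concrete backend selected at
/// runtime through its [`DispatchDevice`].
///
/// A minimal usage sketch (assuming the `ndarray` feature is enabled and
/// `Tensor` is in scope from `burn-tensor`):
///
/// ```ignore
/// let device = DispatchDevice::NdArray(Default::default());
/// let tensor = Tensor::<Dispatch, 2>::zeros([2, 3], &device);
/// ```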
#[derive(Debug, Default, Clone)]
pub struct Dispatch;

impl BackendTypes for Dispatch {
    type Device = DispatchDevice;

    type FloatTensorPrimitive = DispatchTensor;

    type FloatElem = f32;

    type IntTensorPrimitive = DispatchTensor;

    type IntElem = i32;

    type BoolTensorPrimitive = DispatchTensor;

    type BoolElem = u8;

    type QuantizedTensorPrimitive = DispatchTensor;
}

impl Backend for Dispatch {
    fn name(device: &Self::Device) -> String {
        let inner = dispatch_device!(device, |device| B::name(device));
        format!("dispatch<{inner}>")
    }

    fn seed(device: &Self::Device, seed: u64) {
        dispatch_device!(device, |device| B::seed(device, seed))
    }

    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        dispatch_device!(device, |device| B::sync(device))
    }

    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        dispatch_device!(device, |device| B::dtype_usage(device, dtype))
    }

    fn ad_enabled(device: &Self::Device) -> bool {
        match device {
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(_) => true,
            _ => false,
        }
    }

    fn device_count(type_id: u16) -> usize {
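        // Split the combined type id into the dispatch backend id and the
        // type id understood by the selected inner backend.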
        let (dispatch_id, backend_type_id) = DispatchDevice::decode_type_id(type_id);
        match dispatch_id {
            #[cfg(feature = "cpu")]
            BackendId::Cpu => Cpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "cuda")]
            BackendId::Cuda => Cuda::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_metal)]
            BackendId::Metal => Metal::<f32>::device_count(backend_type_id),
            #[cfg(feature = "rocm")]
            BackendId::Rocm => Rocm::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_vulkan)]
            BackendId::Vulkan => Vulkan::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_webgpu)]
            BackendId::Wgpu => Wgpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "flex")]
            BackendId::Flex => Flex::device_count(backend_type_id),
            #[cfg(feature = "ndarray")]
            BackendId::NdArray => NdArray::<f32>::device_count(backend_type_id),
            #[cfg(feature = "tch")]
            BackendId::LibTorch => LibTorch::<f32>::device_count(backend_type_id),
        }
    }
}

#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
    type InnerBackend = Dispatch;

    type Gradients = Gradients;

    fn backward(tensor: DispatchTensor) -> Self::Gradients {
        let DispatchTensor { kind, .. } = tensor;
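        // Unwrap the autodiff wrapper and run the backward pass on the
        // matching concrete backend.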
        match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor.autodiff().backward(),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Flex(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Flex(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let DispatchTensor {
            kind: grad,
            checkpointing: grad_ckp,
        } = grad;
        debug_assert_eq!(checkpointing, &grad_ckp);

        match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match (&**inner_kind, grad) {
                #[cfg(feature = "cpu")]
                (DispatchTensorKind::Cpu(tensor), DispatchTensorKind::Cpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "cuda")]
                (DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_metal)]
                (DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "rocm")]
                (DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_vulkan)]
                (DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_webgpu)]
                (DispatchTensorKind::Wgpu(tensor), DispatchTensorKind::Wgpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "flex")]
                (DispatchTensorKind::Flex(tensor), DispatchTensorKind::Flex(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "ndarray")]
                (DispatchTensorKind::NdArray(tensor), DispatchTensorKind::NdArray(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                // The LibTorch arm is present in every sibling method; without
                // it, matching LibTorch pairs would fall through to the
                // backend-mismatch panic below.
                #[cfg(feature = "tch")]
                (DispatchTensorKind::LibTorch(tensor), DispatchTensorKind::LibTorch(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                (DispatchTensorKind::Autodiff(_), _) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
                (t, g) => panic!(
                    "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
                ),
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    fn inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match *inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Cpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Wgpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => DispatchTensorKind::Flex(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::NdArray(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

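    // Int, bool, and quantized tensors are not tracked by autodiff, so the
    // conversions to the inner backend below are identity functions.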
    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

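    // Re-wrap a plain float tensor in an autodiff wrapper on its original
    // concrete backend.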
    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Cpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cuda(crate::BackendTensor::Autodiff(
                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Metal(crate::BackendTensor::Autodiff(
                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Rocm(crate::BackendTensor::Autodiff(
                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Vulkan(crate::BackendTensor::Autodiff(
                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Wgpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Wgpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "flex")]
            DispatchTensorKind::Flex(tensor) => {
                DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Flex(
                    crate::BackendTensor::Autodiff(Autodiff::<Flex>::from_inner(tensor.float())),
                )))
            }
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::NdArray(crate::BackendTensor::Autodiff(
                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => {
                DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Autodiff(Autodiff::<LibTorch<f32>>::from_inner(
                        tensor.float(),
                    )),
                )))
            }
            DispatchTensorKind::Autodiff(_) => {
                panic!("Autodiff should not wrap an autodiff tensor.")
            }
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }
}

impl DispatchTensorKind {
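    /// Device of the underlying concrete tensor, wrapped as a [`DispatchDevice`].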
    pub(crate) fn device(&self) -> DispatchDevice {
        match self {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchDevice::Wgpu(tensor.device()),
            #[cfg(feature = "flex")]
            DispatchTensorKind::Flex(tensor) => DispatchDevice::Flex(tensor.device()),
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
        }
    }
}

impl DispatchTensor {
    pub(crate) fn device(&self) -> DispatchDevice {
        #[allow(unused_mut)]
        let mut device = self.kind.device();

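        // The checkpointing strategy lives on the tensor, not the inner
        // device, so propagate it onto the autodiff device.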
        #[cfg(feature = "autodiff")]
        if let DispatchDevice::Autodiff(device) = &mut device {
            device.checkpointing = self.checkpointing;
        }

        device
    }
}

impl Dispatch {
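    /// Default quantization scheme of the concrete backend behind the given
    /// device.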
    pub fn default_quant_scheme(device: &Device<Self>) -> QuantScheme {
        match device {
            #[cfg(feature = "cpu")]
            DispatchDevice::Cpu(_) => <QuantizedTensor<Cpu> as QTensorPrimitive>::default_scheme(),
            #[cfg(feature = "cuda")]
            DispatchDevice::Cuda(_) => {
                <QuantizedTensor<Cuda> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_metal)]
            DispatchDevice::Metal(_) => {
                <QuantizedTensor<Metal> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "rocm")]
            DispatchDevice::Rocm(_) => {
                <QuantizedTensor<Rocm> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_vulkan)]
            DispatchDevice::Vulkan(_) => {
                <QuantizedTensor<Vulkan> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_webgpu)]
            DispatchDevice::Wgpu(_) => {
                <QuantizedTensor<Wgpu> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "flex")]
            DispatchDevice::Flex(_) => {
                <QuantizedTensor<Flex> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "ndarray")]
            DispatchDevice::NdArray(_) => {
                <QuantizedTensor<NdArray> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "tch")]
            DispatchDevice::LibTorch(_) => {
                <QuantizedTensor<LibTorch> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(ad_device) => Self::default_quant_scheme(&ad_device.inner),
        }
    }
}