1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4
5use burn_backend::quantization::QuantScheme;
6use burn_backend::{Backend, DType, ExecutionError, QTensorPrimitive};
7
8#[cfg(feature = "autodiff")]
9use burn_autodiff::grads::Gradients;
10#[cfg(feature = "autodiff")]
11use burn_backend::AutodiffBackend;
12
13#[allow(unused)]
14use crate::BackendId;
15use crate::DispatchTensorKind;
16use crate::backends::*;
17use crate::{DispatchDevice, DispatchTensor};
18
/// Backend that dispatches tensor operations at runtime to one of the
/// compiled-in backends (CPU, CUDA, Metal, ROCm, Vulkan, WebGPU, NdArray,
/// LibTorch), selected by the tensor's [`DispatchDevice`].
#[derive(Debug, Default, Clone)]
pub struct Dispatch;
45
impl Backend for Dispatch {
    /// Runtime-selected device; each variant maps to one concrete backend.
    type Device = DispatchDevice;

    type FloatTensorPrimitive = DispatchTensor;

    // Element types are fixed regardless of which backend the tensor lives on.
    type FloatElem = f32;

    type IntTensorPrimitive = DispatchTensor;

    type IntElem = i32;

    type BoolTensorPrimitive = DispatchTensor;

    type BoolElem = u8;

    type QuantizedTensorPrimitive = DispatchTensor;

    /// Returns the backend name, wrapping the selected backend's own name
    /// (e.g. `dispatch<inner-name>`).
    fn name(device: &Self::Device) -> String {
        // `dispatch_device!` resolves the concrete backend `B` matching the
        // device variant and runs the closure against it.
        let inner = dispatch_device!(device, |device| B::name(device));
        format!("dispatch<{inner}>")
    }

    /// Seeds the RNG of the concrete backend behind `device`.
    fn seed(device: &Self::Device, seed: u64) {
        dispatch_device!(device, |device| B::seed(device, seed))
    }

    /// Synchronizes pending work on the concrete backend behind `device`.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        dispatch_device!(device, |device| B::sync(device))
    }

    /// Delegates the dtype-support query to the concrete backend.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        dispatch_device!(device, |device| B::dtype_usage(device, dtype))
    }

    /// Autodiff is enabled only for the autodiff device variant.
    fn ad_enabled(device: &Self::Device) -> bool {
        match device {
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(_) => true,
            _ => false,
        }
    }

    /// Counts the devices of a given type.
    ///
    /// The `type_id` encodes both which backend is targeted and that
    /// backend's own device-type id; `decode_type_id` splits it and the
    /// inner id is forwarded to the selected backend.
    fn device_count(type_id: u16) -> usize {
        let (dispatch_id, backend_type_id) = DispatchDevice::decode_type_id(type_id);
        match dispatch_id {
            #[cfg(feature = "cpu")]
            BackendId::Cpu => Cpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "cuda")]
            BackendId::Cuda => Cuda::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_metal)]
            BackendId::Metal => Metal::<f32>::device_count(backend_type_id),
            #[cfg(feature = "rocm")]
            BackendId::Rocm => Rocm::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_vulkan)]
            BackendId::Vulkan => Vulkan::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_webgpu)]
            BackendId::Wgpu => Wgpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "ndarray")]
            BackendId::NdArray => NdArray::<f32>::device_count(backend_type_id),
            #[cfg(feature = "tch")]
            BackendId::LibTorch => LibTorch::<f32>::device_count(backend_type_id),
        }
    }
}
111
#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
    // Stripping autodiff keeps the dispatch layer in place: `inner` only
    // unwraps the `Autodiff` tensor kind, so the inner backend is `Dispatch`
    // itself.
    type InnerBackend = Dispatch;

    type Gradients = Gradients;

    /// Runs the backward pass from `tensor` and returns the computed
    /// gradients.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if the autodiff kind
    /// wraps another autodiff kind (an invalid state).
    fn backward(tensor: DispatchTensor) -> Self::Gradients {
        let DispatchTensor { kind, .. } = tensor;
        match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor.autodiff().backward(),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Returns the gradient registered for `tensor` in `grads`, if any,
    /// re-wrapped as a (non-autodiff) float tensor on the same backend.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if the autodiff kind
    /// wraps another autodiff kind (an invalid state).
    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        // The source tensor's checkpointing strategy is carried over to the
        // returned gradient tensor.
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    /// Removes and returns the gradient registered for `tensor` in `grads`,
    /// if any, re-wrapped as a (non-autodiff) float tensor on the same
    /// backend.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if the autodiff kind
    /// wraps another autodiff kind (an invalid state).
    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    /// Replaces the gradient registered for `tensor` in `grads` by `grad`.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, if the autodiff kind
    /// wraps another autodiff kind, or if `tensor` and `grad` do not live on
    /// the same backend.
    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let DispatchTensor {
            kind: grad,
            checkpointing: grad_ckp,
        } = grad;
        // Both tensors are expected to share the same checkpointing strategy.
        debug_assert_eq!(checkpointing, &grad_ckp);

        match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match (&**inner_kind, grad) {
                #[cfg(feature = "cpu")]
                (DispatchTensorKind::Cpu(tensor), DispatchTensorKind::Cpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "cuda")]
                (DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_metal)]
                (DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "rocm")]
                (DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_vulkan)]
                (DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_webgpu)]
                (DispatchTensorKind::Wgpu(tensor), DispatchTensorKind::Wgpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "ndarray")]
                (DispatchTensorKind::NdArray(tensor), DispatchTensorKind::NdArray(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                // FIX: this arm was missing; with the `tch` feature enabled a
                // LibTorch tensor/grad pair previously fell into the
                // mismatched-backend panic below, unlike every sibling method
                // (`backward`, `grad`, `grad_remove`, `inner`, `from_inner`)
                // which all handle LibTorch.
                #[cfg(feature = "tch")]
                (DispatchTensorKind::LibTorch(tensor), DispatchTensorKind::LibTorch(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                (DispatchTensorKind::Autodiff(_), _) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
                (t, g) => panic!(
                    "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
                ),
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Unwraps the autodiff layer, returning the inner (non-autodiff) float
    /// tensor on the same backend. The checkpointing flag is preserved.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if the autodiff kind
    /// wraps another autodiff kind (an invalid state).
    fn inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match *inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Cpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Wgpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::NdArray(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    // Only float tensors carry autodiff state, so int/bool/quantized tensors
    // pass through unchanged.
    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    /// Wraps a non-autodiff float tensor into the autodiff layer of its
    /// backend. The checkpointing flag is preserved.
    ///
    /// # Panics
    /// Panics if `tensor` is already an autodiff tensor.
    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Cpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cuda(crate::BackendTensor::Autodiff(
                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Metal(crate::BackendTensor::Autodiff(
                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Rocm(crate::BackendTensor::Autodiff(
                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Vulkan(crate::BackendTensor::Autodiff(
                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Wgpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Wgpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::NdArray(crate::BackendTensor::Autodiff(
                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => {
                DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Autodiff(Autodiff::<LibTorch<f32>>::from_inner(
                        tensor.float(),
                    )),
                )))
            }
            DispatchTensorKind::Autodiff(_) => {
                panic!("Autodiff should not wrap an autodiff tensor.")
            }
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    // See `int_inner`/`bool_inner`/`q_inner`: non-float tensors are identical
    // on the inner backend.
    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }
}
464
465impl DispatchTensorKind {
466 pub(crate) fn device(&self) -> DispatchDevice {
467 match self {
468 #[cfg(feature = "cpu")]
469 DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
470 #[cfg(feature = "cuda")]
471 DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
472 #[cfg(wgpu_metal)]
473 DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
474 #[cfg(feature = "rocm")]
475 DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
476 #[cfg(wgpu_vulkan)]
477 DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
478 #[cfg(wgpu_webgpu)]
479 DispatchTensorKind::Wgpu(tensor) => DispatchDevice::Wgpu(tensor.device()),
480 #[cfg(feature = "ndarray")]
481 DispatchTensorKind::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
482 #[cfg(feature = "tch")]
483 DispatchTensorKind::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
484 #[cfg(feature = "autodiff")]
485 DispatchTensorKind::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
486 }
487 }
488}
489
490impl DispatchTensor {
491 pub(crate) fn device(&self) -> DispatchDevice {
492 #[allow(unused_mut)]
493 let mut device = self.kind.device();
494
495 #[cfg(feature = "autodiff")]
496 if let DispatchDevice::Autodiff(device) = &mut device {
497 device.checkpointing = self.checkpointing;
498 }
499
500 device
501 }
502}
503
impl Dispatch {
    /// Returns the default quantization scheme of the concrete backend
    /// selected by `device`, as reported by that backend's quantized tensor
    /// primitive.
    ///
    /// For autodiff devices, the scheme of the wrapped inner device is used.
    pub fn default_quant_scheme(device: &<Self as Backend>::Device) -> QuantScheme {
        match device {
            #[cfg(feature = "cpu")]
            DispatchDevice::Cpu(_) => {
                <<Cpu as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "cuda")]
            DispatchDevice::Cuda(_) => {
                <<Cuda as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_metal)]
            DispatchDevice::Metal(_) => {
                <<Metal as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "rocm")]
            DispatchDevice::Rocm(_) => {
                <<Rocm as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_vulkan)]
            DispatchDevice::Vulkan(_) => {
                <<Vulkan as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_webgpu)]
            DispatchDevice::Wgpu(_) => {
                <<Wgpu as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "ndarray")]
            DispatchDevice::NdArray(_) => {
                <<NdArray as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "tch")]
            DispatchDevice::LibTorch(_) => {
                <<LibTorch as Backend>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
            }
            // Recurse into the wrapped device to pick up its backend's scheme.
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(ad_device) => Self::default_quant_scheme(&ad_device.inner),
        }
    }
}
546}