//! `Dispatch` backend: routes tensor operations to the appropriate
//! underlying backend (Cpu, Cuda, Wgpu, Metal, …) behind a single
//! backend-agnostic entry point.
//!
//! File: burn_dispatch/backend.rs
1use alloc::boxed::Box;
2use alloc::format;
3use alloc::string::String;
4
5use burn_backend::quantization::QuantScheme;
6use burn_backend::tensor::{Device, QuantizedTensor};
7use burn_backend::{Backend, BackendTypes, DType, ExecutionError, QTensorPrimitive};
8
9#[cfg(feature = "autodiff")]
10use burn_autodiff::grads::Gradients;
11#[cfg(feature = "autodiff")]
12use burn_backend::AutodiffBackend;
13
14#[allow(unused)]
15use crate::BackendId;
16use crate::DispatchTensorKind;
17use crate::backends::*;
18use crate::{DispatchDevice, DispatchTensor};
19
/// The main execution backend in Burn.
///
/// [`Dispatch`] acts as a global backend that can manage multiple underlying
/// backends (e.g., `Cpu`, `Cuda`, `Wgpu`, `Metal`, etc.).
/// It is responsible for:
/// - Dispatching tensor operations to the appropriate backend.
/// - Managing cross-backend tensor transfers.
///
/// Essentially, [`Dispatch`] is the single entry point for executing tensor operations
/// in a backend-agnostic way. It allows Burn to provide a unified, global backend
/// for users while still leveraging multiple specialized backends under the hood.
///
/// # Example
///
/// ```ignore
/// use burn::Dispatch;
/// use burn::DispatchDevice;
///
/// // Select the device to execute operations on
/// let device = DispatchDevice::Cuda(Default::default());
///
/// // Create a tensor using the global backend
/// let t = Tensor::<Dispatch, 2>::zeros([128, 128], &device);
/// ```
// Zero-sized marker type: all runtime state lives in the per-device and
// per-tensor enums (`DispatchDevice`, `DispatchTensor`).
#[derive(Debug, Default, Clone)]
pub struct Dispatch;
46
impl BackendTypes for Dispatch {
    /// Device enum selecting which underlying backend executes operations.
    type Device = DispatchDevice;

    /// All tensor kinds share the same wrapper type; the concrete backend
    /// tensor (and its kind) is stored inside the [`DispatchTensor`].
    type FloatTensorPrimitive = DispatchTensor;

    // TODO: either allow default dtype generic or remove associated types entirely?
    type FloatElem = f32;

    type IntTensorPrimitive = DispatchTensor;

    type IntElem = i32;

    type BoolTensorPrimitive = DispatchTensor;

    type BoolElem = u8;

    type QuantizedTensorPrimitive = DispatchTensor;
}
65
impl Backend for Dispatch {
    /// Returns the backend name as `dispatch<inner>`, where `inner` is the
    /// name reported by the backend selected by `device`.
    fn name(device: &Self::Device) -> String {
        // `dispatch_device!` matches on the device variant and binds `B` to
        // the corresponding concrete backend type for the closure body.
        let inner = dispatch_device!(device, |device| B::name(device));
        format!("dispatch<{inner}>")
    }

    /// Seeds the RNG of the backend selected by `device`.
    fn seed(device: &Self::Device, seed: u64) {
        dispatch_device!(device, |device| B::seed(device, seed))
    }

    /// Blocks until all pending work on the selected backend's device completes.
    fn sync(device: &Self::Device) -> Result<(), ExecutionError> {
        dispatch_device!(device, |device| B::sync(device))
    }

    /// Reports how `dtype` may be used on the selected backend's device.
    fn dtype_usage(device: &Self::Device, dtype: DType) -> burn_backend::DTypeUsageSet {
        dispatch_device!(device, |device| B::dtype_usage(device, dtype))
    }

    /// Autodiff is enabled only for `Autodiff` devices (and only when the
    /// `autodiff` feature is compiled in).
    fn ad_enabled(device: &Self::Device) -> bool {
        match device {
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(_) => true,
            _ => false,
        }
    }

    /// Returns the number of devices available for the encoded `type_id`.
    ///
    /// The id packs both the dispatch backend id and the backend-local type
    /// id; it is split by [`DispatchDevice::decode_type_id`] before
    /// forwarding to the concrete backend.
    fn device_count(type_id: u16) -> usize {
        let (dispatch_id, backend_type_id) = DispatchDevice::decode_type_id(type_id);
        match dispatch_id {
            #[cfg(feature = "cpu")]
            BackendId::Cpu => Cpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "cuda")]
            BackendId::Cuda => Cuda::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_metal)]
            BackendId::Metal => Metal::<f32>::device_count(backend_type_id),
            #[cfg(feature = "rocm")]
            BackendId::Rocm => Rocm::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_vulkan)]
            BackendId::Vulkan => Vulkan::<f32>::device_count(backend_type_id),
            #[cfg(wgpu_webgpu)]
            BackendId::Wgpu => Wgpu::<f32>::device_count(backend_type_id),
            #[cfg(feature = "flex")]
            BackendId::Flex => Flex::device_count(backend_type_id),
            #[cfg(feature = "ndarray")]
            BackendId::NdArray => NdArray::<f32>::device_count(backend_type_id),
            #[cfg(feature = "tch")]
            BackendId::LibTorch => LibTorch::<f32>::device_count(backend_type_id),
        }
    }
}
116
#[cfg(feature = "autodiff")]
impl AutodiffBackend for Dispatch {
    /// `Dispatch` is its own inner backend: stripping autodiff only unwraps
    /// the `Autodiff` tensor kind while staying inside the dispatch layer.
    type InnerBackend = Dispatch;

    type Gradients = Gradients;

    /// Runs the backward pass on an autodiff tensor and returns the gradients.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if the autodiff
    /// wrapper itself wraps another autodiff tensor (an invalid state).
    fn backward(tensor: DispatchTensor) -> Self::Gradients {
        let DispatchTensor { kind, .. } = tensor;
        match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => match *tensor {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor.autodiff().backward(),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor.autodiff().backward(),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor.autodiff().backward(),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Returns the gradient of `tensor` from `grads`, if present.
    ///
    /// Each arm fetches the inner float gradient from the concrete backend
    /// and wraps it back into the dispatch tensor representation, preserving
    /// the original tensor's checkpointing flag.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor.
    fn grad(tensor: &DispatchTensor, grads: &Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::Flex(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    /// Removes and returns the gradient of `tensor` from `grads`, if present.
    ///
    /// Same dispatch shape as [`Self::grad`], but consumes the gradient entry.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor.
    fn grad_remove(tensor: &DispatchTensor, grads: &mut Self::Gradients) -> Option<DispatchTensor> {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let grad = match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match &**inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Cuda(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Metal(crate::BackendTensor::Float(t))),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Rocm(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Vulkan(crate::BackendTensor::Float(t))),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Wgpu(crate::BackendTensor::Float(t))),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::Flex(crate::BackendTensor::Float(t))),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::NdArray(crate::BackendTensor::Float(t))),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => tensor
                    .as_autodiff()
                    .grad_remove(grads)
                    .map(|t| DispatchTensorKind::LibTorch(crate::BackendTensor::Float(t))),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        grad.map(|kind| DispatchTensor {
            kind,
            checkpointing: *checkpointing,
        })
    }

    /// Replaces the gradient of `tensor` in `grads` with `grad`.
    ///
    /// Both tensors must live on the same underlying backend; the arms below
    /// pair the tensor kind with the matching gradient kind.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor, or if `tensor` and
    /// `grad` are not on the same backend.
    fn grad_replace(tensor: &DispatchTensor, grads: &mut Self::Gradients, grad: DispatchTensor) {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;
        let DispatchTensor {
            kind: grad,
            checkpointing: grad_ckp,
        } = grad;
        debug_assert_eq!(checkpointing, &grad_ckp);

        match &kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match (&**inner_kind, grad) {
                #[cfg(feature = "cpu")]
                (DispatchTensorKind::Cpu(tensor), DispatchTensorKind::Cpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "cuda")]
                (DispatchTensorKind::Cuda(tensor), DispatchTensorKind::Cuda(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_metal)]
                (DispatchTensorKind::Metal(tensor), DispatchTensorKind::Metal(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "rocm")]
                (DispatchTensorKind::Rocm(tensor), DispatchTensorKind::Rocm(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_vulkan)]
                (DispatchTensorKind::Vulkan(tensor), DispatchTensorKind::Vulkan(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(wgpu_webgpu)]
                (DispatchTensorKind::Wgpu(tensor), DispatchTensorKind::Wgpu(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "flex")]
                (DispatchTensorKind::Flex(tensor), DispatchTensorKind::Flex(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                #[cfg(feature = "ndarray")]
                (DispatchTensorKind::NdArray(tensor), DispatchTensorKind::NdArray(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                // This arm was missing while every sibling method (`grad`,
                // `grad_remove`, `backward`, `inner`, `from_inner`) handles
                // LibTorch; without it a valid LibTorch tensor/grad pair fell
                // into the "not on the same backend" panic below.
                #[cfg(feature = "tch")]
                (DispatchTensorKind::LibTorch(tensor), DispatchTensorKind::LibTorch(grad)) => {
                    tensor.as_autodiff().grad_replace(grads, grad.float())
                }
                (DispatchTensorKind::Autodiff(_), _) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
                (t, g) => panic!(
                    "The provided tensors are not on the same backend. Got backends {t:?} and {g:?}."
                ),
            },
            _ => panic!("Requires autodiff tensor."),
        }
    }

    /// Unwraps the autodiff tensor into its (non-tracked) inner float tensor,
    /// keeping the dispatch representation and checkpointing flag.
    ///
    /// # Panics
    /// Panics if `tensor` is not an autodiff tensor.
    fn inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(inner_kind) => match *inner_kind {
                #[cfg(feature = "cpu")]
                DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Cpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "cuda")]
                DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Cuda(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_metal)]
                DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Metal(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "rocm")]
                DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Rocm(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_vulkan)]
                DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Vulkan(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(wgpu_webgpu)]
                DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Wgpu(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "flex")]
                DispatchTensorKind::Flex(tensor) => DispatchTensorKind::Flex(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "ndarray")]
                DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::NdArray(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                #[cfg(feature = "tch")]
                DispatchTensorKind::LibTorch(tensor) => DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Float(tensor.autodiff().primitive),
                ),
                DispatchTensorKind::Autodiff(_) => {
                    panic!("Autodiff should not wrap an autodiff tensor.")
                }
            },
            _ => panic!("Requires autodiff tensor."),
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    // Int/bool/quantized tensors carry no autodiff state in the dispatch
    // representation, so inner/from_inner conversions are identities.
    fn int_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    /// Wraps a plain float tensor into an autodiff-tracked tensor on the same
    /// backend, keeping the checkpointing flag.
    ///
    /// # Panics
    /// Panics if `tensor` is already an autodiff tensor.
    fn from_inner(tensor: DispatchTensor) -> DispatchTensor {
        let DispatchTensor {
            kind,
            checkpointing,
        } = tensor;

        let kind = match kind {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Cpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Cuda(crate::BackendTensor::Autodiff(
                    Autodiff::<Cuda<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Metal(crate::BackendTensor::Autodiff(
                    Autodiff::<Metal<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Rocm(crate::BackendTensor::Autodiff(
                    Autodiff::<Rocm<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Vulkan(crate::BackendTensor::Autodiff(
                    Autodiff::<Vulkan<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::Wgpu(crate::BackendTensor::Autodiff(
                    Autodiff::<Wgpu<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "flex")]
            DispatchTensorKind::Flex(tensor) => {
                DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::Flex(
                    crate::BackendTensor::Autodiff(Autodiff::<Flex>::from_inner(tensor.float())),
                )))
            }
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchTensorKind::Autodiff(Box::new(
                DispatchTensorKind::NdArray(crate::BackendTensor::Autodiff(
                    Autodiff::<NdArray<f32>>::from_inner(tensor.float()),
                )),
            )),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => {
                DispatchTensorKind::Autodiff(Box::new(DispatchTensorKind::LibTorch(
                    crate::BackendTensor::Autodiff(Autodiff::<LibTorch<f32>>::from_inner(
                        tensor.float(),
                    )),
                )))
            }
            DispatchTensorKind::Autodiff(_) => {
                panic!("Autodiff should not wrap an autodiff tensor.")
            }
        };
        DispatchTensor {
            kind,
            checkpointing,
        }
    }

    fn int_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn bool_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }

    fn q_from_inner(tensor: DispatchTensor) -> DispatchTensor {
        tensor
    }
}
495
impl DispatchTensorKind {
    /// Returns the [`DispatchDevice`] corresponding to this tensor kind,
    /// wrapping the inner backend tensor's device in the matching variant.
    ///
    /// For autodiff tensors this recurses into the wrapped kind and then
    /// re-wraps via `DispatchDevice::autodiff`; the checkpointing flag is
    /// filled in by `DispatchTensor::device`, not here.
    pub(crate) fn device(&self) -> DispatchDevice {
        match self {
            #[cfg(feature = "cpu")]
            DispatchTensorKind::Cpu(tensor) => DispatchDevice::Cpu(tensor.device()),
            #[cfg(feature = "cuda")]
            DispatchTensorKind::Cuda(tensor) => DispatchDevice::Cuda(tensor.device()),
            #[cfg(wgpu_metal)]
            DispatchTensorKind::Metal(tensor) => DispatchDevice::Metal(tensor.device()),
            #[cfg(feature = "rocm")]
            DispatchTensorKind::Rocm(tensor) => DispatchDevice::Rocm(tensor.device()),
            #[cfg(wgpu_vulkan)]
            DispatchTensorKind::Vulkan(tensor) => DispatchDevice::Vulkan(tensor.device()),
            #[cfg(wgpu_webgpu)]
            DispatchTensorKind::Wgpu(tensor) => DispatchDevice::Wgpu(tensor.device()),
            #[cfg(feature = "flex")]
            DispatchTensorKind::Flex(tensor) => DispatchDevice::Flex(tensor.device()),
            #[cfg(feature = "ndarray")]
            DispatchTensorKind::NdArray(tensor) => DispatchDevice::NdArray(tensor.device()),
            #[cfg(feature = "tch")]
            DispatchTensorKind::LibTorch(tensor) => DispatchDevice::LibTorch(tensor.device()),
            #[cfg(feature = "autodiff")]
            DispatchTensorKind::Autodiff(tensor) => DispatchDevice::autodiff(tensor.device()),
        }
    }
}
522
impl DispatchTensor {
    /// Returns the device of this tensor, propagating the tensor's
    /// checkpointing flag onto autodiff devices.
    pub(crate) fn device(&self) -> DispatchDevice {
        // `mut` is only needed when the `autodiff` feature is enabled.
        #[allow(unused_mut)]
        let mut device = self.kind.device();

        // The kind-level `device()` cannot know the checkpointing setting
        // (it lives on `DispatchTensor`), so it is patched in here.
        #[cfg(feature = "autodiff")]
        if let DispatchDevice::Autodiff(device) = &mut device {
            device.checkpointing = self.checkpointing;
        }

        device
    }
}
536
impl Dispatch {
    /// Returns the default tensor quantization scheme for the device.
    ///
    /// Delegates to the `QTensorPrimitive::default_scheme` of the backend
    /// selected by `device`; autodiff devices forward to their inner device.
    // TODO: replace this + QTensorPrimitive trait method with better API.
    // This is temporary, for test purposes.
    pub fn default_quant_scheme(device: &Device<Self>) -> QuantScheme {
        match device {
            #[cfg(feature = "cpu")]
            DispatchDevice::Cpu(_) => <QuantizedTensor<Cpu> as QTensorPrimitive>::default_scheme(),
            #[cfg(feature = "cuda")]
            DispatchDevice::Cuda(_) => {
                <QuantizedTensor<Cuda> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_metal)]
            DispatchDevice::Metal(_) => {
                <QuantizedTensor<Metal> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "rocm")]
            DispatchDevice::Rocm(_) => {
                <QuantizedTensor<Rocm> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_vulkan)]
            DispatchDevice::Vulkan(_) => {
                <QuantizedTensor<Vulkan> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(wgpu_webgpu)]
            DispatchDevice::Wgpu(_) => {
                <QuantizedTensor<Wgpu> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "flex")]
            DispatchDevice::Flex(_) => {
                <QuantizedTensor<Flex> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "ndarray")]
            DispatchDevice::NdArray(_) => {
                <QuantizedTensor<NdArray> as QTensorPrimitive>::default_scheme()
            }
            #[cfg(feature = "tch")]
            DispatchDevice::LibTorch(_) => {
                <QuantizedTensor<LibTorch> as QTensorPrimitive>::default_scheme()
            }
            // Autodiff devices quantize like their wrapped device.
            #[cfg(feature = "autodiff")]
            DispatchDevice::Autodiff(ad_device) => Self::default_quant_scheme(&ad_device.inner),
        }
    }
}
581}