Skip to main content

burn_dispatch/
device.rs

1use alloc::boxed::Box;
2
3use burn_backend::{DeviceId, DeviceOps};
4
5use crate::backends::*;
6
7/// Represents a device for the [`Dispatch`](crate::Dispatch).
8///
9/// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to.
10///
11/// # Example
12///
13/// ```ignore
14/// use burn::DispatchDevice;
15///
16/// #[cfg(feature = "cpu")]
17/// let cpu_device = DispatchDevice::Cpu(Default::default());
18///
19/// #[cfg(feature = "cuda")]
20/// let cuda_device = DispatchDevice::Cuda(Default::default());
21/// ```
22#[derive(Clone, Eq)]
23pub enum DispatchDevice {
24    /// The [CPU backend](Cpu) device.
25    #[cfg(feature = "cpu")]
26    Cpu(CpuDevice),
27
28    /// The [CUDA backend](Cuda) device.
29    #[cfg(feature = "cuda")]
30    Cuda(CudaDevice),
31
32    /// The [Metal backend](Metal) device (via WGPU runtime).
33    #[cfg(wgpu_metal)]
34    Metal(WgpuDevice),
35
36    /// The [ROCm backend](Rocm) device.
37    #[cfg(feature = "rocm")]
38    Rocm(RocmDevice),
39
40    /// The [Vulkan backend](Vulkan) device.
41    #[cfg(wgpu_vulkan)]
42    Vulkan(WgpuDevice),
43
44    /// The [WebGPU backend](Wgpu) device (via WGPU runtime).
45    #[cfg(wgpu_webgpu)]
46    Wgpu(WgpuDevice),
47
48    /// The [Flex backend](Flex) device (CPU-only).
49    #[cfg(feature = "flex")]
50    Flex(FlexDevice),
51
52    /// The [NdArray backend](NdArray) device (CPU-only).
53    #[cfg(feature = "ndarray")]
54    NdArray(NdArrayDevice),
55
56    /// The [LibTorch backend](LibTorch) device.
57    #[cfg(feature = "tch")]
58    LibTorch(LibTorchDevice),
59
60    /// The [autodiff enabled backend](Autodiff) device.
61    #[cfg(feature = "autodiff")]
62    Autodiff(AutodiffDevice),
63}
64
/// A wrapper that enables automatic differentiation for a [`DispatchDevice`].
///
/// Use [`DispatchDevice::autodiff`] or [`DispatchDevice::autodiff_checkpointed`] to
/// construct this type.
// The fields are crate-private so code outside this crate cannot assemble the wrapper
// by hand (for instance around a device that is already autodiff-enabled); it has to
// go through the `DispatchDevice` constructors instead.
#[cfg(feature = "autodiff")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice {
    /// The wrapped device.
    pub(crate) inner: Box<DispatchDevice>,
    /// Checkpointing strategy used for autodiff on this device.
    pub(crate) checkpointing: CheckpointingStrategy,
}
75
#[cfg(feature = "autodiff")]
impl AutodiffDevice {
    /// Wraps `device` so it dispatches through autodiff with the given `checkpointing`.
    ///
    /// Crate-private: code outside the crate goes through `DispatchDevice::autodiff`.
    pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self {
        let inner = Box::new(device);
        Self {
            inner,
            checkpointing,
        }
    }
}
85
#[cfg(feature = "autodiff")]
// Lets the dispatch macros match on the wrapped `DispatchDevice` through auto-deref.
impl core::ops::Deref for AutodiffDevice {
    type Target = DispatchDevice;

    /// Borrows the wrapped device.
    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}
95
/// Checkpointing strategy for autodiff.
///
/// Stored in every [`AutodiffDevice`]. Tensors taking part in the same operation
/// must use the same strategy; `validate_checkpointing` panics otherwise.
#[cfg(feature = "autodiff")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CheckpointingStrategy {
    /// Balanced checkpointing.
    ///
    /// NOTE(review): presumably selects the autodiff backend's balanced checkpointing
    /// strategy, which recomputes some intermediate values during the backward pass
    /// instead of storing them. Confirm against the backend mapping.
    Balanced,
    /// No checkpointing. This is the default.
    #[default]
    None,
}
106
/// Checks that two operands agree on their checkpointing strategy and returns it.
///
/// # Panics
///
/// If `lhs` and `rhs` are different strategies.
#[cfg(feature = "autodiff")]
pub(crate) fn validate_checkpointing(
    lhs: crate::CheckpointingStrategy,
    rhs: crate::CheckpointingStrategy,
) -> crate::CheckpointingStrategy {
    assert_eq!(
        lhs, rhs,
        "Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy."
    );
    // Past the assertion both sides are the same (fieldless) variant, so either can be returned.
    rhs
}
118
119impl core::fmt::Debug for DispatchDevice {
120    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121        match self {
122            #[cfg(feature = "cpu")]
123            Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
124            #[cfg(feature = "cuda")]
125            Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
126            #[cfg(wgpu_metal)]
127            Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
128            #[cfg(feature = "rocm")]
129            Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
130            #[cfg(wgpu_vulkan)]
131            Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
132            #[cfg(wgpu_webgpu)]
133            Self::Wgpu(device) => f.debug_tuple("Wgpu").field(device).finish(),
134            #[cfg(feature = "flex")]
135            Self::Flex(device) => f.debug_tuple("Flex").field(device).finish(),
136            #[cfg(feature = "ndarray")]
137            Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
138            #[cfg(feature = "tch")]
139            Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
140            #[cfg(feature = "autodiff")]
141            // Format without `AutodiffDevice` wrapper
142            Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.inner).finish(),
143        }
144    }
145}
146
147impl Default for DispatchDevice {
148    #[allow(unreachable_code)]
149    fn default() -> Self {
150        // TODO: which priority?
151
152        #[cfg(feature = "cuda")]
153        return Self::Cuda(CudaDevice::default());
154
155        #[cfg(wgpu_metal)]
156        return Self::Metal(burn_wgpu::WgpuDevice::default());
157
158        #[cfg(feature = "rocm")]
159        return Self::Rocm(RocmDevice::default());
160
161        #[cfg(wgpu_vulkan)]
162        return Self::Vulkan(burn_wgpu::WgpuDevice::default());
163
164        #[cfg(wgpu_webgpu)]
165        return Self::Wgpu(burn_wgpu::WgpuDevice::default());
166
167        #[cfg(feature = "cpu")]
168        return Self::Cpu(CpuDevice);
169
170        #[cfg(feature = "tch")]
171        return Self::LibTorch(LibTorchDevice::default());
172
173        // Prefer Flex over NdArray when both are enabled: Flex is the long-term
174        // CPU backend replacement and should win the default tie.
175        #[cfg(feature = "flex")]
176        return Self::Flex(FlexDevice);
177
178        #[cfg(feature = "ndarray")]
179        return Self::NdArray(NdArrayDevice::default());
180    }
181}
182
183impl PartialEq for DispatchDevice {
184    fn eq(&self, other: &Self) -> bool {
185        match (self, other) {
186            // If both are Autodiff, compare the inner devices
187            #[cfg(feature = "autodiff")]
188            (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
189            // If one is Autodiff, compare it to the raw device
190            #[cfg(feature = "autodiff")]
191            (DispatchDevice::Autodiff(a), b) => a.inner.as_ref() == b,
192            #[cfg(feature = "autodiff")]
193            (a, DispatchDevice::Autodiff(b)) => a == b.inner.as_ref(),
194            #[cfg(feature = "cpu")]
195            (Self::Cpu(a), Self::Cpu(b)) => a == b,
196            #[cfg(feature = "cuda")]
197            (Self::Cuda(a), Self::Cuda(b)) => a == b,
198            #[cfg(wgpu_metal)]
199            (Self::Metal(a), Self::Metal(b)) => a == b,
200            #[cfg(feature = "rocm")]
201            (Self::Rocm(a), Self::Rocm(b)) => a == b,
202            #[cfg(wgpu_vulkan)]
203            (Self::Vulkan(a), Self::Vulkan(b)) => a == b,
204            #[cfg(wgpu_webgpu)]
205            (Self::Wgpu(a), Self::Wgpu(b)) => a == b,
206            #[cfg(feature = "flex")]
207            (Self::Flex(a), Self::Flex(b)) => a == b,
208            #[cfg(feature = "ndarray")]
209            (Self::NdArray(a), Self::NdArray(b)) => a == b,
210            #[cfg(feature = "tch")]
211            (Self::LibTorch(a), Self::LibTorch(b)) => a == b,
212            #[allow(unreachable_patterns)]
213            (_, _) => false,
214        }
215    }
216}
217
218/// Base multiplier to avoid type_id clashes between backends.
219/// Limits the number of device types per backend, but this is a sensible limit.
220const TYPE_ID_BASE: u16 = 10;
221
222impl DispatchDevice {
223    #[cfg(feature = "autodiff")]
224    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
225    pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
226        Self::autodiff_checkpointed(device, CheckpointingStrategy::None)
227    }
228    #[cfg(feature = "autodiff")]
229    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
230    pub fn autodiff_checkpointed(
231        device: impl Into<DispatchDevice>,
232        checkpointing: CheckpointingStrategy,
233    ) -> DispatchDevice {
234        let device = device.into();
235        DispatchDevice::Autodiff(AutodiffDevice::new(device, checkpointing))
236    }
237
238    /// Returns a unique number per variant to encode into type_id.
239    fn backend_id(&self) -> BackendId {
240        match self {
241            #[cfg(feature = "cpu")]
242            Self::Cpu(_) => BackendId::Cpu,
243            #[cfg(feature = "cuda")]
244            Self::Cuda(_) => BackendId::Cuda,
245            #[cfg(wgpu_metal)]
246            Self::Metal(_) => BackendId::Metal,
247            #[cfg(feature = "rocm")]
248            Self::Rocm(_) => BackendId::Rocm,
249            #[cfg(wgpu_vulkan)]
250            Self::Vulkan(_) => BackendId::Vulkan,
251            #[cfg(wgpu_webgpu)]
252            Self::Wgpu(_) => BackendId::Wgpu,
253            #[cfg(feature = "flex")]
254            Self::Flex(_) => BackendId::Flex,
255            #[cfg(feature = "ndarray")]
256            Self::NdArray(_) => BackendId::NdArray,
257            #[cfg(feature = "tch")]
258            Self::LibTorch(_) => BackendId::LibTorch,
259            #[cfg(feature = "autodiff")]
260            Self::Autodiff(device) => device.inner.backend_id(),
261        }
262    }
263
264    /// Encode variant ID and backend type ID into a unique `type_id`.
265    fn encode_type_id(&self, backend_type_id: u16) -> u16 {
266        u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
267    }
268
269    /// Decode an encoded `type_id` into variant ID and backend type ID.
270    pub(crate) fn decode_type_id(type_id: u16) -> (BackendId, u16) {
271        let variant = type_id / TYPE_ID_BASE;
272        let backend_type_id = type_id % TYPE_ID_BASE;
273        (
274            BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
275            backend_type_id,
276        )
277    }
278}
279
280#[derive(Debug, Clone, Copy, PartialEq, Eq)]
281#[repr(u16)]
282pub(crate) enum BackendId {
283    #[cfg(feature = "cpu")]
284    Cpu = 0,
285    #[cfg(feature = "cuda")]
286    Cuda = 1,
287    #[cfg(wgpu_metal)]
288    Metal = 2,
289    #[cfg(feature = "rocm")]
290    Rocm = 3,
291    #[cfg(wgpu_vulkan)]
292    Vulkan = 4,
293    #[cfg(wgpu_webgpu)]
294    Wgpu = 5,
295    #[cfg(feature = "ndarray")]
296    NdArray = 6,
297    #[cfg(feature = "tch")]
298    LibTorch = 7,
299    #[cfg(feature = "flex")]
300    Flex = 8,
301}
302
303impl From<BackendId> for u16 {
304    fn from(variant: BackendId) -> Self {
305        variant as u16
306    }
307}
308
309impl TryFrom<u16> for BackendId {
310    type Error = ();
311
312    fn try_from(value: u16) -> Result<Self, Self::Error> {
313        match value {
314            #[cfg(feature = "cpu")]
315            0 => Ok(Self::Cpu),
316            #[cfg(feature = "cuda")]
317            1 => Ok(Self::Cuda),
318            #[cfg(wgpu_metal)]
319            2 => Ok(Self::Metal),
320            #[cfg(feature = "rocm")]
321            3 => Ok(Self::Rocm),
322            #[cfg(wgpu_vulkan)]
323            4 => Ok(Self::Vulkan),
324            #[cfg(wgpu_webgpu)]
325            5 => Ok(Self::Wgpu),
326            #[cfg(feature = "ndarray")]
327            6 => Ok(Self::NdArray),
328            #[cfg(feature = "tch")]
329            7 => Ok(Self::LibTorch),
330            #[cfg(feature = "flex")]
331            8 => Ok(Self::Flex),
332            _ => Err(()),
333        }
334    }
335}
336
337impl DeviceOps for DispatchDevice {
338    fn inner(&self) -> &Self {
339        match self {
340            #[cfg(feature = "autodiff")]
341            DispatchDevice::Autodiff(device) => &device.inner,
342            device => device,
343        }
344    }
345}
346
347impl burn_backend::Device for DispatchDevice {
348    fn from_id(mut device_id: DeviceId) -> Self {
349        let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
350        device_id.type_id = backend_type_id;
351
352        match dispatch_id {
353            #[cfg(feature = "cpu")]
354            BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
355            #[cfg(feature = "cuda")]
356            BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
357            #[cfg(wgpu_metal)]
358            BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
359            #[cfg(feature = "rocm")]
360            BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
361            #[cfg(wgpu_vulkan)]
362            BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
363            #[cfg(wgpu_webgpu)]
364            BackendId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
365            #[cfg(feature = "flex")]
366            BackendId::Flex => Self::Flex(FlexDevice::from_id(device_id)),
367            #[cfg(feature = "ndarray")]
368            BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
369            #[cfg(feature = "tch")]
370            BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
371        }
372    }
373
374    fn to_id(&self) -> DeviceId {
375        let mut device_id = match self {
376            #[cfg(feature = "cpu")]
377            Self::Cpu(device) => device.to_id(),
378            #[cfg(feature = "cuda")]
379            Self::Cuda(device) => device.to_id(),
380            #[cfg(wgpu_metal)]
381            Self::Metal(device) => device.to_id(),
382            #[cfg(feature = "rocm")]
383            Self::Rocm(device) => device.to_id(),
384            #[cfg(wgpu_vulkan)]
385            Self::Vulkan(device) => device.to_id(),
386            #[cfg(wgpu_webgpu)]
387            Self::Wgpu(device) => device.to_id(),
388            #[cfg(feature = "flex")]
389            Self::Flex(device) => device.to_id(),
390            #[cfg(feature = "ndarray")]
391            Self::NdArray(device) => device.to_id(),
392            #[cfg(feature = "tch")]
393            Self::LibTorch(device) => device.to_id(),
394            #[cfg(feature = "autodiff")]
395            Self::Autodiff(device) => device.inner.to_id(),
396        };
397        device_id.type_id = self.encode_type_id(device_id.type_id);
398        device_id
399    }
400}
401
// Conversions from each backend's native device into the matching `DispatchDevice`
// variant.
//
// Three of these impls convert from `WgpuDevice` (Metal, Vulkan, WebGPU). They would
// be conflicting impls if more than one of the `wgpu_metal`, `wgpu_vulkan` and
// `wgpu_webgpu` cfgs were set, so this file only compiles with at most one of them
// enabled.

#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
    fn from(value: CpuDevice) -> Self {
        Self::Cpu(value)
    }
}

#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
    fn from(value: CudaDevice) -> Self {
        Self::Cuda(value)
    }
}

#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(value: WgpuDevice) -> Self {
        Self::Metal(value)
    }
}

#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
    fn from(value: RocmDevice) -> Self {
        Self::Rocm(value)
    }
}

#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(value: WgpuDevice) -> Self {
        Self::Vulkan(value)
    }
}

#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(value: WgpuDevice) -> Self {
        Self::Wgpu(value)
    }
}

#[cfg(feature = "flex")]
impl From<FlexDevice> for DispatchDevice {
    fn from(value: FlexDevice) -> Self {
        Self::Flex(value)
    }
}

#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
    fn from(value: NdArrayDevice) -> Self {
        Self::NdArray(value)
    }
}

#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
    fn from(value: LibTorchDevice) -> Self {
        Self::LibTorch(value)
    }
}
459impl From<LibTorchDevice> for DispatchDevice {
460    fn from(device: LibTorchDevice) -> Self {
461        DispatchDevice::LibTorch(device)
462    }
463}