Skip to main content

burn_dispatch/
device.rs

1use alloc::boxed::Box;
2
3use burn_backend::{DeviceId, DeviceOps};
4
5use crate::backends::*;
6
7/// Represents a device for the [`Dispatch`](crate::Dispatch).
8///
9/// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to.
10///
11/// # Example
12///
13/// ```ignore
14/// use burn::DispatchDevice;
15///
16/// #[cfg(feature = "cpu")]
17/// let cpu_device = DispatchDevice::Cpu(Default::default());
18///
19/// #[cfg(feature = "cuda")]
20/// let cuda_device = DispatchDevice::Cuda(Default::default());
21/// ```
22#[derive(Clone, Eq)]
23pub enum DispatchDevice {
24    /// The [CPU backend](Cpu) device.
25    #[cfg(feature = "cpu")]
26    Cpu(CpuDevice),
27
28    /// The [CUDA backend](Cuda) device.
29    #[cfg(feature = "cuda")]
30    Cuda(CudaDevice),
31
32    /// The [Metal backend](Metal) device (via WGPU runtime).
33    #[cfg(wgpu_metal)]
34    Metal(WgpuDevice),
35
36    /// The [ROCm backend](Rocm) device.
37    #[cfg(feature = "rocm")]
38    Rocm(RocmDevice),
39
40    /// The [Vulkan backend](Vulkan) device.
41    #[cfg(wgpu_vulkan)]
42    Vulkan(WgpuDevice),
43
44    /// The [WebGPU backend](Wgpu) device (via WGPU runtime).
45    #[cfg(wgpu_webgpu)]
46    Wgpu(WgpuDevice),
47
48    /// The [NdArray backend](NdArray) device (CPU-only).
49    #[cfg(feature = "ndarray")]
50    NdArray(NdArrayDevice),
51
52    /// The [LibTorch backend](LibTorch) device.
53    #[cfg(feature = "tch")]
54    LibTorch(LibTorchDevice),
55
56    /// The [autodiff enabled backend](Autodiff) device.
57    #[cfg(feature = "autodiff")]
58    Autodiff(AutodiffDevice),
59}
60
#[cfg(feature = "autodiff")]
// This wrapper type (its fields are `pub(crate)`) mainly restricts users from building
// `Autodiff(Autodiff(..))` devices by hand.
/// A wrapper that enables automatic differentiation for a [`DispatchDevice`].
///
/// Use [`DispatchDevice::autodiff`] to construct this type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice {
    /// The wrapped device.
    pub(crate) inner: Box<DispatchDevice>,
    /// The autodiff checkpointing strategy associated with this device.
    pub(crate) checkpointing: CheckpointingStrategy,
}
71
#[cfg(feature = "autodiff")]
impl AutodiffDevice {
    /// Boxes `device` and pairs it with the autodiff `checkpointing` strategy.
    pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self {
        let inner = Box::new(device);
        Self {
            checkpointing,
            inner,
        }
    }
}
81
#[cfg(feature = "autodiff")]
impl core::ops::Deref for AutodiffDevice {
    type Target = DispatchDevice;

    /// Borrows the wrapped device.
    ///
    /// This lets the dispatch macros `match` on an `AutodiffDevice` as if it were the
    /// `DispatchDevice` it contains.
    fn deref(&self) -> &DispatchDevice {
        self.inner.as_ref()
    }
}
91
#[cfg(feature = "autodiff")]
/// Checkpointing strategy for autodiff.
///
/// Stored in every [`AutodiffDevice`]. Tensors used in the same operation must share
/// a strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CheckpointingStrategy {
    /// Balanced checkpointing.
    ///
    /// NOTE(review): presumably forwarded to burn's balanced autodiff checkpointing
    /// strategy; confirm against the backend integration.
    Balanced,
    /// The default strategy, used by `DispatchDevice::autodiff`; presumably disables
    /// checkpointing.
    #[default]
    None,
}
102
#[cfg(feature = "autodiff")]
/// Checks that two operands use the same autodiff checkpointing strategy and returns it.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` differ. Thanks to `#[track_caller]`, the panic location is the
/// call site of this function rather than this helper.
#[track_caller]
pub(crate) fn validate_checkpointing(
    lhs: crate::CheckpointingStrategy,
    rhs: crate::CheckpointingStrategy,
) -> crate::CheckpointingStrategy {
    assert_eq!(
        lhs, rhs,
        "Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy."
    );
    lhs
}
114
115impl core::fmt::Debug for DispatchDevice {
116    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117        match self {
118            #[cfg(feature = "cpu")]
119            Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
120            #[cfg(feature = "cuda")]
121            Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
122            #[cfg(wgpu_metal)]
123            Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
124            #[cfg(feature = "rocm")]
125            Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
126            #[cfg(wgpu_vulkan)]
127            Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
128            #[cfg(wgpu_webgpu)]
129            Self::Wgpu(device) => f.debug_tuple("Wgpu").field(device).finish(),
130            #[cfg(feature = "ndarray")]
131            Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
132            #[cfg(feature = "tch")]
133            Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
134            #[cfg(feature = "autodiff")]
135            // Format without `AutodiffDevice` wrapper
136            Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.inner).finish(),
137        }
138    }
139}
140
141impl Default for DispatchDevice {
142    #[allow(unreachable_code)]
143    fn default() -> Self {
144        // TODO: which priority?
145
146        #[cfg(feature = "cuda")]
147        return Self::Cuda(CudaDevice::default());
148
149        #[cfg(wgpu_metal)]
150        return Self::Metal(burn_wgpu::WgpuDevice::default());
151
152        #[cfg(feature = "rocm")]
153        return Self::Rocm(RocmDevice::default());
154
155        #[cfg(wgpu_vulkan)]
156        return Self::Vulkan(burn_wgpu::WgpuDevice::default());
157
158        #[cfg(wgpu_webgpu)]
159        return Self::Wgpu(burn_wgpu::WgpuDevice::default());
160
161        #[cfg(feature = "cpu")]
162        return Self::Cpu(CpuDevice);
163
164        #[cfg(feature = "tch")]
165        return Self::LibTorch(LibTorchDevice::default());
166
167        #[cfg(feature = "ndarray")]
168        return Self::NdArray(NdArrayDevice::default());
169    }
170}
171
172impl PartialEq for DispatchDevice {
173    fn eq(&self, other: &Self) -> bool {
174        match (self, other) {
175            // If both are Autodiff, compare the inner devices
176            #[cfg(feature = "autodiff")]
177            (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
178            // If one is Autodiff, compare it to the raw device
179            #[cfg(feature = "autodiff")]
180            (DispatchDevice::Autodiff(a), b) => a.inner.as_ref() == b,
181            #[cfg(feature = "autodiff")]
182            (a, DispatchDevice::Autodiff(b)) => a == b.inner.as_ref(),
183            #[cfg(feature = "cpu")]
184            (Self::Cpu(a), Self::Cpu(b)) => a == b,
185            #[cfg(feature = "cuda")]
186            (Self::Cuda(a), Self::Cuda(b)) => a == b,
187            #[cfg(wgpu_metal)]
188            (Self::Metal(a), Self::Metal(b)) => a == b,
189            #[cfg(feature = "rocm")]
190            (Self::Rocm(a), Self::Rocm(b)) => a == b,
191            #[cfg(wgpu_vulkan)]
192            (Self::Vulkan(a), Self::Vulkan(b)) => a == b,
193            #[cfg(wgpu_webgpu)]
194            (Self::Wgpu(a), Self::Wgpu(b)) => a == b,
195            #[cfg(feature = "ndarray")]
196            (Self::NdArray(a), Self::NdArray(b)) => a == b,
197            #[cfg(feature = "tch")]
198            (Self::LibTorch(a), Self::LibTorch(b)) => a == b,
199            #[allow(unreachable_patterns)]
200            (_, _) => false,
201        }
202    }
203}
204
205/// Base multiplier to avoid type_id clashes between backends.
206/// Limits the number of device types per backend, but this is a sensible limit.
207const TYPE_ID_BASE: u16 = 10;
208
209impl DispatchDevice {
210    #[cfg(feature = "autodiff")]
211    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
212    pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
213        Self::autodiff_checkpointed(device, CheckpointingStrategy::None)
214    }
215    #[cfg(feature = "autodiff")]
216    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
217    pub fn autodiff_checkpointed(
218        device: impl Into<DispatchDevice>,
219        checkpointing: CheckpointingStrategy,
220    ) -> DispatchDevice {
221        let device = device.into();
222        DispatchDevice::Autodiff(AutodiffDevice::new(device, checkpointing))
223    }
224
225    /// Returns a unique number per variant to encode into type_id.
226    fn backend_id(&self) -> BackendId {
227        match self {
228            #[cfg(feature = "cpu")]
229            Self::Cpu(_) => BackendId::Cpu,
230            #[cfg(feature = "cuda")]
231            Self::Cuda(_) => BackendId::Cuda,
232            #[cfg(wgpu_metal)]
233            Self::Metal(_) => BackendId::Metal,
234            #[cfg(feature = "rocm")]
235            Self::Rocm(_) => BackendId::Rocm,
236            #[cfg(wgpu_vulkan)]
237            Self::Vulkan(_) => BackendId::Vulkan,
238            #[cfg(wgpu_webgpu)]
239            Self::Wgpu(_) => BackendId::Wgpu,
240            #[cfg(feature = "ndarray")]
241            Self::NdArray(_) => BackendId::NdArray,
242            #[cfg(feature = "tch")]
243            Self::LibTorch(_) => BackendId::LibTorch,
244            #[cfg(feature = "autodiff")]
245            Self::Autodiff(device) => device.inner.backend_id(),
246        }
247    }
248
249    /// Encode variant ID and backend type ID into a unique `type_id`.
250    fn encode_type_id(&self, backend_type_id: u16) -> u16 {
251        u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
252    }
253
254    /// Decode an encoded `type_id` into variant ID and backend type ID.
255    pub(crate) fn decode_type_id(type_id: u16) -> (BackendId, u16) {
256        let variant = type_id / TYPE_ID_BASE;
257        let backend_type_id = type_id % TYPE_ID_BASE;
258        (
259            BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
260            backend_type_id,
261        )
262    }
263}
264
265#[derive(Debug, Clone, Copy, PartialEq, Eq)]
266#[repr(u16)]
267pub(crate) enum BackendId {
268    #[cfg(feature = "cpu")]
269    Cpu = 0,
270    #[cfg(feature = "cuda")]
271    Cuda = 1,
272    #[cfg(wgpu_metal)]
273    Metal = 2,
274    #[cfg(feature = "rocm")]
275    Rocm = 3,
276    #[cfg(wgpu_vulkan)]
277    Vulkan = 4,
278    #[cfg(wgpu_webgpu)]
279    Wgpu = 5,
280    #[cfg(feature = "ndarray")]
281    NdArray = 6,
282    #[cfg(feature = "tch")]
283    LibTorch = 7,
284}
285
286impl From<BackendId> for u16 {
287    fn from(variant: BackendId) -> Self {
288        variant as u16
289    }
290}
291
292impl TryFrom<u16> for BackendId {
293    type Error = ();
294
295    fn try_from(value: u16) -> Result<Self, Self::Error> {
296        match value {
297            #[cfg(feature = "cpu")]
298            0 => Ok(Self::Cpu),
299            #[cfg(feature = "cuda")]
300            1 => Ok(Self::Cuda),
301            #[cfg(wgpu_metal)]
302            2 => Ok(Self::Metal),
303            #[cfg(feature = "rocm")]
304            3 => Ok(Self::Rocm),
305            #[cfg(wgpu_vulkan)]
306            4 => Ok(Self::Vulkan),
307            #[cfg(wgpu_webgpu)]
308            5 => Ok(Self::Wgpu),
309            #[cfg(feature = "ndarray")]
310            6 => Ok(Self::NdArray),
311            #[cfg(feature = "tch")]
312            7 => Ok(Self::LibTorch),
313            _ => Err(()),
314        }
315    }
316}
317
318impl DeviceOps for DispatchDevice {
319    fn inner(&self) -> &Self {
320        match self {
321            #[cfg(feature = "autodiff")]
322            DispatchDevice::Autodiff(device) => &device.inner,
323            device => device,
324        }
325    }
326}
327
328impl burn_backend::Device for DispatchDevice {
329    fn from_id(mut device_id: DeviceId) -> Self {
330        let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
331        device_id.type_id = backend_type_id;
332
333        match dispatch_id {
334            #[cfg(feature = "cpu")]
335            BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
336            #[cfg(feature = "cuda")]
337            BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
338            #[cfg(wgpu_metal)]
339            BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
340            #[cfg(feature = "rocm")]
341            BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
342            #[cfg(wgpu_vulkan)]
343            BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
344            #[cfg(wgpu_webgpu)]
345            BackendId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
346            #[cfg(feature = "ndarray")]
347            BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
348            #[cfg(feature = "tch")]
349            BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
350        }
351    }
352
353    fn to_id(&self) -> DeviceId {
354        let mut device_id = match self {
355            #[cfg(feature = "cpu")]
356            Self::Cpu(device) => device.to_id(),
357            #[cfg(feature = "cuda")]
358            Self::Cuda(device) => device.to_id(),
359            #[cfg(wgpu_metal)]
360            Self::Metal(device) => device.to_id(),
361            #[cfg(feature = "rocm")]
362            Self::Rocm(device) => device.to_id(),
363            #[cfg(wgpu_vulkan)]
364            Self::Vulkan(device) => device.to_id(),
365            #[cfg(wgpu_webgpu)]
366            Self::Wgpu(device) => device.to_id(),
367            #[cfg(feature = "ndarray")]
368            Self::NdArray(device) => device.to_id(),
369            #[cfg(feature = "tch")]
370            Self::LibTorch(device) => device.to_id(),
371            #[cfg(feature = "autodiff")]
372            Self::Autodiff(device) => device.inner.to_id(),
373        };
374        device_id.type_id = self.encode_type_id(device_id.type_id);
375        device_id
376    }
377}
378
// Implements `From<$device> for DispatchDevice` for each listed backend device type by
// wrapping the value in the named `DispatchDevice` variant. Attributes written before an
// entry (its `cfg` gate) are copied onto the generated impl.
macro_rules! impl_from_backend_device {
    ($($(#[$attr:meta])* $device:ty => $variant:ident;)*) => {
        $(
            $(#[$attr])*
            impl From<$device> for DispatchDevice {
                fn from(device: $device) -> Self {
                    Self::$variant(device)
                }
            }
        )*
    };
}

// NOTE: the Metal, Vulkan and WebGPU variants all wrap `WgpuDevice`. Their three impls would
// conflict if more than one of those cfgs were enabled at once, so presumably the build
// configuration enables at most one of them; confirm in the build script.
impl_from_backend_device! {
    #[cfg(feature = "cpu")]
    CpuDevice => Cpu;
    #[cfg(feature = "cuda")]
    CudaDevice => Cuda;
    #[cfg(wgpu_metal)]
    WgpuDevice => Metal;
    #[cfg(feature = "rocm")]
    RocmDevice => Rocm;
    #[cfg(wgpu_vulkan)]
    WgpuDevice => Vulkan;
    #[cfg(wgpu_webgpu)]
    WgpuDevice => Wgpu;
    #[cfg(feature = "ndarray")]
    NdArrayDevice => NdArray;
    #[cfg(feature = "tch")]
    LibTorchDevice => LibTorch;
}
433}