Skip to main content

burn_dispatch/
device.rs

1use burn_backend::{DeviceId, DeviceOps};
2
3use crate::backends::*;
4
5/// Represents a device for the [`Dispatch`](crate::Dispatch).
6///
7/// Each variant corresponds to a backend that the [`Dispatch`](crate::Dispatch) can dispatch operations to.
8///
9/// # Example
10///
11/// ```ignore
12/// use burn::DispatchDevice;
13///
14/// #[cfg(feature = "cpu")]
15/// let cpu_device = DispatchDevice::Cpu(Default::default());
16///
17/// #[cfg(feature = "cuda")]
18/// let cuda_device = DispatchDevice::Cuda(Default::default());
19/// ```
20#[derive(Clone, Eq)]
21pub enum DispatchDevice {
22    /// The [CPU backend](Cpu) device.
23    #[cfg(feature = "cpu")]
24    Cpu(CpuDevice),
25
26    /// The [CUDA backend](Cuda) device.
27    #[cfg(feature = "cuda")]
28    Cuda(CudaDevice),
29
30    /// The [Metal backend](Metal) device (via WGPU runtime).
31    #[cfg(wgpu_metal)]
32    Metal(WgpuDevice),
33
34    /// The [ROCm backend](Rocm) device.
35    #[cfg(feature = "rocm")]
36    Rocm(RocmDevice),
37
38    /// The [Vulkan backend](Vulkan) device.
39    #[cfg(wgpu_vulkan)]
40    Vulkan(WgpuDevice),
41
42    /// The [WebGPU backend](WebGpu) device (via WGPU runtime).
43    #[cfg(wgpu_webgpu)]
44    WebGpu(WgpuDevice),
45
46    /// The [NdArray backend](NdArray) device (CPU-only).
47    #[cfg(feature = "ndarray")]
48    NdArray(NdArrayDevice),
49
50    /// The [LibTorch backend](LibTorch) device.
51    #[cfg(feature = "tch")]
52    LibTorch(LibTorchDevice),
53
54    /// The [autodiff enabled backend](Autodiff) device.
55    #[cfg(feature = "autodiff")]
56    Autodiff(AutodiffDevice),
57}
58
59// This tuple struct mainly restricts users from creating Autodiff(Autodiff) devices.
60/// A wrapper that enables automatic differentiation for a [`DispatchDevice`].
61///
62/// Use [`DispatchDevice::autodiff`] to construct this type.
63#[derive(Debug, Clone, PartialEq, Eq)]
64pub struct AutodiffDevice(pub(crate) Box<DispatchDevice>);
65
66// Useful for match in dispatch macros
67impl core::ops::Deref for AutodiffDevice {
68    type Target = DispatchDevice;
69
70    fn deref(&self) -> &Self::Target {
71        &self.0
72    }
73}
74
75impl core::fmt::Debug for DispatchDevice {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            #[cfg(feature = "cpu")]
79            Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
80            #[cfg(feature = "cuda")]
81            Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
82            #[cfg(wgpu_metal)]
83            Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
84            #[cfg(feature = "rocm")]
85            Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
86            #[cfg(wgpu_vulkan)]
87            Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
88            #[cfg(wgpu_webgpu)]
89            Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(),
90            #[cfg(feature = "ndarray")]
91            Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
92            #[cfg(feature = "tch")]
93            Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
94            #[cfg(feature = "autodiff")]
95            // Format without `AutodiffDevice` wrapper
96            Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.0).finish(),
97        }
98    }
99}
100
101impl Default for DispatchDevice {
102    #[allow(unreachable_code)]
103    fn default() -> Self {
104        // TODO: which priority?
105
106        #[cfg(feature = "cpu")]
107        return Self::Cpu(CpuDevice);
108
109        #[cfg(feature = "cuda")]
110        return Self::Cuda(CudaDevice::default());
111
112        #[cfg(wgpu_metal)]
113        return Self::Metal(burn_wgpu::WgpuDevice::default());
114
115        #[cfg(feature = "rocm")]
116        return Self::Rocm(RocmDevice::default());
117
118        #[cfg(wgpu_vulkan)]
119        return Self::Vulkan(burn_wgpu::WgpuDevice::default());
120
121        #[cfg(wgpu_webgpu)]
122        return Self::WebGpu(burn_wgpu::WgpuDevice::default());
123
124        #[cfg(feature = "ndarray")]
125        return Self::NdArray(NdArrayDevice::default());
126
127        #[cfg(feature = "tch")]
128        return Self::LibTorch(LibTorchDevice::default());
129    }
130}
131
132impl PartialEq for DispatchDevice {
133    fn eq(&self, other: &Self) -> bool {
134        match (self, other) {
135            // If both are Autodiff, compare the inner devices
136            #[cfg(feature = "autodiff")]
137            (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
138            // If one is Autodiff, compare it to the raw device
139            #[cfg(feature = "autodiff")]
140            (DispatchDevice::Autodiff(a), b) => a.0.as_ref() == b,
141            #[cfg(feature = "autodiff")]
142            (a, DispatchDevice::Autodiff(b)) => a == b.0.as_ref(),
143            #[cfg(feature = "cpu")]
144            (Self::Cpu(a), Self::Cpu(b)) => a == b,
145            #[cfg(feature = "cuda")]
146            (Self::Cuda(a), Self::Cuda(b)) => a == b,
147            #[cfg(wgpu_metal)]
148            (Self::Metal(a), Self::Metal(b)) => a == b,
149            #[cfg(feature = "rocm")]
150            (Self::Rocm(a), Self::Rocm(b)) => a == b,
151            #[cfg(wgpu_vulkan)]
152            (Self::Vulkan(a), Self::Vulkan(b)) => a == b,
153            #[cfg(wgpu_webgpu)]
154            (Self::WebGpu(a), Self::WebGpu(b)) => a == b,
155            #[cfg(feature = "ndarray")]
156            (Self::NdArray(a), Self::NdArray(b)) => a == b,
157            #[cfg(feature = "tch")]
158            (Self::LibTorch(a), Self::LibTorch(b)) => a == b,
159            #[allow(unreachable_patterns)]
160            (_, _) => false,
161        }
162    }
163}
164
/// Multiplier used to pack a backend id and a backend-local device `type_id` into one `u16`,
/// as `backend_id * TYPE_ID_BASE + backend_type_id`, so type ids of different backends never
/// clash.
///
/// As a consequence, each backend may use at most `TYPE_ID_BASE` distinct device type ids
/// (`0..TYPE_ID_BASE`), which is a sensible limit.
const TYPE_ID_BASE: u16 = 10;
168
169impl DispatchDevice {
170    #[cfg(feature = "autodiff")]
171    /// Creates a new [`DispatchDevice`] with [automatic differentiation](Autodiff) enabled.
172    pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
173        let device = device.into();
174        DispatchDevice::Autodiff(AutodiffDevice(Box::new(device)))
175    }
176
177    /// Returns a unique number per variant to encode into type_id.
178    fn backend_id(&self) -> BackendId {
179        match self {
180            #[cfg(feature = "cpu")]
181            Self::Cpu(_) => BackendId::Cpu,
182            #[cfg(feature = "cuda")]
183            Self::Cuda(_) => BackendId::Cuda,
184            #[cfg(wgpu_metal)]
185            Self::Metal(_) => BackendId::Metal,
186            #[cfg(feature = "rocm")]
187            Self::Rocm(_) => BackendId::Rocm,
188            #[cfg(wgpu_vulkan)]
189            Self::Vulkan(_) => BackendId::Vulkan,
190            #[cfg(wgpu_webgpu)]
191            Self::WebGpu(_) => BackendId::WebGpu,
192            #[cfg(feature = "ndarray")]
193            Self::NdArray(_) => BackendId::NdArray,
194            #[cfg(feature = "tch")]
195            Self::LibTorch(_) => BackendId::LibTorch,
196            #[cfg(feature = "autodiff")]
197            Self::Autodiff(device) => device.0.backend_id(),
198        }
199    }
200
201    /// Encode variant ID and backend type ID into a unique `type_id`.
202    fn encode_type_id(&self, backend_type_id: u16) -> u16 {
203        u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
204    }
205
206    /// Decode an encoded `type_id` into variant ID and backend type ID.
207    fn decode_type_id(type_id: u16) -> (BackendId, u16) {
208        let variant = type_id / TYPE_ID_BASE;
209        let backend_type_id = type_id % TYPE_ID_BASE;
210        (
211            BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
212            backend_type_id,
213        )
214    }
215}
216
217#[derive(Debug, Clone, Copy, PartialEq, Eq)]
218#[repr(u16)]
219enum BackendId {
220    #[cfg(feature = "cpu")]
221    Cpu = 0,
222    #[cfg(feature = "cuda")]
223    Cuda = 1,
224    #[cfg(wgpu_metal)]
225    Metal = 2,
226    #[cfg(feature = "rocm")]
227    Rocm = 3,
228    #[cfg(wgpu_vulkan)]
229    Vulkan = 4,
230    #[cfg(wgpu_webgpu)]
231    WebGpu = 5,
232    #[cfg(feature = "ndarray")]
233    NdArray = 6,
234    #[cfg(feature = "tch")]
235    LibTorch = 7,
236}
237
238impl From<BackendId> for u16 {
239    fn from(variant: BackendId) -> Self {
240        variant as u16
241    }
242}
243
244impl TryFrom<u16> for BackendId {
245    type Error = ();
246
247    fn try_from(value: u16) -> Result<Self, Self::Error> {
248        match value {
249            #[cfg(feature = "cpu")]
250            0 => Ok(Self::Cpu),
251            #[cfg(feature = "cuda")]
252            1 => Ok(Self::Cuda),
253            #[cfg(wgpu_metal)]
254            2 => Ok(Self::Metal),
255            #[cfg(feature = "rocm")]
256            3 => Ok(Self::Rocm),
257            #[cfg(wgpu_vulkan)]
258            4 => Ok(Self::Vulkan),
259            #[cfg(wgpu_webgpu)]
260            5 => Ok(Self::WebGpu),
261            #[cfg(feature = "ndarray")]
262            6 => Ok(Self::NdArray),
263            #[cfg(feature = "tch")]
264            7 => Ok(Self::LibTorch),
265            _ => Err(()),
266        }
267    }
268}
269
270impl DeviceOps for DispatchDevice {
271    fn inner(&self) -> &Self {
272        match self {
273            #[cfg(feature = "autodiff")]
274            DispatchDevice::Autodiff(device) => &device.0,
275            device => device,
276        }
277    }
278}
279
280impl burn_std::device::Device for DispatchDevice {
281    fn from_id(mut device_id: DeviceId) -> Self {
282        let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
283        device_id.type_id = backend_type_id;
284
285        match dispatch_id {
286            #[cfg(feature = "cpu")]
287            BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
288            #[cfg(feature = "cuda")]
289            BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
290            #[cfg(wgpu_metal)]
291            BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
292            #[cfg(feature = "rocm")]
293            BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
294            #[cfg(wgpu_vulkan)]
295            BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
296            #[cfg(wgpu_webgpu)]
297            BackendId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),
298            #[cfg(feature = "ndarray")]
299            BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
300            #[cfg(feature = "tch")]
301            BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
302        }
303    }
304
305    fn to_id(&self) -> DeviceId {
306        let mut device_id = match self {
307            #[cfg(feature = "cpu")]
308            Self::Cpu(device) => device.to_id(),
309            #[cfg(feature = "cuda")]
310            Self::Cuda(device) => device.to_id(),
311            #[cfg(wgpu_metal)]
312            Self::Metal(device) => device.to_id(),
313            #[cfg(feature = "rocm")]
314            Self::Rocm(device) => device.to_id(),
315            #[cfg(wgpu_vulkan)]
316            Self::Vulkan(device) => device.to_id(),
317            #[cfg(wgpu_webgpu)]
318            Self::WebGpu(device) => device.to_id(),
319            #[cfg(feature = "ndarray")]
320            Self::NdArray(device) => device.to_id(),
321            #[cfg(feature = "tch")]
322            Self::LibTorch(device) => device.to_id(),
323            #[cfg(feature = "autodiff")]
324            Self::Autodiff(device) => device.0.to_id(),
325        };
326        device_id.type_id = self.encode_type_id(device_id.type_id);
327        device_id
328    }
329
330    fn device_count(type_id: u16) -> usize {
331        let (dispatch_id, backend_type_id) = Self::decode_type_id(type_id);
332        match dispatch_id {
333            #[cfg(feature = "cpu")]
334            BackendId::Cpu => CpuDevice::device_count(backend_type_id),
335            #[cfg(feature = "cuda")]
336            BackendId::Cuda => CudaDevice::device_count(backend_type_id),
337            #[cfg(wgpu_metal)]
338            BackendId::Metal => WgpuDevice::device_count(backend_type_id),
339            #[cfg(feature = "rocm")]
340            BackendId::Rocm => RocmDevice::device_count(backend_type_id),
341            #[cfg(wgpu_vulkan)]
342            BackendId::Vulkan => WgpuDevice::device_count(backend_type_id),
343            #[cfg(wgpu_webgpu)]
344            BackendId::WebGpu => WgpuDevice::device_count(backend_type_id),
345            #[cfg(feature = "ndarray")]
346            BackendId::NdArray => NdArrayDevice::device_count(backend_type_id),
347            #[cfg(feature = "tch")]
348            BackendId::LibTorch => LibTorchDevice::device_count(backend_type_id),
349        }
350    }
351}
352
#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
    /// Wraps a CPU backend device.
    fn from(inner: CpuDevice) -> Self {
        Self::Cpu(inner)
    }
}

#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
    /// Wraps a CUDA backend device.
    fn from(inner: CudaDevice) -> Self {
        Self::Cuda(inner)
    }
}

// NOTE(review): Metal, Vulkan and WebGPU all implement `From<WgpuDevice>`. This assumes at
// most one of the `wgpu_*` cfgs is active at a time; otherwise the impls conflict.
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps a WGPU device as a Metal backend device.
    fn from(inner: WgpuDevice) -> Self {
        Self::Metal(inner)
    }
}

#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
    /// Wraps a ROCm backend device.
    fn from(inner: RocmDevice) -> Self {
        Self::Rocm(inner)
    }
}

#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps a WGPU device as a Vulkan backend device.
    fn from(inner: WgpuDevice) -> Self {
        Self::Vulkan(inner)
    }
}

#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps a WGPU device as a WebGPU backend device.
    fn from(inner: WgpuDevice) -> Self {
        Self::WebGpu(inner)
    }
}

#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
    /// Wraps an NdArray backend device.
    fn from(inner: NdArrayDevice) -> Self {
        Self::NdArray(inner)
    }
}
401
// A single impl only: the same `From<LibTorchDevice>` impl written twice is a
// conflicting-implementation error (E0119) whenever the `tch` feature is enabled.
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
    /// Wraps a LibTorch backend device.
    fn from(device: LibTorchDevice) -> Self {
        DispatchDevice::LibTorch(device)
    }
}