1use burn_backend::{DeviceId, DeviceOps};
2
3use crate::backends::*;
4
/// Device handle for the dispatch backend.
///
/// Each variant wraps the device type of one concrete backend. A variant only exists when
/// that backend is compiled in, either through a cargo feature or through a `wgpu_*` cfg flag.
/// The `Metal`, `Vulkan` and `WebGpu` variants all wrap a `WgpuDevice`; the variant itself
/// is what tells those backends apart.
///
/// `PartialEq` and `Debug` are implemented by hand below. `Eq` is derived on top of the
/// manual `PartialEq`.
#[derive(Clone, Eq)]
pub enum DispatchDevice {
    /// Device of the `cpu` backend.
    #[cfg(feature = "cpu")]
    Cpu(CpuDevice),

    /// Device of the `cuda` backend.
    #[cfg(feature = "cuda")]
    Cuda(CudaDevice),

    /// WGPU device exposed as the `Metal` backend (`wgpu_metal` cfg).
    #[cfg(wgpu_metal)]
    Metal(WgpuDevice),

    /// Device of the `rocm` backend.
    #[cfg(feature = "rocm")]
    Rocm(RocmDevice),

    /// WGPU device exposed as the `Vulkan` backend (`wgpu_vulkan` cfg).
    #[cfg(wgpu_vulkan)]
    Vulkan(WgpuDevice),

    /// WGPU device exposed as the `WebGpu` backend (`wgpu_webgpu` cfg).
    #[cfg(wgpu_webgpu)]
    WebGpu(WgpuDevice),

    /// Device of the `ndarray` backend.
    #[cfg(feature = "ndarray")]
    NdArray(NdArrayDevice),

    /// Device of the LibTorch backend (`tch` feature).
    #[cfg(feature = "tch")]
    LibTorch(LibTorchDevice),

    /// Another `DispatchDevice` wrapped in an [`AutodiffDevice`].
    /// It compares equal to the device it wraps.
    #[cfg(feature = "autodiff")]
    Autodiff(AutodiffDevice),
}
58
/// Wrapper around another [`DispatchDevice`], stored in the `DispatchDevice::Autodiff`
/// variant.
///
/// The inner device is boxed because `DispatchDevice` itself contains this type, and a
/// recursive type needs indirection to have a finite size.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice(pub(crate) Box<DispatchDevice>);
65
66impl core::ops::Deref for AutodiffDevice {
68 type Target = DispatchDevice;
69
70 fn deref(&self) -> &Self::Target {
71 &self.0
72 }
73}
74
75impl core::fmt::Debug for DispatchDevice {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 #[cfg(feature = "cpu")]
79 Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
80 #[cfg(feature = "cuda")]
81 Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
82 #[cfg(wgpu_metal)]
83 Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
84 #[cfg(feature = "rocm")]
85 Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
86 #[cfg(wgpu_vulkan)]
87 Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
88 #[cfg(wgpu_webgpu)]
89 Self::WebGpu(device) => f.debug_tuple("WebGpu").field(device).finish(),
90 #[cfg(feature = "ndarray")]
91 Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
92 #[cfg(feature = "tch")]
93 Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
94 #[cfg(feature = "autodiff")]
95 Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.0).finish(),
97 }
98 }
99}
100
impl Default for DispatchDevice {
    /// Returns the default device of the first enabled backend.
    ///
    /// Backends are tried in this order: cpu, cuda, metal, rocm, vulkan, webgpu, ndarray,
    /// tch. The `Autodiff` variant is never returned.
    ///
    /// NOTE: if no backend is enabled, the body has no return value and will not compile.
    // Every `return` is cfg-gated. When several backends are enabled, all `return`s after
    // the first one are dead code, so the `unreachable_code` lint is silenced here.
    #[allow(unreachable_code)]
    fn default() -> Self {
        #[cfg(feature = "cpu")]
        return Self::Cpu(CpuDevice);

        #[cfg(feature = "cuda")]
        return Self::Cuda(CudaDevice::default());

        #[cfg(wgpu_metal)]
        return Self::Metal(burn_wgpu::WgpuDevice::default());

        #[cfg(feature = "rocm")]
        return Self::Rocm(RocmDevice::default());

        #[cfg(wgpu_vulkan)]
        return Self::Vulkan(burn_wgpu::WgpuDevice::default());

        #[cfg(wgpu_webgpu)]
        return Self::WebGpu(burn_wgpu::WgpuDevice::default());

        #[cfg(feature = "ndarray")]
        return Self::NdArray(NdArrayDevice::default());

        #[cfg(feature = "tch")]
        return Self::LibTorch(LibTorchDevice::default());
    }
}
131
132impl PartialEq for DispatchDevice {
133 fn eq(&self, other: &Self) -> bool {
134 match (self, other) {
135 #[cfg(feature = "autodiff")]
137 (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
138 #[cfg(feature = "autodiff")]
140 (DispatchDevice::Autodiff(a), b) => a.0.as_ref() == b,
141 #[cfg(feature = "autodiff")]
142 (a, DispatchDevice::Autodiff(b)) => a == b.0.as_ref(),
143 #[cfg(feature = "cpu")]
144 (Self::Cpu(a), Self::Cpu(b)) => a == b,
145 #[cfg(feature = "cuda")]
146 (Self::Cuda(a), Self::Cuda(b)) => a == b,
147 #[cfg(wgpu_metal)]
148 (Self::Metal(a), Self::Metal(b)) => a == b,
149 #[cfg(feature = "rocm")]
150 (Self::Rocm(a), Self::Rocm(b)) => a == b,
151 #[cfg(wgpu_vulkan)]
152 (Self::Vulkan(a), Self::Vulkan(b)) => a == b,
153 #[cfg(wgpu_webgpu)]
154 (Self::WebGpu(a), Self::WebGpu(b)) => a == b,
155 #[cfg(feature = "ndarray")]
156 (Self::NdArray(a), Self::NdArray(b)) => a == b,
157 #[cfg(feature = "tch")]
158 (Self::LibTorch(a), Self::LibTorch(b)) => a == b,
159 #[allow(unreachable_patterns)]
160 (_, _) => false,
161 }
162 }
163}
164
/// Multiplier used to pack a `BackendId` and a backend-local device type id into one
/// `u16`: `encoded = backend_id * TYPE_ID_BASE + backend_type_id`.
///
/// Decoding splits the value with `/` and `%`, so every backend-local type id must be
/// smaller than this value.
const TYPE_ID_BASE: u16 = 10;
168
169impl DispatchDevice {
170 #[cfg(feature = "autodiff")]
171 pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
173 let device = device.into();
174 DispatchDevice::Autodiff(AutodiffDevice(Box::new(device)))
175 }
176
177 fn backend_id(&self) -> BackendId {
179 match self {
180 #[cfg(feature = "cpu")]
181 Self::Cpu(_) => BackendId::Cpu,
182 #[cfg(feature = "cuda")]
183 Self::Cuda(_) => BackendId::Cuda,
184 #[cfg(wgpu_metal)]
185 Self::Metal(_) => BackendId::Metal,
186 #[cfg(feature = "rocm")]
187 Self::Rocm(_) => BackendId::Rocm,
188 #[cfg(wgpu_vulkan)]
189 Self::Vulkan(_) => BackendId::Vulkan,
190 #[cfg(wgpu_webgpu)]
191 Self::WebGpu(_) => BackendId::WebGpu,
192 #[cfg(feature = "ndarray")]
193 Self::NdArray(_) => BackendId::NdArray,
194 #[cfg(feature = "tch")]
195 Self::LibTorch(_) => BackendId::LibTorch,
196 #[cfg(feature = "autodiff")]
197 Self::Autodiff(device) => device.0.backend_id(),
198 }
199 }
200
201 fn encode_type_id(&self, backend_type_id: u16) -> u16 {
203 u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
204 }
205
206 fn decode_type_id(type_id: u16) -> (BackendId, u16) {
208 let variant = type_id / TYPE_ID_BASE;
209 let backend_type_id = type_id % TYPE_ID_BASE;
210 (
211 BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
212 backend_type_id,
213 )
214 }
215}
216
/// Identifies which concrete backend a [`DispatchDevice`] belongs to.
///
/// The explicit discriminants are the "backend digit" of encoded device type ids
/// (`TYPE_ID_BASE`). They must stay in sync with the `TryFrom<u16>` impl below. They are
/// fixed per backend rather than renumbered when features are toggled, so a given backend
/// always keeps the same value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
enum BackendId {
    #[cfg(feature = "cpu")]
    Cpu = 0,
    #[cfg(feature = "cuda")]
    Cuda = 1,
    #[cfg(wgpu_metal)]
    Metal = 2,
    #[cfg(feature = "rocm")]
    Rocm = 3,
    #[cfg(wgpu_vulkan)]
    Vulkan = 4,
    #[cfg(wgpu_webgpu)]
    WebGpu = 5,
    #[cfg(feature = "ndarray")]
    NdArray = 6,
    #[cfg(feature = "tch")]
    LibTorch = 7,
}
237
238impl From<BackendId> for u16 {
239 fn from(variant: BackendId) -> Self {
240 variant as u16
241 }
242}
243
244impl TryFrom<u16> for BackendId {
245 type Error = ();
246
247 fn try_from(value: u16) -> Result<Self, Self::Error> {
248 match value {
249 #[cfg(feature = "cpu")]
250 0 => Ok(Self::Cpu),
251 #[cfg(feature = "cuda")]
252 1 => Ok(Self::Cuda),
253 #[cfg(wgpu_metal)]
254 2 => Ok(Self::Metal),
255 #[cfg(feature = "rocm")]
256 3 => Ok(Self::Rocm),
257 #[cfg(wgpu_vulkan)]
258 4 => Ok(Self::Vulkan),
259 #[cfg(wgpu_webgpu)]
260 5 => Ok(Self::WebGpu),
261 #[cfg(feature = "ndarray")]
262 6 => Ok(Self::NdArray),
263 #[cfg(feature = "tch")]
264 7 => Ok(Self::LibTorch),
265 _ => Err(()),
266 }
267 }
268}
269
270impl DeviceOps for DispatchDevice {
271 fn inner(&self) -> &Self {
272 match self {
273 #[cfg(feature = "autodiff")]
274 DispatchDevice::Autodiff(device) => &device.0,
275 device => device,
276 }
277 }
278}
279
280impl burn_std::device::Device for DispatchDevice {
281 fn from_id(mut device_id: DeviceId) -> Self {
282 let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
283 device_id.type_id = backend_type_id;
284
285 match dispatch_id {
286 #[cfg(feature = "cpu")]
287 BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
288 #[cfg(feature = "cuda")]
289 BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
290 #[cfg(wgpu_metal)]
291 BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
292 #[cfg(feature = "rocm")]
293 BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
294 #[cfg(wgpu_vulkan)]
295 BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
296 #[cfg(wgpu_webgpu)]
297 BackendId::WebGpu => Self::WebGpu(WgpuDevice::from_id(device_id)),
298 #[cfg(feature = "ndarray")]
299 BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
300 #[cfg(feature = "tch")]
301 BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
302 }
303 }
304
305 fn to_id(&self) -> DeviceId {
306 let mut device_id = match self {
307 #[cfg(feature = "cpu")]
308 Self::Cpu(device) => device.to_id(),
309 #[cfg(feature = "cuda")]
310 Self::Cuda(device) => device.to_id(),
311 #[cfg(wgpu_metal)]
312 Self::Metal(device) => device.to_id(),
313 #[cfg(feature = "rocm")]
314 Self::Rocm(device) => device.to_id(),
315 #[cfg(wgpu_vulkan)]
316 Self::Vulkan(device) => device.to_id(),
317 #[cfg(wgpu_webgpu)]
318 Self::WebGpu(device) => device.to_id(),
319 #[cfg(feature = "ndarray")]
320 Self::NdArray(device) => device.to_id(),
321 #[cfg(feature = "tch")]
322 Self::LibTorch(device) => device.to_id(),
323 #[cfg(feature = "autodiff")]
324 Self::Autodiff(device) => device.0.to_id(),
325 };
326 device_id.type_id = self.encode_type_id(device_id.type_id);
327 device_id
328 }
329
330 fn device_count(type_id: u16) -> usize {
331 let (dispatch_id, backend_type_id) = Self::decode_type_id(type_id);
332 match dispatch_id {
333 #[cfg(feature = "cpu")]
334 BackendId::Cpu => CpuDevice::device_count(backend_type_id),
335 #[cfg(feature = "cuda")]
336 BackendId::Cuda => CudaDevice::device_count(backend_type_id),
337 #[cfg(wgpu_metal)]
338 BackendId::Metal => WgpuDevice::device_count(backend_type_id),
339 #[cfg(feature = "rocm")]
340 BackendId::Rocm => RocmDevice::device_count(backend_type_id),
341 #[cfg(wgpu_vulkan)]
342 BackendId::Vulkan => WgpuDevice::device_count(backend_type_id),
343 #[cfg(wgpu_webgpu)]
344 BackendId::WebGpu => WgpuDevice::device_count(backend_type_id),
345 #[cfg(feature = "ndarray")]
346 BackendId::NdArray => NdArrayDevice::device_count(backend_type_id),
347 #[cfg(feature = "tch")]
348 BackendId::LibTorch => LibTorchDevice::device_count(backend_type_id),
349 }
350 }
351}
352
/// Wraps a [`CpuDevice`] as [`DispatchDevice::Cpu`].
#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
    fn from(inner: CpuDevice) -> Self {
        Self::Cpu(inner)
    }
}

/// Wraps a [`CudaDevice`] as [`DispatchDevice::Cuda`].
#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
    fn from(inner: CudaDevice) -> Self {
        Self::Cuda(inner)
    }
}

// NOTE(review): the three `From<WgpuDevice>` impls below would conflict if more than one of
// `wgpu_metal`, `wgpu_vulkan` and `wgpu_webgpu` were active at once. Presumably the build
// configuration makes these cfgs mutually exclusive; confirm.

/// Wraps a [`WgpuDevice`] as [`DispatchDevice::Metal`].
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(inner: WgpuDevice) -> Self {
        Self::Metal(inner)
    }
}

/// Wraps a [`RocmDevice`] as [`DispatchDevice::Rocm`].
#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
    fn from(inner: RocmDevice) -> Self {
        Self::Rocm(inner)
    }
}

/// Wraps a [`WgpuDevice`] as [`DispatchDevice::Vulkan`].
#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(inner: WgpuDevice) -> Self {
        Self::Vulkan(inner)
    }
}

/// Wraps a [`WgpuDevice`] as [`DispatchDevice::WebGpu`].
#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(inner: WgpuDevice) -> Self {
        Self::WebGpu(inner)
    }
}

/// Wraps a [`NdArrayDevice`] as [`DispatchDevice::NdArray`].
#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
    fn from(inner: NdArrayDevice) -> Self {
        Self::NdArray(inner)
    }
}
401
/// Wraps a [`LibTorchDevice`] as [`DispatchDevice::LibTorch`].
///
/// This must be the only such impl: two impls of the same trait for the same type are
/// rejected with E0119 whenever the `tch` feature is enabled.
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
    fn from(device: LibTorchDevice) -> Self {
        DispatchDevice::LibTorch(device)
    }
}