1use alloc::boxed::Box;
2
3use burn_backend::{DeviceId, DeviceOps};
4
5use crate::backends::*;
6
/// A device handle for whichever backend(s) this crate was compiled with.
///
/// Each variant exists only when its backend's feature flag (or `wgpu_*` cfg) is
/// enabled. `PartialEq` and `Debug` are implemented manually further down in this
/// file; `Eq` is derived as a marker on top of that manual `PartialEq`.
#[derive(Clone, Eq)]
pub enum DispatchDevice {
    /// Device of the CPU backend (`cpu` feature).
    #[cfg(feature = "cpu")]
    Cpu(CpuDevice),

    /// Device of the CUDA backend (`cuda` feature).
    #[cfg(feature = "cuda")]
    Cuda(CudaDevice),

    /// A `WgpuDevice` used through the Metal flavour (`wgpu_metal` cfg).
    #[cfg(wgpu_metal)]
    Metal(WgpuDevice),

    /// Device of the ROCm backend (`rocm` feature).
    #[cfg(feature = "rocm")]
    Rocm(RocmDevice),

    /// A `WgpuDevice` used through the Vulkan flavour (`wgpu_vulkan` cfg).
    #[cfg(wgpu_vulkan)]
    Vulkan(WgpuDevice),

    /// A `WgpuDevice` used through the WebGPU flavour (`wgpu_webgpu` cfg).
    #[cfg(wgpu_webgpu)]
    Wgpu(WgpuDevice),

    /// Device of the Flex backend (`flex` feature).
    #[cfg(feature = "flex")]
    Flex(FlexDevice),

    /// Device of the NdArray backend (`ndarray` feature).
    #[cfg(feature = "ndarray")]
    NdArray(NdArrayDevice),

    /// Device of the LibTorch backend (`tch` feature).
    #[cfg(feature = "tch")]
    LibTorch(LibTorchDevice),

    /// Another `DispatchDevice` wrapped with an autodiff checkpointing strategy
    /// (`autodiff` feature).
    #[cfg(feature = "autodiff")]
    Autodiff(AutodiffDevice),
}
64
/// Autodiff wrapper around another [`DispatchDevice`].
///
/// The inner device is boxed because `DispatchDevice::Autodiff` contains this
/// struct, which in turn contains a `DispatchDevice`; without indirection the type
/// would be infinitely sized.
#[cfg(feature = "autodiff")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice {
    // The wrapped device (may itself be another autodiff device).
    pub(crate) inner: Box<DispatchDevice>,
    // Checkpointing strategy attached to this wrapper; it takes part in the
    // derived `PartialEq`.
    pub(crate) checkpointing: CheckpointingStrategy,
}
75
#[cfg(feature = "autodiff")]
impl AutodiffDevice {
    /// Boxes `device` and attaches the given `checkpointing` strategy to it.
    pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self {
        let inner = Box::new(device);
        Self {
            checkpointing,
            inner,
        }
    }
}
85
#[cfg(feature = "autodiff")]
impl core::ops::Deref for AutodiffDevice {
    type Target = DispatchDevice;

    /// Borrows the wrapped device, looking through the box.
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
95
/// Checkpointing strategy carried by an [`AutodiffDevice`].
///
/// `None` is the `Default`. The enum is `repr(u8)`, so each variant has a
/// one-byte discriminant.
#[cfg(feature = "autodiff")]
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CheckpointingStrategy {
    // NOTE(review): presumably maps to the autodiff backend's balanced
    // checkpointing strategy — confirm against where this is consumed.
    Balanced,
    // No checkpointing; selected by `DispatchDevice::autodiff`.
    #[default]
    None,
}
106
/// Ensures both operands of an operation share one checkpointing strategy and
/// returns that shared strategy.
///
/// # Panics
///
/// Panics (via `assert_eq!`) when `lhs` and `rhs` differ.
#[cfg(feature = "autodiff")]
pub(crate) fn validate_checkpointing(
    lhs: crate::CheckpointingStrategy,
    rhs: crate::CheckpointingStrategy,
) -> crate::CheckpointingStrategy {
    assert_eq!(
        lhs, rhs,
        "Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy."
    );
    // Past the assertion both sides are the same `Copy` value, so either works.
    rhs
}
118
119impl core::fmt::Debug for DispatchDevice {
120 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
121 match self {
122 #[cfg(feature = "cpu")]
123 Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
124 #[cfg(feature = "cuda")]
125 Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
126 #[cfg(wgpu_metal)]
127 Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
128 #[cfg(feature = "rocm")]
129 Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
130 #[cfg(wgpu_vulkan)]
131 Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
132 #[cfg(wgpu_webgpu)]
133 Self::Wgpu(device) => f.debug_tuple("Wgpu").field(device).finish(),
134 #[cfg(feature = "flex")]
135 Self::Flex(device) => f.debug_tuple("Flex").field(device).finish(),
136 #[cfg(feature = "ndarray")]
137 Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
138 #[cfg(feature = "tch")]
139 Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
140 #[cfg(feature = "autodiff")]
141 Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.inner).finish(),
143 }
144 }
145}
146
147impl Default for DispatchDevice {
148 #[allow(unreachable_code)]
149 fn default() -> Self {
150 #[cfg(feature = "cuda")]
153 return Self::Cuda(CudaDevice::default());
154
155 #[cfg(wgpu_metal)]
156 return Self::Metal(burn_wgpu::WgpuDevice::default());
157
158 #[cfg(feature = "rocm")]
159 return Self::Rocm(RocmDevice::default());
160
161 #[cfg(wgpu_vulkan)]
162 return Self::Vulkan(burn_wgpu::WgpuDevice::default());
163
164 #[cfg(wgpu_webgpu)]
165 return Self::Wgpu(burn_wgpu::WgpuDevice::default());
166
167 #[cfg(feature = "cpu")]
168 return Self::Cpu(CpuDevice);
169
170 #[cfg(feature = "tch")]
171 return Self::LibTorch(LibTorchDevice::default());
172
173 #[cfg(feature = "flex")]
176 return Self::Flex(FlexDevice);
177
178 #[cfg(feature = "ndarray")]
179 return Self::NdArray(NdArrayDevice::default());
180 }
181}
182
/// Device equality for the dispatch layer.
///
/// Two non-autodiff devices are equal when they are the same backend variant and
/// the wrapped devices compare equal. An autodiff device compares equal to a
/// non-autodiff device when its inner device does; the wrapper and its
/// checkpointing strategy are ignored on that path. Two autodiff devices are
/// compared with `AutodiffDevice`'s derived `PartialEq`, so they must also agree
/// on checkpointing.
///
/// NOTE(review): because of the mixed case, this relation is not transitive.
/// `Autodiff(x, Balanced) == x` and `x == Autodiff(x, None)`, yet the two
/// autodiff devices are unequal, which contradicts the `Eq` derived on
/// `DispatchDevice`. Confirm whether callers rely on the mixed comparison before
/// tightening it.
impl PartialEq for DispatchDevice {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            // The autodiff arms must precede the `(_, _)` catch-all so wrapped
            // devices get unwrapped instead of being reported unequal.
            #[cfg(feature = "autodiff")]
            (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
            #[cfg(feature = "autodiff")]
            (DispatchDevice::Autodiff(a), b) => a.inner.as_ref() == b,
            #[cfg(feature = "autodiff")]
            (a, DispatchDevice::Autodiff(b)) => a == b.inner.as_ref(),
            #[cfg(feature = "cpu")]
            (Self::Cpu(a), Self::Cpu(b)) => a == b,
            #[cfg(feature = "cuda")]
            (Self::Cuda(a), Self::Cuda(b)) => a == b,
            #[cfg(wgpu_metal)]
            (Self::Metal(a), Self::Metal(b)) => a == b,
            #[cfg(feature = "rocm")]
            (Self::Rocm(a), Self::Rocm(b)) => a == b,
            #[cfg(wgpu_vulkan)]
            (Self::Vulkan(a), Self::Vulkan(b)) => a == b,
            #[cfg(wgpu_webgpu)]
            (Self::Wgpu(a), Self::Wgpu(b)) => a == b,
            #[cfg(feature = "flex")]
            (Self::Flex(a), Self::Flex(b)) => a == b,
            #[cfg(feature = "ndarray")]
            (Self::NdArray(a), Self::NdArray(b)) => a == b,
            #[cfg(feature = "tch")]
            (Self::LibTorch(a), Self::LibTorch(b)) => a == b,
            // Different backend variants are never equal. The allow covers builds
            // with a single backend, where the arms above are already exhaustive.
            #[allow(unreachable_patterns)]
            (_, _) => false,
        }
    }
}
217
218const TYPE_ID_BASE: u16 = 10;
221
222impl DispatchDevice {
223 #[cfg(feature = "autodiff")]
224 pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
226 Self::autodiff_checkpointed(device, CheckpointingStrategy::None)
227 }
228 #[cfg(feature = "autodiff")]
229 pub fn autodiff_checkpointed(
231 device: impl Into<DispatchDevice>,
232 checkpointing: CheckpointingStrategy,
233 ) -> DispatchDevice {
234 let device = device.into();
235 DispatchDevice::Autodiff(AutodiffDevice::new(device, checkpointing))
236 }
237
238 fn backend_id(&self) -> BackendId {
240 match self {
241 #[cfg(feature = "cpu")]
242 Self::Cpu(_) => BackendId::Cpu,
243 #[cfg(feature = "cuda")]
244 Self::Cuda(_) => BackendId::Cuda,
245 #[cfg(wgpu_metal)]
246 Self::Metal(_) => BackendId::Metal,
247 #[cfg(feature = "rocm")]
248 Self::Rocm(_) => BackendId::Rocm,
249 #[cfg(wgpu_vulkan)]
250 Self::Vulkan(_) => BackendId::Vulkan,
251 #[cfg(wgpu_webgpu)]
252 Self::Wgpu(_) => BackendId::Wgpu,
253 #[cfg(feature = "flex")]
254 Self::Flex(_) => BackendId::Flex,
255 #[cfg(feature = "ndarray")]
256 Self::NdArray(_) => BackendId::NdArray,
257 #[cfg(feature = "tch")]
258 Self::LibTorch(_) => BackendId::LibTorch,
259 #[cfg(feature = "autodiff")]
260 Self::Autodiff(device) => device.inner.backend_id(),
261 }
262 }
263
264 fn encode_type_id(&self, backend_type_id: u16) -> u16 {
266 u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
267 }
268
269 pub(crate) fn decode_type_id(type_id: u16) -> (BackendId, u16) {
271 let variant = type_id / TYPE_ID_BASE;
272 let backend_type_id = type_id % TYPE_ID_BASE;
273 (
274 BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
275 backend_type_id,
276 )
277 }
278}
279
/// Numeric identifier of each dispatch backend.
///
/// The discriminant is the high part of an encoded `DeviceId::type_id` (see
/// `DispatchDevice::encode_type_id`). The values are spelled out explicitly so they
/// do not shift when features toggle variants on or off. They must stay in sync
/// with the `TryFrom<u16>` impl below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub(crate) enum BackendId {
    #[cfg(feature = "cpu")]
    Cpu = 0,
    #[cfg(feature = "cuda")]
    Cuda = 1,
    #[cfg(wgpu_metal)]
    Metal = 2,
    #[cfg(feature = "rocm")]
    Rocm = 3,
    #[cfg(wgpu_vulkan)]
    Vulkan = 4,
    #[cfg(wgpu_webgpu)]
    Wgpu = 5,
    #[cfg(feature = "ndarray")]
    NdArray = 6,
    #[cfg(feature = "tch")]
    LibTorch = 7,
    #[cfg(feature = "flex")]
    Flex = 8,
}
302
303impl From<BackendId> for u16 {
304 fn from(variant: BackendId) -> Self {
305 variant as u16
306 }
307}
308
309impl TryFrom<u16> for BackendId {
310 type Error = ();
311
312 fn try_from(value: u16) -> Result<Self, Self::Error> {
313 match value {
314 #[cfg(feature = "cpu")]
315 0 => Ok(Self::Cpu),
316 #[cfg(feature = "cuda")]
317 1 => Ok(Self::Cuda),
318 #[cfg(wgpu_metal)]
319 2 => Ok(Self::Metal),
320 #[cfg(feature = "rocm")]
321 3 => Ok(Self::Rocm),
322 #[cfg(wgpu_vulkan)]
323 4 => Ok(Self::Vulkan),
324 #[cfg(wgpu_webgpu)]
325 5 => Ok(Self::Wgpu),
326 #[cfg(feature = "ndarray")]
327 6 => Ok(Self::NdArray),
328 #[cfg(feature = "tch")]
329 7 => Ok(Self::LibTorch),
330 #[cfg(feature = "flex")]
331 8 => Ok(Self::Flex),
332 _ => Err(()),
333 }
334 }
335}
336
337impl DeviceOps for DispatchDevice {
338 fn inner(&self) -> &Self {
339 match self {
340 #[cfg(feature = "autodiff")]
341 DispatchDevice::Autodiff(device) => &device.inner,
342 device => device,
343 }
344 }
345}
346
347impl burn_backend::Device for DispatchDevice {
348 fn from_id(mut device_id: DeviceId) -> Self {
349 let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
350 device_id.type_id = backend_type_id;
351
352 match dispatch_id {
353 #[cfg(feature = "cpu")]
354 BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
355 #[cfg(feature = "cuda")]
356 BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
357 #[cfg(wgpu_metal)]
358 BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
359 #[cfg(feature = "rocm")]
360 BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
361 #[cfg(wgpu_vulkan)]
362 BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
363 #[cfg(wgpu_webgpu)]
364 BackendId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
365 #[cfg(feature = "flex")]
366 BackendId::Flex => Self::Flex(FlexDevice::from_id(device_id)),
367 #[cfg(feature = "ndarray")]
368 BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
369 #[cfg(feature = "tch")]
370 BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
371 }
372 }
373
374 fn to_id(&self) -> DeviceId {
375 let mut device_id = match self {
376 #[cfg(feature = "cpu")]
377 Self::Cpu(device) => device.to_id(),
378 #[cfg(feature = "cuda")]
379 Self::Cuda(device) => device.to_id(),
380 #[cfg(wgpu_metal)]
381 Self::Metal(device) => device.to_id(),
382 #[cfg(feature = "rocm")]
383 Self::Rocm(device) => device.to_id(),
384 #[cfg(wgpu_vulkan)]
385 Self::Vulkan(device) => device.to_id(),
386 #[cfg(wgpu_webgpu)]
387 Self::Wgpu(device) => device.to_id(),
388 #[cfg(feature = "flex")]
389 Self::Flex(device) => device.to_id(),
390 #[cfg(feature = "ndarray")]
391 Self::NdArray(device) => device.to_id(),
392 #[cfg(feature = "tch")]
393 Self::LibTorch(device) => device.to_id(),
394 #[cfg(feature = "autodiff")]
395 Self::Autodiff(device) => device.inner.to_id(),
396 };
397 device_id.type_id = self.encode_type_id(device_id.type_id);
398 device_id
399 }
400}
401
/// Lifts a [`CpuDevice`] into [`DispatchDevice::Cpu`].
#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
    fn from(cpu: CpuDevice) -> Self {
        Self::Cpu(cpu)
    }
}
408
/// Lifts a [`CudaDevice`] into [`DispatchDevice::Cuda`].
#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
    fn from(cuda: CudaDevice) -> Self {
        Self::Cuda(cuda)
    }
}
415
/// Lifts a `WgpuDevice` into [`DispatchDevice::Metal`].
///
/// NOTE(review): the Vulkan and WebGPU flavours also implement
/// `From<WgpuDevice>`, so presumably at most one `wgpu_*` cfg is active per
/// build (otherwise these impls would conflict) — confirm in the build script.
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(wgpu: WgpuDevice) -> Self {
        Self::Metal(wgpu)
    }
}
422
/// Lifts a [`RocmDevice`] into [`DispatchDevice::Rocm`].
#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
    fn from(rocm: RocmDevice) -> Self {
        Self::Rocm(rocm)
    }
}
429
/// Lifts a `WgpuDevice` into [`DispatchDevice::Vulkan`].
#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(wgpu: WgpuDevice) -> Self {
        Self::Vulkan(wgpu)
    }
}
436
/// Lifts a `WgpuDevice` into [`DispatchDevice::Wgpu`].
#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
    fn from(wgpu: WgpuDevice) -> Self {
        Self::Wgpu(wgpu)
    }
}
443
/// Lifts a [`FlexDevice`] into [`DispatchDevice::Flex`].
#[cfg(feature = "flex")]
impl From<FlexDevice> for DispatchDevice {
    fn from(flex: FlexDevice) -> Self {
        Self::Flex(flex)
    }
}
450
/// Lifts an [`NdArrayDevice`] into [`DispatchDevice::NdArray`].
#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
    fn from(ndarray: NdArrayDevice) -> Self {
        Self::NdArray(ndarray)
    }
}
457
/// Lifts a [`LibTorchDevice`] into [`DispatchDevice::LibTorch`].
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
    fn from(torch: LibTorchDevice) -> Self {
        Self::LibTorch(torch)
    }
}