1use alloc::boxed::Box;
2
3use burn_backend::{DeviceId, DeviceOps};
4
5use crate::backends::*;
6
/// A device that dispatches tensor operations to one of the compiled-in backends at runtime.
///
/// Each variant exists only when its backend feature (or wgpu cfg flag) is enabled.
/// `PartialEq` is implemented manually below so that an `Autodiff` device also compares
/// equal to the plain device it wraps; `Debug` is manual for the same reason.
#[derive(Clone, Eq)]
pub enum DispatchDevice {
    /// CPU backend device.
    #[cfg(feature = "cpu")]
    Cpu(CpuDevice),

    /// CUDA backend device.
    #[cfg(feature = "cuda")]
    Cuda(CudaDevice),

    /// Metal (via wgpu) backend device.
    #[cfg(wgpu_metal)]
    Metal(WgpuDevice),

    /// ROCm backend device.
    #[cfg(feature = "rocm")]
    Rocm(RocmDevice),

    /// Vulkan (via wgpu) backend device.
    #[cfg(wgpu_vulkan)]
    Vulkan(WgpuDevice),

    /// WebGPU backend device.
    #[cfg(wgpu_webgpu)]
    Wgpu(WgpuDevice),

    /// NdArray backend device.
    #[cfg(feature = "ndarray")]
    NdArray(NdArrayDevice),

    /// LibTorch backend device.
    #[cfg(feature = "tch")]
    LibTorch(LibTorchDevice),

    /// A device wrapped for automatic differentiation (see [`AutodiffDevice`]).
    #[cfg(feature = "autodiff")]
    Autodiff(AutodiffDevice),
}
60
/// Marks a [`DispatchDevice`] as an autodiff device, carrying the gradient
/// checkpointing strategy to use with it.
#[cfg(feature = "autodiff")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AutodiffDevice {
    // The wrapped base device. Boxed because `DispatchDevice` contains this type
    // (the enum would otherwise be infinitely sized).
    pub(crate) inner: Box<DispatchDevice>,
    // Checkpointing strategy associated with this device.
    pub(crate) checkpointing: CheckpointingStrategy,
}
71
#[cfg(feature = "autodiff")]
impl AutodiffDevice {
    /// Wraps `device` together with the given checkpointing strategy.
    pub(crate) fn new(device: DispatchDevice, checkpointing: CheckpointingStrategy) -> Self {
        let inner = Box::new(device);
        Self {
            inner,
            checkpointing,
        }
    }
}
81
#[cfg(feature = "autodiff")]
impl core::ops::Deref for AutodiffDevice {
    type Target = DispatchDevice;

    /// Dereferences to the wrapped base device.
    fn deref(&self) -> &Self::Target {
        // Deref through the `Box` explicitly rather than relying on auto-coercion
        // from `&Box<DispatchDevice>`.
        &*self.inner
    }
}
91
/// Gradient checkpointing strategy associated with an autodiff device.
#[cfg(feature = "autodiff")]
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum CheckpointingStrategy {
    // NOTE(review): presumably trades memory for recomputation during the backward
    // pass — confirm against the autodiff backend's documentation.
    Balanced,
    // No checkpointing; this is the default strategy.
    #[default]
    None,
}
102
/// Asserts that two checkpointing strategies agree and returns the shared strategy.
///
/// # Panics
/// Panics when `lhs != rhs`: tensors combined in the same operation must carry the
/// same autodiff strategy.
#[cfg(feature = "autodiff")]
pub(crate) fn validate_checkpointing(
    lhs: crate::CheckpointingStrategy,
    rhs: crate::CheckpointingStrategy,
) -> crate::CheckpointingStrategy {
    assert_eq!(
        lhs, rhs,
        "Autodiff strategy mismatch: {lhs:?} vs {rhs:?}. Tensors in the same operation must share a strategy."
    );
    lhs
}
114
115impl core::fmt::Debug for DispatchDevice {
116 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117 match self {
118 #[cfg(feature = "cpu")]
119 Self::Cpu(device) => f.debug_tuple("Cpu").field(device).finish(),
120 #[cfg(feature = "cuda")]
121 Self::Cuda(device) => f.debug_tuple("Cuda").field(device).finish(),
122 #[cfg(wgpu_metal)]
123 Self::Metal(device) => f.debug_tuple("Metal").field(device).finish(),
124 #[cfg(feature = "rocm")]
125 Self::Rocm(device) => f.debug_tuple("Rocm").field(device).finish(),
126 #[cfg(wgpu_vulkan)]
127 Self::Vulkan(device) => f.debug_tuple("Vulkan").field(device).finish(),
128 #[cfg(wgpu_webgpu)]
129 Self::Wgpu(device) => f.debug_tuple("Wgpu").field(device).finish(),
130 #[cfg(feature = "ndarray")]
131 Self::NdArray(device) => f.debug_tuple("NdArray").field(device).finish(),
132 #[cfg(feature = "tch")]
133 Self::LibTorch(device) => f.debug_tuple("LibTorch").field(device).finish(),
134 #[cfg(feature = "autodiff")]
135 Self::Autodiff(device) => f.debug_tuple("Autodiff").field(&device.inner).finish(),
137 }
138 }
139}
140
impl Default for DispatchDevice {
    /// Picks the default device from the enabled backends, in fixed priority order:
    /// Cuda > Metal > Rocm > Vulkan > Wgpu > Cpu > LibTorch > NdArray.
    ///
    /// Each `cfg`-gated `return` short-circuits, so only the first enabled backend's
    /// device is constructed. `unreachable_code` is allowed because every return after
    /// the first enabled one is dead code. If no backend feature is enabled at all,
    /// the body is empty and this fails to compile — presumably at least one backend
    /// is always required by the crate's feature setup.
    #[allow(unreachable_code)]
    fn default() -> Self {
        #[cfg(feature = "cuda")]
        return Self::Cuda(CudaDevice::default());

        #[cfg(wgpu_metal)]
        return Self::Metal(burn_wgpu::WgpuDevice::default());

        #[cfg(feature = "rocm")]
        return Self::Rocm(RocmDevice::default());

        #[cfg(wgpu_vulkan)]
        return Self::Vulkan(burn_wgpu::WgpuDevice::default());

        #[cfg(wgpu_webgpu)]
        return Self::Wgpu(burn_wgpu::WgpuDevice::default());

        #[cfg(feature = "cpu")]
        return Self::Cpu(CpuDevice);

        #[cfg(feature = "tch")]
        return Self::LibTorch(LibTorchDevice::default());

        #[cfg(feature = "ndarray")]
        return Self::NdArray(NdArrayDevice::default());
    }
}
171
impl PartialEq for DispatchDevice {
    /// Devices are equal when they are the same backend variant wrapping equal
    /// backend devices. Additionally, an `Autodiff` device compares equal to the
    /// plain device it wraps (both directions, keeping the relation symmetric).
    ///
    /// NOTE(review): `Autodiff` vs `Autodiff` compares the full `AutodiffDevice`
    /// (including the checkpointing strategy), while `Autodiff` vs plain ignores
    /// checkpointing. This can break the transitivity required by the derived `Eq`
    /// (A(x, Balanced) == x == A(x, None) but A(x, Balanced) != A(x, None)) —
    /// confirm this is intended.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            #[cfg(feature = "autodiff")]
            (DispatchDevice::Autodiff(a), DispatchDevice::Autodiff(b)) => a == b,
            #[cfg(feature = "autodiff")]
            (DispatchDevice::Autodiff(a), b) => a.inner.as_ref() == b,
            #[cfg(feature = "autodiff")]
            (a, DispatchDevice::Autodiff(b)) => a == b.inner.as_ref(),
            #[cfg(feature = "cpu")]
            (Self::Cpu(a), Self::Cpu(b)) => a == b,
            #[cfg(feature = "cuda")]
            (Self::Cuda(a), Self::Cuda(b)) => a == b,
            #[cfg(wgpu_metal)]
            (Self::Metal(a), Self::Metal(b)) => a == b,
            #[cfg(feature = "rocm")]
            (Self::Rocm(a), Self::Rocm(b)) => a == b,
            #[cfg(wgpu_vulkan)]
            (Self::Vulkan(a), Self::Vulkan(b)) => a == b,
            #[cfg(wgpu_webgpu)]
            (Self::Wgpu(a), Self::Wgpu(b)) => a == b,
            #[cfg(feature = "ndarray")]
            (Self::NdArray(a), Self::NdArray(b)) => a == b,
            #[cfg(feature = "tch")]
            (Self::LibTorch(a), Self::LibTorch(b)) => a == b,
            // Mismatched variants are never equal; unreachable when only one backend
            // is enabled, hence the allow.
            #[allow(unreachable_patterns)]
            (_, _) => false,
        }
    }
}
204
/// Base used to pack a backend discriminant and a backend-local type id into one
/// `u16`: `encoded = backend_id * TYPE_ID_BASE + backend_type_id`. This assumes every
/// backend's own `type_id` is < 10; larger values would decode to the wrong backend.
const TYPE_ID_BASE: u16 = 10;
208
209impl DispatchDevice {
210 #[cfg(feature = "autodiff")]
211 pub fn autodiff(device: impl Into<DispatchDevice>) -> DispatchDevice {
213 Self::autodiff_checkpointed(device, CheckpointingStrategy::None)
214 }
215 #[cfg(feature = "autodiff")]
216 pub fn autodiff_checkpointed(
218 device: impl Into<DispatchDevice>,
219 checkpointing: CheckpointingStrategy,
220 ) -> DispatchDevice {
221 let device = device.into();
222 DispatchDevice::Autodiff(AutodiffDevice::new(device, checkpointing))
223 }
224
225 fn backend_id(&self) -> BackendId {
227 match self {
228 #[cfg(feature = "cpu")]
229 Self::Cpu(_) => BackendId::Cpu,
230 #[cfg(feature = "cuda")]
231 Self::Cuda(_) => BackendId::Cuda,
232 #[cfg(wgpu_metal)]
233 Self::Metal(_) => BackendId::Metal,
234 #[cfg(feature = "rocm")]
235 Self::Rocm(_) => BackendId::Rocm,
236 #[cfg(wgpu_vulkan)]
237 Self::Vulkan(_) => BackendId::Vulkan,
238 #[cfg(wgpu_webgpu)]
239 Self::Wgpu(_) => BackendId::Wgpu,
240 #[cfg(feature = "ndarray")]
241 Self::NdArray(_) => BackendId::NdArray,
242 #[cfg(feature = "tch")]
243 Self::LibTorch(_) => BackendId::LibTorch,
244 #[cfg(feature = "autodiff")]
245 Self::Autodiff(device) => device.inner.backend_id(),
246 }
247 }
248
249 fn encode_type_id(&self, backend_type_id: u16) -> u16 {
251 u16::from(self.backend_id()) * TYPE_ID_BASE + backend_type_id
252 }
253
254 pub(crate) fn decode_type_id(type_id: u16) -> (BackendId, u16) {
256 let variant = type_id / TYPE_ID_BASE;
257 let backend_type_id = type_id % TYPE_ID_BASE;
258 (
259 BackendId::try_from(variant).expect("Unknown DispatchDevice variant"),
260 backend_type_id,
261 )
262 }
263}
264
/// Stable numeric discriminant identifying which backend a device belongs to.
///
/// The explicit values participate in the packed device-id encoding (see
/// `encode_type_id` / `decode_type_id`) and stay fixed even when some variants are
/// compiled out by feature flags — do not renumber them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub(crate) enum BackendId {
    #[cfg(feature = "cpu")]
    Cpu = 0,
    #[cfg(feature = "cuda")]
    Cuda = 1,
    #[cfg(wgpu_metal)]
    Metal = 2,
    #[cfg(feature = "rocm")]
    Rocm = 3,
    #[cfg(wgpu_vulkan)]
    Vulkan = 4,
    #[cfg(wgpu_webgpu)]
    Wgpu = 5,
    #[cfg(feature = "ndarray")]
    NdArray = 6,
    #[cfg(feature = "tch")]
    LibTorch = 7,
}
285
286impl From<BackendId> for u16 {
287 fn from(variant: BackendId) -> Self {
288 variant as u16
289 }
290}
291
impl TryFrom<u16> for BackendId {
    type Error = ();

    /// Maps a raw discriminant back to a [`BackendId`].
    ///
    /// Returns `Err(())` when the value is out of range *or* when it names a backend
    /// whose feature is not enabled in this build (the matching arm is compiled out).
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        match value {
            #[cfg(feature = "cpu")]
            0 => Ok(Self::Cpu),
            #[cfg(feature = "cuda")]
            1 => Ok(Self::Cuda),
            #[cfg(wgpu_metal)]
            2 => Ok(Self::Metal),
            #[cfg(feature = "rocm")]
            3 => Ok(Self::Rocm),
            #[cfg(wgpu_vulkan)]
            4 => Ok(Self::Vulkan),
            #[cfg(wgpu_webgpu)]
            5 => Ok(Self::Wgpu),
            #[cfg(feature = "ndarray")]
            6 => Ok(Self::NdArray),
            #[cfg(feature = "tch")]
            7 => Ok(Self::LibTorch),
            _ => Err(()),
        }
    }
}
317
318impl DeviceOps for DispatchDevice {
319 fn inner(&self) -> &Self {
320 match self {
321 #[cfg(feature = "autodiff")]
322 DispatchDevice::Autodiff(device) => &device.inner,
323 device => device,
324 }
325 }
326}
327
impl burn_backend::Device for DispatchDevice {
    /// Reconstructs a device from a packed [`DeviceId`].
    ///
    /// The incoming `type_id` carries both the backend discriminant and the
    /// backend-local type id (see `decode_type_id`); the backend-local part is
    /// restored before delegating to the concrete backend's `from_id`.
    ///
    /// NOTE(review): an `Autodiff` device round-trips through `to_id`/`from_id` as
    /// its inner device — the autodiff wrapper and its checkpointing strategy are
    /// not encoded. Confirm this lossy round-trip is intended.
    fn from_id(mut device_id: DeviceId) -> Self {
        let (dispatch_id, backend_type_id) = Self::decode_type_id(device_id.type_id);
        device_id.type_id = backend_type_id;

        match dispatch_id {
            #[cfg(feature = "cpu")]
            BackendId::Cpu => Self::Cpu(CpuDevice::from_id(device_id)),
            #[cfg(feature = "cuda")]
            BackendId::Cuda => Self::Cuda(CudaDevice::from_id(device_id)),
            #[cfg(wgpu_metal)]
            BackendId::Metal => Self::Metal(WgpuDevice::from_id(device_id)),
            #[cfg(feature = "rocm")]
            BackendId::Rocm => Self::Rocm(RocmDevice::from_id(device_id)),
            #[cfg(wgpu_vulkan)]
            BackendId::Vulkan => Self::Vulkan(WgpuDevice::from_id(device_id)),
            #[cfg(wgpu_webgpu)]
            BackendId::Wgpu => Self::Wgpu(WgpuDevice::from_id(device_id)),
            #[cfg(feature = "ndarray")]
            BackendId::NdArray => Self::NdArray(NdArrayDevice::from_id(device_id)),
            #[cfg(feature = "tch")]
            BackendId::LibTorch => Self::LibTorch(LibTorchDevice::from_id(device_id)),
        }
    }

    /// Converts to a [`DeviceId`] whose `type_id` packs the backend discriminant
    /// together with the backend-local type id (see `encode_type_id`).
    fn to_id(&self) -> DeviceId {
        // Ask the concrete backend for its own id first...
        let mut device_id = match self {
            #[cfg(feature = "cpu")]
            Self::Cpu(device) => device.to_id(),
            #[cfg(feature = "cuda")]
            Self::Cuda(device) => device.to_id(),
            #[cfg(wgpu_metal)]
            Self::Metal(device) => device.to_id(),
            #[cfg(feature = "rocm")]
            Self::Rocm(device) => device.to_id(),
            #[cfg(wgpu_vulkan)]
            Self::Vulkan(device) => device.to_id(),
            #[cfg(wgpu_webgpu)]
            Self::Wgpu(device) => device.to_id(),
            #[cfg(feature = "ndarray")]
            Self::NdArray(device) => device.to_id(),
            #[cfg(feature = "tch")]
            Self::LibTorch(device) => device.to_id(),
            #[cfg(feature = "autodiff")]
            Self::Autodiff(device) => device.inner.to_id(),
        };
        // ...then overwrite its type id with the packed dispatch encoding.
        device_id.type_id = self.encode_type_id(device_id.type_id);
        device_id
    }
}
378
#[cfg(feature = "cpu")]
impl From<CpuDevice> for DispatchDevice {
    /// Wraps a CPU device in the dispatch enum.
    fn from(device: CpuDevice) -> Self {
        Self::Cpu(device)
    }
}
385
#[cfg(feature = "cuda")]
impl From<CudaDevice> for DispatchDevice {
    /// Wraps a CUDA device in the dispatch enum.
    fn from(device: CudaDevice) -> Self {
        Self::Cuda(device)
    }
}
392
// NOTE(review): this impl and the wgpu_vulkan / wgpu_webgpu ones below all implement
// `From<WgpuDevice>`; the three cfg flags are presumably mutually exclusive, otherwise
// these impls conflict — confirm in the build configuration.
#[cfg(wgpu_metal)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps a wgpu device as a Metal dispatch device.
    fn from(device: WgpuDevice) -> Self {
        Self::Metal(device)
    }
}
399
#[cfg(feature = "rocm")]
impl From<RocmDevice> for DispatchDevice {
    /// Wraps a ROCm device in the dispatch enum.
    fn from(device: RocmDevice) -> Self {
        Self::Rocm(device)
    }
}
406
#[cfg(wgpu_vulkan)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps a wgpu device as a Vulkan dispatch device.
    fn from(device: WgpuDevice) -> Self {
        Self::Vulkan(device)
    }
}
413
#[cfg(wgpu_webgpu)]
impl From<WgpuDevice> for DispatchDevice {
    /// Wraps a wgpu device as a WebGPU dispatch device.
    fn from(device: WgpuDevice) -> Self {
        Self::Wgpu(device)
    }
}
420
#[cfg(feature = "ndarray")]
impl From<NdArrayDevice> for DispatchDevice {
    /// Wraps an NdArray device in the dispatch enum.
    fn from(device: NdArrayDevice) -> Self {
        Self::NdArray(device)
    }
}
427
#[cfg(feature = "tch")]
impl From<LibTorchDevice> for DispatchDevice {
    /// Wraps a LibTorch device in the dispatch enum.
    fn from(device: LibTorchDevice) -> Self {
        Self::LibTorch(device)
    }
}