cubecl_wgpu/device.rs
1use cubecl_common::device::{Device, DeviceId};
2
/// The device enum used to select an adapter for the `wgpu` backend.
///
/// Note that you need to provide the device index when using a GPU backend.
///
/// # Example
///
/// ```ignore
/// use cubecl_wgpu::WgpuDevice;
///
/// let device_gpu_1 = WgpuDevice::DiscreteGpu(0); // First discrete GPU found.
/// let device_gpu_2 = WgpuDevice::DiscreteGpu(1); // Second discrete GPU found.
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq, Default)]
pub enum WgpuDevice {
    /// Discrete GPU with the given index. The index is the position of the discrete GPU in the
    /// list of all discrete GPUs found on the system.
    DiscreteGpu(usize),

    /// Integrated GPU with the given index. The index is the position of the integrated GPU in
    /// the list of all integrated GPUs found on the system.
    IntegratedGpu(usize),

    /// Virtual GPU with the given index. The index is the position of the virtual GPU in the
    /// list of all virtual GPUs found on the system.
    VirtualGpu(usize),

    /// CPU.
    Cpu,

    /// The best available device found with the current [graphics API](crate::GraphicsApi).
    ///
    /// This will prioritize GPUs wgpu recognizes as "high power". You can override this choice
    /// with the `CUBECL_WGPU_DEFAULT_DEVICE` environment variable, whose value is written the
    /// same way as a `WgpuDevice` variant, for example
    /// `CUBECL_WGPU_DEFAULT_DEVICE=IntegratedGpu(1)` or `CUBECL_WGPU_DEFAULT_DEVICE=Cpu`.
    ///
    /// This is the variant returned by [`Default::default`].
    #[default]
    DefaultDevice,

    /// Deprecated, use [`DefaultDevice`](WgpuDevice::DefaultDevice).
    #[deprecated(note = "use `WgpuDevice::DefaultDevice` instead")]
    BestAvailable,

    /// Use an externally created, existing wgpu setup. This is helpful when using CubeCL in
    /// conjunction with some existing wgpu setup (e.g. egui or bevy), as resources can be
    /// transferred in & out of CubeCL.
    ///
    /// # Notes
    ///
    /// This can be initialized with [`init_device`](crate::runtime::init_device).
    Existing(u32),
}
52
53impl Device for WgpuDevice {
54 fn from_id(device_id: DeviceId) -> Self {
55 match device_id.type_id {
56 0 => Self::DiscreteGpu(device_id.index_id as usize),
57 1 => Self::IntegratedGpu(device_id.index_id as usize),
58 2 => Self::VirtualGpu(device_id.index_id as usize),
59 3 => Self::Cpu,
60 4 => Self::DefaultDevice,
61 5 => Self::Existing(device_id.index_id),
62 _ => Self::DefaultDevice,
63 }
64 }
65
66 fn to_id(&self) -> DeviceId {
67 #[allow(deprecated)]
68 match self {
69 Self::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
70 Self::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
71 Self::VirtualGpu(index) => DeviceId::new(2, *index as u32),
72 Self::Cpu => DeviceId::new(3, 0),
73 Self::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
74 Self::Existing(id) => DeviceId::new(5, *id),
75 }
76 }
77
78 fn device_count(type_id: u16) -> usize {
79 #[cfg(target_family = "wasm")]
80 {
81 // WebGPU only supports a single device currently.
82 1
83 }
84
85 #[cfg(not(target_family = "wasm"))]
86 {
87 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
88 backends: wgpu::Backends::all(),
89 ..Default::default()
90 });
91 let adapters: Vec<_> = instance
92 .enumerate_adapters(wgpu::Backends::all())
93 .into_iter()
94 .filter(|adapter| {
95 // Default doesn't filter device types.
96 if type_id == 4 {
97 return true;
98 }
99
100 let device_type = adapter.get_info().device_type;
101
102 let adapter_type_id = match device_type {
103 wgpu::DeviceType::Other => 4,
104 wgpu::DeviceType::IntegratedGpu => 1,
105 wgpu::DeviceType::DiscreteGpu => 0,
106 wgpu::DeviceType::VirtualGpu => 2,
107 wgpu::DeviceType::Cpu => 3,
108 };
109
110 adapter_type_id == type_id
111 })
112 .collect();
113 adapters.len()
114 }
115 }
116
117 fn device_count_total() -> usize {
118 #[cfg(target_family = "wasm")]
119 {
120 // WebGPU only supports a single device currently.
121 1
122 }
123
124 #[cfg(not(target_family = "wasm"))]
125 {
126 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
127 backends: wgpu::Backends::all(),
128 ..Default::default()
129 });
130 let adapters: Vec<_> = instance
131 .enumerate_adapters(wgpu::Backends::all())
132 .into_iter()
133 .collect();
134 adapters.len()
135 }
136 }
137}