// cubecl_wgpu/device.rs

1use cubecl_common::device::{Device, DeviceId};
2
/// The device type used by the `wgpu` backend.
///
/// Note that GPU variants carry an index: you need to say which adapter of that kind to use.
///
/// # Example
///
/// ```ignore
/// use cubecl_wgpu::WgpuDevice;
///
/// let device_gpu_1 = WgpuDevice::DiscreteGpu(0); // First discrete GPU found.
/// let device_gpu_2 = WgpuDevice::DiscreteGpu(1); // Second discrete GPU found.
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq, Default)]
pub enum WgpuDevice {
    /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list
    /// of all discrete GPUs found on the system.
    DiscreteGpu(usize),

    /// Integrated GPU with the given index. The index is the index of the integrated GPU in the
    /// list of all integrated GPUs found on the system.
    IntegratedGpu(usize),

    /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of
    /// all virtual GPUs found on the system.
    VirtualGpu(usize),

    /// CPU device, i.e. an adapter that wgpu reports as `DeviceType::Cpu` (for example a
    /// software renderer).
    Cpu,

    /// The best available device found with the current [graphics API](crate::GraphicsApi).
    ///
    /// This prioritizes GPUs that wgpu recognizes as "high power". You can override the choice
    /// with the `CUBECL_WGPU_DEFAULT_DEVICE` environment variable, whose value is spelled as if
    /// it were a `WgpuDevice` variant, for example `CUBECL_WGPU_DEFAULT_DEVICE=IntegratedGpu(1)`
    /// or `CUBECL_WGPU_DEFAULT_DEVICE=Cpu`.
    #[default]
    DefaultDevice,

    /// Deprecated, use [`DefaultDevice`](WgpuDevice::DefaultDevice).
    #[deprecated(note = "use `WgpuDevice::DefaultDevice` instead")]
    BestAvailable,

    /// Use an externally created, existing wgpu setup. This is helpful when using CubeCL in
    /// conjunction with an existing wgpu setup (e.g. egui or bevy), as resources can be
    /// transferred in and out of CubeCL.
    ///
    /// The wrapped `u32` is the identifier of that setup.
    ///
    /// # Notes
    ///
    /// This can be initialized with [`init_device`](crate::runtime::init_device).
    Existing(u32),
}
52
53impl Device for WgpuDevice {
54    fn from_id(device_id: DeviceId) -> Self {
55        match device_id.type_id {
56            0 => Self::DiscreteGpu(device_id.index_id as usize),
57            1 => Self::IntegratedGpu(device_id.index_id as usize),
58            2 => Self::VirtualGpu(device_id.index_id as usize),
59            3 => Self::Cpu,
60            4 => Self::DefaultDevice,
61            5 => Self::Existing(device_id.index_id),
62            _ => Self::DefaultDevice,
63        }
64    }
65
66    fn to_id(&self) -> DeviceId {
67        #[allow(deprecated)]
68        match self {
69            Self::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
70            Self::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
71            Self::VirtualGpu(index) => DeviceId::new(2, *index as u32),
72            Self::Cpu => DeviceId::new(3, 0),
73            Self::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
74            Self::Existing(id) => DeviceId::new(5, *id),
75        }
76    }
77
78    fn device_count(type_id: u16) -> usize {
79        #[cfg(target_family = "wasm")]
80        {
81            // WebGPU only supports a single device currently.
82            1
83        }
84
85        #[cfg(not(target_family = "wasm"))]
86        {
87            let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
88                backends: wgpu::Backends::all(),
89                ..Default::default()
90            });
91            let adapters: Vec<_> = instance
92                .enumerate_adapters(wgpu::Backends::all())
93                .into_iter()
94                .filter(|adapter| {
95                    // Default doesn't filter device types.
96                    if type_id == 4 {
97                        return true;
98                    }
99
100                    let device_type = adapter.get_info().device_type;
101
102                    let adapter_type_id = match device_type {
103                        wgpu::DeviceType::Other => 4,
104                        wgpu::DeviceType::IntegratedGpu => 1,
105                        wgpu::DeviceType::DiscreteGpu => 0,
106                        wgpu::DeviceType::VirtualGpu => 2,
107                        wgpu::DeviceType::Cpu => 3,
108                    };
109
110                    adapter_type_id == type_id
111                })
112                .collect();
113            adapters.len()
114        }
115    }
116
117    fn device_count_total() -> usize {
118        #[cfg(target_family = "wasm")]
119        {
120            // WebGPU only supports a single device currently.
121            1
122        }
123
124        #[cfg(not(target_family = "wasm"))]
125        {
126            let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
127                backends: wgpu::Backends::all(),
128                ..Default::default()
129            });
130            let adapters: Vec<_> = instance
131                .enumerate_adapters(wgpu::Backends::all())
132                .into_iter()
133                .collect();
134            adapters.len()
135        }
136    }
137}