Skip to main content

cubecl_wgpu/
device.rs

1use cubecl_common::device::{Device, DeviceId};
2
/// The device type used by the `wgpu` backend.
///
/// GPU variants carry an index that selects among the adapters of that kind found on the
/// system, so you need to provide it when targeting a specific GPU.
///
/// # Example
///
/// ```ignore
/// use cubecl_wgpu::WgpuDevice;
///
/// let device_gpu_1 = WgpuDevice::DiscreteGpu(0); // First discrete GPU found.
/// let device_gpu_2 = WgpuDevice::DiscreteGpu(1); // Second discrete GPU found.
/// ```
#[derive(Clone, Debug, Hash, PartialEq, Eq, Default)]
pub enum WgpuDevice {
    /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list
    /// of all discrete GPUs found on the system.
    DiscreteGpu(usize),

    /// Integrated GPU with the given index. The index is the index of the integrated GPU in the
    /// list of all integrated GPUs found on the system.
    IntegratedGpu(usize),

    /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of
    /// all virtual GPUs found on the system.
    VirtualGpu(usize),

    /// CPU.
    Cpu,

    /// The best available device found with the current [graphics API](crate::GraphicsApi).
    ///
    /// This will prioritize GPUs wgpu recognizes as "high power". Additionally, you can override
    /// this using the `CUBECL_WGPU_DEFAULT_DEVICE` environment variable. This variable is spelled
    /// as if it were a `WgpuDevice`, e.g. `CUBECL_WGPU_DEFAULT_DEVICE=IntegratedGpu(1)` or
    /// `CUBECL_WGPU_DEFAULT_DEVICE=Cpu`.
    #[default]
    DefaultDevice,

    /// Deprecated, use [`DefaultDevice`](WgpuDevice::DefaultDevice).
    #[deprecated(note = "use `WgpuDevice::DefaultDevice` instead")]
    BestAvailable,

    /// Use an externally created, existing wgpu setup. This is helpful when using `CubeCL` in
    /// conjunction with some existing wgpu setup (e.g. egui or bevy), as resources can be
    /// transferred in & out of `CubeCL`.
    ///
    /// # Notes
    ///
    /// This can be initialized with [`init_device`](crate::runtime::init_device).
    Existing(u32),
}
52
53impl Device for WgpuDevice {
54    fn from_id(device_id: DeviceId) -> Self {
55        match device_id.type_id {
56            0 => Self::DiscreteGpu(device_id.index_id as usize),
57            1 => Self::IntegratedGpu(device_id.index_id as usize),
58            2 => Self::VirtualGpu(device_id.index_id as usize),
59            3 => Self::Cpu,
60            4 => Self::DefaultDevice,
61            5 => Self::Existing(device_id.index_id),
62            _ => Self::DefaultDevice,
63        }
64    }
65
66    fn to_id(&self) -> DeviceId {
67        #[allow(deprecated)]
68        match self {
69            Self::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
70            Self::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
71            Self::VirtualGpu(index) => DeviceId::new(2, *index as u32),
72            Self::Cpu => DeviceId::new(3, 0),
73            Self::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
74            Self::Existing(id) => DeviceId::new(5, *id),
75        }
76    }
77
78    fn device_count(type_id: u16) -> usize {
79        #[cfg(target_family = "wasm")]
80        {
81            // WebGPU only supports a single device currently.
82            1
83        }
84
85        #[cfg(not(target_family = "wasm"))]
86        {
87            let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
88                backends: wgpu::Backends::all(),
89                ..Default::default()
90            });
91            let adapters: Vec<_> = enumerate_all_adapters(instance)
92                .into_iter()
93                .filter(|adapter| {
94                    // Default doesn't filter device types.
95                    if type_id == 4 {
96                        return true;
97                    }
98
99                    let device_type = adapter.get_info().device_type;
100
101                    let adapter_type_id = match device_type {
102                        wgpu::DeviceType::Other => 4,
103                        wgpu::DeviceType::IntegratedGpu => 1,
104                        wgpu::DeviceType::DiscreteGpu => 0,
105                        wgpu::DeviceType::VirtualGpu => 2,
106                        wgpu::DeviceType::Cpu => 3,
107                    };
108
109                    adapter_type_id == type_id
110                })
111                .collect();
112            adapters.len()
113        }
114    }
115
116    fn device_count_total() -> usize {
117        #[cfg(target_family = "wasm")]
118        {
119            // WebGPU only supports a single device currently.
120            1
121        }
122
123        #[cfg(not(target_family = "wasm"))]
124        {
125            let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
126                backends: wgpu::Backends::all(),
127                ..Default::default()
128            });
129            let adapters = enumerate_all_adapters(instance);
130            adapters.len()
131        }
132    }
133}
134
135#[cfg(not(target_family = "wasm"))]
136fn enumerate_all_adapters(instance: wgpu::Instance) -> Vec<wgpu::Adapter> {
137    // `enumerate_adapters` is now async & available on WebGPU
138    cubecl_common::future::block_on(instance.enumerate_adapters(wgpu::Backends::all()))
139}