use cubecl_common::device::{Device, DeviceId};
/// Selects which wgpu adapter work should run on.
///
/// Every variant round-trips through a `DeviceId` (a numeric type tag plus an index)
/// via its `Device` implementation.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Default)]
pub enum WgpuDevice {
    /// A discrete GPU adapter. The value is an index — presumably among the adapters
    /// wgpu reports as `DeviceType::DiscreteGpu` (TODO: confirm against the
    /// adapter-selection code).
    DiscreteGpu(usize),
    /// An integrated GPU adapter; the value is an index (same caveat as `DiscreteGpu`).
    IntegratedGpu(usize),
    /// A virtual GPU adapter; the value is an index (same caveat as `DiscreteGpu`).
    VirtualGpu(usize),
    /// A CPU adapter (wgpu `DeviceType::Cpu`).
    Cpu,
    /// The default device. This is also the value returned by `WgpuDevice::default()`.
    #[default]
    DefaultDevice,
    /// Deprecated. Encodes to the same `DeviceId` as `DefaultDevice`, and therefore
    /// decodes back as `DefaultDevice`.
    #[deprecated(note = "use `WgpuDevice::DefaultDevice` instead")]
    BestAvailable,
    /// A device referred to by a raw `u32` id — presumably one registered outside of
    /// adapter enumeration (NOTE(review): confirm the id's origin).
    Existing(u32),
}
// Numeric `DeviceId::type_id` tag for each `WgpuDevice` variant. Centralised so that
// `from_id`, `to_id` and `device_count` cannot disagree on the encoding. These values are
// part of the `DeviceId` round-trip and must not be renumbered.
const TYPE_ID_DISCRETE_GPU: u16 = 0;
const TYPE_ID_INTEGRATED_GPU: u16 = 1;
const TYPE_ID_VIRTUAL_GPU: u16 = 2;
const TYPE_ID_CPU: u16 = 3;
const TYPE_ID_DEFAULT: u16 = 4;
const TYPE_ID_EXISTING: u16 = 5;

impl Device for WgpuDevice {
    /// Rebuilds a `WgpuDevice` from its `DeviceId` encoding.
    ///
    /// Unknown type tags fall back to `WgpuDevice::DefaultDevice`. The deprecated
    /// `BestAvailable` variant is never produced, because it shares its tag with
    /// `DefaultDevice`.
    fn from_id(device_id: DeviceId) -> Self {
        match device_id.type_id {
            TYPE_ID_DISCRETE_GPU => Self::DiscreteGpu(device_id.index_id as usize),
            TYPE_ID_INTEGRATED_GPU => Self::IntegratedGpu(device_id.index_id as usize),
            TYPE_ID_VIRTUAL_GPU => Self::VirtualGpu(device_id.index_id as usize),
            TYPE_ID_CPU => Self::Cpu,
            TYPE_ID_DEFAULT => Self::DefaultDevice,
            TYPE_ID_EXISTING => Self::Existing(device_id.index_id),
            // Tags this version does not know about degrade to the default device.
            _ => Self::DefaultDevice,
        }
    }

    /// Encodes this device as a `DeviceId` (type tag + index).
    ///
    /// Variants without an index (`Cpu`, `DefaultDevice`, `BestAvailable`) use index 0.
    /// `BestAvailable` is encoded exactly like `DefaultDevice`.
    fn to_id(&self) -> DeviceId {
        // `BestAvailable` is deprecated but must still be matched for exhaustiveness.
        #[allow(deprecated)]
        match self {
            // NOTE: the `usize -> u32` casts truncate indices above `u32::MAX`; adapter
            // indices that large are not expected.
            Self::DiscreteGpu(index) => DeviceId::new(TYPE_ID_DISCRETE_GPU, *index as u32),
            Self::IntegratedGpu(index) => DeviceId::new(TYPE_ID_INTEGRATED_GPU, *index as u32),
            Self::VirtualGpu(index) => DeviceId::new(TYPE_ID_VIRTUAL_GPU, *index as u32),
            Self::Cpu => DeviceId::new(TYPE_ID_CPU, 0),
            Self::BestAvailable | Self::DefaultDevice => DeviceId::new(TYPE_ID_DEFAULT, 0),
            Self::Existing(id) => DeviceId::new(TYPE_ID_EXISTING, *id),
        }
    }

    /// Counts the wgpu adapters whose type corresponds to `type_id`.
    ///
    /// - `TYPE_ID_DEFAULT` (4) counts every adapter on every backend.
    /// - Adapters reported as `DeviceType::Other` carry the default tag, so they are only
    ///   counted in that "all adapters" case.
    /// - `TYPE_ID_EXISTING` (5) matches no adapter type and therefore yields 0.
    ///
    /// On wasm targets no enumeration is performed and 1 is reported.
    fn device_count(type_id: u16) -> usize {
        #[cfg(target_family = "wasm")]
        {
            // `type_id` is only needed for native enumeration; silence the unused warning.
            let _ = type_id;
            1
        }
        #[cfg(not(target_family = "wasm"))]
        {
            // The default tag matches every adapter, so no per-adapter filter is applied.
            let wanted = (type_id != TYPE_ID_DEFAULT).then_some(type_id);
            count_adapters(wanted)
        }
    }

    /// Counts every wgpu adapter across all backends, regardless of type.
    ///
    /// On wasm targets no enumeration is performed and 1 is reported.
    fn device_count_total() -> usize {
        #[cfg(target_family = "wasm")]
        {
            1
        }
        #[cfg(not(target_family = "wasm"))]
        {
            count_adapters(None)
        }
    }
}

/// Maps a wgpu adapter type to the `DeviceId::type_id` tag used by `WgpuDevice`.
///
/// `DeviceType::Other` has no dedicated variant and is tagged as the default device.
#[cfg(not(target_family = "wasm"))]
fn adapter_type_id(device_type: wgpu::DeviceType) -> u16 {
    match device_type {
        wgpu::DeviceType::DiscreteGpu => TYPE_ID_DISCRETE_GPU,
        wgpu::DeviceType::IntegratedGpu => TYPE_ID_INTEGRATED_GPU,
        wgpu::DeviceType::VirtualGpu => TYPE_ID_VIRTUAL_GPU,
        wgpu::DeviceType::Cpu => TYPE_ID_CPU,
        wgpu::DeviceType::Other => TYPE_ID_DEFAULT,
    }
}

/// Enumerates adapters on all wgpu backends and counts them.
///
/// With `Some(tag)`, only adapters whose `adapter_type_id` equals `tag` are counted.
/// With `None`, every adapter is counted.
#[cfg(not(target_family = "wasm"))]
fn count_adapters(type_id: Option<u16>) -> usize {
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: wgpu::Backends::all(),
        ..Default::default()
    });
    instance
        .enumerate_adapters(wgpu::Backends::all())
        .into_iter()
        .filter(|adapter| match type_id {
            None => true,
            Some(wanted) => adapter_type_id(adapter.get_info().device_type) == wanted,
        })
        .count()
}