use crate::{
AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
contiguous_strides,
};
use cubecl_common::device::{Device, DeviceState};
use cubecl_common::{future, profile::TimingMethod};
use cubecl_core::{Runtime, ir::TargetProperties};
use cubecl_core::{ir::LineSize, server::ServerUtilities};
use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
pub use cubecl_runtime::memory_management::MemoryConfiguration;
use cubecl_runtime::{
client::ComputeClient,
logging::{ProfileLevel, ServerLogger},
};
use wgpu::{InstanceFlags, RequestAdapterOptions};
/// The wgpu-backed [`Runtime`], compiling kernels with whichever backend was
/// selected at build time (WGSL by default, SPIR-V or MSL via features).
#[derive(Debug)]
pub struct WgpuRuntime;
impl DeviceState for WgpuServer {
    /// Creates a server for `device_id` by resolving the matching
    /// [`WgpuDevice`] and building a wgpu setup with the automatically
    /// selected graphics API, using default runtime options.
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let wgpu_device = WgpuDevice::from_id(device_id);
        let setup = future::block_on(create_setup_for_device(
            &wgpu_device,
            AutoGraphicsApi::backend(),
        ));
        create_server(setup, RuntimeOptions::default())
    }
}
impl Runtime for WgpuRuntime {
    type Compiler = AutoCompiler;
    type Server = WgpuServer;
    type Device = WgpuDevice;

    /// Loads (or lazily initializes) the compute client for `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Reports the runtime name, reflecting the compiler backend actually
    /// used for the client's wgpu backend.
    fn name(client: &ComputeClient<Self>) -> &'static str {
        match client.info() {
            wgpu::Backend::Vulkan => {
                #[cfg(feature = "spirv")]
                return "wgpu<spirv>";
                #[cfg(not(feature = "spirv"))]
                return "wgpu<wgsl>";
            }
            wgpu::Backend::Metal => {
                #[cfg(feature = "msl")]
                return "wgpu<msl>";
                #[cfg(not(feature = "msl"))]
                return "wgpu<wgsl>";
            }
            _ => "wgpu<wgsl>",
        }
    }

    /// Line (vector) sizes supported by the active compiler backend, widest
    /// first. MSL additionally supports lines of 8.
    fn supported_line_sizes() -> &'static [LineSize] {
        #[cfg(feature = "msl")]
        const SIZES: &[LineSize] = &[8, 4, 2, 1];
        #[cfg(not(feature = "msl"))]
        const SIZES: &[LineSize] = &[4, 2, 1];
        SIZES
    }

    /// Widest line size usable for global memory accesses.
    fn max_global_line_size() -> LineSize {
        4
    }

    /// Upper bound on the cube count per dispatch dimension.
    fn max_cube_count() -> (u32, u32, u32) {
        let per_dim = u16::MAX as u32;
        (per_dim, per_dim, per_dim)
    }

    /// A tensor can be read only when its strides are exactly the contiguous
    /// (row-major) strides for `shape`; scalars (empty shape) always qualify.
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        shape.is_empty()
            || contiguous_strides(shape)
                .into_iter()
                .zip(strides)
                .all(|(expected, &actual)| expected == actual)
    }

    /// Target properties; MMA support uses its default configuration.
    fn target_properties() -> TargetProperties {
        TargetProperties {
            mma: Default::default(),
        }
    }
}
/// Tunable options used when constructing a [`WgpuServer`].
pub struct RuntimeOptions {
    // Maximum number of tasks the server buffers before flushing —
    // presumably a submission batch bound; confirm against `WgpuServer::new`.
    pub tasks_max: usize,
    // Memory management configuration forwarded to the server.
    pub memory_config: MemoryConfiguration,
}
impl Default for RuntimeOptions {
    /// Builds the default options.
    ///
    /// The maximum task count is taken from the `CUBECL_WGPU_MAX_TASKS`
    /// environment variable when set; otherwise it defaults to 1 in test
    /// builds (minimal batching, deterministic) and 32 otherwise.
    ///
    /// # Panics
    ///
    /// Panics when `CUBECL_WGPU_MAX_TASKS` is set but is not a positive
    /// (non-zero) integer.
    fn default() -> Self {
        #[cfg(test)]
        const DEFAULT_MAX_TASKS: usize = 1;
        #[cfg(not(test))]
        const DEFAULT_MAX_TASKS: usize = 32;

        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
            Ok(value) => {
                let parsed = value
                    .parse::<usize>()
                    .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer.");
                // `parse::<usize>` accepts 0, but the contract stated in the
                // message above is "positive": a zero batch bound is
                // meaningless, so reject it explicitly.
                assert!(
                    parsed > 0,
                    "CUBECL_WGPU_MAX_TASKS should be a positive integer."
                );
                parsed
            }
            Err(_) => DEFAULT_MAX_TASKS,
        };

        Self {
            tasks_max,
            memory_config: MemoryConfiguration::default(),
        }
    }
}
/// The bundle of wgpu handles needed to drive a compute server, as produced
/// by [`create_setup_for_device`] or supplied externally via [`init_device`].
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    // The wgpu instance the adapter was requested from.
    pub instance: wgpu::Instance,
    // The physical adapter backing the logical device.
    pub adapter: wgpu::Adapter,
    // The logical device used to create resources and pipelines.
    pub device: wgpu::Device,
    // The queue command buffers are submitted on.
    pub queue: wgpu::Queue,
    // The backend API (Vulkan, Metal, DX12, ...) the adapter runs on.
    pub backend: wgpu::Backend,
}
/// Registers an externally created [`WgpuSetup`] as a new runtime device.
///
/// Returns the freshly allocated [`WgpuDevice::Existing`] id under which the
/// server was registered.
pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
    use core::sync::atomic::{AtomicU32, Ordering};

    // Hands out a unique id to every externally registered device.
    static COUNTER: AtomicU32 = AtomicU32::new(0);
    let id = COUNTER.fetch_add(1, Ordering::Relaxed);
    if id == u32::MAX {
        core::panic!("Memory ID overflowed");
    }

    let device = WgpuDevice::Existing(id);
    let _ = ComputeClient::<WgpuRuntime>::init(&device, create_server(setup, options));
    device
}
/// Like [`init_setup_async`], but blocks the current thread until the setup
/// has been created and registered.
///
/// # Panics
///
/// Panics on wasm targets, where blocking on a future is not possible; use
/// [`init_setup_async`] there instead.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // Consume the arguments so the wasm path compiles without
            // unused-variable warnings.
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
/// Creates a wgpu setup for `device` with graphics API `G`, registers a
/// compute server for it, and hands the setup back to the caller.
pub async fn init_setup_async<G: GraphicsApi>(
    device: &WgpuDevice,
    options: RuntimeOptions,
) -> WgpuSetup {
    let setup = create_setup_for_device(device, G::backend()).await;
    // The server consumes a setup, so give it a clone and return the original.
    let server = create_server(setup.clone(), options);
    let _ = ComputeClient::<WgpuRuntime>::init(device, server);
    setup
}
/// Assembles a [`WgpuServer`] from an already-created [`WgpuSetup`].
///
/// Derives the memory, hardware, and device properties the server advertises
/// from the adapter/device limits and features, then constructs the server
/// with the provided runtime `options`.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let mut adapter_limits = setup.adapter.limits();
    // Some adapters report 0 for both subgroup bounds; substitute a broad
    // 8..=128 range so the plane-size properties are never zero.
    if adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0 {
        adapter_limits.min_subgroup_size = 8;
        adapter_limits.max_subgroup_size = 128;
    }
    // Memory pages are capped by the largest storage-buffer binding, and
    // allocations must respect the storage-buffer offset alignment.
    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple Silicon the plane (subgroup) size is pinned to 32
        // regardless of what the adapter reports.
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        // One storage-buffer slot is held back — presumably reserved for
        // internal server use; confirm against the bind-group layout.
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: (max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: (
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // Not derivable from the wgpu limits queried here; left unknown.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
        num_cpu_cores: None, };
    let mut compilation_options = Default::default();
    let features = setup.adapter.features();
    // Prefer device-side timestamp queries for profiling when supported;
    // otherwise fall back to host-side timing.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };
    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );
    // Advertise plane (subgroup) operations when the adapter supports them
    // and is not a CPU implementation. Skipped on macOS with the MSL backend
    // — presumably handled by `backend::register_features` below; confirm.
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_ir::features::Plane;
            device_props.features.plane.insert(Plane::Ops);
        }
    }
    // With the SPIR-V or MSL backend enabled, non-uniform control flow within
    // a plane is advertised unconditionally.
    #[cfg(any(feature = "spirv", feature = "msl"))]
    device_props
        .features
        .plane
        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);
    // Backend-specific feature and compilation-option registration.
    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);
    let logger = alloc::sync::Arc::new(ServerLogger::default());
    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
/// Requests an adapter and logical device for `device` on the given `backend`
/// and bundles the resulting wgpu handles into a [`WgpuSetup`].
pub(crate) async fn create_setup_for_device(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> WgpuSetup {
    let (instance, adapter) = request_adapter(device, backend).await;
    let (wgpu_device, queue) = backend::request_device(&adapter).await;

    log::info!(
        "Created wgpu compute server on device {:?} => {:?}",
        wgpu_device,
        adapter.get_info()
    );

    WgpuSetup {
        instance,
        adapter,
        device: wgpu_device,
        queue,
        backend,
    }
}
/// Creates a wgpu instance for `backend` and selects an adapter that matches
/// the requested `device`.
///
/// When the caller asked for the default/best-available device, an override
/// from `CUBECL_WGPU_DEFAULT_DEVICE` is honored first.
///
/// # Panics
///
/// Panics when no suitable adapter exists, and on
/// [`WgpuDevice::Existing`] (existing devices already own an adapter).
async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    // Scale instance debug flags with the configured logging: full profiling
    // gets advanced debugging, active compilation logging gets standard
    // debugging, otherwise the defaults.
    let debug = ServerLogger::default();
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });
    // The environment override only applies when no specific device kind was
    // requested by the caller.
    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };
    let device = override_device.unwrap_or_else(|| device.clone());
    let adapter = match device {
        // Specific device kinds are resolved by enumerating adapters; these
        // arms are compiled out on wasm.
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
            num,
            "No Discrete GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
            num,
            "No Integrated GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
            num,
            "No Virtual GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        // Default/best-available (and any kind on wasm): let wgpu pick a
        // high-performance adapter.
        _ => instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await
            .expect("No possible adapter available for backend. Falling back to first available."),
    };
    log::info!("Using adapter {:?}", adapter.get_info());
    (instance, adapter)
}
/// Picks the `num`-th enumerated adapter whose device type matches `device`.
///
/// Adapters reporting `DeviceType::Other` are set aside and used as a
/// fallback pool when the matching pool does not contain enough entries.
///
/// # Panics
///
/// Panics with `error` (plus both adapter lists) when neither pool has an
/// entry at index `num`, and on [`WgpuDevice::Existing`].
#[cfg(not(target_family = "wasm"))]
fn select_from_adapter_list(
    num: usize,
    error: &str,
    instance: &wgpu::Instance,
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> wgpu::Adapter {
    let mut fallback = Vec::new();
    let mut matching = Vec::new();

    for adapter in instance.enumerate_adapters(backend.into()) {
        let device_type = adapter.get_info().device_type;
        if matches!(device_type, wgpu::DeviceType::Other) {
            fallback.push(adapter);
            continue;
        }
        let accepts = match device {
            WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
            WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
            WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
            WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
            #[allow(deprecated)]
            WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
            WgpuDevice::Existing(_) => {
                unreachable!("Cannot select an adapter for an existing device.")
            }
        };
        if accepts {
            matching.push(adapter);
        }
    }

    if num < matching.len() {
        return matching.remove(num);
    }
    if num < fallback.len() {
        return fallback.remove(num);
    }

    // Neither pool is deep enough: report both with full adapter info.
    let infos =
        |list: Vec<wgpu::Adapter>| list.into_iter().map(|a| a.get_info()).collect::<Vec<_>>();
    panic!(
        "{}, adapters {:?}, other adapters {:?}",
        error,
        infos(matching),
        infos(fallback),
    );
}
/// Reads the `CUBECL_WGPU_DEFAULT_DEVICE` environment variable and parses it
/// into a device override.
///
/// Accepted spellings are `DiscreteGpu(n)`, `IntegratedGpu(n)`,
/// `VirtualGpu(n)`, and `Cpu`. Returns `None` silently when the variable is
/// unset, and `None` with a warning logged when it is set to anything
/// unrecognized.
fn get_device_override() -> Option<WgpuDevice> {
    // Parses the index out of a `Variant(n)` spelling for the given prefix.
    fn parse_indexed(var: &str, prefix: &str) -> Option<usize> {
        var.strip_prefix(prefix)?.strip_suffix(')')?.parse().ok()
    }

    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
        .ok()
        .and_then(|var| {
            let override_device = if let Some(num) = parse_indexed(&var, "DiscreteGpu(") {
                Some(WgpuDevice::DiscreteGpu(num))
            } else if let Some(num) = parse_indexed(&var, "IntegratedGpu(") {
                Some(WgpuDevice::IntegratedGpu(num))
            } else if let Some(num) = parse_indexed(&var, "VirtualGpu(") {
                Some(WgpuDevice::VirtualGpu(num))
            } else if var == "Cpu" {
                Some(WgpuDevice::Cpu)
            } else {
                None
            };
            if override_device.is_none() {
                // Fixed: the message previously named `CUBECL_WGPU_DEVICE`,
                // which is not the variable actually read above.
                log::warn!("Unknown CUBECL_WGPU_DEFAULT_DEVICE override {var}");
            }
            override_device
        })
}