1use crate::{
2 AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3 contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::server::ServerUtilities;
8use cubecl_core::{Runtime, ir::TargetProperties};
9use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
10pub use cubecl_runtime::memory_management::MemoryConfiguration;
11use cubecl_runtime::{
12 client::ComputeClient,
13 logging::{ProfileLevel, ServerLogger},
14};
15use wgpu::{InstanceFlags, RequestAdapterOptions};
16
/// CubeCL runtime backed by the wgpu library, targeting WebGPU-compatible
/// backends (Vulkan, Metal, DX12, WebGPU, ...). See the [`Runtime`] impl below.
#[derive(Debug)]
pub struct WgpuRuntime;
22
23impl DeviceState for WgpuServer {
24 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
25 let device = WgpuDevice::from_id(device_id);
26 let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
27 create_server(setup, RuntimeOptions::default())
28 }
29}
30
31impl Runtime for WgpuRuntime {
32 type Compiler = AutoCompiler;
33 type Server = WgpuServer;
34 type Device = WgpuDevice;
35
36 fn client(device: &Self::Device) -> ComputeClient<Self> {
37 ComputeClient::load(device)
38 }
39
40 fn name(client: &ComputeClient<Self>) -> &'static str {
41 match client.info() {
42 wgpu::Backend::Vulkan => {
43 #[cfg(feature = "spirv")]
44 return "wgpu<spirv>";
45
46 #[cfg(not(feature = "spirv"))]
47 return "wgpu<wgsl>";
48 }
49 wgpu::Backend::Metal => {
50 #[cfg(feature = "msl")]
51 return "wgpu<msl>";
52
53 #[cfg(not(feature = "msl"))]
54 return "wgpu<wgsl>";
55 }
56 _ => "wgpu<wgsl>",
57 }
58 }
59
60 fn max_cube_count() -> (u32, u32, u32) {
61 let max_dim = u16::MAX as u32;
62 (max_dim, max_dim, max_dim)
63 }
64
65 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
66 if shape.is_empty() {
67 return true;
68 }
69
70 for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) {
71 if expected != stride {
72 return false;
73 }
74 }
75
76 true
77 }
78
79 fn target_properties() -> TargetProperties {
80 TargetProperties {
81 mma: Default::default(),
83 }
84 }
85}
86
/// Options for configuring the wgpu runtime.
pub struct RuntimeOptions {
    // Maximum number of tasks, forwarded to `WgpuServer::new`. Defaults to
    // the `CUBECL_WGPU_MAX_TASKS` env var, or 32 (1 under `cfg(test)`).
    pub tasks_max: usize,
    // Memory pool configuration forwarded to the server's memory management.
    pub memory_config: MemoryConfiguration,
}
94
95impl Default for RuntimeOptions {
96 fn default() -> Self {
97 #[cfg(test)]
98 const DEFAULT_MAX_TASKS: usize = 1;
99 #[cfg(not(test))]
100 const DEFAULT_MAX_TASKS: usize = 32;
101
102 let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
103 Ok(value) => value
104 .parse::<usize>()
105 .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
106 Err(_) => DEFAULT_MAX_TASKS,
107 };
108
109 Self {
110 tasks_max,
111 memory_config: MemoryConfiguration::default(),
112 }
113 }
114}
115
/// Everything required to run compute work on a wgpu device: the instance,
/// the selected adapter, the logical device, its queue, and the backend.
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    // The wgpu instance the adapter was requested from.
    pub instance: wgpu::Instance,
    // The adapter (physical device) backing the logical device.
    pub adapter: wgpu::Adapter,
    // The logical wgpu device.
    pub device: wgpu::Device,
    // The queue commands are submitted to.
    pub queue: wgpu::Queue,
    // The graphics backend in use (Vulkan, Metal, ...).
    pub backend: wgpu::Backend,
}
132
133pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
143 use core::sync::atomic::{AtomicU32, Ordering};
144
145 static COUNTER: AtomicU32 = AtomicU32::new(0);
146
147 let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
148 if device_id == u32::MAX {
149 core::panic!("Memory ID overflowed");
150 }
151
152 let device_id = WgpuDevice::Existing(device_id);
153 let server = create_server(setup, options);
154 let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
155 device_id
156}
157
/// Synchronous wrapper around [`init_setup_async`]: creates the wgpu setup
/// for `device` and initializes its compute client, blocking on the future.
///
/// # Panics
/// On wasm, where blocking on a future is not possible.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // Consume the arguments so the wasm build has no unused warnings.
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
170
171pub async fn init_setup_async<G: GraphicsApi>(
175 device: &WgpuDevice,
176 options: RuntimeOptions,
177) -> WgpuSetup {
178 let setup = create_setup_for_device(device, G::backend()).await;
179 let return_setup = setup.clone();
180 let server = create_server(setup, options);
181 let _ = ComputeClient::<WgpuRuntime>::init(device, server);
182 return_setup
183}
184
/// Builds a [`WgpuServer`] from an initialized wgpu setup: derives memory and
/// hardware properties from the device/adapter limits, picks the timing
/// method, registers backend-specific features, and wires up logging.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let adapter_limits = setup.adapter.limits();
    let mut adapter_info = setup.adapter.get_info();

    // Some adapters report 0/0 for the subgroup size range; substitute a wide
    // default range (8..=128) so plane-size properties stay usable.
    if adapter_info.subgroup_min_size == 0 && adapter_info.subgroup_max_size == 0 {
        adapter_info.subgroup_min_size = 8;
        adapter_info.subgroup_max_size = 128;
    }

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple Silicon the plane (SIMD-group) size is pinned to 32
        // regardless of what the adapter reports.
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_info.subgroup_min_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_info.subgroup_max_size,
        // One storage buffer binding is held back from the limit —
        // presumably reserved for internal use; confirm before relying on it.
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: (max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: (
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // Not exposed through wgpu.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
        num_cpu_cores: None, max_line_size: 4,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamps when the adapter supports them; fall back
    // to host-side (system) timing otherwise.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    // Enable plane (subgroup) operations when supported, except on CPU
    // adapters. Skipped entirely on macOS with the msl compiler.
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_ir::features::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    // Only the spirv/msl compilers support non-uniform control flow in planes.
    #[cfg(any(feature = "spirv", feature = "msl"))]
    device_props
        .features
        .plane
        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);

    // Let the active backend add its own features/compilation options.
    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
289
290pub(crate) async fn create_setup_for_device(
293 device: &WgpuDevice,
294 backend: wgpu::Backend,
295) -> WgpuSetup {
296 let (instance, adapter) = request_adapter(device, backend).await;
297 let (device, queue) = backend::request_device(&adapter).await;
298
299 log::info!(
300 "Created wgpu compute server on device {:?} => {:?}",
301 device,
302 adapter.get_info()
303 );
304
305 WgpuSetup {
306 instance,
307 adapter,
308 device,
309 queue,
310 backend,
311 }
312}
313
/// Selects a wgpu instance and adapter matching the requested device on the
/// given backend, honoring the `CUBECL_WGPU_DEFAULT_DEVICE` override for
/// default device selection.
async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    // Pick instance debug flags based on the configured logging/profiling level.
    let debug = ServerLogger::default();
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });

    // The env-var override only applies when the caller asked for a default
    // device, not for an explicitly chosen one.
    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        // On native targets, enumerate all adapters and select by type + index.
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => {
            select_from_adapter_list(
                num,
                "No Discrete GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => {
            select_from_adapter_list(
                num,
                "No Integrated GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => {
            select_from_adapter_list(
                num,
                "No Virtual GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend).await
        }
        // On wasm, adapter enumeration is unavailable; map integrated GPU to
        // the low-power preference and everything else to high performance.
        #[cfg(target_family = "wasm")]
        WgpuDevice::IntegratedGpu(_) => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::LowPower).await
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        _ => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::HighPerformance).await
        }
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}
397
398async fn request_adapter_with_preference(
399 instance: &wgpu::Instance,
400 power_preference: wgpu::PowerPreference,
401) -> wgpu::Adapter {
402 instance
403 .request_adapter(&RequestAdapterOptions {
404 power_preference,
405 force_fallback_adapter: false,
406 compatible_surface: None,
407 })
408 .await
409 .expect("No possible adapter available for backend. Falling back to first available.")
410}
411
412#[cfg(not(target_family = "wasm"))]
413async fn select_from_adapter_list(
414 num: usize,
415 error: &str,
416 instance: &wgpu::Instance,
417 device: &WgpuDevice,
418 backend: wgpu::Backend,
419) -> wgpu::Adapter {
420 let mut adapters_other = Vec::new();
421 let mut adapters = Vec::new();
422
423 instance
424 .enumerate_adapters(backend.into())
425 .await
426 .into_iter()
427 .for_each(|adapter| {
428 let device_type = adapter.get_info().device_type;
429
430 if let wgpu::DeviceType::Other = device_type {
431 adapters_other.push(adapter);
432 return;
433 }
434
435 let is_same_type = match device {
436 WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
437 WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
438 WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
439 WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
440 #[allow(deprecated)]
441 WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
442 WgpuDevice::Existing(_) => {
443 unreachable!("Cannot select an adapter for an existing device.")
444 }
445 };
446
447 if is_same_type {
448 adapters.push(adapter);
449 }
450 });
451
452 if adapters.len() <= num {
453 if adapters_other.len() <= num {
454 panic!(
455 "{}, adapters {:?}, other adapters {:?}",
456 error,
457 adapters
458 .into_iter()
459 .map(|adapter| adapter.get_info())
460 .collect::<Vec<_>>(),
461 adapters_other
462 .into_iter()
463 .map(|adapter| adapter.get_info())
464 .collect::<Vec<_>>(),
465 );
466 }
467
468 return adapters_other.remove(num);
469 }
470
471 adapters.remove(num)
472}
473
474fn get_device_override() -> Option<WgpuDevice> {
475 std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
478 .ok()
479 .and_then(|var| {
480 let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
481 inner
482 .strip_suffix(")")
483 .and_then(|s| s.parse().ok())
484 .map(WgpuDevice::DiscreteGpu)
485 } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
486 inner
487 .strip_suffix(")")
488 .and_then(|s| s.parse().ok())
489 .map(WgpuDevice::IntegratedGpu)
490 } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
491 inner
492 .strip_suffix(")")
493 .and_then(|s| s.parse().ok())
494 .map(WgpuDevice::VirtualGpu)
495 } else if var == "Cpu" {
496 Some(WgpuDevice::Cpu)
497 } else {
498 None
499 };
500
501 if override_device.is_none() {
502 log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
503 }
504 override_device
505 })
506}