1use crate::{
2 AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3 contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceService};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::device::{DeviceId, ServerUtilitiesHandle};
8use cubecl_core::server::ServerUtilities;
9use cubecl_core::zspace::{Shape, Strides};
10use cubecl_core::{Runtime, ir::TargetProperties};
11use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
12use cubecl_runtime::allocator::ContiguousMemoryLayoutPolicy;
13#[cfg(not(feature = "vulkan-validate"))]
14use cubecl_runtime::logging::ProfileLevel;
15pub use cubecl_runtime::memory_management::MemoryConfiguration;
16use cubecl_runtime::{client::ComputeClient, logging::ServerLogger};
17use wgpu::{InstanceFlags, RequestAdapterOptions};
18
/// Runtime implemented on top of wgpu; see the [`Runtime`] impl below for the
/// compiler/server/device types it wires together.
#[derive(Debug, Clone)]
pub struct WgpuRuntime;
24
/// Allows a [`WgpuServer`] to be created lazily from a plain device id.
impl DeviceService for WgpuServer {
    /// Builds a server for `device_id` using the automatically selected
    /// graphics backend and default [`RuntimeOptions`].
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let device = WgpuDevice::from_id(device_id);
        // Adapter/device requests are async; block since `init` is sync.
        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
        create_server(setup, RuntimeOptions::default())
    }

    /// Shared handle to the server utilities (device properties, logger,
    /// backend id and allocator — see `create_server`).
    fn utilities(&self) -> ServerUtilitiesHandle {
        self.utilities.clone() as ServerUtilitiesHandle
    }
}
36
impl Runtime for WgpuRuntime {
    type Compiler = AutoCompiler;
    type Server = WgpuServer;
    type Device = WgpuDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Human-readable runtime name; the compiler part depends on the wgpu
    /// backend in use and on the enabled cargo features.
    fn name(client: &ComputeClient<Self>) -> &'static str {
        match client.info() {
            wgpu::Backend::Vulkan => {
                // Vulkan consumes SPIR-V directly when the feature is enabled.
                #[cfg(feature = "spirv")]
                return "wgpu<spirv>";

                #[cfg(not(feature = "spirv"))]
                return "wgpu<wgsl>";
            }
            wgpu::Backend::Metal => {
                // Metal consumes MSL directly when the feature is enabled.
                #[cfg(feature = "msl")]
                return "wgpu<msl>";

                #[cfg(not(feature = "msl"))]
                return "wgpu<wgsl>";
            }
            // All other backends go through WGSL.
            _ => "wgpu<wgsl>",
        }
    }

    /// Maximum cube count per dimension (u16::MAX in each of x, y, z).
    fn max_cube_count() -> (u32, u32, u32) {
        let max_dim = u16::MAX as u32;
        (max_dim, max_dim, max_dim)
    }

    /// A tensor is readable only when its strides match the fully contiguous
    /// layout for `shape`; an empty shape is trivially readable.
    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
        if shape.is_empty() {
            return true;
        }

        for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides.iter()) {
            if expected != stride {
                return false;
            }
        }

        true
    }

    fn target_properties() -> TargetProperties {
        TargetProperties {
            mma: Default::default(),
        }
    }

    /// Enumerates device ids of adapters matching `type_id` on `info`'s
    /// backend.
    ///
    /// Type ids: 0 = DiscreteGpu, 1 = IntegratedGpu, 2 = VirtualGpu, 3 = Cpu,
    /// 4 = Other; passing 4 accepts adapters of every type.
    fn enumerate_devices(type_id: u16, info: &wgpu::Backend) -> Vec<DeviceId> {
        #[cfg(target_family = "wasm")]
        {
            // Adapter enumeration is unavailable on the web; expose one
            // default device.
            let _ = type_id;
            let _ = info;
            vec![DeviceId::new(0, 0)]
        }

        #[cfg(not(target_family = "wasm"))]
        {
            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
                backends: wgpu::Backends::all(),
                ..wgpu::InstanceDescriptor::new_without_display_handle()
            });

            let adapters = enumerate_all_adapters(instance, *info);
            adapters
                .into_iter()
                .filter(|adapter| {
                    // 4 (`Other`) acts as a wildcard and matches everything.
                    if type_id == 4 {
                        return true;
                    }

                    let device_type = adapter.get_info().device_type;

                    let adapter_type_id = match device_type {
                        wgpu::DeviceType::Other => 4,
                        wgpu::DeviceType::IntegratedGpu => 1,
                        wgpu::DeviceType::DiscreteGpu => 0,
                        wgpu::DeviceType::VirtualGpu => 2,
                        wgpu::DeviceType::Cpu => 3,
                    };

                    adapter_type_id == type_id
                })
                .enumerate()
                // NOTE(review): `Cpu` and `Other` always get index 0
                // regardless of position in the filtered list, unlike the GPU
                // variants — confirm this is intentional.
                .map(|(index, adapter)| match adapter.get_info().device_type {
                    wgpu::DeviceType::DiscreteGpu => DeviceId::new(0, index as u16),
                    wgpu::DeviceType::IntegratedGpu => DeviceId::new(1, index as u16),
                    wgpu::DeviceType::VirtualGpu => DeviceId::new(2, index as u16),
                    wgpu::DeviceType::Cpu => DeviceId::new(3, 0),
                    wgpu::DeviceType::Other => DeviceId::new(4, 0),
                })
                .collect()
        }
    }

    /// Enumerates device ids for every adapter on `info`'s backend, without
    /// filtering by type. Same id-mapping caveats as [`Self::enumerate_devices`].
    fn enumerate_all_devices(info: &wgpu::Backend) -> Vec<DeviceId> {
        #[cfg(target_family = "wasm")]
        {
            // Adapter enumeration is unavailable on the web.
            let _ = info;
            vec![DeviceId::new(0, 0)]
        }

        #[cfg(not(target_family = "wasm"))]
        {
            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
                backends: wgpu::Backends::all(),
                ..wgpu::InstanceDescriptor::new_without_display_handle()
            });
            let adapters = enumerate_all_adapters(instance, *info);
            adapters
                .into_iter()
                .enumerate()
                .map(|(index, adapter)| match adapter.get_info().device_type {
                    wgpu::DeviceType::DiscreteGpu => DeviceId::new(0, index as u16),
                    wgpu::DeviceType::IntegratedGpu => DeviceId::new(1, index as u16),
                    wgpu::DeviceType::VirtualGpu => DeviceId::new(2, index as u16),
                    wgpu::DeviceType::Cpu => DeviceId::new(3, 0),
                    wgpu::DeviceType::Other => DeviceId::new(4, 0),
                })
                .collect()
        }
    }
}
170
171#[cfg(not(target_family = "wasm"))]
172fn enumerate_all_adapters(instance: wgpu::Instance, backend: wgpu::Backend) -> Vec<wgpu::Adapter> {
173 cubecl_common::future::block_on(instance.enumerate_adapters(backend.into()))
175}
176
/// Options configuring the wgpu runtime.
pub struct RuntimeOptions {
    /// Maximum number of tasks handed to the server before work is submitted
    /// (forwarded to `WgpuServer::new` — TODO(review): confirm exact
    /// batching/flush semantics there).
    pub tasks_max: usize,
    /// Configuration of the memory-management strategy.
    pub memory_config: MemoryConfiguration,
}
184
185impl Default for RuntimeOptions {
186 fn default() -> Self {
187 #[cfg(test)]
188 const DEFAULT_MAX_TASKS: usize = 32;
189 #[cfg(not(test))]
190 const DEFAULT_MAX_TASKS: usize = 32;
191
192 let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
193 Ok(value) => value
194 .parse::<usize>()
195 .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
196 Err(_) => DEFAULT_MAX_TASKS,
197 };
198
199 Self {
200 tasks_max,
201 memory_config: MemoryConfiguration::default(),
202 }
203 }
204}
205
/// Handles making up an initialized wgpu compute context.
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The wgpu instance the adapter was created from.
    pub instance: wgpu::Instance,
    /// The selected adapter (physical device).
    pub adapter: wgpu::Adapter,
    /// The logical wgpu device.
    pub device: wgpu::Device,
    /// The queue commands are submitted on.
    pub queue: wgpu::Queue,
    /// The backend (Vulkan, Metal, DX12, ...) the adapter runs on.
    pub backend: wgpu::Backend,
}
222
223pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
233 use core::sync::atomic::{AtomicU32, Ordering};
234
235 static COUNTER: AtomicU32 = AtomicU32::new(0);
236
237 let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
238 if device_id == u32::MAX {
239 core::panic!("Memory ID overflowed");
240 }
241
242 let device_id = WgpuDevice::Existing(device_id);
243 let server = create_server(setup, options);
244 let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
245 device_id
246}
247
/// Like [`init_setup_async`], but synchronous.
///
/// # Panics
///
/// Panics on wasm, where blocking on a future is impossible.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
260
261pub async fn init_setup_async<G: GraphicsApi>(
265 device: &WgpuDevice,
266 options: RuntimeOptions,
267) -> WgpuSetup {
268 let setup = create_setup_for_device(device, G::backend()).await;
269 let return_setup = setup.clone();
270 let server = create_server(setup, options);
271 let _ = ComputeClient::<WgpuRuntime>::init(device, server);
272 return_setup
273}
274
/// Builds a [`WgpuServer`] from an initialized [`WgpuSetup`] and the runtime
/// options, deriving memory/hardware properties from the adapter and device
/// limits.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let adapter_limits = setup.adapter.limits();
    let mut adapter_info = setup.adapter.get_info();

    // Some adapters report 0 for both subgroup bounds; substitute a wide
    // default range (8..=128) so plane-size based logic still has usable
    // values.
    if adapter_info.subgroup_min_size == 0 && adapter_info.subgroup_max_size == 0 {
        adapter_info.subgroup_min_size = 8;
        adapter_info.subgroup_max_size = 128;
    }

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size,
        alignment: limits.min_uniform_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple silicon the plane size is pinned to 32; elsewhere use the
        // adapter-reported subgroup bounds (possibly the 8/128 fallback above).
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_info.subgroup_min_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_info.subgroup_max_size,
        // One storage binding is held back from the limit — presumably
        // reserved for runtime bookkeeping; TODO(review): confirm.
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: (max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: (
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // Not exposed through the wgpu API.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
        num_cpu_cores: None, max_vector_size: 4,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamps when supported; otherwise time on the host.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props,
        hardware_props,
        time_measurement,
    );

    // Register plane (subgroup) ops when the adapter supports them — skipped
    // when compiling MSL on macOS, and never enabled for CPU adapters.
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_ir::features::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    // The SPIR-V and MSL paths additionally support non-uniform control flow
    // inside a plane.
    #[cfg(any(feature = "spirv", feature = "msl"))]
    device_props
        .features
        .plane
        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);

    // Give the backend-specific code a chance to register extra features and
    // adjust compilation options.
    backend::register_features(
        &setup.adapter,
        &mut device_props,
        &mut compilation_options,
        &options.memory_config,
    );

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    let allocator = ContiguousMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
    WgpuServer::new(
        device_props.memory.clone(),
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend, allocator),
    )
}
385
386pub(crate) async fn create_setup_for_device(
389 device: &WgpuDevice,
390 backend: wgpu::Backend,
391) -> WgpuSetup {
392 let (instance, adapter) = request_adapter(device, backend).await;
393 let (device, queue) = backend::request_device(&adapter).await;
394
395 log::info!(
396 "Created wgpu compute server on device {:?} => {:?}",
397 device,
398 adapter.get_info()
399 );
400
401 WgpuSetup {
402 instance,
403 adapter,
404 device,
405 queue,
406 backend,
407 }
408}
409
/// Creates a wgpu instance for `backend` and selects an adapter matching the
/// requested `device`.
async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    // Instance debug flags follow the logger configuration, unless the
    // `vulkan-validate` feature forces full validation.
    #[cfg(not(feature = "vulkan-validate"))]
    let instance_flags = {
        let debug = ServerLogger::default();
        match (debug.profile_level(), debug.compilation_activated()) {
            (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
            (_, true) => InstanceFlags::debugging(),
            (_, false) => InstanceFlags::default(),
        }
    };
    #[cfg(feature = "vulkan-validate")]
    let instance_flags = InstanceFlags::advanced_debugging();
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..wgpu::InstanceDescriptor::new_without_display_handle()
    });

    // The environment override only applies when the caller did not ask for a
    // specific device.
    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        // On native targets, pick the num-th adapter of the requested type
        // from the enumerated list.
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => {
            select_from_adapter_list(
                num,
                "No Discrete GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => {
            select_from_adapter_list(
                num,
                "No Integrated GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => {
            select_from_adapter_list(
                num,
                "No Virtual GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend).await
        }
        // On wasm, adapters cannot be enumerated; a power-preference hint is
        // the only selection mechanism available.
        #[cfg(target_family = "wasm")]
        WgpuDevice::IntegratedGpu(_) => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::LowPower).await
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        _ => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::HighPerformance).await
        }
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}
498
499async fn request_adapter_with_preference(
500 instance: &wgpu::Instance,
501 power_preference: wgpu::PowerPreference,
502) -> wgpu::Adapter {
503 instance
504 .request_adapter(&RequestAdapterOptions {
505 power_preference,
506 force_fallback_adapter: false,
507 compatible_surface: None,
508 })
509 .await
510 .expect("No possible adapter available for backend. Falling back to first available.")
511}
512
/// Picks the `num`-th adapter whose type matches `device` from the backend's
/// adapter list; when the typed list is too short, falls back to the `num`-th
/// adapter of type `Other`.
///
/// # Panics
///
/// Panics with `error` (plus both adapter lists) when neither list has an
/// adapter at index `num`.
#[cfg(not(target_family = "wasm"))]
async fn select_from_adapter_list(
    num: usize,
    error: &str,
    instance: &wgpu::Instance,
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> wgpu::Adapter {
    // Adapters reporting `DeviceType::Other` are kept aside as a fallback
    // pool; all others are kept only when they match the requested type.
    let mut adapters_other = Vec::new();
    let mut adapters = Vec::new();

    instance
        .enumerate_adapters(backend.into())
        .await
        .into_iter()
        .for_each(|adapter| {
            let device_type = adapter.get_info().device_type;

            if let wgpu::DeviceType::Other = device_type {
                adapters_other.push(adapter);
                return;
            }

            let is_same_type = match device {
                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
                // The default selectors accept any adapter type.
                #[allow(deprecated)]
                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
                WgpuDevice::Existing(_) => {
                    unreachable!("Cannot select an adapter for an existing device.")
                }
            };

            if is_same_type {
                adapters.push(adapter);
            }
        });

    // Not enough typed adapters: try the `Other` pool before giving up.
    if adapters.len() <= num {
        if adapters_other.len() <= num {
            panic!(
                "{}, adapters {:?}, other adapters {:?}",
                error,
                adapters
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
                adapters_other
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
            );
        }

        return adapters_other.remove(num);
    }

    adapters.remove(num)
}
574
/// Reads the `CUBECL_WGPU_DEFAULT_DEVICE` environment variable and parses it
/// into a device override: `DiscreteGpu(n)`, `IntegratedGpu(n)`,
/// `VirtualGpu(n)` or `Cpu`.
///
/// Returns [`None`] when the variable is unset, and logs a warning and
/// returns [`None`] when it is set to an unrecognized value.
fn get_device_override() -> Option<WgpuDevice> {
    // Parses `Prefix(index)` into the given device constructor; the index
    // type is inferred from the constructor, as before.
    fn parse_indexed<T: std::str::FromStr>(
        var: &str,
        prefix: &str,
        ctor: fn(T) -> WgpuDevice,
    ) -> Option<WgpuDevice> {
        var.strip_prefix(prefix)?
            .strip_suffix(")")
            .and_then(|s| s.parse().ok())
            .map(ctor)
    }

    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
        .ok()
        .and_then(|var| {
            let override_device = parse_indexed(&var, "DiscreteGpu(", WgpuDevice::DiscreteGpu)
                .or_else(|| parse_indexed(&var, "IntegratedGpu(", WgpuDevice::IntegratedGpu))
                .or_else(|| parse_indexed(&var, "VirtualGpu(", WgpuDevice::VirtualGpu))
                .or_else(|| (var == "Cpu").then_some(WgpuDevice::Cpu));

            if override_device.is_none() {
                // Bug fix: the warning previously named the wrong variable
                // (`CUBECL_WGPU_DEVICE`); it is `CUBECL_WGPU_DEFAULT_DEVICE`
                // that is read above.
                log::warn!("Unknown CUBECL_WGPU_DEFAULT_DEVICE override {var}");
            }
            override_device
        })
607}