1use crate::{
2 AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3 contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::{Runtime, ir::TargetProperties};
8use cubecl_core::{ir::LineSize, server::ServerUtilities};
9use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
10pub use cubecl_runtime::memory_management::MemoryConfiguration;
11use cubecl_runtime::{
12 client::ComputeClient,
13 logging::{ProfileLevel, ServerLogger},
14};
15use wgpu::{InstanceFlags, RequestAdapterOptions};
16
/// Runtime backed by the `wgpu` crate, compiling kernels through [`AutoCompiler`]
/// (WGSL by default; SPIR-V / MSL when the corresponding features are enabled).
#[derive(Debug)]
pub struct WgpuRuntime;
22
impl DeviceState for WgpuServer {
    /// Creates a server for the given device id by requesting a wgpu
    /// adapter/device for the automatically selected graphics API,
    /// using default [`RuntimeOptions`].
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let device = WgpuDevice::from_id(device_id);
        // Blocks on the async wgpu setup.
        // NOTE(review): blocking is presumably unsupported on wasm — see `init_setup`.
        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
        create_server(setup, RuntimeOptions::default())
    }
}
30
impl Runtime for WgpuRuntime {
    type Compiler = AutoCompiler;
    type Server = WgpuServer;
    type Device = WgpuDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Name of the runtime, including the shading language actually compiled
    /// for the client's backend (decided at compile time by cargo features).
    fn name(client: &ComputeClient<Self>) -> &'static str {
        match client.info() {
            wgpu::Backend::Vulkan => {
                // With the `spirv` feature, Vulkan kernels are emitted as SPIR-V
                // instead of going through WGSL.
                #[cfg(feature = "spirv")]
                return "wgpu<spirv>";

                #[cfg(not(feature = "spirv"))]
                return "wgpu<wgsl>";
            }
            wgpu::Backend::Metal => {
                // With the `msl` feature, Metal kernels are emitted as MSL.
                #[cfg(feature = "msl")]
                return "wgpu<msl>";

                #[cfg(not(feature = "msl"))]
                return "wgpu<wgsl>";
            }
            // Every other backend always compiles through WGSL.
            _ => "wgpu<wgsl>",
        }
    }

    /// Line (vector) widths kernels may use; the MSL path additionally
    /// supports 8-wide lines.
    fn supported_line_sizes() -> &'static [LineSize] {
        #[cfg(feature = "msl")]
        {
            &[8, 4, 2, 1]
        }
        #[cfg(not(feature = "msl"))]
        {
            &[4, 2, 1]
        }
    }

    fn max_global_line_size() -> LineSize {
        4
    }

    /// Maximum cube count per dimension.
    // NOTE(review): capped at u16::MAX per axis — presumably a conservative
    // portable limit rather than an adapter-reported one; confirm.
    fn max_cube_count() -> (u32, u32, u32) {
        let max_dim = u16::MAX as u32;
        (max_dim, max_dim, max_dim)
    }

    /// A tensor is directly readable only when it is fully contiguous:
    /// every stride must equal the contiguous stride derived from `shape`.
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        // Rank-0 (scalar) tensors are trivially contiguous.
        if shape.is_empty() {
            return true;
        }

        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
            if expected != stride {
                return false;
            }
        }

        true
    }

    fn target_properties() -> TargetProperties {
        TargetProperties {
            mma: Default::default(),
        }
    }
}
101
/// Options for configuring the wgpu runtime.
pub struct RuntimeOptions {
    /// Maximum number of tasks accepted before work is submitted.
    // NOTE(review): presumably the flush threshold used by `WgpuServer` —
    // confirm against the server implementation. Overridable at runtime via
    // the `CUBECL_WGPU_MAX_TASKS` environment variable (see `Default`).
    pub tasks_max: usize,
    /// Configuration for the server's memory management.
    pub memory_config: MemoryConfiguration,
}
109
110impl Default for RuntimeOptions {
111 fn default() -> Self {
112 #[cfg(test)]
113 const DEFAULT_MAX_TASKS: usize = 1;
114 #[cfg(not(test))]
115 const DEFAULT_MAX_TASKS: usize = 32;
116
117 let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
118 Ok(value) => value
119 .parse::<usize>()
120 .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
121 Err(_) => DEFAULT_MAX_TASKS,
122 };
123
124 Self {
125 tasks_max,
126 memory_config: MemoryConfiguration::default(),
127 }
128 }
129}
130
/// The handles created while setting up a wgpu device, bundled so an
/// externally created device can be registered with the runtime
/// (see [`init_device`]).
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The instance the adapter was requested from.
    pub instance: wgpu::Instance,
    /// The adapter (physical device) backing the logical device.
    pub adapter: wgpu::Adapter,
    /// The logical wgpu device.
    pub device: wgpu::Device,
    /// The queue commands are submitted to.
    pub queue: wgpu::Queue,
    /// The backend (Vulkan, Metal, DX12, ...) the adapter runs on.
    pub backend: wgpu::Backend,
}
147
148pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
158 use core::sync::atomic::{AtomicU32, Ordering};
159
160 static COUNTER: AtomicU32 = AtomicU32::new(0);
161
162 let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
163 if device_id == u32::MAX {
164 core::panic!("Memory ID overflowed");
165 }
166
167 let device_id = WgpuDevice::Existing(device_id);
168 let server = create_server(setup, options);
169 let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
170 device_id
171}
172
/// Synchronously initializes a client for the given device and returns the
/// [`WgpuSetup`] used to create it.
///
/// # Panics
///
/// Panics on wasm, where blocking on a future is impossible; use
/// [`init_setup_async`] there instead.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // Consume the arguments so the wasm build stays warning-free.
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
185
186pub async fn init_setup_async<G: GraphicsApi>(
190 device: &WgpuDevice,
191 options: RuntimeOptions,
192) -> WgpuSetup {
193 let setup = create_setup_for_device(device, G::backend()).await;
194 let return_setup = setup.clone();
195 let server = create_server(setup, options);
196 let _ = ComputeClient::<WgpuRuntime>::init(device, server);
197 return_setup
198}
199
/// Builds a [`WgpuServer`] from an already-created wgpu setup, deriving memory
/// and hardware properties from the device and adapter limits.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let adapter_limits = setup.adapter.limits();
    let mut adapter_info = setup.adapter.get_info();

    // Some drivers report no subgroup size range at all; substitute a wide
    // range so plane-size properties remain usable.
    // NOTE(review): 8..=128 looks like a heuristic covering common hardware —
    // confirm where these bounds come from.
    if adapter_info.subgroup_min_size == 0 && adapter_info.subgroup_max_size == 0 {
        adapter_info.subgroup_min_size = 8;
        adapter_info.subgroup_max_size = 128;
    }

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple silicon the plane (simdgroup) size is pinned to 32 instead
        // of using the adapter-reported subgroup range.
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_info.subgroup_min_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_info.subgroup_max_size,
        // One storage binding is held back from kernels.
        // NOTE(review): presumably reserved for runtime-internal use — confirm.
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: (max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: (
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // wgpu does not expose these; leave them unknown.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
        num_cpu_cores: None,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamp queries for profiling when supported.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    // Register plane (subgroup) ops when the adapter supports subgroups and is
    // not a CPU adapter. Skipped on macOS with the `msl` feature.
    // NOTE(review): presumably the MSL backend registers plane support itself
    // via `register_features` below — confirm.
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_ir::features::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    // SPIR-V and MSL targets allow plane ops inside non-uniform control flow.
    #[cfg(any(feature = "spirv", feature = "msl"))]
    device_props
        .features
        .plane
        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);

    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
303
304pub(crate) async fn create_setup_for_device(
307 device: &WgpuDevice,
308 backend: wgpu::Backend,
309) -> WgpuSetup {
310 let (instance, adapter) = request_adapter(device, backend).await;
311 let (device, queue) = backend::request_device(&adapter).await;
312
313 log::info!(
314 "Created wgpu compute server on device {:?} => {:?}",
315 device,
316 adapter.get_info()
317 );
318
319 WgpuSetup {
320 instance,
321 adapter,
322 device,
323 queue,
324 backend,
325 }
326}
327
/// Creates a wgpu instance for `backend` and selects an adapter matching the
/// requested [`WgpuDevice`] variant.
async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    // Pick wgpu debug layers based on the configured profiling/compilation
    // logging level.
    let debug = ServerLogger::default();
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });

    // The environment override is only honored for the "pick anything"
    // variants; explicitly requested devices are used as-is.
    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        // On native targets, enumerate adapters and pick by device type/index.
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => {
            select_from_adapter_list(
                num,
                "No Discrete GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => {
            select_from_adapter_list(
                num,
                "No Integrated GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => {
            select_from_adapter_list(
                num,
                "No Virtual GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend).await
        }
        // On wasm, adapters cannot be enumerated; approximate an integrated
        // GPU with the low-power preference.
        #[cfg(target_family = "wasm")]
        WgpuDevice::IntegratedGpu(_) => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::LowPower).await
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        // Default/BestAvailable (and wasm fallthroughs): take the
        // highest-performance adapter the instance offers.
        _ => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::HighPerformance).await
        }
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}
411
412async fn request_adapter_with_preference(
413 instance: &wgpu::Instance,
414 power_preference: wgpu::PowerPreference,
415) -> wgpu::Adapter {
416 instance
417 .request_adapter(&RequestAdapterOptions {
418 power_preference,
419 force_fallback_adapter: false,
420 compatible_surface: None,
421 })
422 .await
423 .expect("No possible adapter available for backend. Falling back to first available.")
424}
425
/// Picks the `num`-th enumerated adapter whose device type matches the
/// requested [`WgpuDevice`] variant.
///
/// Adapters reporting [`wgpu::DeviceType::Other`] are collected separately and
/// used as a fallback when the typed list has too few entries.
///
/// # Panics
///
/// Panics with `error` (plus the adapter lists) when neither list contains an
/// adapter at index `num`.
#[cfg(not(target_family = "wasm"))]
async fn select_from_adapter_list(
    num: usize,
    error: &str,
    instance: &wgpu::Instance,
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> wgpu::Adapter {
    let mut adapters_other = Vec::new();
    let mut adapters = Vec::new();

    instance
        .enumerate_adapters(backend.into())
        .await
        .into_iter()
        .for_each(|adapter| {
            let device_type = adapter.get_info().device_type;

            // Untyped adapters go to the fallback list.
            if let wgpu::DeviceType::Other = device_type {
                adapters_other.push(adapter);
                return;
            }

            let is_same_type = match device {
                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
                #[allow(deprecated)]
                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
                WgpuDevice::Existing(_) => {
                    unreachable!("Cannot select an adapter for an existing device.")
                }
            };

            if is_same_type {
                adapters.push(adapter);
            }
        });

    // Not enough adapters of the requested type: try the fallback list before
    // giving up with a diagnostic listing everything that was found.
    if adapters.len() <= num {
        if adapters_other.len() <= num {
            panic!(
                "{}, adapters {:?}, other adapters {:?}",
                error,
                adapters
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
                adapters_other
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
            );
        }

        return adapters_other.remove(num);
    }

    adapters.remove(num)
}
487
488fn get_device_override() -> Option<WgpuDevice> {
489 std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
492 .ok()
493 .and_then(|var| {
494 let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
495 inner
496 .strip_suffix(")")
497 .and_then(|s| s.parse().ok())
498 .map(WgpuDevice::DiscreteGpu)
499 } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
500 inner
501 .strip_suffix(")")
502 .and_then(|s| s.parse().ok())
503 .map(WgpuDevice::IntegratedGpu)
504 } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
505 inner
506 .strip_suffix(")")
507 .and_then(|s| s.parse().ok())
508 .map(WgpuDevice::VirtualGpu)
509 } else if var == "Cpu" {
510 Some(WgpuDevice::Cpu)
511 } else {
512 None
513 };
514
515 if override_device.is_none() {
516 log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
517 }
518 override_device
519 })
520}