use crate::{
    AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
    contiguous_strides,
};
use cubecl_common::{future, profile::TimingMethod};

#[cfg(not(all(target_os = "macos", feature = "msl")))]
use cubecl_core::{
    AtomicFeature, Feature,
    ir::{Elem, FloatKind},
};
use cubecl_core::{CubeCount, CubeDim, Runtime};
pub use cubecl_runtime::memory_management::MemoryConfiguration;
use cubecl_runtime::memory_management::MemoryDeviceProperties;
use cubecl_runtime::{
    ComputeRuntime,
    channel::MutexComputeChannel,
    client::ComputeClient,
    id::DeviceId,
    logging::{ProfileLevel, ServerLogger},
};
use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties};
use wgpu::{InstanceFlags, RequestAdapterOptions};

#[derive(Debug)]
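/// Runtime that executes CubeCL kernels through the [wgpu] crate, targeting WGSL
/// by default, or SPIR-V / MSL when the `spirv` / `msl` features apply on
/// Vulkan / Metal.
///
/// A minimal usage sketch (the device choice here is illustrative):
///
/// ```rust,ignore
/// let device = WgpuDevice::DefaultDevice;
/// let client = WgpuRuntime::client(&device);
/// ```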
pub struct WgpuRuntime;

type Server = WgpuServer;

static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
    ComputeRuntime::new();

impl Runtime for WgpuRuntime {
    type Compiler = AutoCompiler;
    type Server = WgpuServer;

    type Channel = MutexComputeChannel<WgpuServer>;
    type Device = WgpuDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
        RUNTIME.client(device, move || {
            let setup =
                future::block_on(create_setup_for_device(device, AutoGraphicsApi::backend()));
            create_client_on_setup(setup, RuntimeOptions::default())
        })
    }

    fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
        match client.info() {
            wgpu::Backend::Vulkan => {
                #[cfg(feature = "spirv")]
                return "wgpu<spirv>";

                #[cfg(not(feature = "spirv"))]
                return "wgpu<wgsl>";
            }
            wgpu::Backend::Metal => {
                #[cfg(feature = "msl")]
                return "wgpu<msl>";

                #[cfg(not(feature = "msl"))]
                return "wgpu<wgsl>";
            }
            _ => "wgpu<wgsl>",
        }
    }

    fn supported_line_sizes() -> &'static [u8] {
        #[cfg(feature = "msl")]
        {
            &[8, 4, 2, 1]
        }
        #[cfg(not(feature = "msl"))]
        {
            &[4, 2, 1]
        }
    }

    fn max_cube_count() -> (u32, u32, u32) {
        let max_dim = u16::MAX as u32;
        (max_dim, max_dim, max_dim)
    }

    fn device_id(device: &Self::Device) -> DeviceId {
        #[allow(deprecated)]
        match device {
            WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
            WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
            WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
            WgpuDevice::Cpu => DeviceId::new(3, 0),
            WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
            WgpuDevice::Existing(id) => DeviceId::new(5, *id),
        }
    }

    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        if shape.is_empty() {
            return true;
        }

        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
            if expected != stride {
                return false;
            }
        }

        true
    }
}

/// Options that control how the wgpu runtime is initialized.
pub struct RuntimeOptions {
    /// Maximum number of compute tasks queued before work is flushed to the GPU.
    /// Can be overridden with the `CUBECL_WGPU_MAX_TASKS` environment variable.
    pub tasks_max: usize,
    /// Configuration of the memory management strategy.
    pub memory_config: MemoryConfiguration,
}

impl Default for RuntimeOptions {
    fn default() -> Self {
        #[cfg(test)]
        const DEFAULT_MAX_TASKS: usize = 1;
        #[cfg(not(test))]
        const DEFAULT_MAX_TASKS: usize = 32;

        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
            Ok(value) => value
                .parse::<usize>()
                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
            Err(_) => DEFAULT_MAX_TASKS,
        };

        Self {
            tasks_max,
            memory_config: MemoryConfiguration::default(),
        }
    }
}

#[derive(Clone, Debug)]
/// Handles to the wgpu objects a runtime client is built on.
pub struct WgpuSetup {
    /// The wgpu instance the adapter was created from.
    pub instance: wgpu::Instance,
    /// The adapter (physical device) in use.
    pub adapter: wgpu::Adapter,
    /// The logical device used to allocate resources and create pipelines.
    pub device: wgpu::Device,
    /// The queue commands are submitted to.
    pub queue: wgpu::Queue,
    /// The backend the adapter was requested for.
    pub backend: wgpu::Backend,
}

/// Registers a compute client for an existing [`WgpuSetup`] and returns a new
/// [`WgpuDevice::Existing`] handle referring to it.
pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
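    // Every call hands out a fresh id from a global counter, so each setup is
    // registered as a distinct `Existing` device.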
    use core::sync::atomic::{AtomicU32, Ordering};

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
    if device_id == u32::MAX {
        core::panic!("Device ID overflowed");
    }

    let device_id = WgpuDevice::Existing(device_id);
    let client = create_client_on_setup(setup, options);
    RUNTIME.register(&device_id, client);
    device_id
}

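/// Synchronous variant of [`init_setup_async`].
///
/// Panics on wasm targets, where blocking on a future is not possible.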
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_setup_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}

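/// Creates a [`WgpuSetup`] for the given device, registers a client built on it,
/// and returns the setup so it can be shared with other wgpu code.
///
/// A minimal sketch, assuming an async context and default options:
///
/// ```rust,ignore
/// let setup = init_setup_async::<AutoGraphicsApi>(
///     &WgpuDevice::DefaultDevice,
///     RuntimeOptions::default(),
/// )
/// .await;
/// ```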
pub async fn init_setup_async<G: GraphicsApi>(
    device: &WgpuDevice,
    options: RuntimeOptions,
) -> WgpuSetup {
    let setup = create_setup_for_device(device, G::backend()).await;
    let return_setup = setup.clone();
    let client = create_client_on_setup(setup, options);
    RUNTIME.register(device, client);
    return_setup
}

pub(crate) fn create_client_on_setup(
    setup: WgpuSetup,
    options: RuntimeOptions,
) -> ComputeClient<WgpuServer, MutexComputeChannel<WgpuServer>> {
    let limits = setup.device.limits();
    let adapter_limits = setup.adapter.limits();

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
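        // On Apple silicon the plane size is pinned to the 32-lane SIMD-group
        // width instead of trusting the subgroup limits wgpu reports.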
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        max_bindings: limits.max_storage_buffers_per_shader_stage,
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: CubeCount::new_3d(max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: CubeDim::new_3d(
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props =
        DeviceProperties::new(&[], mem_props.clone(), hardware_props, time_measurement);

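    // Some adapters advertise SUBGROUP support while reporting a zero subgroup
    // size range; such plane info is considered unusable below.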
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        let fake_plane_info =
            adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0;

        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
            && !fake_plane_info
        {
            device_props.register_feature(Feature::Plane);
        }
    }

    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let server = WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
    );
    let channel = MutexComputeChannel::new(server);

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    if features.contains(wgpu::Features::SHADER_FLOAT32_ATOMIC) {
        device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F32)));

        device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
        device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
    }

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        use cubecl_core::ir::{IntKind, UIntKind};

        device_props.register_feature(Feature::Type(Elem::AtomicInt(IntKind::I32)));
        device_props.register_feature(Feature::Type(Elem::AtomicUInt(UIntKind::U32)));
        device_props.register_feature(Feature::AtomicInt(AtomicFeature::LoadStore));
        device_props.register_feature(Feature::AtomicInt(AtomicFeature::Add));
        device_props.register_feature(Feature::AtomicUInt(AtomicFeature::LoadStore));
        device_props.register_feature(Feature::AtomicUInt(AtomicFeature::Add));
    }

    ComputeClient::new(channel, device_props, setup.backend)
}

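/// Creates the [`WgpuSetup`] (instance, adapter, device, and queue) backing the
/// given [`WgpuDevice`].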
pub(crate) async fn create_setup_for_device(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> WgpuSetup {
    let (instance, adapter) = request_adapter(device, backend).await;
    let (device, queue) = backend::request_device(&adapter).await;

    log::info!(
        "Created wgpu compute server on device {:?} => {:?}",
        device,
        adapter.get_info()
    );

    WgpuSetup {
        instance,
        adapter,
        device,
        queue,
        backend,
    }
}

async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    let debug = ServerLogger::default();
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });

    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
            num,
            "No Discrete GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
            num,
            "No Integrated GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
            num,
            "No Virtual GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        _ => instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await
            .expect("No adapter available for the requested backend."),
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}

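/// Selects the `num`-th enumerated adapter whose device type matches `device`,
/// falling back to adapters reported as [`wgpu::DeviceType::Other`].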
#[cfg(not(target_family = "wasm"))]
fn select_from_adapter_list(
    num: usize,
    error: &str,
    instance: &wgpu::Instance,
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> wgpu::Adapter {
    let mut adapters_other = Vec::new();
    let mut adapters = Vec::new();

    instance
        .enumerate_adapters(backend.into())
        .into_iter()
        .for_each(|adapter| {
            let device_type = adapter.get_info().device_type;

            if let wgpu::DeviceType::Other = device_type {
                adapters_other.push(adapter);
                return;
            }

            let is_same_type = match device {
                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
                #[allow(deprecated)]
                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
                WgpuDevice::Existing(_) => {
                    unreachable!("Cannot select an adapter for an existing device.")
                }
            };

            if is_same_type {
                adapters.push(adapter);
            }
        });

    if adapters.len() <= num {
        if adapters_other.len() <= num {
            panic!(
                "{}, adapters {:?}, other adapters {:?}",
                error,
                adapters
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
                adapters_other
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
            );
        }

        return adapters_other.remove(num);
    }

    adapters.remove(num)
}

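/// Parses the `CUBECL_WGPU_DEFAULT_DEVICE` environment variable into a device
/// override, e.g. `DiscreteGpu(0)`, `IntegratedGpu(1)`, `VirtualGpu(0)`, or `Cpu`.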
fn get_device_override() -> Option<WgpuDevice> {
    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
        .ok()
        .and_then(|var| {
            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
                inner
                    .strip_suffix(")")
                    .and_then(|s| s.parse().ok())
                    .map(WgpuDevice::DiscreteGpu)
            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
                inner
                    .strip_suffix(")")
                    .and_then(|s| s.parse().ok())
                    .map(WgpuDevice::IntegratedGpu)
            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
                inner
                    .strip_suffix(")")
                    .and_then(|s| s.parse().ok())
                    .map(WgpuDevice::VirtualGpu)
            } else if var == "Cpu" {
                Some(WgpuDevice::Cpu)
            } else {
                None
            };

            if override_device.is_none() {
                log::warn!("Unknown CUBECL_WGPU_DEFAULT_DEVICE override {var}");
            }
            override_device
        })
}