1use crate::{
2 AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3 contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::server::ServerUtilities;
8use cubecl_core::{CubeCount, CubeDim, Runtime, ir::TargetProperties};
9pub use cubecl_runtime::memory_management::MemoryConfiguration;
10use cubecl_runtime::memory_management::MemoryDeviceProperties;
11use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties};
12use cubecl_runtime::{
13 client::ComputeClient,
14 logging::{ProfileLevel, ServerLogger},
15};
16use wgpu::{InstanceFlags, RequestAdapterOptions};
17
/// Runtime that uses the [wgpu] crate to execute compute shaders, compiling
/// kernels through [`AutoCompiler`] (WGSL by default; SPIR-V or MSL when the
/// corresponding feature is enabled — see [`Runtime::name`]).
#[derive(Debug)]
pub struct WgpuRuntime;
23
24impl DeviceState for WgpuServer {
25 fn init(device_id: cubecl_common::device::DeviceId) -> Self {
26 let device = WgpuDevice::from_id(device_id);
27 let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
28 create_server(setup, RuntimeOptions::default())
29 }
30}
31
32impl Runtime for WgpuRuntime {
33 type Compiler = AutoCompiler;
34 type Server = WgpuServer;
35 type Device = WgpuDevice;
36
37 fn client(device: &Self::Device) -> ComputeClient<Self::Server> {
38 ComputeClient::load(device)
39 }
40
41 fn name(client: &ComputeClient<Self::Server>) -> &'static str {
42 match client.info() {
43 wgpu::Backend::Vulkan => {
44 #[cfg(feature = "spirv")]
45 return "wgpu<spirv>";
46
47 #[cfg(not(feature = "spirv"))]
48 return "wgpu<wgsl>";
49 }
50 wgpu::Backend::Metal => {
51 #[cfg(feature = "msl")]
52 return "wgpu<msl>";
53
54 #[cfg(not(feature = "msl"))]
55 return "wgpu<wgsl>";
56 }
57 _ => "wgpu<wgsl>",
58 }
59 }
60
61 fn supported_line_sizes() -> &'static [u8] {
62 #[cfg(feature = "msl")]
63 {
64 &[8, 4, 2, 1]
65 }
66 #[cfg(not(feature = "msl"))]
67 {
68 &[4, 2, 1]
69 }
70 }
71
72 fn max_global_line_size() -> u8 {
73 4
74 }
75
76 fn max_cube_count() -> (u32, u32, u32) {
77 let max_dim = u16::MAX as u32;
78 (max_dim, max_dim, max_dim)
79 }
80
81 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
82 if shape.is_empty() {
83 return true;
84 }
85
86 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
87 if expected != stride {
88 return false;
89 }
90 }
91
92 true
93 }
94
95 fn target_properties() -> TargetProperties {
96 TargetProperties {
97 mma: Default::default(),
99 }
100 }
101}
102
/// Options configuring the wgpu runtime.
pub struct RuntimeOptions {
    // Maximum number of tasks accepted before work is flushed to the server
    // (NOTE(review): exact flush semantics live in `WgpuServer` — confirm).
    pub tasks_max: usize,
    // Configuration of the memory management/pools.
    pub memory_config: MemoryConfiguration,
}
110
111impl Default for RuntimeOptions {
112 fn default() -> Self {
113 #[cfg(test)]
114 const DEFAULT_MAX_TASKS: usize = 1;
115 #[cfg(not(test))]
116 const DEFAULT_MAX_TASKS: usize = 32;
117
118 let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
119 Ok(value) => value
120 .parse::<usize>()
121 .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
122 Err(_) => DEFAULT_MAX_TASKS,
123 };
124
125 Self {
126 tasks_max,
127 memory_config: MemoryConfiguration::default(),
128 }
129 }
130}
131
/// Handles to an initialized wgpu context: everything needed to build a
/// compute server on top of an existing device. All handles are cheaply
/// clonable.
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    // The wgpu instance the adapter was requested from.
    pub instance: wgpu::Instance,
    // The selected adapter (physical device).
    pub adapter: wgpu::Adapter,
    // The logical wgpu device.
    pub device: wgpu::Device,
    // The queue commands are submitted on.
    pub queue: wgpu::Queue,
    // The backend (Vulkan, Metal, DX12, ...) the adapter runs on.
    pub backend: wgpu::Backend,
}
148
149pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
159 use core::sync::atomic::{AtomicU32, Ordering};
160
161 static COUNTER: AtomicU32 = AtomicU32::new(0);
162
163 let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
164 if device_id == u32::MAX {
165 core::panic!("Memory ID overflowed");
166 }
167
168 let device_id = WgpuDevice::Existing(device_id);
169 let server = create_server(setup, options);
170 let _ = ComputeClient::init(&device_id, server);
171 device_id
172}
173
/// Synchronously creates a [`WgpuSetup`] for `device`, registering the
/// resulting compute server, and returns the setup for reuse.
///
/// # Panics
///
/// Panics on wasm targets, where blocking on a future is impossible; use
/// [`init_setup_async`] there instead.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
186
187pub async fn init_setup_async<G: GraphicsApi>(
191 device: &WgpuDevice,
192 options: RuntimeOptions,
193) -> WgpuSetup {
194 let setup = create_setup_for_device(device, G::backend()).await;
195 let return_setup = setup.clone();
196 let server = create_server(setup, options);
197 let _ = ComputeClient::init(device, server);
198 return_setup
199}
200
/// Builds a [`WgpuServer`] from a [`WgpuSetup`], deriving memory and hardware
/// properties from the wgpu device/adapter limits and registering
/// backend-specific features and compilation options.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let mut adapter_limits = setup.adapter.limits();

    // Both bounds at 0 means the adapter did not report subgroup sizes;
    // fall back to a wide conservative range
    // (NOTE(review): 8..=128 assumed to cover all targeted hardware — confirm).
    if adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0 {
        adapter_limits.min_subgroup_size = 8;
        adapter_limits.max_subgroup_size = 128;
    }

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        // Apple silicon GPUs execute with a fixed SIMD-group width of 32,
        // regardless of what the adapter reports.
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        // One storage binding is kept in reserve
        // (NOTE(review): presumably for internal server use — confirm).
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: CubeCount::new_3d(max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: CubeDim::new_3d(
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // Not reported through wgpu.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamps when the adapter supports them.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        // Plane (subgroup) ops require the SUBGROUP feature; CPU adapters are
        // excluded (NOTE(review): presumably unreliable there — confirm).
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_runtime::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
295
/// Creates the wgpu instance/adapter/device/queue for the requested
/// [`WgpuDevice`] on the given `backend`, without registering a compute
/// server.
pub(crate) async fn create_setup_for_device(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> WgpuSetup {
    let (instance, adapter) = request_adapter(device, backend).await;
    // Shadowing: `device` is now the logical wgpu::Device, not the selector.
    let (device, queue) = backend::request_device(&adapter).await;

    log::info!(
        "Created wgpu compute server on device {:?} => {:?}",
        device,
        adapter.get_info()
    );

    WgpuSetup {
        instance,
        adapter,
        device,
        queue,
        backend,
    }
}
319
320async fn request_adapter(
321 device: &WgpuDevice,
322 backend: wgpu::Backend,
323) -> (wgpu::Instance, wgpu::Adapter) {
324 let debug = ServerLogger::default();
325 let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
326 (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
327 (_, true) => InstanceFlags::debugging(),
328 (_, false) => InstanceFlags::default(),
329 };
330 log::debug!("{instance_flags:?}");
331 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
332 backends: backend.into(),
333 flags: instance_flags,
334 ..Default::default()
335 });
336
337 #[allow(deprecated)]
338 let override_device = if matches!(
339 device,
340 WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
341 ) {
342 get_device_override()
343 } else {
344 None
345 };
346
347 let device = override_device.unwrap_or_else(|| device.clone());
348
349 let adapter = match device {
350 #[cfg(not(target_family = "wasm"))]
351 WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
352 num,
353 "No Discrete GPU device found",
354 &instance,
355 &device,
356 backend,
357 ),
358 #[cfg(not(target_family = "wasm"))]
359 WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
360 num,
361 "No Integrated GPU device found",
362 &instance,
363 &device,
364 backend,
365 ),
366 #[cfg(not(target_family = "wasm"))]
367 WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
368 num,
369 "No Virtual GPU device found",
370 &instance,
371 &device,
372 backend,
373 ),
374 #[cfg(not(target_family = "wasm"))]
375 WgpuDevice::Cpu => {
376 select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
377 }
378 WgpuDevice::Existing(_) => {
379 unreachable!("Cannot select an adapter for an existing device.")
380 }
381 _ => instance
382 .request_adapter(&RequestAdapterOptions {
383 power_preference: wgpu::PowerPreference::HighPerformance,
384 force_fallback_adapter: false,
385 compatible_surface: None,
386 })
387 .await
388 .expect("No possible adapter available for backend. Falling back to first available."),
389 };
390
391 log::info!("Using adapter {:?}", adapter.get_info());
392
393 (instance, adapter)
394}
395
396#[cfg(not(target_family = "wasm"))]
397fn select_from_adapter_list(
398 num: usize,
399 error: &str,
400 instance: &wgpu::Instance,
401 device: &WgpuDevice,
402 backend: wgpu::Backend,
403) -> wgpu::Adapter {
404 let mut adapters_other = Vec::new();
405 let mut adapters = Vec::new();
406
407 instance
408 .enumerate_adapters(backend.into())
409 .into_iter()
410 .for_each(|adapter| {
411 let device_type = adapter.get_info().device_type;
412
413 if let wgpu::DeviceType::Other = device_type {
414 adapters_other.push(adapter);
415 return;
416 }
417
418 let is_same_type = match device {
419 WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
420 WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
421 WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
422 WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
423 #[allow(deprecated)]
424 WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
425 WgpuDevice::Existing(_) => {
426 unreachable!("Cannot select an adapter for an existing device.")
427 }
428 };
429
430 if is_same_type {
431 adapters.push(adapter);
432 }
433 });
434
435 if adapters.len() <= num {
436 if adapters_other.len() <= num {
437 panic!(
438 "{}, adapters {:?}, other adapters {:?}",
439 error,
440 adapters
441 .into_iter()
442 .map(|adapter| adapter.get_info())
443 .collect::<Vec<_>>(),
444 adapters_other
445 .into_iter()
446 .map(|adapter| adapter.get_info())
447 .collect::<Vec<_>>(),
448 );
449 }
450
451 return adapters_other.remove(num);
452 }
453
454 adapters.remove(num)
455}
456
457fn get_device_override() -> Option<WgpuDevice> {
458 std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
461 .ok()
462 .and_then(|var| {
463 let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
464 inner
465 .strip_suffix(")")
466 .and_then(|s| s.parse().ok())
467 .map(WgpuDevice::DiscreteGpu)
468 } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
469 inner
470 .strip_suffix(")")
471 .and_then(|s| s.parse().ok())
472 .map(WgpuDevice::IntegratedGpu)
473 } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
474 inner
475 .strip_suffix(")")
476 .and_then(|s| s.parse().ok())
477 .map(WgpuDevice::VirtualGpu)
478 } else if var == "Cpu" {
479 Some(WgpuDevice::Cpu)
480 } else {
481 None
482 };
483
484 if override_device.is_none() {
485 log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
486 }
487 override_device
488 })
489}