1use crate::{
2 AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3 contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::server::ServerUtilities;
8use cubecl_core::{CubeCount, CubeDim, Runtime, ir::TargetProperties};
9pub use cubecl_runtime::memory_management::MemoryConfiguration;
10use cubecl_runtime::memory_management::MemoryDeviceProperties;
11use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties};
12use cubecl_runtime::{
13 client::ComputeClient,
14 logging::{ProfileLevel, ServerLogger},
15};
16use wgpu::{InstanceFlags, RequestAdapterOptions};
17
/// Runtime that targets GPUs through the wgpu API, compiling kernels with
/// [`AutoCompiler`] and executing them on a [`WgpuServer`].
#[derive(Debug)]
pub struct WgpuRuntime;
23
impl DeviceState for WgpuServer {
    /// Initializes a server for the device identified by `device_id`,
    /// blocking on adapter/device creation and using default
    /// [`RuntimeOptions`].
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let device = WgpuDevice::from_id(device_id);
        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
        create_server(setup, RuntimeOptions::default())
    }
}
31
32impl Runtime for WgpuRuntime {
33 type Compiler = AutoCompiler;
34 type Server = WgpuServer;
35 type Device = WgpuDevice;
36
37 fn client(device: &Self::Device) -> ComputeClient<Self> {
38 ComputeClient::load(device)
39 }
40
41 fn name(client: &ComputeClient<Self>) -> &'static str {
42 match client.info() {
43 wgpu::Backend::Vulkan => {
44 #[cfg(feature = "spirv")]
45 return "wgpu<spirv>";
46
47 #[cfg(not(feature = "spirv"))]
48 return "wgpu<wgsl>";
49 }
50 wgpu::Backend::Metal => {
51 #[cfg(feature = "msl")]
52 return "wgpu<msl>";
53
54 #[cfg(not(feature = "msl"))]
55 return "wgpu<wgsl>";
56 }
57 _ => "wgpu<wgsl>",
58 }
59 }
60
61 fn supported_line_sizes() -> &'static [u8] {
62 #[cfg(feature = "msl")]
63 {
64 &[8, 4, 2, 1]
65 }
66 #[cfg(not(feature = "msl"))]
67 {
68 &[4, 2, 1]
69 }
70 }
71
72 fn max_global_line_size() -> u8 {
73 4
74 }
75
76 fn max_cube_count() -> (u32, u32, u32) {
77 let max_dim = u16::MAX as u32;
78 (max_dim, max_dim, max_dim)
79 }
80
81 fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
82 if shape.is_empty() {
83 return true;
84 }
85
86 for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
87 if expected != stride {
88 return false;
89 }
90 }
91
92 true
93 }
94
95 fn target_properties() -> TargetProperties {
96 TargetProperties {
97 mma: Default::default(),
99 }
100 }
101}
102
/// Options configuring the wgpu runtime.
pub struct RuntimeOptions {
    /// Maximum number of tasks handed to the server before work is flushed.
    /// Overridable at startup via the `CUBECL_WGPU_MAX_TASKS` env var —
    /// NOTE(review): exact flushing semantics live in `WgpuServer`; confirm there.
    pub tasks_max: usize,
    /// Configuration for the memory-management strategy.
    pub memory_config: MemoryConfiguration,
}
110
111impl Default for RuntimeOptions {
112 fn default() -> Self {
113 #[cfg(test)]
114 const DEFAULT_MAX_TASKS: usize = 1;
115 #[cfg(not(test))]
116 const DEFAULT_MAX_TASKS: usize = 32;
117
118 let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
119 Ok(value) => value
120 .parse::<usize>()
121 .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
122 Err(_) => DEFAULT_MAX_TASKS,
123 };
124
125 Self {
126 tasks_max,
127 memory_config: MemoryConfiguration::default(),
128 }
129 }
130}
131
/// A fully initialized wgpu context from which a runtime server can be built.
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The wgpu instance the adapter was created from.
    pub instance: wgpu::Instance,
    /// The selected adapter (physical device).
    pub adapter: wgpu::Adapter,
    /// The logical wgpu device.
    pub device: wgpu::Device,
    /// The queue commands are submitted on.
    pub queue: wgpu::Queue,
    /// The graphics backend the adapter was requested for.
    pub backend: wgpu::Backend,
}
148
149pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
159 use core::sync::atomic::{AtomicU32, Ordering};
160
161 static COUNTER: AtomicU32 = AtomicU32::new(0);
162
163 let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
164 if device_id == u32::MAX {
165 core::panic!("Memory ID overflowed");
166 }
167
168 let device_id = WgpuDevice::Existing(device_id);
169 let server = create_server(setup, options);
170 let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
171 device_id
172}
173
/// Like [`init_setup_async`], but blocks on the future.
///
/// # Panics
///
/// Panics on wasm, where blocking is impossible; use the async variant there.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // Consume the arguments to avoid unused-variable warnings on wasm.
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
186
187pub async fn init_setup_async<G: GraphicsApi>(
191 device: &WgpuDevice,
192 options: RuntimeOptions,
193) -> WgpuSetup {
194 let setup = create_setup_for_device(device, G::backend()).await;
195 let return_setup = setup.clone();
196 let server = create_server(setup, options);
197 let _ = ComputeClient::<WgpuRuntime>::init(device, server);
198 return_setup
199}
200
/// Builds a [`WgpuServer`] from an established [`WgpuSetup`].
///
/// Derives memory and hardware properties from the wgpu device/adapter
/// limits, picks a timing method, registers backend-specific features, and
/// wires everything into the server.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let mut adapter_limits = setup.adapter.limits();

    // Some adapters report 0 for both subgroup bounds, meaning the values
    // are unknown; substitute a conservative range covering common hardware.
    if adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0 {
        adapter_limits.min_subgroup_size = 8;
        adapter_limits.max_subgroup_size = 128;
    }

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple Silicon the plane (subgroup) size is pinned to 32;
        // elsewhere the adapter-reported bounds are trusted.
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        // One storage-buffer slot is held back from the limit — presumably
        // reserved for internal use by the server; confirm in WgpuServer.
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: CubeCount::new_3d(max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: CubeDim::new_3d(
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // Not reported through wgpu.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamp queries when available; otherwise fall
    // back to host-side timing.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        // Plane (subgroup) ops are only advertised when the adapter supports
        // SUBGROUP and is not a CPU implementation.
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_runtime::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    // Backend-specific feature/compilation-option registration.
    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
296
297pub(crate) async fn create_setup_for_device(
300 device: &WgpuDevice,
301 backend: wgpu::Backend,
302) -> WgpuSetup {
303 let (instance, adapter) = request_adapter(device, backend).await;
304 let (device, queue) = backend::request_device(&adapter).await;
305
306 log::info!(
307 "Created wgpu compute server on device {:?} => {:?}",
308 device,
309 adapter.get_info()
310 );
311
312 WgpuSetup {
313 instance,
314 adapter,
315 device,
316 queue,
317 backend,
318 }
319}
320
/// Creates a wgpu [`wgpu::Instance`] for `backend` and selects the adapter
/// corresponding to `device`, honoring the `CUBECL_WGPU_DEFAULT_DEVICE`
/// override when a default/best-available device was requested.
async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    // Enable wgpu debugging facilities based on the configured logging level.
    let debug = ServerLogger::default();
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });

    // The env-var override only applies when the caller asked for the
    // default/best device; explicit selections are respected as-is.
    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
            num,
            "No Discrete GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
            num,
            "No Integrated GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
            num,
            "No Virtual GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        // Default/best-available (and all variants on wasm): let wgpu pick a
        // high-performance adapter.
        _ => instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await
            .expect("No possible adapter available for backend. Falling back to first available."),
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}
396
397#[cfg(not(target_family = "wasm"))]
398fn select_from_adapter_list(
399 num: usize,
400 error: &str,
401 instance: &wgpu::Instance,
402 device: &WgpuDevice,
403 backend: wgpu::Backend,
404) -> wgpu::Adapter {
405 let mut adapters_other = Vec::new();
406 let mut adapters = Vec::new();
407
408 instance
409 .enumerate_adapters(backend.into())
410 .into_iter()
411 .for_each(|adapter| {
412 let device_type = adapter.get_info().device_type;
413
414 if let wgpu::DeviceType::Other = device_type {
415 adapters_other.push(adapter);
416 return;
417 }
418
419 let is_same_type = match device {
420 WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
421 WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
422 WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
423 WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
424 #[allow(deprecated)]
425 WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
426 WgpuDevice::Existing(_) => {
427 unreachable!("Cannot select an adapter for an existing device.")
428 }
429 };
430
431 if is_same_type {
432 adapters.push(adapter);
433 }
434 });
435
436 if adapters.len() <= num {
437 if adapters_other.len() <= num {
438 panic!(
439 "{}, adapters {:?}, other adapters {:?}",
440 error,
441 adapters
442 .into_iter()
443 .map(|adapter| adapter.get_info())
444 .collect::<Vec<_>>(),
445 adapters_other
446 .into_iter()
447 .map(|adapter| adapter.get_info())
448 .collect::<Vec<_>>(),
449 );
450 }
451
452 return adapters_other.remove(num);
453 }
454
455 adapters.remove(num)
456}
457
458fn get_device_override() -> Option<WgpuDevice> {
459 std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
462 .ok()
463 .and_then(|var| {
464 let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
465 inner
466 .strip_suffix(")")
467 .and_then(|s| s.parse().ok())
468 .map(WgpuDevice::DiscreteGpu)
469 } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
470 inner
471 .strip_suffix(")")
472 .and_then(|s| s.parse().ok())
473 .map(WgpuDevice::IntegratedGpu)
474 } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
475 inner
476 .strip_suffix(")")
477 .and_then(|s| s.parse().ok())
478 .map(WgpuDevice::VirtualGpu)
479 } else if var == "Cpu" {
480 Some(WgpuDevice::Cpu)
481 } else {
482 None
483 };
484
485 if override_device.is_none() {
486 log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
487 }
488 override_device
489 })
490}