1use std::marker::PhantomData;
2
3use crate::{
4 compiler::{base::WgpuCompiler, wgsl::WgslCompiler},
5 compute::{WgpuServer, WgpuStorage},
6 AutoGraphicsApi, GraphicsApi, WgpuDevice,
7};
8use alloc::sync::Arc;
9use cubecl_common::future;
10use cubecl_core::{Feature, Runtime};
11pub use cubecl_runtime::memory_management::MemoryConfiguration;
12use cubecl_runtime::{
13 channel::MutexComputeChannel,
14 client::ComputeClient,
15 debug::{DebugLogger, ProfileLevel},
16 ComputeRuntime,
17};
18use cubecl_runtime::{memory_management::HardwareProperties, DeviceProperties};
19use cubecl_runtime::{
20 memory_management::{MemoryDeviceProperties, MemoryManagement},
21 storage::ComputeStorage,
22};
23use wgpu::{InstanceFlags, RequestAdapterOptions};
24
25#[derive(Debug)]
29pub struct WgpuRuntime<C: WgpuCompiler = WgslCompiler>(PhantomData<C>);
30
31type Server = WgpuServer<WgslCompiler>;
32
33static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
35 ComputeRuntime::new();
36
37impl Runtime for WgpuRuntime<WgslCompiler> {
38 type Compiler = WgslCompiler;
39 type Server = WgpuServer<WgslCompiler>;
40
41 type Channel = MutexComputeChannel<WgpuServer<WgslCompiler>>;
42 type Device = WgpuDevice;
43
44 fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
45 RUNTIME.client(device, move || {
46 let setup = future::block_on(create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(
47 device,
48 ));
49 create_client_on_setup(setup, RuntimeOptions::default())
50 })
51 }
52
53 fn name() -> &'static str {
54 "wgpu<wgsl>"
55 }
56
57 fn supported_line_sizes() -> &'static [u8] {
58 &[4, 2, 1]
59 }
60
61 fn max_cube_count() -> (u32, u32, u32) {
62 let max_dim = u16::MAX as u32;
63 (max_dim, max_dim, max_dim)
64 }
65
66 fn extension() -> &'static str {
67 "wgsl"
68 }
69}
70
71pub struct RuntimeOptions {
73 pub tasks_max: usize,
75 pub memory_config: MemoryConfiguration,
77}
78
79impl Default for RuntimeOptions {
80 fn default() -> Self {
81 #[cfg(test)]
82 const DEFAULT_MAX_TASKS: usize = 1;
83 #[cfg(not(test))]
84 const DEFAULT_MAX_TASKS: usize = 32;
85
86 let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
87 Ok(value) => value
88 .parse::<usize>()
89 .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
90 Err(_) => DEFAULT_MAX_TASKS,
91 };
92
93 Self {
94 tasks_max,
95 memory_config: MemoryConfiguration::default(),
96 }
97 }
98}
99
100#[derive(Clone, Debug)]
104pub struct WgpuSetup {
105 pub instance: Arc<wgpu::Instance>,
107 pub adapter: Arc<wgpu::Adapter>,
109 pub device: Arc<wgpu::Device>,
111 pub queue: Arc<wgpu::Queue>,
113}
114
115pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
125 use core::sync::atomic::{AtomicU32, Ordering};
126
127 static COUNTER: AtomicU32 = AtomicU32::new(0);
128
129 let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
130 if device_id == u32::MAX {
131 core::panic!("Memory ID overflowed");
132 }
133
134 let device_id = WgpuDevice::Existing(device_id);
135 let client = create_client_on_setup(setup, options);
136 RUNTIME.register(&device_id, client);
137 device_id
138}
139
140pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
143 cfg_if::cfg_if! {
144 if #[cfg(target_family = "wasm")] {
145 let _ = (device, options);
146 panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
147 } else {
148 future::block_on(init_setup_async::<G>(device, options))
149 }
150 }
151}
152
153pub async fn init_setup_async<G: GraphicsApi>(
157 device: &WgpuDevice,
158 options: RuntimeOptions,
159) -> WgpuSetup {
160 let setup = create_setup_for_device::<G, WgslCompiler>(device).await;
161 let return_setup = setup.clone();
162 let client = create_client_on_setup(setup, options);
163 RUNTIME.register(device, client);
164 return_setup
165}
166
167pub(crate) fn create_client_on_setup<C: WgpuCompiler>(
168 setup: WgpuSetup,
169 options: RuntimeOptions,
170) -> ComputeClient<WgpuServer<C>, MutexComputeChannel<WgpuServer<C>>> {
171 let limits = setup.device.limits();
172 let adapter_limits = setup.adapter.limits();
173
174 let mem_props = MemoryDeviceProperties {
175 max_page_size: limits.max_storage_buffer_binding_size as u64,
176 alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64),
177 };
178 let hardware_props = HardwareProperties {
179 plane_size_min: adapter_limits.min_subgroup_size,
180 plane_size_max: adapter_limits.max_subgroup_size,
181 max_bindings: limits.max_storage_buffers_per_shader_stage,
182 };
183
184 let memory_management = {
185 let device = setup.device.clone();
186 let mem_props = mem_props.clone();
187 let config = options.memory_config;
188 let storage = WgpuStorage::new(device.clone());
189 MemoryManagement::from_configuration(storage, mem_props, config)
190 };
191 let compilation_options = Default::default();
192 let server = WgpuServer::new(
193 memory_management,
194 compilation_options,
195 setup.device.clone(),
196 setup.queue,
197 options.tasks_max,
198 );
199 let channel = MutexComputeChannel::new(server);
200
201 let features = setup.adapter.features();
202 let mut device_props = DeviceProperties::new(&[], mem_props, hardware_props);
203
204 let fake_plane_info =
208 adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0;
209
210 if features.contains(wgpu::Features::SUBGROUP)
211 && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
212 && !fake_plane_info
213 {
214 device_props.register_feature(Feature::Plane);
215 }
216 C::register_features(&setup.adapter, &setup.device, &mut device_props);
217 ComputeClient::new(channel, device_props)
218}
219
220pub(crate) async fn create_setup_for_device<G: GraphicsApi, C: WgpuCompiler>(
222 device: &WgpuDevice,
223) -> WgpuSetup {
224 let (instance, adapter) = request_adapter::<G>(device).await;
225 let (device, queue) = C::request_device(&adapter).await;
226
227 log::info!(
228 "Created wgpu compute server on device {:?} => {:?}",
229 device,
230 adapter.get_info()
231 );
232
233 WgpuSetup {
234 instance: Arc::new(instance),
235 adapter: Arc::new(adapter),
236 device: Arc::new(device),
237 queue: Arc::new(queue),
238 }
239}
240
241async fn request_adapter<G: GraphicsApi>(device: &WgpuDevice) -> (wgpu::Instance, wgpu::Adapter) {
242 let debug = DebugLogger::default();
243 let instance_flags = match (debug.profile_level(), debug.is_activated()) {
244 (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
245 (_, true) => InstanceFlags::debugging(),
246 (_, false) => InstanceFlags::default(),
247 };
248 log::debug!("{instance_flags:?}");
249 let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
250 backends: G::backend().into(),
251 flags: instance_flags,
252 ..Default::default()
253 });
254
255 #[allow(deprecated)]
256 let override_device = if matches!(
257 device,
258 WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
259 ) {
260 get_device_override()
261 } else {
262 None
263 };
264
265 let device = override_device.unwrap_or_else(|| device.clone());
266
267 let adapter = match device {
268 #[cfg(not(target_family = "wasm"))]
269 WgpuDevice::DiscreteGpu(num) => {
270 select_from_adapter_list::<G>(num, "No Discrete GPU device found", &instance, &device)
271 }
272 #[cfg(not(target_family = "wasm"))]
273 WgpuDevice::IntegratedGpu(num) => {
274 select_from_adapter_list::<G>(num, "No Integrated GPU device found", &instance, &device)
275 }
276 #[cfg(not(target_family = "wasm"))]
277 WgpuDevice::VirtualGpu(num) => {
278 select_from_adapter_list::<G>(num, "No Virtual GPU device found", &instance, &device)
279 }
280 #[cfg(not(target_family = "wasm"))]
281 WgpuDevice::Cpu => {
282 select_from_adapter_list::<G>(0, "No CPU device found", &instance, &device)
283 }
284 WgpuDevice::Existing(_) => {
285 unreachable!("Cannot select an adapter for an existing device.")
286 }
287 _ => instance
288 .request_adapter(&RequestAdapterOptions {
289 power_preference: wgpu::PowerPreference::HighPerformance,
290 force_fallback_adapter: false,
291 compatible_surface: None,
292 })
293 .await
294 .expect("No possible adapter available for backend. Falling back to first available."),
295 };
296
297 log::info!("Using adapter {:?}", adapter.get_info());
298
299 (instance, adapter)
300}
301
302#[cfg(not(target_family = "wasm"))]
303fn select_from_adapter_list<G: GraphicsApi>(
304 num: usize,
305 error: &str,
306 instance: &wgpu::Instance,
307 device: &WgpuDevice,
308) -> wgpu::Adapter {
309 let mut adapters_other = Vec::new();
310 let mut adapters = Vec::new();
311
312 instance
313 .enumerate_adapters(G::backend().into())
314 .into_iter()
315 .for_each(|adapter| {
316 let device_type = adapter.get_info().device_type;
317
318 if let wgpu::DeviceType::Other = device_type {
319 adapters_other.push(adapter);
320 return;
321 }
322
323 let is_same_type = match device {
324 WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
325 WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
326 WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
327 WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
328 #[allow(deprecated)]
329 WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
330 WgpuDevice::Existing(_) => {
331 unreachable!("Cannot select an adapter for an existing device.")
332 }
333 };
334
335 if is_same_type {
336 adapters.push(adapter);
337 }
338 });
339
340 if adapters.len() <= num {
341 if adapters_other.len() <= num {
342 panic!(
343 "{}, adapters {:?}, other adapters {:?}",
344 error,
345 adapters
346 .into_iter()
347 .map(|adapter| adapter.get_info())
348 .collect::<Vec<_>>(),
349 adapters_other
350 .into_iter()
351 .map(|adapter| adapter.get_info())
352 .collect::<Vec<_>>(),
353 );
354 }
355
356 return adapters_other.remove(num);
357 }
358
359 adapters.remove(num)
360}
361
362fn get_device_override() -> Option<WgpuDevice> {
363 std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
366 .ok()
367 .and_then(|var| {
368 let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
369 inner
370 .strip_suffix(")")
371 .and_then(|s| s.parse().ok())
372 .map(WgpuDevice::DiscreteGpu)
373 } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
374 inner
375 .strip_suffix(")")
376 .and_then(|s| s.parse().ok())
377 .map(WgpuDevice::IntegratedGpu)
378 } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
379 inner
380 .strip_suffix(")")
381 .and_then(|s| s.parse().ok())
382 .map(WgpuDevice::VirtualGpu)
383 } else if var == "Cpu" {
384 Some(WgpuDevice::Cpu)
385 } else {
386 None
387 };
388
389 if override_device.is_none() {
390 log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
391 }
392 override_device
393 })
394}