cubecl_wgpu/
runtime.rs

1use std::marker::PhantomData;
2
3use crate::{
4    compiler::{base::WgpuCompiler, wgsl::WgslCompiler},
5    compute::{WgpuServer, WgpuStorage},
6    AutoGraphicsApi, GraphicsApi, WgpuDevice,
7};
8use alloc::sync::Arc;
9use cubecl_common::future;
10use cubecl_core::{Feature, Runtime};
11pub use cubecl_runtime::memory_management::MemoryConfiguration;
12use cubecl_runtime::{
13    channel::MutexComputeChannel,
14    client::ComputeClient,
15    debug::{DebugLogger, ProfileLevel},
16    ComputeRuntime,
17};
18use cubecl_runtime::{memory_management::HardwareProperties, DeviceProperties};
19use cubecl_runtime::{
20    memory_management::{MemoryDeviceProperties, MemoryManagement},
21    storage::ComputeStorage,
22};
23use wgpu::{InstanceFlags, RequestAdapterOptions};
24
/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_setup`] (or [`init_setup_async`] on wasm) to pass in
/// runtime options or to select a specific graphics API.
#[derive(Debug)]
pub struct WgpuRuntime<C: WgpuCompiler = WgslCompiler>(PhantomData<C>);
30
/// Concrete server type used by the shared runtime below.
type Server = WgpuServer<WgslCompiler>;

/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime).
static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
    ComputeRuntime::new();
36
37impl Runtime for WgpuRuntime<WgslCompiler> {
38    type Compiler = WgslCompiler;
39    type Server = WgpuServer<WgslCompiler>;
40
41    type Channel = MutexComputeChannel<WgpuServer<WgslCompiler>>;
42    type Device = WgpuDevice;
43
44    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
45        RUNTIME.client(device, move || {
46            let setup = future::block_on(create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(
47                device,
48            ));
49            create_client_on_setup(setup, RuntimeOptions::default())
50        })
51    }
52
53    fn name() -> &'static str {
54        "wgpu<wgsl>"
55    }
56
57    fn supported_line_sizes() -> &'static [u8] {
58        &[4, 2, 1]
59    }
60
61    fn max_cube_count() -> (u32, u32, u32) {
62        let max_dim = u16::MAX as u32;
63        (max_dim, max_dim, max_dim)
64    }
65
66    fn extension() -> &'static str {
67        "wgsl"
68    }
69}
70
/// The values that control how a WGPU Runtime will perform its calculations.
pub struct RuntimeOptions {
    /// Control the amount of compute tasks to be aggregated into a single GPU command.
    /// Can be overridden via the `CUBECL_WGPU_MAX_TASKS` environment variable
    /// when using [`RuntimeOptions::default`].
    pub tasks_max: usize,
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}
78
79impl Default for RuntimeOptions {
80    fn default() -> Self {
81        #[cfg(test)]
82        const DEFAULT_MAX_TASKS: usize = 1;
83        #[cfg(not(test))]
84        const DEFAULT_MAX_TASKS: usize = 32;
85
86        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
87            Ok(value) => value
88                .parse::<usize>()
89                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
90            Err(_) => DEFAULT_MAX_TASKS,
91        };
92
93        Self {
94            tasks_max,
95            memory_config: MemoryConfiguration::default(),
96        }
97    }
98}
99
/// A complete setup used to run wgpu.
///
/// These can either be created with [`init_setup`] or [`init_setup_async`].
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The underlying wgpu instance.
    pub instance: Arc<wgpu::Instance>,
    /// The selected 'adapter'. This corresponds to a physical device.
    pub adapter: Arc<wgpu::Adapter>,
    /// The wgpu device the runtime will use. Nb: There can only be one device per adapter.
    pub device: Arc<wgpu::Device>,
    /// The queue commands will be submitted to.
    pub queue: Arc<wgpu::Queue>,
}
114
115/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
116/// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
117///
118/// # Note
119///
120/// Please **do not** to call on the same [`setup`](WgpuSetup) more than once.
121///
122/// This function generates a new, globally unique ID for the device every time it is called,
123/// even if called on the same device multiple times.
124pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
125    use core::sync::atomic::{AtomicU32, Ordering};
126
127    static COUNTER: AtomicU32 = AtomicU32::new(0);
128
129    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
130    if device_id == u32::MAX {
131        core::panic!("Memory ID overflowed");
132    }
133
134    let device_id = WgpuDevice::Existing(device_id);
135    let client = create_client_on_setup(setup, options);
136    RUNTIME.register(&device_id, client);
137    device_id
138}
139
140/// Like [`init_setup_async`], but synchronous.
141/// On wasm, it is necessary to use [`init_setup_async`] instead.
142pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
143    cfg_if::cfg_if! {
144        if #[cfg(target_family = "wasm")] {
145            let _ = (device, options);
146            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
147        } else {
148            future::block_on(init_setup_async::<G>(device, options))
149        }
150    }
151}
152
153/// Initialize a client on the given device with the given options.
154/// This function is useful to configure the runtime options
155/// or to pick a different graphics API.
156pub async fn init_setup_async<G: GraphicsApi>(
157    device: &WgpuDevice,
158    options: RuntimeOptions,
159) -> WgpuSetup {
160    let setup = create_setup_for_device::<G, WgslCompiler>(device).await;
161    let return_setup = setup.clone();
162    let client = create_client_on_setup(setup, options);
163    RUNTIME.register(device, client);
164    return_setup
165}
166
167pub(crate) fn create_client_on_setup<C: WgpuCompiler>(
168    setup: WgpuSetup,
169    options: RuntimeOptions,
170) -> ComputeClient<WgpuServer<C>, MutexComputeChannel<WgpuServer<C>>> {
171    let limits = setup.device.limits();
172    let adapter_limits = setup.adapter.limits();
173
174    let mem_props = MemoryDeviceProperties {
175        max_page_size: limits.max_storage_buffer_binding_size as u64,
176        alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64),
177    };
178    let hardware_props = HardwareProperties {
179        plane_size_min: adapter_limits.min_subgroup_size,
180        plane_size_max: adapter_limits.max_subgroup_size,
181        max_bindings: limits.max_storage_buffers_per_shader_stage,
182    };
183
184    let memory_management = {
185        let device = setup.device.clone();
186        let mem_props = mem_props.clone();
187        let config = options.memory_config;
188        let storage = WgpuStorage::new(device.clone());
189        MemoryManagement::from_configuration(storage, mem_props, config)
190    };
191    let compilation_options = Default::default();
192    let server = WgpuServer::new(
193        memory_management,
194        compilation_options,
195        setup.device.clone(),
196        setup.queue,
197        options.tasks_max,
198    );
199    let channel = MutexComputeChannel::new(server);
200
201    let features = setup.adapter.features();
202    let mut device_props = DeviceProperties::new(&[], mem_props, hardware_props);
203
204    // Workaround: WebGPU does support subgroups and correctly reports this, but wgpu
205    // doesn't plumb through this info. Instead min/max are just reported as 0, which can cause issues.
206    // For now just disable subgroups on WebGPU, until this information is added.
207    let fake_plane_info =
208        adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0;
209
210    if features.contains(wgpu::Features::SUBGROUP)
211        && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
212        && !fake_plane_info
213    {
214        device_props.register_feature(Feature::Plane);
215    }
216    C::register_features(&setup.adapter, &setup.device, &mut device_props);
217    ComputeClient::new(channel, device_props)
218}
219
220/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
221pub(crate) async fn create_setup_for_device<G: GraphicsApi, C: WgpuCompiler>(
222    device: &WgpuDevice,
223) -> WgpuSetup {
224    let (instance, adapter) = request_adapter::<G>(device).await;
225    let (device, queue) = C::request_device(&adapter).await;
226
227    log::info!(
228        "Created wgpu compute server on device {:?} => {:?}",
229        device,
230        adapter.get_info()
231    );
232
233    WgpuSetup {
234        instance: Arc::new(instance),
235        adapter: Arc::new(adapter),
236        device: Arc::new(device),
237        queue: Arc::new(queue),
238    }
239}
240
241async fn request_adapter<G: GraphicsApi>(device: &WgpuDevice) -> (wgpu::Instance, wgpu::Adapter) {
242    let debug = DebugLogger::default();
243    let instance_flags = match (debug.profile_level(), debug.is_activated()) {
244        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
245        (_, true) => InstanceFlags::debugging(),
246        (_, false) => InstanceFlags::default(),
247    };
248    log::debug!("{instance_flags:?}");
249    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
250        backends: G::backend().into(),
251        flags: instance_flags,
252        ..Default::default()
253    });
254
255    #[allow(deprecated)]
256    let override_device = if matches!(
257        device,
258        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
259    ) {
260        get_device_override()
261    } else {
262        None
263    };
264
265    let device = override_device.unwrap_or_else(|| device.clone());
266
267    let adapter = match device {
268        #[cfg(not(target_family = "wasm"))]
269        WgpuDevice::DiscreteGpu(num) => {
270            select_from_adapter_list::<G>(num, "No Discrete GPU device found", &instance, &device)
271        }
272        #[cfg(not(target_family = "wasm"))]
273        WgpuDevice::IntegratedGpu(num) => {
274            select_from_adapter_list::<G>(num, "No Integrated GPU device found", &instance, &device)
275        }
276        #[cfg(not(target_family = "wasm"))]
277        WgpuDevice::VirtualGpu(num) => {
278            select_from_adapter_list::<G>(num, "No Virtual GPU device found", &instance, &device)
279        }
280        #[cfg(not(target_family = "wasm"))]
281        WgpuDevice::Cpu => {
282            select_from_adapter_list::<G>(0, "No CPU device found", &instance, &device)
283        }
284        WgpuDevice::Existing(_) => {
285            unreachable!("Cannot select an adapter for an existing device.")
286        }
287        _ => instance
288            .request_adapter(&RequestAdapterOptions {
289                power_preference: wgpu::PowerPreference::HighPerformance,
290                force_fallback_adapter: false,
291                compatible_surface: None,
292            })
293            .await
294            .expect("No possible adapter available for backend. Falling back to first available."),
295    };
296
297    log::info!("Using adapter {:?}", adapter.get_info());
298
299    (instance, adapter)
300}
301
302#[cfg(not(target_family = "wasm"))]
303fn select_from_adapter_list<G: GraphicsApi>(
304    num: usize,
305    error: &str,
306    instance: &wgpu::Instance,
307    device: &WgpuDevice,
308) -> wgpu::Adapter {
309    let mut adapters_other = Vec::new();
310    let mut adapters = Vec::new();
311
312    instance
313        .enumerate_adapters(G::backend().into())
314        .into_iter()
315        .for_each(|adapter| {
316            let device_type = adapter.get_info().device_type;
317
318            if let wgpu::DeviceType::Other = device_type {
319                adapters_other.push(adapter);
320                return;
321            }
322
323            let is_same_type = match device {
324                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
325                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
326                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
327                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
328                #[allow(deprecated)]
329                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
330                WgpuDevice::Existing(_) => {
331                    unreachable!("Cannot select an adapter for an existing device.")
332                }
333            };
334
335            if is_same_type {
336                adapters.push(adapter);
337            }
338        });
339
340    if adapters.len() <= num {
341        if adapters_other.len() <= num {
342            panic!(
343                "{}, adapters {:?}, other adapters {:?}",
344                error,
345                adapters
346                    .into_iter()
347                    .map(|adapter| adapter.get_info())
348                    .collect::<Vec<_>>(),
349                adapters_other
350                    .into_iter()
351                    .map(|adapter| adapter.get_info())
352                    .collect::<Vec<_>>(),
353            );
354        }
355
356        return adapters_other.remove(num);
357    }
358
359    adapters.remove(num)
360}
361
362fn get_device_override() -> Option<WgpuDevice> {
363    // If BestAvailable, check if we should instead construct as
364    // if a specific device was specified.
365    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
366        .ok()
367        .and_then(|var| {
368            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
369                inner
370                    .strip_suffix(")")
371                    .and_then(|s| s.parse().ok())
372                    .map(WgpuDevice::DiscreteGpu)
373            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
374                inner
375                    .strip_suffix(")")
376                    .and_then(|s| s.parse().ok())
377                    .map(WgpuDevice::IntegratedGpu)
378            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
379                inner
380                    .strip_suffix(")")
381                    .and_then(|s| s.parse().ok())
382                    .map(WgpuDevice::VirtualGpu)
383            } else if var == "Cpu" {
384                Some(WgpuDevice::Cpu)
385            } else {
386                None
387            };
388
389            if override_device.is_none() {
390                log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
391            }
392            override_device
393        })
394}