cubecl_wgpu/
runtime.rs

1use crate::{
2    AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3    contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::{Runtime, ir::TargetProperties};
8use cubecl_core::{ir::LineSize, server::ServerUtilities};
9use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
10pub use cubecl_runtime::memory_management::MemoryConfiguration;
11use cubecl_runtime::{
12    client::ComputeClient,
13    logging::{ProfileLevel, ServerLogger},
14};
15use wgpu::{InstanceFlags, RequestAdapterOptions};
16
/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_setup`] to pass in runtime options or to select a
/// specific graphics API.
///
/// This is a zero-sized marker type; all state lives in the per-device [`WgpuServer`]
/// created through [`DeviceState::init`] or [`init_device`].
#[derive(Debug)]
pub struct WgpuRuntime;
22
23impl DeviceState for WgpuServer {
24    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
25        let device = WgpuDevice::from_id(device_id);
26        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
27        create_server(setup, RuntimeOptions::default())
28    }
29}
30
31impl Runtime for WgpuRuntime {
32    type Compiler = AutoCompiler;
33    type Server = WgpuServer;
34    type Device = WgpuDevice;
35
36    fn client(device: &Self::Device) -> ComputeClient<Self> {
37        ComputeClient::load(device)
38    }
39
40    fn name(client: &ComputeClient<Self>) -> &'static str {
41        match client.info() {
42            wgpu::Backend::Vulkan => {
43                #[cfg(feature = "spirv")]
44                return "wgpu<spirv>";
45
46                #[cfg(not(feature = "spirv"))]
47                return "wgpu<wgsl>";
48            }
49            wgpu::Backend::Metal => {
50                #[cfg(feature = "msl")]
51                return "wgpu<msl>";
52
53                #[cfg(not(feature = "msl"))]
54                return "wgpu<wgsl>";
55            }
56            _ => "wgpu<wgsl>",
57        }
58    }
59
60    fn supported_line_sizes() -> &'static [LineSize] {
61        #[cfg(feature = "msl")]
62        {
63            &[8, 4, 2, 1]
64        }
65        #[cfg(not(feature = "msl"))]
66        {
67            &[4, 2, 1]
68        }
69    }
70
71    fn max_global_line_size() -> LineSize {
72        4
73    }
74
75    fn max_cube_count() -> (u32, u32, u32) {
76        let max_dim = u16::MAX as u32;
77        (max_dim, max_dim, max_dim)
78    }
79
80    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
81        if shape.is_empty() {
82            return true;
83        }
84
85        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
86            if expected != stride {
87                return false;
88            }
89        }
90
91        true
92    }
93
94    fn target_properties() -> TargetProperties {
95        TargetProperties {
96            // Values are irrelevant, since no wgsl backends currently support manual mma
97            mma: Default::default(),
98        }
99    }
100}
101
/// The values that control how a WGPU Runtime will perform its calculations.
pub struct RuntimeOptions {
    /// Control the amount of compute tasks to be aggregated into a single GPU command.
    ///
    /// The default can be overridden with the `CUBECL_WGPU_MAX_TASKS`
    /// environment variable (see the [`Default`] impl).
    pub tasks_max: usize,
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}
109
110impl Default for RuntimeOptions {
111    fn default() -> Self {
112        #[cfg(test)]
113        const DEFAULT_MAX_TASKS: usize = 1;
114        #[cfg(not(test))]
115        const DEFAULT_MAX_TASKS: usize = 32;
116
117        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
118            Ok(value) => value
119                .parse::<usize>()
120                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
121            Err(_) => DEFAULT_MAX_TASKS,
122        };
123
124        Self {
125            tasks_max,
126            memory_config: MemoryConfiguration::default(),
127        }
128    }
129}
130
/// A complete setup used to run wgpu.
///
/// These can either be created with [`init_setup`] or [`init_setup_async`].
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The underlying wgpu instance.
    pub instance: wgpu::Instance,
    /// The selected 'adapter'. This corresponds to a physical device.
    pub adapter: wgpu::Adapter,
    /// The wgpu device the runtime will use. Nb: There can only be one device per adapter.
    pub device: wgpu::Device,
    /// The queue commands will be submitted to.
    pub queue: wgpu::Queue,
    /// The backend used by the setup.
    pub backend: wgpu::Backend,
}
147
148/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
149/// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
150///
151/// # Note
152///
153/// Please **do not** to call on the same [`setup`](WgpuSetup) more than once.
154///
155/// This function generates a new, globally unique ID for the device every time it is called,
156/// even if called on the same device multiple times.
157pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
158    use core::sync::atomic::{AtomicU32, Ordering};
159
160    static COUNTER: AtomicU32 = AtomicU32::new(0);
161
162    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
163    if device_id == u32::MAX {
164        core::panic!("Memory ID overflowed");
165    }
166
167    let device_id = WgpuDevice::Existing(device_id);
168    let server = create_server(setup, options);
169    let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
170    device_id
171}
172
173/// Like [`init_setup_async`], but synchronous.
174/// On wasm, it is necessary to use [`init_setup_async`] instead.
175pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
176    cfg_if::cfg_if! {
177        if #[cfg(target_family = "wasm")] {
178            let _ = (device, options);
179            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
180        } else {
181            future::block_on(init_setup_async::<G>(device, options))
182        }
183    }
184}
185
186/// Initialize a client on the given device with the given options.
187/// This function is useful to configure the runtime options
188/// or to pick a different graphics API.
189pub async fn init_setup_async<G: GraphicsApi>(
190    device: &WgpuDevice,
191    options: RuntimeOptions,
192) -> WgpuSetup {
193    let setup = create_setup_for_device(device, G::backend()).await;
194    let return_setup = setup.clone();
195    let server = create_server(setup, options);
196    let _ = ComputeClient::<WgpuRuntime>::init(device, server);
197    return_setup
198}
199
/// Build a [`WgpuServer`] from a completed [`WgpuSetup`], deriving memory,
/// hardware, and feature properties from the wgpu device/adapter limits.
///
/// Consumes `setup` (device/queue handles move into the server) and applies
/// `options` for task batching and memory configuration.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let mut adapter_limits = setup.adapter.limits();

    // Workaround: WebGPU reports some "fake" subgroup info atm, as it's not really supported yet.
    // However, some algorithms do rely on having this information eg. cubecl-reduce uses max subgroup size _even_ when
    // subgroups aren't used. For now, just override with the maximum range of subgroups possible.
    if adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0 {
        // There is in theory nothing limiting the size to go below 8 but in practice 8 is the minimum found anywhere.
        adapter_limits.min_subgroup_size = 8;
        // This is a hard limit of GPU APIs (subgroup ballot returns 4 * 32 bits).
        adapter_limits.max_subgroup_size = 128;
    }

    // Memory pages are bounded by the storage-buffer binding size; allocations
    // must respect the device's storage-buffer offset alignment.
    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple Silicon, the plane size is 32,
        // though the minimum and maximum differ.
        // https://github.com/gpuweb/gpuweb/issues/3950
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        // wgpu uses an additional buffer for variable-length buffers,
        // so we have to use one buffer less on our side to make room for that wgpu internal buffer.
        // See: https://github.com/gfx-rs/wgpu/blob/a9638c8e3ac09ce4f27ac171f8175671e30365fd/wgpu-hal/src/metal/device.rs#L799
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: (max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: (
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
        num_cpu_cores: None, // TODO: Check if device is CPU.
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer GPU timestamp queries for profiling when the adapter supports them,
    // otherwise fall back to host-side timing.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    // Plane (subgroup) ops are advertised only off the macOS/msl path, and only
    // when the adapter supports SUBGROUP and is not a CPU device.
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_ir::features::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    // spirv/msl compilers can handle plane ops under non-uniform control flow.
    #[cfg(any(feature = "spirv", feature = "msl"))]
    device_props
        .features
        .plane
        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);

    // Let the backend adjust feature flags and compilation options per adapter.
    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
302
303/// Select the wgpu device and queue based on the provided [device](WgpuDevice) and
304/// [backend](wgpu::Backend).
305pub(crate) async fn create_setup_for_device(
306    device: &WgpuDevice,
307    backend: wgpu::Backend,
308) -> WgpuSetup {
309    let (instance, adapter) = request_adapter(device, backend).await;
310    let (device, queue) = backend::request_device(&adapter).await;
311
312    log::info!(
313        "Created wgpu compute server on device {:?} => {:?}",
314        device,
315        adapter.get_info()
316    );
317
318    WgpuSetup {
319        instance,
320        adapter,
321        device,
322        queue,
323        backend,
324    }
325}
326
327async fn request_adapter(
328    device: &WgpuDevice,
329    backend: wgpu::Backend,
330) -> (wgpu::Instance, wgpu::Adapter) {
331    let debug = ServerLogger::default();
332    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
333        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
334        (_, true) => InstanceFlags::debugging(),
335        (_, false) => InstanceFlags::default(),
336    };
337    log::debug!("{instance_flags:?}");
338    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
339        backends: backend.into(),
340        flags: instance_flags,
341        ..Default::default()
342    });
343
344    #[allow(deprecated)]
345    let override_device = if matches!(
346        device,
347        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
348    ) {
349        get_device_override()
350    } else {
351        None
352    };
353
354    let device = override_device.unwrap_or_else(|| device.clone());
355
356    let adapter = match device {
357        #[cfg(not(target_family = "wasm"))]
358        WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
359            num,
360            "No Discrete GPU device found",
361            &instance,
362            &device,
363            backend,
364        ),
365        #[cfg(not(target_family = "wasm"))]
366        WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
367            num,
368            "No Integrated GPU device found",
369            &instance,
370            &device,
371            backend,
372        ),
373        #[cfg(not(target_family = "wasm"))]
374        WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
375            num,
376            "No Virtual GPU device found",
377            &instance,
378            &device,
379            backend,
380        ),
381        #[cfg(not(target_family = "wasm"))]
382        WgpuDevice::Cpu => {
383            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
384        }
385        WgpuDevice::Existing(_) => {
386            unreachable!("Cannot select an adapter for an existing device.")
387        }
388        _ => instance
389            .request_adapter(&RequestAdapterOptions {
390                power_preference: wgpu::PowerPreference::HighPerformance,
391                force_fallback_adapter: false,
392                compatible_surface: None,
393            })
394            .await
395            .expect("No possible adapter available for backend. Falling back to first available."),
396    };
397
398    log::info!("Using adapter {:?}", adapter.get_info());
399
400    (instance, adapter)
401}
402
403#[cfg(not(target_family = "wasm"))]
404fn select_from_adapter_list(
405    num: usize,
406    error: &str,
407    instance: &wgpu::Instance,
408    device: &WgpuDevice,
409    backend: wgpu::Backend,
410) -> wgpu::Adapter {
411    let mut adapters_other = Vec::new();
412    let mut adapters = Vec::new();
413
414    instance
415        .enumerate_adapters(backend.into())
416        .into_iter()
417        .for_each(|adapter| {
418            let device_type = adapter.get_info().device_type;
419
420            if let wgpu::DeviceType::Other = device_type {
421                adapters_other.push(adapter);
422                return;
423            }
424
425            let is_same_type = match device {
426                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
427                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
428                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
429                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
430                #[allow(deprecated)]
431                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
432                WgpuDevice::Existing(_) => {
433                    unreachable!("Cannot select an adapter for an existing device.")
434                }
435            };
436
437            if is_same_type {
438                adapters.push(adapter);
439            }
440        });
441
442    if adapters.len() <= num {
443        if adapters_other.len() <= num {
444            panic!(
445                "{}, adapters {:?}, other adapters {:?}",
446                error,
447                adapters
448                    .into_iter()
449                    .map(|adapter| adapter.get_info())
450                    .collect::<Vec<_>>(),
451                adapters_other
452                    .into_iter()
453                    .map(|adapter| adapter.get_info())
454                    .collect::<Vec<_>>(),
455            );
456        }
457
458        return adapters_other.remove(num);
459    }
460
461    adapters.remove(num)
462}
463
464fn get_device_override() -> Option<WgpuDevice> {
465    // If BestAvailable, check if we should instead construct as
466    // if a specific device was specified.
467    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
468        .ok()
469        .and_then(|var| {
470            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
471                inner
472                    .strip_suffix(")")
473                    .and_then(|s| s.parse().ok())
474                    .map(WgpuDevice::DiscreteGpu)
475            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
476                inner
477                    .strip_suffix(")")
478                    .and_then(|s| s.parse().ok())
479                    .map(WgpuDevice::IntegratedGpu)
480            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
481                inner
482                    .strip_suffix(")")
483                    .and_then(|s| s.parse().ok())
484                    .map(WgpuDevice::VirtualGpu)
485            } else if var == "Cpu" {
486                Some(WgpuDevice::Cpu)
487            } else {
488                None
489            };
490
491            if override_device.is_none() {
492                log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
493            }
494            override_device
495        })
496}