// cubecl_wgpu/runtime.rs

1use crate::{
2    AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3    contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::server::ServerUtilities;
8use cubecl_core::{CubeCount, CubeDim, Runtime, ir::TargetProperties};
9pub use cubecl_runtime::memory_management::MemoryConfiguration;
10use cubecl_runtime::memory_management::MemoryDeviceProperties;
11use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties};
12use cubecl_runtime::{
13    client::ComputeClient,
14    logging::{ProfileLevel, ServerLogger},
15};
16use wgpu::{InstanceFlags, RequestAdapterOptions};
17
/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_setup`] to pass in runtime options or to select a
/// specific graphics API.
///
/// Unit struct: all state lives in the per-device [`WgpuServer`] created on first use.
#[derive(Debug)]
pub struct WgpuRuntime;
23
impl DeviceState for WgpuServer {
    /// Initialize a server for `device_id`: resolve the [`WgpuDevice`], create a full
    /// wgpu setup (instance/adapter/device/queue) with the auto-selected graphics API,
    /// then build the server with default [`RuntimeOptions`].
    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
        let device = WgpuDevice::from_id(device_id);
        // Blocks the current thread on the async adapter/device request.
        // NOTE(review): block_on is unavailable on wasm — presumably wasm goes
        // through `init_setup_async` instead; confirm.
        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
        create_server(setup, RuntimeOptions::default())
    }
}
31
impl Runtime for WgpuRuntime {
    type Compiler = AutoCompiler;
    type Server = WgpuServer;
    type Device = WgpuDevice;

    /// Fetch (or lazily initialize) the compute client for `device`.
    fn client(device: &Self::Device) -> ComputeClient<Self> {
        ComputeClient::load(device)
    }

    /// Runtime name reflecting the compiler actually used for the client's backend:
    /// SPIR-V on Vulkan and MSL on Metal when the matching cargo feature is enabled,
    /// WGSL in every other case.
    fn name(client: &ComputeClient<Self>) -> &'static str {
        match client.info() {
            wgpu::Backend::Vulkan => {
                #[cfg(feature = "spirv")]
                return "wgpu<spirv>";

                #[cfg(not(feature = "spirv"))]
                return "wgpu<wgsl>";
            }
            wgpu::Backend::Metal => {
                #[cfg(feature = "msl")]
                return "wgpu<msl>";

                #[cfg(not(feature = "msl"))]
                return "wgpu<wgsl>";
            }
            _ => "wgpu<wgsl>",
        }
    }

    /// Supported vectorization (line) widths, widest first.
    /// MSL supports 8-wide vectors; WGSL tops out at 4.
    fn supported_line_sizes() -> &'static [u8] {
        #[cfg(feature = "msl")]
        {
            &[8, 4, 2, 1]
        }
        #[cfg(not(feature = "msl"))]
        {
            &[4, 2, 1]
        }
    }

    // NOTE(review): capped at 4 even when `msl` allows 8-wide lines above —
    // presumably global memory access is limited to 4 regardless; confirm.
    fn max_global_line_size() -> u8 {
        4
    }

    /// Maximum cube (workgroup) count per dispatch dimension.
    // NOTE(review): hard-coded to u16::MAX per axis rather than read from the
    // adapter limits — looks like a conservative portable bound; confirm.
    fn max_cube_count() -> (u32, u32, u32) {
        let max_dim = u16::MAX as u32;
        (max_dim, max_dim, max_dim)
    }

    /// A tensor can be read directly only when its strides match the canonical
    /// contiguous (row-major) strides for `shape`.
    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        // Rank-0 tensors are trivially contiguous.
        if shape.is_empty() {
            return true;
        }

        // NOTE(review): `zip` stops at the shorter slice, so a `strides` slice
        // shorter than `shape` would pass the check — assumes callers always
        // supply matching ranks; confirm.
        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
            if expected != stride {
                return false;
            }
        }

        true
    }

    fn target_properties() -> TargetProperties {
        TargetProperties {
            // Values are irrelevant, since no wgsl backends currently support manual mma
            mma: Default::default(),
        }
    }
}
102
/// The values that control how a WGPU Runtime will perform its calculations.
pub struct RuntimeOptions {
    /// Control the amount of compute tasks to be aggregated into a single GPU command.
    ///
    /// Defaults to 32 (1 under `cfg(test)`), and can be overridden with the
    /// `CUBECL_WGPU_MAX_TASKS` environment variable (see [`Default`] impl).
    pub tasks_max: usize,
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}
110
111impl Default for RuntimeOptions {
112    fn default() -> Self {
113        #[cfg(test)]
114        const DEFAULT_MAX_TASKS: usize = 1;
115        #[cfg(not(test))]
116        const DEFAULT_MAX_TASKS: usize = 32;
117
118        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
119            Ok(value) => value
120                .parse::<usize>()
121                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
122            Err(_) => DEFAULT_MAX_TASKS,
123        };
124
125        Self {
126            tasks_max,
127            memory_config: MemoryConfiguration::default(),
128        }
129    }
130}
131
/// A complete setup used to run wgpu.
///
/// These can either be created with [`init_setup`] or [`init_setup_async`].
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The underlying wgpu instance.
    pub instance: wgpu::Instance,
    /// The selected 'adapter'. This corresponds to a physical device.
    pub adapter: wgpu::Adapter,
    /// The wgpu device CubeCL will use. Nb: There can only be one device per adapter.
    pub device: wgpu::Device,
    /// The queue CubeCL commands will be submitted to.
    pub queue: wgpu::Queue,
    /// The backend used by the setup.
    pub backend: wgpu::Backend,
}
148
149/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
150/// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
151///
152/// # Note
153///
154/// Please **do not** to call on the same [`setup`](WgpuSetup) more than once.
155///
156/// This function generates a new, globally unique ID for the device every time it is called,
157/// even if called on the same device multiple times.
158pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
159    use core::sync::atomic::{AtomicU32, Ordering};
160
161    static COUNTER: AtomicU32 = AtomicU32::new(0);
162
163    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
164    if device_id == u32::MAX {
165        core::panic!("Memory ID overflowed");
166    }
167
168    let device_id = WgpuDevice::Existing(device_id);
169    let server = create_server(setup, options);
170    let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
171    device_id
172}
173
/// Like [`init_setup_async`], but synchronous.
/// On wasm, it is necessary to use [`init_setup_async`] instead.
///
/// # Panics
/// Always panics on wasm targets, where blocking on a future is impossible.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // Consume the arguments so the wasm build has no unused-variable warnings.
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
186
187/// Initialize a client on the given device with the given options.
188/// This function is useful to configure the runtime options
189/// or to pick a different graphics API.
190pub async fn init_setup_async<G: GraphicsApi>(
191    device: &WgpuDevice,
192    options: RuntimeOptions,
193) -> WgpuSetup {
194    let setup = create_setup_for_device(device, G::backend()).await;
195    let return_setup = setup.clone();
196    let server = create_server(setup, options);
197    let _ = ComputeClient::<WgpuRuntime>::init(device, server);
198    return_setup
199}
200
/// Build a [`WgpuServer`] from an already-created [`WgpuSetup`].
///
/// Derives the memory, hardware, and timing properties from the wgpu device and
/// adapter limits, registers backend-specific features, and wires everything
/// into the server together with the `options` (task batching, memory config).
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let mut adapter_limits = setup.adapter.limits();

    // Workaround: WebGPU reports some "fake" subgroup info atm, as it's not really supported yet.
    // However, some algorithms do rely on having this information eg. cubecl-reduce uses max subgroup size _even_ when
    // subgroups aren't used. For now, just override with the maximum range of subgroups possible.
    if adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0 {
        // There is in theory nothing limiting the size to go below 8 but in practice 8 is the minimum found anywhere.
        adapter_limits.min_subgroup_size = 8;
        // This is a hard limit of GPU APIs (subgroup ballot returns 4 * 32 bits).
        adapter_limits.max_subgroup_size = 128;
    }

    // A memory page can be at most one storage-buffer binding, aligned to the
    // device's minimum storage-buffer offset alignment.
    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple Silicon, the plane size is 32,
        // though the minimum and maximum differ.
        // https://github.com/gpuweb/gpuweb/issues/3950
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        // wgpu uses an additional buffer for variable-length buffers,
        // so we have to use one buffer less on our side to make room for that wgpu internal buffer.
        // See: https://github.com/gfx-rs/wgpu/blob/a9638c8e3ac09ce4f27ac171f8175671e30365fd/wgpu-hal/src/metal/device.rs#L799
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: CubeCount::new_3d(max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: CubeDim::new_3d(
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        // Not exposed through wgpu limits.
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamp queries for profiling when the adapter supports them.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    // Advertise plane (subgroup) operations when supported, except on CPU adapters.
    // NOTE(review): skipped on macOS with the `msl` feature — presumably the MSL
    // backend registers plane support itself in `register_features`; confirm.
    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_runtime::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
296
297/// Select the wgpu device and queue based on the provided [device](WgpuDevice) and
298/// [backend](wgpu::Backend).
299pub(crate) async fn create_setup_for_device(
300    device: &WgpuDevice,
301    backend: wgpu::Backend,
302) -> WgpuSetup {
303    let (instance, adapter) = request_adapter(device, backend).await;
304    let (device, queue) = backend::request_device(&adapter).await;
305
306    log::info!(
307        "Created wgpu compute server on device {:?} => {:?}",
308        device,
309        adapter.get_info()
310    );
311
312    WgpuSetup {
313        instance,
314        adapter,
315        device,
316        queue,
317        backend,
318    }
319}
320
321async fn request_adapter(
322    device: &WgpuDevice,
323    backend: wgpu::Backend,
324) -> (wgpu::Instance, wgpu::Adapter) {
325    let debug = ServerLogger::default();
326    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
327        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
328        (_, true) => InstanceFlags::debugging(),
329        (_, false) => InstanceFlags::default(),
330    };
331    log::debug!("{instance_flags:?}");
332    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
333        backends: backend.into(),
334        flags: instance_flags,
335        ..Default::default()
336    });
337
338    #[allow(deprecated)]
339    let override_device = if matches!(
340        device,
341        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
342    ) {
343        get_device_override()
344    } else {
345        None
346    };
347
348    let device = override_device.unwrap_or_else(|| device.clone());
349
350    let adapter = match device {
351        #[cfg(not(target_family = "wasm"))]
352        WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
353            num,
354            "No Discrete GPU device found",
355            &instance,
356            &device,
357            backend,
358        ),
359        #[cfg(not(target_family = "wasm"))]
360        WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
361            num,
362            "No Integrated GPU device found",
363            &instance,
364            &device,
365            backend,
366        ),
367        #[cfg(not(target_family = "wasm"))]
368        WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
369            num,
370            "No Virtual GPU device found",
371            &instance,
372            &device,
373            backend,
374        ),
375        #[cfg(not(target_family = "wasm"))]
376        WgpuDevice::Cpu => {
377            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
378        }
379        WgpuDevice::Existing(_) => {
380            unreachable!("Cannot select an adapter for an existing device.")
381        }
382        _ => instance
383            .request_adapter(&RequestAdapterOptions {
384                power_preference: wgpu::PowerPreference::HighPerformance,
385                force_fallback_adapter: false,
386                compatible_surface: None,
387            })
388            .await
389            .expect("No possible adapter available for backend. Falling back to first available."),
390    };
391
392    log::info!("Using adapter {:?}", adapter.get_info());
393
394    (instance, adapter)
395}
396
397#[cfg(not(target_family = "wasm"))]
398fn select_from_adapter_list(
399    num: usize,
400    error: &str,
401    instance: &wgpu::Instance,
402    device: &WgpuDevice,
403    backend: wgpu::Backend,
404) -> wgpu::Adapter {
405    let mut adapters_other = Vec::new();
406    let mut adapters = Vec::new();
407
408    instance
409        .enumerate_adapters(backend.into())
410        .into_iter()
411        .for_each(|adapter| {
412            let device_type = adapter.get_info().device_type;
413
414            if let wgpu::DeviceType::Other = device_type {
415                adapters_other.push(adapter);
416                return;
417            }
418
419            let is_same_type = match device {
420                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
421                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
422                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
423                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
424                #[allow(deprecated)]
425                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
426                WgpuDevice::Existing(_) => {
427                    unreachable!("Cannot select an adapter for an existing device.")
428                }
429            };
430
431            if is_same_type {
432                adapters.push(adapter);
433            }
434        });
435
436    if adapters.len() <= num {
437        if adapters_other.len() <= num {
438            panic!(
439                "{}, adapters {:?}, other adapters {:?}",
440                error,
441                adapters
442                    .into_iter()
443                    .map(|adapter| adapter.get_info())
444                    .collect::<Vec<_>>(),
445                adapters_other
446                    .into_iter()
447                    .map(|adapter| adapter.get_info())
448                    .collect::<Vec<_>>(),
449            );
450        }
451
452        return adapters_other.remove(num);
453    }
454
455    adapters.remove(num)
456}
457
458fn get_device_override() -> Option<WgpuDevice> {
459    // If BestAvailable, check if we should instead construct as
460    // if a specific device was specified.
461    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
462        .ok()
463        .and_then(|var| {
464            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
465                inner
466                    .strip_suffix(")")
467                    .and_then(|s| s.parse().ok())
468                    .map(WgpuDevice::DiscreteGpu)
469            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
470                inner
471                    .strip_suffix(")")
472                    .and_then(|s| s.parse().ok())
473                    .map(WgpuDevice::IntegratedGpu)
474            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
475                inner
476                    .strip_suffix(")")
477                    .and_then(|s| s.parse().ok())
478                    .map(WgpuDevice::VirtualGpu)
479            } else if var == "Cpu" {
480                Some(WgpuDevice::Cpu)
481            } else {
482                None
483            };
484
485            if override_device.is_none() {
486                log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
487            }
488            override_device
489        })
490}