Skip to main content

cubecl_wgpu/
runtime.rs

1use crate::{
2    AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3    contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceService};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::device::{DeviceId, ServerUtilitiesHandle};
8use cubecl_core::server::ServerUtilities;
9use cubecl_core::zspace::{Shape, Strides};
10use cubecl_core::{Runtime, ir::TargetProperties};
11use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
12use cubecl_runtime::allocator::ContiguousMemoryLayoutPolicy;
13#[cfg(not(feature = "vulkan-validate"))]
14use cubecl_runtime::logging::ProfileLevel;
15pub use cubecl_runtime::memory_management::MemoryConfiguration;
16use cubecl_runtime::{client::ComputeClient, logging::ServerLogger};
17use wgpu::{InstanceFlags, RequestAdapterOptions};
18
/// Runtime that uses the [wgpu] crate. This is used in the Wgpu backend.
///
/// Kernels are compiled through [`AutoCompiler`]. The emitted shader language depends on the
/// graphics backend and enabled features, as reported by the runtime name:
/// - WGSL by default;
/// - SPIR-V on Vulkan with the `spirv` feature;
/// - MSL on Metal with the `msl` feature.
///
/// For advanced configuration, use [`init_setup`] to pass in runtime options or to select a
/// specific graphics API.
#[derive(Debug, Clone)]
pub struct WgpuRuntime;
24
25impl DeviceService for WgpuServer {
26    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
27        let device = WgpuDevice::from_id(device_id);
28        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
29        create_server(setup, RuntimeOptions::default())
30    }
31
32    fn utilities(&self) -> ServerUtilitiesHandle {
33        self.utilities.clone() as ServerUtilitiesHandle
34    }
35}
36
37impl Runtime for WgpuRuntime {
38    type Compiler = AutoCompiler;
39    type Server = WgpuServer;
40    type Device = WgpuDevice;
41
42    fn client(device: &Self::Device) -> ComputeClient<Self> {
43        ComputeClient::load(device)
44    }
45
46    fn name(client: &ComputeClient<Self>) -> &'static str {
47        match client.info() {
48            wgpu::Backend::Vulkan => {
49                #[cfg(feature = "spirv")]
50                return "wgpu<spirv>";
51
52                #[cfg(not(feature = "spirv"))]
53                return "wgpu<wgsl>";
54            }
55            wgpu::Backend::Metal => {
56                #[cfg(feature = "msl")]
57                return "wgpu<msl>";
58
59                #[cfg(not(feature = "msl"))]
60                return "wgpu<wgsl>";
61            }
62            _ => "wgpu<wgsl>",
63        }
64    }
65
66    fn max_cube_count() -> (u32, u32, u32) {
67        let max_dim = u16::MAX as u32;
68        (max_dim, max_dim, max_dim)
69    }
70
71    fn can_read_tensor(shape: &Shape, strides: &Strides) -> bool {
72        if shape.is_empty() {
73            return true;
74        }
75
76        for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides.iter()) {
77            if expected != stride {
78                return false;
79            }
80        }
81
82        true
83    }
84
85    fn target_properties() -> TargetProperties {
86        TargetProperties {
87            // Values are irrelevant, since no wgsl backends currently support manual mma
88            mma: Default::default(),
89        }
90    }
91
92    fn enumerate_devices(type_id: u16, info: &wgpu::Backend) -> Vec<DeviceId> {
93        #[cfg(target_family = "wasm")]
94        {
95            let _ = type_id;
96            let _ = info;
97            // WebGPU only supports a single device currently.
98            vec![DeviceId::new(0, 0)]
99        }
100
101        #[cfg(not(target_family = "wasm"))]
102        {
103            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
104                backends: wgpu::Backends::all(),
105                ..wgpu::InstanceDescriptor::new_without_display_handle()
106            });
107
108            let adapters = enumerate_all_adapters(instance, *info);
109            adapters
110                .into_iter()
111                .filter(|adapter| {
112                    // Default doesn't filter device types.
113                    if type_id == 4 {
114                        return true;
115                    }
116
117                    let device_type = adapter.get_info().device_type;
118
119                    let adapter_type_id = match device_type {
120                        wgpu::DeviceType::Other => 4,
121                        wgpu::DeviceType::IntegratedGpu => 1,
122                        wgpu::DeviceType::DiscreteGpu => 0,
123                        wgpu::DeviceType::VirtualGpu => 2,
124                        wgpu::DeviceType::Cpu => 3,
125                    };
126
127                    adapter_type_id == type_id
128                })
129                .enumerate()
130                .map(|(index, adapter)| match adapter.get_info().device_type {
131                    wgpu::DeviceType::DiscreteGpu => DeviceId::new(0, index as u16),
132                    wgpu::DeviceType::IntegratedGpu => DeviceId::new(1, index as u16),
133                    wgpu::DeviceType::VirtualGpu => DeviceId::new(2, index as u16),
134                    wgpu::DeviceType::Cpu => DeviceId::new(3, 0),
135                    wgpu::DeviceType::Other => DeviceId::new(4, 0),
136                })
137                .collect()
138        }
139    }
140
141    fn enumerate_all_devices(info: &wgpu::Backend) -> Vec<DeviceId> {
142        #[cfg(target_family = "wasm")]
143        {
144            let _ = info;
145            // WebGPU only supports a single device currently.
146            vec![DeviceId::new(0, 0)]
147        }
148
149        #[cfg(not(target_family = "wasm"))]
150        {
151            let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
152                backends: wgpu::Backends::all(),
153                ..wgpu::InstanceDescriptor::new_without_display_handle()
154            });
155            let adapters = enumerate_all_adapters(instance, *info);
156            adapters
157                .into_iter()
158                .enumerate()
159                .map(|(index, adapter)| match adapter.get_info().device_type {
160                    wgpu::DeviceType::DiscreteGpu => DeviceId::new(0, index as u16),
161                    wgpu::DeviceType::IntegratedGpu => DeviceId::new(1, index as u16),
162                    wgpu::DeviceType::VirtualGpu => DeviceId::new(2, index as u16),
163                    wgpu::DeviceType::Cpu => DeviceId::new(3, 0),
164                    wgpu::DeviceType::Other => DeviceId::new(4, 0),
165                })
166                .collect()
167        }
168    }
169}
170
171#[cfg(not(target_family = "wasm"))]
172fn enumerate_all_adapters(instance: wgpu::Instance, backend: wgpu::Backend) -> Vec<wgpu::Adapter> {
173    // `enumerate_adapters` is now async & available on WebGPU
174    cubecl_common::future::block_on(instance.enumerate_adapters(backend.into()))
175}
176
177/// The values that control how a WGPU Runtime will perform its calculations.
178pub struct RuntimeOptions {
179    /// Control the amount of compute tasks to be aggregated into a single GPU command.
180    pub tasks_max: usize,
181    /// Configures the memory management.
182    pub memory_config: MemoryConfiguration,
183}
184
185impl Default for RuntimeOptions {
186    fn default() -> Self {
187        #[cfg(test)]
188        const DEFAULT_MAX_TASKS: usize = 32;
189        #[cfg(not(test))]
190        const DEFAULT_MAX_TASKS: usize = 32;
191
192        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
193            Ok(value) => value
194                .parse::<usize>()
195                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
196            Err(_) => DEFAULT_MAX_TASKS,
197        };
198
199        Self {
200            tasks_max,
201            memory_config: MemoryConfiguration::default(),
202        }
203    }
204}
205
206/// A complete setup used to run wgpu.
207///
208/// These can either be created with [`init_setup`] or [`init_setup_async`].
209#[derive(Clone, Debug)]
210pub struct WgpuSetup {
211    /// The underlying wgpu instance.
212    pub instance: wgpu::Instance,
213    /// The selected 'adapter'. This corresponds to a physical device.
214    pub adapter: wgpu::Adapter,
215    /// The wgpu device Burn will use. Nb: There can only be one device per adapter.
216    pub device: wgpu::Device,
217    /// The queue Burn commands will be submitted to.
218    pub queue: wgpu::Queue,
219    /// The backend used by the setup.
220    pub backend: wgpu::Backend,
221}
222
223/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
224/// Useful when you want to share a device between `CubeCL` and other wgpu-dependent libraries.
225///
226/// # Note
227///
228/// Please **do not** to call on the same [`setup`](WgpuSetup) more than once.
229///
230/// This function generates a new, globally unique ID for the device every time it is called,
231/// even if called on the same device multiple times.
232pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
233    use core::sync::atomic::{AtomicU32, Ordering};
234
235    static COUNTER: AtomicU32 = AtomicU32::new(0);
236
237    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
238    if device_id == u32::MAX {
239        core::panic!("Memory ID overflowed");
240    }
241
242    let device_id = WgpuDevice::Existing(device_id);
243    let server = create_server(setup, options);
244    let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
245    device_id
246}
247
248/// Like [`init_setup_async`], but synchronous.
249/// On wasm, it is necessary to use [`init_setup_async`] instead.
250pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
251    cfg_if::cfg_if! {
252        if #[cfg(target_family = "wasm")] {
253            let _ = (device, options);
254            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
255        } else {
256            future::block_on(init_setup_async::<G>(device, options))
257        }
258    }
259}
260
261/// Initialize a client on the given device with the given options.
262/// This function is useful to configure the runtime options
263/// or to pick a different graphics API.
264pub async fn init_setup_async<G: GraphicsApi>(
265    device: &WgpuDevice,
266    options: RuntimeOptions,
267) -> WgpuSetup {
268    let setup = create_setup_for_device(device, G::backend()).await;
269    let return_setup = setup.clone();
270    let server = create_server(setup, options);
271    let _ = ComputeClient::<WgpuRuntime>::init(device, server);
272    return_setup
273}
274
275pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
276    let limits = setup.device.limits();
277    let adapter_limits = setup.adapter.limits();
278    let mut adapter_info = setup.adapter.get_info();
279
280    // Workaround: WebGPU reports some "fake" subgroup info atm, as it's not really supported yet.
281    // However, some algorithms do rely on having this information eg. cubecl-reduce uses max subgroup size _even_ when
282    // subgroups aren't used. For now, just override with the maximum range of subgroups possible.
283    if adapter_info.subgroup_min_size == 0 && adapter_info.subgroup_max_size == 0 {
284        // There is in theory nothing limiting the size to go below 8 but in practice 8 is the minimum found anywhere.
285        adapter_info.subgroup_min_size = 8;
286        // This is a hard limit of GPU APIs (subgroup ballot returns 4 * 32 bits).
287        adapter_info.subgroup_max_size = 128;
288    }
289
290    let mem_props = MemoryDeviceProperties {
291        max_page_size: limits.max_storage_buffer_binding_size,
292        alignment: limits.min_uniform_buffer_offset_alignment as u64,
293    };
294    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
295    let hardware_props = HardwareProperties {
296        load_width: 128,
297        // On Apple Silicon, the plane size is 32,
298        // though the minimum and maximum differ.
299        // https://github.com/gpuweb/gpuweb/issues/3950
300        #[cfg(apple_silicon)]
301        plane_size_min: 32,
302        #[cfg(not(apple_silicon))]
303        plane_size_min: adapter_info.subgroup_min_size,
304        #[cfg(apple_silicon)]
305        plane_size_max: 32,
306        #[cfg(not(apple_silicon))]
307        plane_size_max: adapter_info.subgroup_max_size,
308        // wgpu uses an additional buffer for variable-length buffers,
309        // so we have to use one buffer less on our side to make room for that wgpu internal buffer.
310        // See: https://github.com/gfx-rs/wgpu/blob/a9638c8e3ac09ce4f27ac171f8175671e30365fd/wgpu-hal/src/metal/device.rs#L799
311        max_bindings: limits
312            .max_storage_buffers_per_shader_stage
313            .saturating_sub(1),
314        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
315        max_cube_count: (max_count, max_count, max_count),
316        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
317        max_cube_dim: (
318            adapter_limits.max_compute_workgroup_size_x,
319            adapter_limits.max_compute_workgroup_size_y,
320            adapter_limits.max_compute_workgroup_size_z,
321        ),
322        num_streaming_multiprocessors: None,
323        num_tensor_cores: None,
324        min_tensor_cores_dim: None,
325        num_cpu_cores: None, // TODO: Check if device is CPU.
326        max_vector_size: 4,
327    };
328
329    let mut compilation_options = Default::default();
330
331    let features = setup.adapter.features();
332
333    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
334        TimingMethod::Device
335    } else {
336        TimingMethod::System
337    };
338
339    let mut device_props = DeviceProperties::new(
340        Default::default(),
341        mem_props,
342        hardware_props,
343        time_measurement,
344    );
345
346    #[cfg(not(all(target_os = "macos", feature = "msl")))]
347    {
348        if features.contains(wgpu::Features::SUBGROUP)
349            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
350        {
351            use cubecl_ir::features::Plane;
352
353            device_props.features.plane.insert(Plane::Ops);
354        }
355    }
356
357    #[cfg(any(feature = "spirv", feature = "msl"))]
358    device_props
359        .features
360        .plane
361        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);
362
363    backend::register_features(
364        &setup.adapter,
365        &mut device_props,
366        &mut compilation_options,
367        &options.memory_config,
368    );
369
370    let logger = alloc::sync::Arc::new(ServerLogger::default());
371
372    let allocator = ContiguousMemoryLayoutPolicy::new(device_props.memory.alignment as usize);
373    WgpuServer::new(
374        device_props.memory.clone(),
375        options.memory_config,
376        compilation_options,
377        setup.device.clone(),
378        setup.queue,
379        options.tasks_max,
380        setup.backend,
381        time_measurement,
382        ServerUtilities::new(device_props, logger, setup.backend, allocator),
383    )
384}
385
386/// Select the wgpu device and queue based on the provided [device](WgpuDevice) and
387/// [backend](wgpu::Backend).
388pub(crate) async fn create_setup_for_device(
389    device: &WgpuDevice,
390    backend: wgpu::Backend,
391) -> WgpuSetup {
392    let (instance, adapter) = request_adapter(device, backend).await;
393    let (device, queue) = backend::request_device(&adapter).await;
394
395    log::info!(
396        "Created wgpu compute server on device {:?} => {:?}",
397        device,
398        adapter.get_info()
399    );
400
401    WgpuSetup {
402        instance,
403        adapter,
404        device,
405        queue,
406        backend,
407    }
408}
409
410async fn request_adapter(
411    device: &WgpuDevice,
412    backend: wgpu::Backend,
413) -> (wgpu::Instance, wgpu::Adapter) {
414    #[cfg(not(feature = "vulkan-validate"))]
415    let instance_flags = {
416        let debug = ServerLogger::default();
417        match (debug.profile_level(), debug.compilation_activated()) {
418            (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
419            (_, true) => InstanceFlags::debugging(),
420            (_, false) => InstanceFlags::default(),
421        }
422    };
423    #[cfg(feature = "vulkan-validate")]
424    let instance_flags = InstanceFlags::advanced_debugging();
425    log::debug!("{instance_flags:?}");
426    let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
427        backends: backend.into(),
428        flags: instance_flags,
429        ..wgpu::InstanceDescriptor::new_without_display_handle()
430    });
431
432    #[allow(deprecated)]
433    let override_device = if matches!(
434        device,
435        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
436    ) {
437        get_device_override()
438    } else {
439        None
440    };
441
442    let device = override_device.unwrap_or_else(|| device.clone());
443
444    let adapter = match device {
445        #[cfg(not(target_family = "wasm"))]
446        WgpuDevice::DiscreteGpu(num) => {
447            select_from_adapter_list(
448                num,
449                "No Discrete GPU device found",
450                &instance,
451                &device,
452                backend,
453            )
454            .await
455        }
456        #[cfg(not(target_family = "wasm"))]
457        WgpuDevice::IntegratedGpu(num) => {
458            select_from_adapter_list(
459                num,
460                "No Integrated GPU device found",
461                &instance,
462                &device,
463                backend,
464            )
465            .await
466        }
467        #[cfg(not(target_family = "wasm"))]
468        WgpuDevice::VirtualGpu(num) => {
469            select_from_adapter_list(
470                num,
471                "No Virtual GPU device found",
472                &instance,
473                &device,
474                backend,
475            )
476            .await
477        }
478        #[cfg(not(target_family = "wasm"))]
479        WgpuDevice::Cpu => {
480            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend).await
481        }
482        #[cfg(target_family = "wasm")]
483        WgpuDevice::IntegratedGpu(_) => {
484            request_adapter_with_preference(&instance, wgpu::PowerPreference::LowPower).await
485        }
486        WgpuDevice::Existing(_) => {
487            unreachable!("Cannot select an adapter for an existing device.")
488        }
489        _ => {
490            request_adapter_with_preference(&instance, wgpu::PowerPreference::HighPerformance).await
491        }
492    };
493
494    log::info!("Using adapter {:?}", adapter.get_info());
495
496    (instance, adapter)
497}
498
499async fn request_adapter_with_preference(
500    instance: &wgpu::Instance,
501    power_preference: wgpu::PowerPreference,
502) -> wgpu::Adapter {
503    instance
504        .request_adapter(&RequestAdapterOptions {
505            power_preference,
506            force_fallback_adapter: false,
507            compatible_surface: None,
508        })
509        .await
510        .expect("No possible adapter available for backend. Falling back to first available.")
511}
512
513#[cfg(not(target_family = "wasm"))]
514async fn select_from_adapter_list(
515    num: usize,
516    error: &str,
517    instance: &wgpu::Instance,
518    device: &WgpuDevice,
519    backend: wgpu::Backend,
520) -> wgpu::Adapter {
521    let mut adapters_other = Vec::new();
522    let mut adapters = Vec::new();
523
524    instance
525        .enumerate_adapters(backend.into())
526        .await
527        .into_iter()
528        .for_each(|adapter| {
529            let device_type = adapter.get_info().device_type;
530
531            if let wgpu::DeviceType::Other = device_type {
532                adapters_other.push(adapter);
533                return;
534            }
535
536            let is_same_type = match device {
537                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
538                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
539                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
540                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
541                #[allow(deprecated)]
542                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
543                WgpuDevice::Existing(_) => {
544                    unreachable!("Cannot select an adapter for an existing device.")
545                }
546            };
547
548            if is_same_type {
549                adapters.push(adapter);
550            }
551        });
552
553    if adapters.len() <= num {
554        if adapters_other.len() <= num {
555            panic!(
556                "{}, adapters {:?}, other adapters {:?}",
557                error,
558                adapters
559                    .into_iter()
560                    .map(|adapter| adapter.get_info())
561                    .collect::<Vec<_>>(),
562                adapters_other
563                    .into_iter()
564                    .map(|adapter| adapter.get_info())
565                    .collect::<Vec<_>>(),
566            );
567        }
568
569        return adapters_other.remove(num);
570    }
571
572    adapters.remove(num)
573}
574
575fn get_device_override() -> Option<WgpuDevice> {
576    // If BestAvailable, check if we should instead construct as
577    // if a specific device was specified.
578    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
579        .ok()
580        .and_then(|var| {
581            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
582                inner
583                    .strip_suffix(")")
584                    .and_then(|s| s.parse().ok())
585                    .map(WgpuDevice::DiscreteGpu)
586            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
587                inner
588                    .strip_suffix(")")
589                    .and_then(|s| s.parse().ok())
590                    .map(WgpuDevice::IntegratedGpu)
591            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
592                inner
593                    .strip_suffix(")")
594                    .and_then(|s| s.parse().ok())
595                    .map(WgpuDevice::VirtualGpu)
596            } else if var == "Cpu" {
597                Some(WgpuDevice::Cpu)
598            } else {
599                None
600            };
601
602            if override_device.is_none() {
603                log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
604            }
605            override_device
606        })
607}