Skip to main content

cubecl_wgpu/
runtime.rs

1use crate::{
2    AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
3    contiguous_strides,
4};
5use cubecl_common::device::{Device, DeviceState};
6use cubecl_common::{future, profile::TimingMethod};
7use cubecl_core::server::ServerUtilities;
8use cubecl_core::{Runtime, ir::TargetProperties};
9use cubecl_ir::{DeviceProperties, HardwareProperties, MemoryDeviceProperties};
10pub use cubecl_runtime::memory_management::MemoryConfiguration;
11use cubecl_runtime::{
12    client::ComputeClient,
13    logging::{ProfileLevel, ServerLogger},
14};
15use wgpu::{InstanceFlags, RequestAdapterOptions};
16
/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_setup`] to pass in runtime options or to select a
/// specific graphics API.
///
/// This is a zero-sized marker type; per-device state lives in [`WgpuServer`].
#[derive(Debug)]
pub struct WgpuRuntime;
22
23impl DeviceState for WgpuServer {
24    fn init(device_id: cubecl_common::device::DeviceId) -> Self {
25        let device = WgpuDevice::from_id(device_id);
26        let setup = future::block_on(create_setup_for_device(&device, AutoGraphicsApi::backend()));
27        create_server(setup, RuntimeOptions::default())
28    }
29}
30
31impl Runtime for WgpuRuntime {
32    type Compiler = AutoCompiler;
33    type Server = WgpuServer;
34    type Device = WgpuDevice;
35
36    fn client(device: &Self::Device) -> ComputeClient<Self> {
37        ComputeClient::load(device)
38    }
39
40    fn name(client: &ComputeClient<Self>) -> &'static str {
41        match client.info() {
42            wgpu::Backend::Vulkan => {
43                #[cfg(feature = "spirv")]
44                return "wgpu<spirv>";
45
46                #[cfg(not(feature = "spirv"))]
47                return "wgpu<wgsl>";
48            }
49            wgpu::Backend::Metal => {
50                #[cfg(feature = "msl")]
51                return "wgpu<msl>";
52
53                #[cfg(not(feature = "msl"))]
54                return "wgpu<wgsl>";
55            }
56            _ => "wgpu<wgsl>",
57        }
58    }
59
60    fn max_cube_count() -> (u32, u32, u32) {
61        let max_dim = u16::MAX as u32;
62        (max_dim, max_dim, max_dim)
63    }
64
65    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
66        if shape.is_empty() {
67            return true;
68        }
69
70        for (&expected, &stride) in contiguous_strides(shape).iter().zip(strides) {
71            if expected != stride {
72                return false;
73            }
74        }
75
76        true
77    }
78
79    fn target_properties() -> TargetProperties {
80        TargetProperties {
81            // Values are irrelevant, since no wgsl backends currently support manual mma
82            mma: Default::default(),
83        }
84    }
85}
86
/// The values that control how a WGPU Runtime will perform its calculations.
pub struct RuntimeOptions {
    /// Control the amount of compute tasks to be aggregated into a single GPU command.
    ///
    /// The default can be overridden at runtime through the
    /// `CUBECL_WGPU_MAX_TASKS` environment variable.
    pub tasks_max: usize,
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}
94
95impl Default for RuntimeOptions {
96    fn default() -> Self {
97        #[cfg(test)]
98        const DEFAULT_MAX_TASKS: usize = 1;
99        #[cfg(not(test))]
100        const DEFAULT_MAX_TASKS: usize = 32;
101
102        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
103            Ok(value) => value
104                .parse::<usize>()
105                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
106            Err(_) => DEFAULT_MAX_TASKS,
107        };
108
109        Self {
110            tasks_max,
111            memory_config: MemoryConfiguration::default(),
112        }
113    }
114}
115
/// A complete setup used to run wgpu.
///
/// These can either be created with [`init_setup`] or [`init_setup_async`].
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The underlying wgpu instance.
    pub instance: wgpu::Instance,
    /// The selected 'adapter'. This corresponds to a physical device.
    pub adapter: wgpu::Adapter,
    /// The wgpu device CubeCL will use. Nb: There can only be one device per adapter.
    pub device: wgpu::Device,
    /// The queue CubeCL commands will be submitted to.
    pub queue: wgpu::Queue,
    /// The backend used by the setup.
    pub backend: wgpu::Backend,
}
132
/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
/// Useful when you want to share a device between `CubeCL` and other wgpu-dependent libraries.
///
/// # Note
///
/// Please **do not** call this on the same [`setup`](WgpuSetup) more than once.
///
/// This function generates a new, globally unique ID for the device every time it is called,
/// even if called on the same device multiple times.
pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
    use core::sync::atomic::{AtomicU32, Ordering};

    // Process-wide counter: each call yields a distinct `WgpuDevice::Existing` id.
    static COUNTER: AtomicU32 = AtomicU32::new(0);

    // `fetch_add` returns the previous value; seeing `u32::MAX` means the id
    // space is exhausted, so abort rather than hand out a reused id.
    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
    if device_id == u32::MAX {
        core::panic!("Memory ID overflowed");
    }

    let device_id = WgpuDevice::Existing(device_id);
    let server = create_server(setup, options);
    // Register the client globally; the returned handle itself is not needed.
    let _ = ComputeClient::<WgpuRuntime>::init(&device_id, server);
    device_id
}
157
/// Like [`init_setup_async`], but synchronous.
/// On wasm, it is necessary to use [`init_setup_async`] instead.
///
/// # Panics
///
/// Panics on `wasm` targets, where blocking on a future is not possible.
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            // Consume the arguments so the wasm path compiles without warnings.
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}
170
171/// Initialize a client on the given device with the given options.
172/// This function is useful to configure the runtime options
173/// or to pick a different graphics API.
174pub async fn init_setup_async<G: GraphicsApi>(
175    device: &WgpuDevice,
176    options: RuntimeOptions,
177) -> WgpuSetup {
178    let setup = create_setup_for_device(device, G::backend()).await;
179    let return_setup = setup.clone();
180    let server = create_server(setup, options);
181    let _ = ComputeClient::<WgpuRuntime>::init(device, server);
182    return_setup
183}
184
/// Assemble a [`WgpuServer`] from an existing [`WgpuSetup`] and [`RuntimeOptions`].
///
/// Derives memory and hardware properties from the wgpu device/adapter limits,
/// registers backend-specific features, and wires everything into the server.
pub(crate) fn create_server(setup: WgpuSetup, options: RuntimeOptions) -> WgpuServer {
    let limits = setup.device.limits();
    let adapter_limits = setup.adapter.limits();
    let mut adapter_info = setup.adapter.get_info();

    // Workaround: WebGPU reports some "fake" subgroup info atm, as it's not really supported yet.
    // However, some algorithms do rely on having this information eg. cubecl-reduce uses max subgroup size _even_ when
    // subgroups aren't used. For now, just override with the maximum range of subgroups possible.
    if adapter_info.subgroup_min_size == 0 && adapter_info.subgroup_max_size == 0 {
        // There is in theory nothing limiting the size to go below 8 but in practice 8 is the minimum found anywhere.
        adapter_info.subgroup_min_size = 8;
        // This is a hard limit of GPU APIs (subgroup ballot returns 4 * 32 bits).
        adapter_info.subgroup_max_size = 128;
    }

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        load_width: 128,
        // On Apple Silicon, the plane size is 32,
        // though the minimum and maximum differ.
        // https://github.com/gpuweb/gpuweb/issues/3950
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_info.subgroup_min_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_info.subgroup_max_size,
        // wgpu uses an additional buffer for variable-length buffers,
        // so we have to use one buffer less on our side to make room for that wgpu internal buffer.
        // See: https://github.com/gfx-rs/wgpu/blob/a9638c8e3ac09ce4f27ac171f8175671e30365fd/wgpu-hal/src/metal/device.rs#L799
        max_bindings: limits
            .max_storage_buffers_per_shader_stage
            .saturating_sub(1),
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: (max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: (
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
        num_cpu_cores: None, // TODO: Check if device is CPU.
        max_line_size: 4,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    // Prefer device-side timestamp queries when the adapter supports them;
    // otherwise fall back to host (system) timing.
    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props = DeviceProperties::new(
        Default::default(),
        mem_props.clone(),
        hardware_props,
        time_measurement,
    );

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        // Advertise plane (subgroup) ops only when supported and the adapter
        // is not a CPU implementation.
        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
        {
            use cubecl_ir::features::Plane;

            device_props.features.plane.insert(Plane::Ops);
        }
    }

    #[cfg(any(feature = "spirv", feature = "msl"))]
    device_props
        .features
        .plane
        .insert(cubecl_ir::features::Plane::NonUniformControlFlow);

    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let logger = alloc::sync::Arc::new(ServerLogger::default());

    WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
        ServerUtilities::new(device_props, logger, setup.backend),
    )
}
289
290/// Select the wgpu device and queue based on the provided [device](WgpuDevice) and
291/// [backend](wgpu::Backend).
292pub(crate) async fn create_setup_for_device(
293    device: &WgpuDevice,
294    backend: wgpu::Backend,
295) -> WgpuSetup {
296    let (instance, adapter) = request_adapter(device, backend).await;
297    let (device, queue) = backend::request_device(&adapter).await;
298
299    log::info!(
300        "Created wgpu compute server on device {:?} => {:?}",
301        device,
302        adapter.get_info()
303    );
304
305    WgpuSetup {
306        instance,
307        adapter,
308        device,
309        queue,
310        backend,
311    }
312}
313
/// Create a wgpu instance and select an adapter matching `device` on the given
/// `backend`.
///
/// Instance debug flags are derived from the server logger configuration.
/// When the caller asked for the default/best-available device, the
/// `CUBECL_WGPU_DEFAULT_DEVICE` environment override is consulted first.
async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    let debug = ServerLogger::default();
    // Map logging configuration to wgpu instance flags: full profiling gets
    // the most verbose debugging, active compilation logging gets standard
    // debugging, anything else uses the defaults.
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });

    // The env-var override only applies when no specific device kind was
    // requested by the caller.
    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => {
            select_from_adapter_list(
                num,
                "No Discrete GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => {
            select_from_adapter_list(
                num,
                "No Integrated GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => {
            select_from_adapter_list(
                num,
                "No Virtual GPU device found",
                &instance,
                &device,
                backend,
            )
            .await
        }
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend).await
        }
        // On wasm, adapters cannot be enumerated, so fall back to wgpu's
        // power-preference based selection.
        #[cfg(target_family = "wasm")]
        WgpuDevice::IntegratedGpu(_) => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::LowPower).await
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        _ => {
            request_adapter_with_preference(&instance, wgpu::PowerPreference::HighPerformance).await
        }
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}
397
398async fn request_adapter_with_preference(
399    instance: &wgpu::Instance,
400    power_preference: wgpu::PowerPreference,
401) -> wgpu::Adapter {
402    instance
403        .request_adapter(&RequestAdapterOptions {
404            power_preference,
405            force_fallback_adapter: false,
406            compatible_surface: None,
407        })
408        .await
409        .expect("No possible adapter available for backend. Falling back to first available.")
410}
411
/// Pick the `num`-th enumerated adapter whose device type matches `device`.
///
/// Adapters reporting `DeviceType::Other` are collected as a fallback pool:
/// when there are not enough typed matches, the `num`-th "other" adapter is
/// returned instead.
///
/// # Panics
///
/// Panics with `error` (plus diagnostic listings of both pools) when neither
/// pool contains more than `num` adapters.
#[cfg(not(target_family = "wasm"))]
async fn select_from_adapter_list(
    num: usize,
    error: &str,
    instance: &wgpu::Instance,
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> wgpu::Adapter {
    let mut adapters_other = Vec::new();
    let mut adapters = Vec::new();

    instance
        .enumerate_adapters(backend.into())
        .await
        .into_iter()
        .for_each(|adapter| {
            let device_type = adapter.get_info().device_type;

            // Untyped adapters go to the fallback pool.
            if let wgpu::DeviceType::Other = device_type {
                adapters_other.push(adapter);
                return;
            }

            let is_same_type = match device {
                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
                #[allow(deprecated)]
                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
                WgpuDevice::Existing(_) => {
                    unreachable!("Cannot select an adapter for an existing device.")
                }
            };

            if is_same_type {
                adapters.push(adapter);
            }
        });

    // Not enough typed matches: try the fallback pool, otherwise panic with a
    // listing of everything that was found to help diagnose the selection.
    if adapters.len() <= num {
        if adapters_other.len() <= num {
            panic!(
                "{}, adapters {:?}, other adapters {:?}",
                error,
                adapters
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
                adapters_other
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
            );
        }

        return adapters_other.remove(num);
    }

    adapters.remove(num)
}
473
474fn get_device_override() -> Option<WgpuDevice> {
475    // If BestAvailable, check if we should instead construct as
476    // if a specific device was specified.
477    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
478        .ok()
479        .and_then(|var| {
480            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
481                inner
482                    .strip_suffix(")")
483                    .and_then(|s| s.parse().ok())
484                    .map(WgpuDevice::DiscreteGpu)
485            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
486                inner
487                    .strip_suffix(")")
488                    .and_then(|s| s.parse().ok())
489                    .map(WgpuDevice::IntegratedGpu)
490            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
491                inner
492                    .strip_suffix(")")
493                    .and_then(|s| s.parse().ok())
494                    .map(WgpuDevice::VirtualGpu)
495            } else if var == "Cpu" {
496                Some(WgpuDevice::Cpu)
497            } else {
498                None
499            };
500
501            if override_device.is_none() {
502                log::warn!("Unknown CUBECL_WGPU_DEVICE override {var}");
503            }
504            override_device
505        })
506}