cubecl_wgpu/
runtime.rs

use crate::{
    AutoCompiler, AutoGraphicsApi, GraphicsApi, WgpuDevice, backend, compute::WgpuServer,
    contiguous_strides,
};
use cubecl_common::{future, profile::TimingMethod};

#[cfg(not(all(target_os = "macos", feature = "msl")))]
use cubecl_core::{
    AtomicFeature, Feature,
    ir::{Elem, FloatKind},
};
use cubecl_core::{CubeCount, CubeDim, Runtime};
pub use cubecl_runtime::memory_management::MemoryConfiguration;
use cubecl_runtime::memory_management::MemoryDeviceProperties;
use cubecl_runtime::{
    ComputeRuntime,
    channel::MutexComputeChannel,
    client::ComputeClient,
    id::DeviceId,
    logging::{ProfileLevel, ServerLogger},
};
use cubecl_runtime::{DeviceProperties, memory_management::HardwareProperties};
use wgpu::{InstanceFlags, RequestAdapterOptions};

/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
/// For advanced configuration, use [`init_setup`] to pass in runtime options or to select a
/// specific graphics API.
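///
/// # Example
///
/// A minimal sketch of requesting a compute client through the [`Runtime`] trait
/// (paths assume `cubecl_wgpu` and `cubecl_core` are in scope):
///
/// ```ignore
/// use cubecl_core::Runtime;
/// use cubecl_wgpu::{WgpuDevice, WgpuRuntime};
///
/// // `DefaultDevice` lets the runtime pick the best available adapter.
/// let device = WgpuDevice::DefaultDevice;
/// let client = WgpuRuntime::client(&device);
/// ```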
#[derive(Debug)]
pub struct WgpuRuntime;

type Server = WgpuServer;

/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime).
static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
    ComputeRuntime::new();

impl Runtime for WgpuRuntime {
    type Compiler = AutoCompiler;
    type Server = WgpuServer;

    type Channel = MutexComputeChannel<WgpuServer>;
    type Device = WgpuDevice;

    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
        RUNTIME.client(device, move || {
            let setup =
                future::block_on(create_setup_for_device(device, AutoGraphicsApi::backend()));
            create_client_on_setup(setup, RuntimeOptions::default())
        })
    }

    fn name(client: &ComputeClient<Self::Server, Self::Channel>) -> &'static str {
        match client.info() {
            wgpu::Backend::Vulkan => {
                #[cfg(feature = "spirv")]
                return "wgpu<spirv>";

                #[cfg(not(feature = "spirv"))]
                return "wgpu<wgsl>";
            }
            wgpu::Backend::Metal => {
                #[cfg(feature = "msl")]
                return "wgpu<msl>";

                #[cfg(not(feature = "msl"))]
                return "wgpu<wgsl>";
            }
            _ => "wgpu<wgsl>",
        }
    }

    fn supported_line_sizes() -> &'static [u8] {
        #[cfg(feature = "msl")]
        {
            &[8, 4, 2, 1]
        }
        #[cfg(not(feature = "msl"))]
        {
            &[4, 2, 1]
        }
    }

    fn max_cube_count() -> (u32, u32, u32) {
        let max_dim = u16::MAX as u32;
        (max_dim, max_dim, max_dim)
    }

    fn device_id(device: &Self::Device) -> DeviceId {
        #[allow(deprecated)]
        match device {
            WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32),
            WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32),
            WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
            WgpuDevice::Cpu => DeviceId::new(3, 0),
            WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
            WgpuDevice::Existing(id) => DeviceId::new(5, *id),
        }
    }

    fn can_read_tensor(shape: &[usize], strides: &[usize]) -> bool {
        if shape.is_empty() {
            return true;
        }

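        // e.g. shape [2, 3] is only readable with the contiguous strides [3, 1];
        // any other stride layout fails the check below.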
        for (expected, &stride) in contiguous_strides(shape).into_iter().zip(strides) {
            if expected != stride {
                return false;
            }
        }

        true
    }
}

/// The values that control how a WGPU Runtime will perform its calculations.
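///
/// # Example
///
/// A sketch of overriding the default batching size (all other fields keep their defaults):
///
/// ```ignore
/// use cubecl_wgpu::RuntimeOptions;
///
/// // Flush work to the GPU after every 16 aggregated tasks.
/// let options = RuntimeOptions {
///     tasks_max: 16,
///     ..Default::default()
/// };
/// ```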
pub struct RuntimeOptions {
    /// Controls the number of compute tasks to aggregate into a single GPU command.
    pub tasks_max: usize,
    /// Configures the memory management.
    pub memory_config: MemoryConfiguration,
}

impl Default for RuntimeOptions {
    fn default() -> Self {
        #[cfg(test)]
        const DEFAULT_MAX_TASKS: usize = 1;
        #[cfg(not(test))]
        const DEFAULT_MAX_TASKS: usize = 32;

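        // The default can be overridden through the environment,
        // e.g. `CUBECL_WGPU_MAX_TASKS=64`.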
        let tasks_max = match std::env::var("CUBECL_WGPU_MAX_TASKS") {
            Ok(value) => value
                .parse::<usize>()
                .expect("CUBECL_WGPU_MAX_TASKS should be a positive integer."),
            Err(_) => DEFAULT_MAX_TASKS,
        };

        Self {
            tasks_max,
            memory_config: MemoryConfiguration::default(),
        }
    }
}

/// A complete setup used to run wgpu.
///
/// These can either be created with [`init_setup`] or [`init_setup_async`].
#[derive(Clone, Debug)]
pub struct WgpuSetup {
    /// The underlying wgpu instance.
    pub instance: wgpu::Instance,
    /// The selected 'adapter'. This corresponds to a physical device.
    pub adapter: wgpu::Adapter,
    /// The wgpu device Burn will use. N.B.: there can only be one device per adapter.
    pub device: wgpu::Device,
    /// The queue Burn commands will be submitted to.
    pub queue: wgpu::Queue,
    /// The backend used by the setup.
    pub backend: wgpu::Backend,
}

/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
/// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
///
/// # Note
///
/// Please **do not** call this on the same [`setup`](WgpuSetup) more than once:
/// this function generates a new, globally unique ID for the device every time it is called,
/// even if called on the same device multiple times.
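///
/// # Example
///
/// A minimal sketch; `instance`, `adapter`, `wgpu_device`, and `queue` are assumed to come
/// from the host application's own wgpu initialization:
///
/// ```ignore
/// use cubecl_wgpu::{init_device, RuntimeOptions, WgpuSetup};
///
/// let setup = WgpuSetup {
///     instance,
///     adapter,
///     device: wgpu_device,
///     queue,
///     backend: wgpu::Backend::Vulkan,
/// };
/// let device = init_device(setup, RuntimeOptions::default());
/// ```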
pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice {
    use core::sync::atomic::{AtomicU32, Ordering};

    static COUNTER: AtomicU32 = AtomicU32::new(0);

    let device_id = COUNTER.fetch_add(1, Ordering::Relaxed);
    if device_id == u32::MAX {
        core::panic!("Device ID overflowed");
    }

    let device_id = WgpuDevice::Existing(device_id);
    let client = create_client_on_setup(setup, options);
    RUNTIME.register(&device_id, client);
    device_id
}

/// Like [`init_setup_async`], but synchronous.
/// On wasm, it is necessary to use [`init_setup_async`] instead.
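///
/// # Example
///
/// A minimal sketch for native targets (this call blocks on an async future):
///
/// ```ignore
/// use cubecl_wgpu::{init_setup, AutoGraphicsApi, RuntimeOptions, WgpuDevice};
///
/// let device = WgpuDevice::DefaultDevice;
/// let setup = init_setup::<AutoGraphicsApi>(&device, RuntimeOptions::default());
/// println!("Selected adapter: {:?}", setup.adapter.get_info());
/// ```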
pub fn init_setup<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup {
    cfg_if::cfg_if! {
        if #[cfg(target_family = "wasm")] {
            let _ = (device, options);
            panic!("Creating a wgpu setup synchronously is unsupported on wasm. Use init_setup_async instead");
        } else {
            future::block_on(init_setup_async::<G>(device, options))
        }
    }
}

/// Initialize a client on the given device with the given options.
/// This function is useful for configuring the runtime options
/// or for picking a different graphics API.
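///
/// # Example
///
/// A sketch from an async context, which is also the only supported path on wasm:
///
/// ```ignore
/// use cubecl_wgpu::{init_setup_async, AutoGraphicsApi, RuntimeOptions, WgpuDevice};
///
/// let device = WgpuDevice::DefaultDevice;
/// let setup =
///     init_setup_async::<AutoGraphicsApi>(&device, RuntimeOptions::default()).await;
/// ```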
pub async fn init_setup_async<G: GraphicsApi>(
    device: &WgpuDevice,
    options: RuntimeOptions,
) -> WgpuSetup {
    let setup = create_setup_for_device(device, G::backend()).await;
    let return_setup = setup.clone();
    let client = create_client_on_setup(setup, options);
    RUNTIME.register(device, client);
    return_setup
}

pub(crate) fn create_client_on_setup(
    setup: WgpuSetup,
    options: RuntimeOptions,
) -> ComputeClient<WgpuServer, MutexComputeChannel<WgpuServer>> {
    let limits = setup.device.limits();
    let adapter_limits = setup.adapter.limits();

    let mem_props = MemoryDeviceProperties {
        max_page_size: limits.max_storage_buffer_binding_size as u64,
        alignment: limits.min_storage_buffer_offset_alignment as u64,
    };
    let max_count = adapter_limits.max_compute_workgroups_per_dimension;
    let hardware_props = HardwareProperties {
        // On Apple Silicon, the plane size is always 32, even though the
        // reported minimum and maximum subgroup sizes differ.
        // https://github.com/gpuweb/gpuweb/issues/3950
        #[cfg(apple_silicon)]
        plane_size_min: 32,
        #[cfg(not(apple_silicon))]
        plane_size_min: adapter_limits.min_subgroup_size,
        #[cfg(apple_silicon)]
        plane_size_max: 32,
        #[cfg(not(apple_silicon))]
        plane_size_max: adapter_limits.max_subgroup_size,
        max_bindings: limits.max_storage_buffers_per_shader_stage,
        max_shared_memory_size: limits.max_compute_workgroup_storage_size as usize,
        max_cube_count: CubeCount::new_3d(max_count, max_count, max_count),
        max_units_per_cube: adapter_limits.max_compute_invocations_per_workgroup,
        max_cube_dim: CubeDim::new_3d(
            adapter_limits.max_compute_workgroup_size_x,
            adapter_limits.max_compute_workgroup_size_y,
            adapter_limits.max_compute_workgroup_size_z,
        ),
        num_streaming_multiprocessors: None,
        num_tensor_cores: None,
        min_tensor_cores_dim: None,
    };

    let mut compilation_options = Default::default();

    let features = setup.adapter.features();

    let time_measurement = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
        TimingMethod::Device
    } else {
        TimingMethod::System
    };

    let mut device_props =
        DeviceProperties::new(&[], mem_props.clone(), hardware_props, time_measurement);

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        // Workaround: WebGPU supports subgroups and reports them correctly, but wgpu
        // doesn't plumb this information through; the min/max sizes are simply reported
        // as 0, which can cause issues. For now, disable subgroups on WebGPU until this
        // information is added.
        let fake_plane_info =
            adapter_limits.min_subgroup_size == 0 && adapter_limits.max_subgroup_size == 0;

        if features.contains(wgpu::Features::SUBGROUP)
            && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu
            && !fake_plane_info
        {
            device_props.register_feature(Feature::Plane);
        }
    }

    backend::register_features(&setup.adapter, &mut device_props, &mut compilation_options);

    let server = WgpuServer::new(
        mem_props,
        options.memory_config,
        compilation_options,
        setup.device.clone(),
        setup.queue,
        options.tasks_max,
        setup.backend,
        time_measurement,
    );
    let channel = MutexComputeChannel::new(server);

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    if features.contains(wgpu::Features::SHADER_FLOAT32_ATOMIC) {
        device_props.register_feature(Feature::Type(Elem::AtomicFloat(FloatKind::F32)));

        device_props.register_feature(Feature::AtomicFloat(AtomicFeature::LoadStore));
        device_props.register_feature(Feature::AtomicFloat(AtomicFeature::Add));
    }

    #[cfg(not(all(target_os = "macos", feature = "msl")))]
    {
        use cubecl_core::ir::{IntKind, UIntKind};

        device_props.register_feature(Feature::Type(Elem::AtomicInt(IntKind::I32)));
        device_props.register_feature(Feature::Type(Elem::AtomicUInt(UIntKind::U32)));
        device_props.register_feature(Feature::AtomicInt(AtomicFeature::LoadStore));
        device_props.register_feature(Feature::AtomicInt(AtomicFeature::Add));
        device_props.register_feature(Feature::AtomicUInt(AtomicFeature::LoadStore));
        device_props.register_feature(Feature::AtomicUInt(AtomicFeature::Add));
    }

    ComputeClient::new(channel, device_props, setup.backend)
}

/// Select the wgpu device and queue based on the provided [device](WgpuDevice) and
/// [backend](wgpu::Backend).
pub(crate) async fn create_setup_for_device(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> WgpuSetup {
    let (instance, adapter) = request_adapter(device, backend).await;
    let (device, queue) = backend::request_device(&adapter).await;

    log::info!(
        "Created wgpu compute server on device {:?} => {:?}",
        device,
        adapter.get_info()
    );

    WgpuSetup {
        instance,
        adapter,
        device,
        queue,
        backend,
    }
}

async fn request_adapter(
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> (wgpu::Instance, wgpu::Adapter) {
    let debug = ServerLogger::default();
    let instance_flags = match (debug.profile_level(), debug.compilation_activated()) {
        (Some(ProfileLevel::Full), _) => InstanceFlags::advanced_debugging(),
        (_, true) => InstanceFlags::debugging(),
        (_, false) => InstanceFlags::default(),
    };
    log::debug!("{instance_flags:?}");
    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
        backends: backend.into(),
        flags: instance_flags,
        ..Default::default()
    });

    #[allow(deprecated)]
    let override_device = if matches!(
        device,
        WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable
    ) {
        get_device_override()
    } else {
        None
    };

    let device = override_device.unwrap_or_else(|| device.clone());

    let adapter = match device {
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::DiscreteGpu(num) => select_from_adapter_list(
            num,
            "No Discrete GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::IntegratedGpu(num) => select_from_adapter_list(
            num,
            "No Integrated GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::VirtualGpu(num) => select_from_adapter_list(
            num,
            "No Virtual GPU device found",
            &instance,
            &device,
            backend,
        ),
        #[cfg(not(target_family = "wasm"))]
        WgpuDevice::Cpu => {
            select_from_adapter_list(0, "No CPU device found", &instance, &device, backend)
        }
        WgpuDevice::Existing(_) => {
            unreachable!("Cannot select an adapter for an existing device.")
        }
        _ => instance
            .request_adapter(&RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                force_fallback_adapter: false,
                compatible_surface: None,
            })
            .await
            .expect("No adapter available for the requested backend."),
    };

    log::info!("Using adapter {:?}", adapter.get_info());

    (instance, adapter)
}

#[cfg(not(target_family = "wasm"))]
fn select_from_adapter_list(
    num: usize,
    error: &str,
    instance: &wgpu::Instance,
    device: &WgpuDevice,
    backend: wgpu::Backend,
) -> wgpu::Adapter {
    let mut adapters_other = Vec::new();
    let mut adapters = Vec::new();

    instance
        .enumerate_adapters(backend.into())
        .into_iter()
        .for_each(|adapter| {
            let device_type = adapter.get_info().device_type;

            if let wgpu::DeviceType::Other = device_type {
                adapters_other.push(adapter);
                return;
            }

            let is_same_type = match device {
                WgpuDevice::DiscreteGpu(_) => device_type == wgpu::DeviceType::DiscreteGpu,
                WgpuDevice::IntegratedGpu(_) => device_type == wgpu::DeviceType::IntegratedGpu,
                WgpuDevice::VirtualGpu(_) => device_type == wgpu::DeviceType::VirtualGpu,
                WgpuDevice::Cpu => device_type == wgpu::DeviceType::Cpu,
                #[allow(deprecated)]
                WgpuDevice::DefaultDevice | WgpuDevice::BestAvailable => true,
                WgpuDevice::Existing(_) => {
                    unreachable!("Cannot select an adapter for an existing device.")
                }
            };

            if is_same_type {
                adapters.push(adapter);
            }
        });

    if adapters.len() <= num {
        if adapters_other.len() <= num {
            panic!(
                "{}, adapters {:?}, other adapters {:?}",
                error,
                adapters
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
                adapters_other
                    .into_iter()
                    .map(|adapter| adapter.get_info())
                    .collect::<Vec<_>>(),
            );
        }

        return adapters_other.remove(num);
    }

    adapters.remove(num)
}

fn get_device_override() -> Option<WgpuDevice> {
    // If BestAvailable, check if we should instead construct as
    // if a specific device was specified.
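    //
    // Accepted values mirror the `WgpuDevice` variants parsed below, e.g.:
    //   CUBECL_WGPU_DEFAULT_DEVICE=DiscreteGpu(0)
    //   CUBECL_WGPU_DEFAULT_DEVICE=IntegratedGpu(1)
    //   CUBECL_WGPU_DEFAULT_DEVICE=VirtualGpu(0)
    //   CUBECL_WGPU_DEFAULT_DEVICE=Cpu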
    std::env::var("CUBECL_WGPU_DEFAULT_DEVICE")
        .ok()
        .and_then(|var| {
            let override_device = if let Some(inner) = var.strip_prefix("DiscreteGpu(") {
                inner
                    .strip_suffix(")")
                    .and_then(|s| s.parse().ok())
                    .map(WgpuDevice::DiscreteGpu)
            } else if let Some(inner) = var.strip_prefix("IntegratedGpu(") {
                inner
                    .strip_suffix(")")
                    .and_then(|s| s.parse().ok())
                    .map(WgpuDevice::IntegratedGpu)
            } else if let Some(inner) = var.strip_prefix("VirtualGpu(") {
                inner
                    .strip_suffix(")")
                    .and_then(|s| s.parse().ok())
                    .map(WgpuDevice::VirtualGpu)
            } else if var == "Cpu" {
                Some(WgpuDevice::Cpu)
            } else {
                None
            };

            if override_device.is_none() {
                log::warn!("Unknown CUBECL_WGPU_DEFAULT_DEVICE override {var}");
            }
            override_device
        })
}