Skip to main content

cubecl_wgpu/backend/
base.rs

1use super::wgsl;
2use crate::AutoRepresentationRef;
3use crate::WgpuServer;
4use cubecl_core::MemoryConfiguration;
5use cubecl_core::{
6    ExecutionMode, WgpuCompilationOptions, hash::StableHash, server::KernelArguments,
7};
8use cubecl_ir::DeviceProperties;
9use cubecl_runtime::{compiler::CompilationError, id::KernelId};
10use std::{borrow::Cow, sync::Arc};
11use wgpu::{
12    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
13    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModule, ShaderModuleDescriptor,
14    ShaderStages,
15};
16
17#[cfg(feature = "spirv")]
18use super::vulkan;
19
20#[cfg(all(feature = "msl", target_os = "macos"))]
21use super::metal;
22#[cfg(all(feature = "msl", target_os = "macos"))]
23use cubecl_cpp::metal as cpp_metal;
24
25impl WgpuServer {
26    /// Loads a cached kernel if present and creates the pipeline for it.
27    /// Returns `None` if the cache isn't enabled, `Some(Ok(pipeline))` if a cache entry was found,
28    /// and `Some(Err(cache_key))` if the cache is enabled but doesn't contain this kernel.
29    #[allow(
30        clippy::type_complexity,
31        reason = "required because of error propagation"
32    )]
33    #[allow(unused_variables)]
34    pub fn load_cached_pipeline(
35        &self,
36        kernel_id: &KernelId,
37        bindings: &KernelArguments,
38        mode: ExecutionMode,
39    ) -> Result<Option<Result<Arc<ComputePipeline>, (u64, StableHash)>>, CompilationError> {
40        #[cfg(not(feature = "spirv"))]
41        let res = Ok(None);
42        #[cfg(feature = "spirv")]
43        let res = if let Some(cache) = &self.spirv_cache {
44            let key = (self.utilities.properties_hash, kernel_id.stable_hash());
45            if let Some(entry) = cache.get(&key) {
46                log::trace!("Using SPIR-V cache");
47
48                let repr = AutoRepresentationRef::SpirV(&entry.kernel);
49                let module = self.create_module(&entry.entrypoint_name, Some(repr), "", mode)?;
50                let pipeline =
51                    self.create_pipeline(&entry.entrypoint_name, Some(repr), module, bindings);
52                Ok(Some(Ok(pipeline)))
53            } else {
54                Ok(Some(Err(key)))
55            }
56        } else {
57            Ok(None)
58        };
59
60        res
61    }
62
63    pub fn create_module(
64        &self,
65        entrypoint_name: &str,
66        repr: Option<AutoRepresentationRef<'_>>,
67        source: &str,
68        mode: ExecutionMode,
69    ) -> Result<ShaderModule, CompilationError> {
70        match repr {
71            #[cfg(feature = "spirv")]
72            Some(AutoRepresentationRef::SpirV(repr)) => unsafe {
73                Ok(self.device.create_shader_module_passthrough(
74                    wgpu::ShaderModuleDescriptorPassthrough {
75                        label: Some(entrypoint_name),
76                        spirv: Some(Cow::Borrowed(&repr.assembled_module)),
77                        ..Default::default()
78                    },
79                ))
80            },
81            #[cfg(all(feature = "msl", target_os = "macos"))]
82            Some(AutoRepresentationRef::Msl(repr)) => unsafe {
83                Ok(self.device.create_shader_module_passthrough(
84                    wgpu::ShaderModuleDescriptorPassthrough {
85                        label: Some(entrypoint_name),
86                        msl: Some(Cow::Borrowed(source)),
87                        num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
88                        ..Default::default()
89                    },
90                ))
91            },
92            _ => {
93                let checks = wgpu::ShaderRuntimeChecks {
94                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
95                    // checked by cube (if enabled).
96                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
97                    bounds_checks: false,
98                    // Loop bounds are only checked in checked mode.
99                    force_loop_bounding: mode == ExecutionMode::Checked,
100                    ..wgpu::ShaderRuntimeChecks::unchecked()
101                };
102
103                log::trace!("[cubecl-wgpu] compiling WGSL module `{entrypoint_name}`\n{source}");
104
105                let error_scope = self.device.push_error_scope(wgpu::ErrorFilter::Validation);
106
107                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
108                // is only available through the use of unsafe code.
109                let module = unsafe {
110                    self.device.create_shader_module_trusted(
111                        ShaderModuleDescriptor {
112                            label: Some(entrypoint_name),
113                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
114                        },
115                        checks,
116                    )
117                };
118
119                // `pop()` detaches from the LIFO stack immediately; only the
120                // result is async. Safe to interleave with other push/pops.
121                let err_future = error_scope.pop();
122
123                #[cfg(not(target_family = "wasm"))]
124                if let Some(err) = cubecl_common::future::block_on(err_future) {
125                    log::error!(
126                        "[cubecl-wgpu] WGSL compilation failed for kernel `{entrypoint_name}`:\n{err}\n--- shader source ({} bytes) ---\n{source}\n--- end shader ---",
127                        source.len()
128                    );
129                    return Err(CompilationError::Generic {
130                        reason: format!(
131                            "WGSL compilation failed for kernel `{entrypoint_name}`: {err}"
132                        ),
133                        backtrace: cubecl_common::backtrace::BackTrace::capture(),
134                    });
135                }
136
137                // On wasm we can't block; spawn a task that awaits the pop
138                // future and logs.
139                #[cfg(target_family = "wasm")]
140                {
141                    let entrypoint_name = entrypoint_name.to_string();
142                    let source = source.to_string();
143                    wasm_bindgen_futures::spawn_local(async move {
144                        if let Some(err) = err_future.await {
145                            log::error!(
146                                "[cubecl-wgpu] WGSL compilation failed for kernel `{entrypoint_name}`:\n{err}\n--- shader source ({} bytes) ---\n{source}\n--- end shader ---",
147                                source.len()
148                            );
149                        }
150                    });
151                }
152
153                Ok(module)
154            }
155        }
156    }
157
158    #[allow(unused_variables)]
159    pub fn create_pipeline(
160        &self,
161        entrypoint_name: &str,
162        repr: Option<AutoRepresentationRef<'_>>,
163        module: ShaderModule,
164        bindings: &KernelArguments,
165    ) -> Arc<ComputePipeline> {
166        let bindings_info = match repr {
167            Some(AutoRepresentationRef::Wgsl(repr)) => Some(wgsl::bindings(repr, bindings)),
168            #[cfg(all(feature = "msl", target_os = "macos"))]
169            Some(AutoRepresentationRef::Msl(repr)) => Some(cpp_metal::bindings(repr, bindings)),
170            #[cfg(feature = "spirv")]
171            Some(AutoRepresentationRef::SpirV(repr)) => Some(vulkan::bindings(repr, bindings)),
172            _ => None,
173        };
174
175        let layout = bindings_info.map(|bindings| {
176            let (mut bindings, info, uniform_info) = bindings;
177            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
178            // and since we can't be sure, we'll assume everything is read-write.
179            if !cfg!(exclusive_memory_only) {
180                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
181            }
182
183            let info = info.map(|_| match uniform_info {
184                true => BufferBindingType::Uniform,
185                false => BufferBindingType::Storage { read_only: true },
186            });
187
188            let bindings = bindings
189                .into_iter()
190                .map(|visibility| BufferBindingType::Storage {
191                    read_only: matches!(visibility, cubecl_runtime::kernel::Visibility::Read),
192                })
193                .chain(info)
194                .enumerate()
195                .map(|(i, ty)| BindGroupLayoutEntry {
196                    binding: i as u32,
197                    visibility: ShaderStages::COMPUTE,
198                    ty: BindingType::Buffer {
199                        ty,
200                        has_dynamic_offset: false,
201                        min_binding_size: None,
202                    },
203                    count: None,
204                })
205                .collect::<Vec<_>>();
206            let layout = self
207                .device
208                .create_bind_group_layout(&BindGroupLayoutDescriptor {
209                    label: None,
210                    entries: &bindings,
211                });
212            self.device
213                .create_pipeline_layout(&PipelineLayoutDescriptor {
214                    label: None,
215                    bind_group_layouts: &[Some(&layout)],
216                    immediate_size: 0,
217                })
218        });
219
220        let pipeline = self
221            .device
222            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
223                label: Some(entrypoint_name),
224                layout: layout.as_ref(),
225                module: &module,
226                entry_point: Some(entrypoint_name),
227                compilation_options: wgpu::PipelineCompilationOptions {
228                    zero_initialize_workgroup_memory: false,
229                    ..Default::default()
230                },
231                cache: None,
232            });
233        Arc::new(pipeline)
234    }
235}
236
237pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
238    if let Some(result) = request_vulkan_device(adapter).await {
239        return result;
240    }
241    if let Some(result) = request_metal_device(adapter).await {
242        return result;
243    }
244    wgsl::request_device(adapter).await
245}
246
247#[cfg(feature = "spirv")]
248async fn request_vulkan_device(adapter: &Adapter) -> Option<(Device, Queue)> {
249    if is_vulkan(adapter) {
250        vulkan::request_vulkan_device(adapter).await
251    } else {
252        None
253    }
254}
255
256#[cfg(not(feature = "spirv"))]
257async fn request_vulkan_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
258    None
259}
260
261#[cfg(all(feature = "msl", target_os = "macos"))]
262async fn request_metal_device(adapter: &Adapter) -> Option<(Device, Queue)> {
263    if is_metal(adapter) {
264        Some(metal::request_metal_device(adapter).await)
265    } else {
266        None
267    }
268}
269
270#[cfg(not(all(feature = "msl", target_os = "macos")))]
271async fn request_metal_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
272    None
273}
274
275pub fn register_features(
276    adapter: &Adapter,
277    props: &mut DeviceProperties,
278    comp_options: &mut WgpuCompilationOptions,
279    memory_config: &MemoryConfiguration,
280) {
281    if register_vulkan_features(adapter, props, comp_options, memory_config) {
282        return;
283    }
284    if register_metal_features(adapter, props, comp_options, memory_config) {
285        return;
286    }
287    wgsl::register_wgsl_features(adapter, props, comp_options);
288}
289
290#[cfg(feature = "spirv")]
291pub fn register_vulkan_features(
292    adapter: &Adapter,
293    props: &mut DeviceProperties,
294    comp_options: &mut WgpuCompilationOptions,
295    memory_config: &MemoryConfiguration,
296) -> bool {
297    if is_vulkan(adapter) {
298        vulkan::register_vulkan_features(adapter, props, comp_options, memory_config)
299    } else {
300        false
301    }
302}
303
304#[cfg(not(feature = "spirv"))]
305pub fn register_vulkan_features(
306    _adapter: &Adapter,
307    _props: &mut DeviceProperties,
308    _comp_options: &mut WgpuCompilationOptions,
309    _memory_config: &MemoryConfiguration,
310) -> bool {
311    false
312}
313
314#[cfg(all(feature = "msl", target_os = "macos"))]
315pub fn register_metal_features(
316    adapter: &Adapter,
317    props: &mut DeviceProperties,
318    comp_options: &mut WgpuCompilationOptions,
319    _memory_config: &MemoryConfiguration,
320) -> bool {
321    if is_metal(adapter) {
322        metal::register_metal_features(adapter, props, comp_options);
323        true
324    } else {
325        false
326    }
327}
328
329#[cfg(not(all(feature = "msl", target_os = "macos")))]
330pub fn register_metal_features(
331    _adapter: &Adapter,
332    _props: &mut DeviceProperties,
333    _comp_options: &mut WgpuCompilationOptions,
334    _memory_config: &MemoryConfiguration,
335) -> bool {
336    false
337}
338
/// Reports whether `adapter` exposes a Vulkan HAL adapter.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; it is never used, stored or destroyed.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}
343
/// Reports whether `adapter` exposes a Metal HAL adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; it is never used, stored or destroyed.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}