Skip to main content

cubecl_wgpu/backend/
base.rs

1use super::wgsl;
2use crate::AutoRepresentationRef;
3use crate::WgpuServer;
4use cubecl_core::MemoryConfiguration;
5use cubecl_core::{
6    ExecutionMode, WgpuCompilationOptions, hash::StableHash, server::KernelArguments,
7};
8use cubecl_ir::DeviceProperties;
9use cubecl_runtime::{compiler::CompilationError, id::KernelId};
10use std::{borrow::Cow, sync::Arc};
11use wgpu::{
12    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
13    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModule, ShaderModuleDescriptor,
14    ShaderStages,
15};
16
17#[cfg(feature = "spirv")]
18use super::vulkan;
19
20#[cfg(all(feature = "msl", target_os = "macos"))]
21use super::metal;
22#[cfg(all(feature = "msl", target_os = "macos"))]
23use cubecl_cpp::metal as cpp_metal;
24
25impl WgpuServer {
26    /// Loads a cached kernel if present and creates the pipeline for it.
27    /// Returns `None` if the cache isn't enabled, `Some(Ok(pipeline))` if a cache entry was found,
28    /// and `Some(Err(cache_key))` if the cache is enabled but doesn't contain this kernel.
29    #[allow(
30        clippy::type_complexity,
31        reason = "required because of error propagation"
32    )]
33    #[allow(unused_variables)]
34    pub fn load_cached_pipeline(
35        &self,
36        kernel_id: &KernelId,
37        bindings: &KernelArguments,
38        mode: ExecutionMode,
39    ) -> Result<Option<Result<Arc<ComputePipeline>, (u64, StableHash)>>, CompilationError> {
40        #[cfg(not(feature = "spirv"))]
41        let res = Ok(None);
42        #[cfg(feature = "spirv")]
43        let res = if let Some(cache) = &self.spirv_cache {
44            let key = (self.utilities.properties_hash, kernel_id.stable_hash());
45            if let Some(entry) = cache.get(&key) {
46                log::trace!("Using SPIR-V cache");
47
48                let repr = AutoRepresentationRef::SpirV(&entry.kernel);
49                let module = self.create_module(&entry.entrypoint_name, Some(repr), "", mode)?;
50                let pipeline =
51                    self.create_pipeline(&entry.entrypoint_name, Some(repr), module, bindings);
52                Ok(Some(Ok(pipeline)))
53            } else {
54                Ok(Some(Err(key)))
55            }
56        } else {
57            Ok(None)
58        };
59
60        res
61    }
62
63    pub fn create_module(
64        &self,
65        entrypoint_name: &str,
66        repr: Option<AutoRepresentationRef<'_>>,
67        source: &str,
68        mode: ExecutionMode,
69    ) -> Result<ShaderModule, CompilationError> {
70        #[allow(unused_assignments)]
71        #[cfg(not(target_family = "wasm"))]
72        let mut error_scope = None;
73
74        match repr {
75            #[cfg(feature = "spirv")]
76            Some(AutoRepresentationRef::SpirV(repr)) => unsafe {
77                Ok(self.device.create_shader_module_passthrough(
78                    wgpu::ShaderModuleDescriptorPassthrough {
79                        label: Some(entrypoint_name),
80                        spirv: Some(Cow::Borrowed(&repr.assembled_module)),
81                        ..Default::default()
82                    },
83                ))
84            },
85            #[cfg(all(feature = "msl", target_os = "macos"))]
86            Some(AutoRepresentationRef::Msl(repr)) => unsafe {
87                Ok(self.device.create_shader_module_passthrough(
88                    wgpu::ShaderModuleDescriptorPassthrough {
89                        label: Some(entrypoint_name),
90                        msl: Some(Cow::Borrowed(source)),
91                        num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
92                        ..Default::default()
93                    },
94                ))
95            },
96            _ => {
97                let _ = entrypoint_name; // otherwise unused
98                let checks = wgpu::ShaderRuntimeChecks {
99                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
100                    // checked by cube (if enabled).
101                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
102                    bounds_checks: false,
103                    // Loop bounds are only checked in checked mode.
104                    force_loop_bounding: mode == ExecutionMode::Checked,
105                    ..wgpu::ShaderRuntimeChecks::unchecked()
106                };
107
108                #[cfg(not(target_family = "wasm"))]
109                {
110                    error_scope = Some(self.device.push_error_scope(wgpu::ErrorFilter::Validation));
111                }
112
113                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
114                // is only available through the use of unsafe code.
115                let module = unsafe {
116                    self.device.create_shader_module_trusted(
117                        ShaderModuleDescriptor {
118                            label: None,
119                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
120                        },
121                        checks,
122                    )
123                };
124
125                #[cfg(not(target_family = "wasm"))]
126                if let Some(scope) = error_scope
127                    && let Some(err) = cubecl_common::future::block_on(scope.pop())
128                {
129                    return Err(CompilationError::Generic {
130                        reason: format!("{err}"),
131                        backtrace: cubecl_common::backtrace::BackTrace::capture(),
132                    });
133                }
134
135                Ok(module)
136            }
137        }
138    }
139
140    #[allow(unused_variables)]
141    pub fn create_pipeline(
142        &self,
143        entrypoint_name: &str,
144        repr: Option<AutoRepresentationRef<'_>>,
145        module: ShaderModule,
146        bindings: &KernelArguments,
147    ) -> Arc<ComputePipeline> {
148        let bindings_info = match repr {
149            Some(AutoRepresentationRef::Wgsl(repr)) => Some(wgsl::bindings(repr, bindings)),
150            #[cfg(all(feature = "msl", target_os = "macos"))]
151            Some(AutoRepresentationRef::Msl(repr)) => Some(cpp_metal::bindings(repr, bindings)),
152            #[cfg(feature = "spirv")]
153            Some(AutoRepresentationRef::SpirV(repr)) => Some(vulkan::bindings(repr, bindings)),
154            _ => None,
155        };
156
157        let layout = bindings_info.map(|bindings| {
158            let (mut bindings, info, uniform_info) = bindings;
159            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
160            // and since we can't be sure, we'll assume everything is read-write.
161            if !cfg!(exclusive_memory_only) {
162                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
163            }
164
165            let info = info.map(|_| match uniform_info {
166                true => BufferBindingType::Uniform,
167                false => BufferBindingType::Storage { read_only: true },
168            });
169
170            let bindings = bindings
171                .into_iter()
172                .map(|visibility| BufferBindingType::Storage {
173                    read_only: matches!(visibility, cubecl_runtime::kernel::Visibility::Read),
174                })
175                .chain(info)
176                .enumerate()
177                .map(|(i, ty)| BindGroupLayoutEntry {
178                    binding: i as u32,
179                    visibility: ShaderStages::COMPUTE,
180                    ty: BindingType::Buffer {
181                        ty,
182                        has_dynamic_offset: false,
183                        min_binding_size: None,
184                    },
185                    count: None,
186                })
187                .collect::<Vec<_>>();
188            let layout = self
189                .device
190                .create_bind_group_layout(&BindGroupLayoutDescriptor {
191                    label: None,
192                    entries: &bindings,
193                });
194            self.device
195                .create_pipeline_layout(&PipelineLayoutDescriptor {
196                    label: None,
197                    bind_group_layouts: &[Some(&layout)],
198                    immediate_size: 0,
199                })
200        });
201
202        let pipeline = self
203            .device
204            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
205                label: Some(entrypoint_name),
206                layout: layout.as_ref(),
207                module: &module,
208                entry_point: Some(entrypoint_name),
209                compilation_options: wgpu::PipelineCompilationOptions {
210                    zero_initialize_workgroup_memory: false,
211                    ..Default::default()
212                },
213                cache: None,
214            });
215        Arc::new(pipeline)
216    }
217}
218
219pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
220    if let Some(result) = request_vulkan_device(adapter).await {
221        return result;
222    }
223    if let Some(result) = request_metal_device(adapter).await {
224        return result;
225    }
226    wgsl::request_device(adapter).await
227}
228
/// Requests a device through the Vulkan backend.
///
/// Returns `None` when `adapter` is not a Vulkan adapter; otherwise forwards whatever the
/// Vulkan backend returns.
#[cfg(feature = "spirv")]
async fn request_vulkan_device(adapter: &Adapter) -> Option<(Device, Queue)> {
    if !is_vulkan(adapter) {
        return None;
    }
    vulkan::request_vulkan_device(adapter).await
}
237
238#[cfg(not(feature = "spirv"))]
239async fn request_vulkan_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
240    None
241}
242
/// Requests a device through the Metal backend.
///
/// Returns `None` when `adapter` is not a Metal adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
async fn request_metal_device(adapter: &Adapter) -> Option<(Device, Queue)> {
    if !is_metal(adapter) {
        return None;
    }
    Some(metal::request_metal_device(adapter).await)
}
251
252#[cfg(not(all(feature = "msl", target_os = "macos")))]
253async fn request_metal_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
254    None
255}
256
257pub fn register_features(
258    adapter: &Adapter,
259    props: &mut DeviceProperties,
260    comp_options: &mut WgpuCompilationOptions,
261    memory_config: &MemoryConfiguration,
262) {
263    if register_vulkan_features(adapter, props, comp_options, memory_config) {
264        return;
265    }
266    if register_metal_features(adapter, props, comp_options, memory_config) {
267        return;
268    }
269    wgsl::register_wgsl_features(adapter, props, comp_options);
270}
271
/// Registers Vulkan features when `adapter` is a Vulkan adapter.
///
/// Returns `false` without touching anything for non-Vulkan adapters; otherwise returns the
/// result reported by the Vulkan backend.
#[cfg(feature = "spirv")]
pub fn register_vulkan_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
    memory_config: &MemoryConfiguration,
) -> bool {
    is_vulkan(adapter)
        && vulkan::register_vulkan_features(adapter, props, comp_options, memory_config)
}
285
286#[cfg(not(feature = "spirv"))]
287pub fn register_vulkan_features(
288    _adapter: &Adapter,
289    _props: &mut DeviceProperties,
290    _comp_options: &mut WgpuCompilationOptions,
291    _memory_config: &MemoryConfiguration,
292) -> bool {
293    false
294}
295
/// Registers Metal features when `adapter` is a Metal adapter.
///
/// Returns `true` if the adapter is a Metal adapter (and its features were registered),
/// `false` otherwise. The memory configuration is not used on this path.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_metal_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
    _memory_config: &MemoryConfiguration,
) -> bool {
    let on_metal = is_metal(adapter);
    if on_metal {
        metal::register_metal_features(adapter, props, comp_options);
    }
    on_metal
}
310
311#[cfg(not(all(feature = "msl", target_os = "macos")))]
312pub fn register_metal_features(
313    _adapter: &Adapter,
314    _props: &mut DeviceProperties,
315    _comp_options: &mut WgpuCompilationOptions,
316    _memory_config: &MemoryConfiguration,
317) -> bool {
318    false
319}
320
/// Reports whether `adapter` is backed by the Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; it is never dereferenced or stored.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}
325
/// Reports whether `adapter` is backed by the Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; it is never dereferenced or stored.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}