Skip to main content

cubecl_wgpu/backend/
base.rs

1use super::wgsl;
2use crate::AutoRepresentationRef;
3use crate::WgpuServer;
4use cubecl_core::{ExecutionMode, WgpuCompilationOptions, hash::StableHash, server::Bindings};
5use cubecl_ir::DeviceProperties;
6use cubecl_runtime::{compiler::CompilationError, id::KernelId};
7use std::{borrow::Cow, sync::Arc};
8use wgpu::{
9    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
10    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModule, ShaderModuleDescriptor,
11    ShaderStages,
12};
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23    /// Loads a cached kernel if present and creates the pipeline for it.
24    /// Returns `None` if the cache isn't enabled, `Some(Ok(pipeline))` if a cache entry was found,
25    /// and `Some(Err(cache_key))` if the cache is enabled but doesn't contain this kernel.
26    #[allow(
27        clippy::type_complexity,
28        reason = "required because of error propagation"
29    )]
30    #[allow(unused_variables)]
31    pub fn load_cached_pipeline(
32        &self,
33        kernel_id: &KernelId,
34        bindings: &Bindings,
35        mode: ExecutionMode,
36    ) -> Result<Option<Result<Arc<ComputePipeline>, (u64, StableHash)>>, CompilationError> {
37        #[cfg(not(feature = "spirv"))]
38        let res = Ok(None);
39        #[cfg(feature = "spirv")]
40        let res = if let Some(cache) = &self.spirv_cache {
41            let key = (self.utilities.properties_hash, kernel_id.stable_hash());
42            if let Some(entry) = cache.get(&key) {
43                log::trace!("Using SPIR-V cache");
44
45                let repr = AutoRepresentationRef::SpirV(&entry.kernel);
46                let module = self.create_module(&entry.entrypoint_name, Some(repr), "", mode)?;
47                let pipeline =
48                    self.create_pipeline(&entry.entrypoint_name, Some(repr), module, bindings);
49                Ok(Some(Ok(pipeline)))
50            } else {
51                Ok(Some(Err(key)))
52            }
53        } else {
54            Ok(None)
55        };
56
57        res
58    }
59
60    pub fn create_module(
61        &self,
62        entrypoint_name: &str,
63        repr: Option<AutoRepresentationRef<'_>>,
64        source: &str,
65        mode: ExecutionMode,
66    ) -> Result<ShaderModule, CompilationError> {
67        #[allow(unused_assignments)]
68        #[cfg(not(target_family = "wasm"))]
69        let mut error_scope = None;
70
71        match repr {
72            #[cfg(feature = "spirv")]
73            Some(AutoRepresentationRef::SpirV(repr)) => unsafe {
74                Ok(self.device.create_shader_module_passthrough(
75                    wgpu::ShaderModuleDescriptorPassthrough {
76                        label: Some(entrypoint_name),
77                        spirv: Some(Cow::Borrowed(&repr.assembled_module)),
78                        ..Default::default()
79                    },
80                ))
81            },
82            #[cfg(all(feature = "msl", target_os = "macos"))]
83            Some(AutoRepresentationRef::Msl(repr)) => unsafe {
84                Ok(self.device.create_shader_module_passthrough(
85                    wgpu::ShaderModuleDescriptorPassthrough {
86                        entry_point: entrypoint_name.to_string(),
87                        label: Some(entrypoint_name),
88                        msl: Some(Cow::Borrowed(source)),
89                        num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
90                        ..Default::default()
91                    },
92                ))
93            },
94            _ => {
95                let _ = entrypoint_name; // otherwise unused
96                let checks = wgpu::ShaderRuntimeChecks {
97                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
98                    // checked by cube (if enabled).
99                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
100                    bounds_checks: false,
101                    // Loop bounds are only checked in checked mode.
102                    force_loop_bounding: mode == ExecutionMode::Checked,
103                    ray_query_initialization_tracking: false,
104                };
105
106                #[cfg(not(target_family = "wasm"))]
107                {
108                    error_scope = Some(self.device.push_error_scope(wgpu::ErrorFilter::Validation));
109                }
110
111                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
112                // is only available through the use of unsafe code.
113                let module = unsafe {
114                    self.device.create_shader_module_trusted(
115                        ShaderModuleDescriptor {
116                            label: None,
117                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
118                        },
119                        checks,
120                    )
121                };
122
123                #[cfg(not(target_family = "wasm"))]
124                if let Some(scope) = error_scope
125                    && let Some(err) = cubecl_common::future::block_on(scope.pop())
126                {
127                    return Err(CompilationError::Generic {
128                        reason: format!("{err}"),
129                        backtrace: cubecl_common::backtrace::BackTrace::capture(),
130                    });
131                }
132
133                Ok(module)
134            }
135        }
136    }
137
138    #[allow(unused_variables)]
139    pub fn create_pipeline(
140        &self,
141        entrypoint_name: &str,
142        repr: Option<AutoRepresentationRef<'_>>,
143        module: ShaderModule,
144        bindings: &Bindings,
145    ) -> Arc<ComputePipeline> {
146        let bindings_info = match repr {
147            Some(AutoRepresentationRef::Wgsl(repr)) => Some(wgsl::bindings(repr)),
148            #[cfg(all(feature = "msl", target_os = "macos"))]
149            Some(AutoRepresentationRef::Msl(repr)) => Some(cpp_metal::bindings(repr)),
150            #[cfg(feature = "spirv")]
151            Some(AutoRepresentationRef::SpirV(repr)) => Some(vulkan::bindings(repr, bindings)),
152            _ => None,
153        };
154
155        let layout = bindings_info.map(|bindings| {
156            let (mut bindings, meta) = bindings;
157            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
158            // and since we can't be sure, we'll assume everything is read-write.
159            if !cfg!(exclusive_memory_only) {
160                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
161            }
162
163            let bindings = bindings
164                .into_iter()
165                .chain(meta)
166                .enumerate()
167                .map(|(i, visibility)| BindGroupLayoutEntry {
168                    binding: i as u32,
169                    visibility: ShaderStages::COMPUTE,
170                    ty: BindingType::Buffer {
171                        ty: BufferBindingType::Storage {
172                            read_only: matches!(
173                                visibility,
174                                cubecl_runtime::kernel::Visibility::Read
175                            ),
176                        },
177                        has_dynamic_offset: false,
178                        min_binding_size: None,
179                    },
180                    count: None,
181                })
182                .collect::<Vec<_>>();
183            let layout = self
184                .device
185                .create_bind_group_layout(&BindGroupLayoutDescriptor {
186                    label: None,
187                    entries: &bindings,
188                });
189            self.device
190                .create_pipeline_layout(&PipelineLayoutDescriptor {
191                    label: None,
192                    bind_group_layouts: &[&layout],
193                    immediate_size: 0,
194                })
195        });
196
197        let pipeline = self
198            .device
199            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
200                label: Some(entrypoint_name),
201                layout: layout.as_ref(),
202                module: &module,
203                entry_point: Some(entrypoint_name),
204                compilation_options: wgpu::PipelineCompilationOptions {
205                    zero_initialize_workgroup_memory: false,
206                    ..Default::default()
207                },
208                cache: None,
209            });
210        Arc::new(pipeline)
211    }
212}
213
214#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
215pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
216    wgsl::request_device(adapter).await
217}
218
/// Requests the device and queue when the `spirv` feature is enabled.
///
/// Adapters that expose a Vulkan HAL handle are set up through the Vulkan backend; every other
/// adapter falls back to the plain WGSL backend.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }

    vulkan::request_vulkan_device(adapter).await
}
227
/// Requests the device and queue through the Metal backend (`msl` feature on macOS).
///
/// # Panics
///
/// Panics with `metal device not found!` when the adapter does not expose a Metal HAL handle.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    use super::metal;

    assert!(is_metal(adapter), "metal device not found!");

    metal::request_metal_device(adapter).await
}
238
239#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
240pub fn register_features(
241    adapter: &Adapter,
242    props: &mut DeviceProperties,
243    comp_options: &mut WgpuCompilationOptions,
244) {
245    wgsl::register_wgsl_features(adapter, props, comp_options);
246}
247
/// Registers backend features on `props` and `comp_options` when the `spirv` feature is
/// enabled.
///
/// Adapters that expose a Vulkan HAL handle get the Vulkan feature set; every other adapter
/// gets the plain WGSL feature set.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }

    vulkan::register_vulkan_features(adapter, props, comp_options);
}
260
/// Registers the Metal backend features on `props` and `comp_options` (`msl` feature on
/// macOS).
///
/// # Panics
///
/// Panics with `metal device not found!` when the adapter does not expose a Metal HAL handle.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");

    metal::register_metal_features(adapter, props, comp_options);
}
273
/// Returns `true` when `adapter` exposes a Vulkan HAL handle.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; nothing is done with it.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}
278
/// Returns `true` when `adapter` exposes a Metal HAL handle.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; nothing is done with it.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}