1use super::wgsl;
2use crate::AutoRepresentationRef;
3use crate::WgpuServer;
4use cubecl_core::MemoryConfiguration;
5use cubecl_core::{
6 ExecutionMode, WgpuCompilationOptions, hash::StableHash, server::KernelArguments,
7};
8use cubecl_ir::DeviceProperties;
9use cubecl_runtime::{compiler::CompilationError, id::KernelId};
10use std::{borrow::Cow, sync::Arc};
11use wgpu::{
12 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
13 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModule, ShaderModuleDescriptor,
14 ShaderStages,
15};
16
17#[cfg(feature = "spirv")]
18use super::vulkan;
19
20#[cfg(all(feature = "msl", target_os = "macos"))]
21use super::metal;
22#[cfg(all(feature = "msl", target_os = "macos"))]
23use cubecl_cpp::metal as cpp_metal;
24
impl WgpuServer {
    /// Attempts to build a compute pipeline straight from the SPIR-V kernel cache.
    ///
    /// Possible results, as produced by the branches below:
    /// - `Ok(None)`: no cache is available — either the `spirv` feature is
    ///   compiled out, or `self.spirv_cache` is `None`.
    /// - `Ok(Some(Ok(pipeline)))`: the cache holds an entry for this kernel; its
    ///   SPIR-V was turned into a shader module and then into a pipeline.
    /// - `Ok(Some(Err(key)))`: the cache exists but has no entry for this kernel.
    ///   `key` is the `(properties_hash, kernel_id.stable_hash())` lookup key —
    ///   presumably returned so the caller can store the freshly compiled kernel
    ///   under it (TODO confirm at the call site).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Self::create_module`].
    #[allow(
        clippy::type_complexity,
        reason = "required because of error propagation"
    )]
    // Without the `spirv` feature, `kernel_id`, `bindings` and `mode` are never read.
    #[allow(unused_variables)]
    pub fn load_cached_pipeline(
        &self,
        kernel_id: &KernelId,
        bindings: &KernelArguments,
        mode: ExecutionMode,
    ) -> Result<Option<Result<Arc<ComputePipeline>, (u64, StableHash)>>, CompilationError> {
        // Exactly one of the two cfg-selected `let res` bindings exists in any build.
        #[cfg(not(feature = "spirv"))]
        let res = Ok(None);
        #[cfg(feature = "spirv")]
        let res = if let Some(cache) = &self.spirv_cache {
            // Cache key: device-properties hash paired with the kernel's stable hash.
            let key = (self.utilities.properties_hash, kernel_id.stable_hash());
            if let Some(entry) = cache.get(&key) {
                log::trace!("Using SPIR-V cache");

                let repr = AutoRepresentationRef::SpirV(&entry.kernel);
                // The SPIR-V path of `create_module` does not read `source`, hence "".
                let module = self.create_module(&entry.entrypoint_name, Some(repr), "", mode)?;
                let pipeline =
                    self.create_pipeline(&entry.entrypoint_name, Some(repr), module, bindings);
                Ok(Some(Ok(pipeline)))
            } else {
                // Cache miss: hand the key back to the caller.
                Ok(Some(Err(key)))
            }
        } else {
            Ok(None)
        };

        res
    }

    /// Creates a shader module named `entrypoint_name` for the given representation.
    ///
    /// - SPIR-V (`spirv` feature): the pre-assembled module is handed to
    ///   `create_shader_module_passthrough`; `source` is not used.
    /// - MSL (`msl` feature on macOS): `source` is handed to
    ///   `create_shader_module_passthrough` as Metal source, together with the
    ///   kernel's `cube_dim` as `num_workgroups`.
    /// - Anything else (a WGSL representation, `None`, or a variant whose
    ///   feature is compiled out): `source` is compiled as WGSL through
    ///   `create_shader_module_trusted` inside a validation error scope.
    ///
    /// # Errors
    ///
    /// On non-wasm targets, returns `CompilationError::Generic` when the WGSL
    /// validation error scope reports an error. The passthrough arms and the
    /// wasm WGSL path always return `Ok`.
    pub fn create_module(
        &self,
        entrypoint_name: &str,
        repr: Option<AutoRepresentationRef<'_>>,
        source: &str,
        mode: ExecutionMode,
    ) -> Result<ShaderModule, CompilationError> {
        match repr {
            // NOTE(review): passthrough modules are given to the backend as-is
            // and are not validated by wgpu, so soundness relies on the cubecl
            // SPIR-V compiler's output being valid.
            #[cfg(feature = "spirv")]
            Some(AutoRepresentationRef::SpirV(repr)) => unsafe {
                Ok(self.device.create_shader_module_passthrough(
                    wgpu::ShaderModuleDescriptorPassthrough {
                        label: Some(entrypoint_name),
                        spirv: Some(Cow::Borrowed(&repr.assembled_module)),
                        ..Default::default()
                    },
                ))
            },
            // NOTE(review): same unvalidated-passthrough caveat as the SPIR-V arm,
            // here relying on the cubecl MSL compiler's output.
            #[cfg(all(feature = "msl", target_os = "macos"))]
            Some(AutoRepresentationRef::Msl(repr)) => unsafe {
                Ok(self.device.create_shader_module_passthrough(
                    wgpu::ShaderModuleDescriptorPassthrough {
                        label: Some(entrypoint_name),
                        msl: Some(Cow::Borrowed(source)),
                        num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
                        ..Default::default()
                    },
                ))
            },
            _ => {
                // Runtime bounds checks are always disabled. Loop bounding (wgpu's
                // protection against non-terminating loops) is forced only in
                // `ExecutionMode::Checked`. All other checks stay at the
                // `unchecked()` defaults.
                let checks = wgpu::ShaderRuntimeChecks {
                    bounds_checks: false,
                    force_loop_bounding: mode == ExecutionMode::Checked,
                    ..wgpu::ShaderRuntimeChecks::unchecked()
                };

                log::trace!("[cubecl-wgpu] compiling WGSL module `{entrypoint_name}`\n{source}");

                // The scope must be pushed before module creation and popped after
                // it, so that validation errors from this module are captured.
                let error_scope = self.device.push_error_scope(wgpu::ErrorFilter::Validation);

                // NOTE(review): a trusted module skips the runtime checks disabled
                // above, so out-of-bounds accesses in the generated WGSL are not
                // guarded by wgpu.
                let module = unsafe {
                    self.device.create_shader_module_trusted(
                        ShaderModuleDescriptor {
                            label: Some(entrypoint_name),
                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
                        },
                        checks,
                    )
                };

                let err_future = error_scope.pop();

                // Native: wait for the scope result and turn a validation error
                // into a `CompilationError`, logging the full shader source first.
                #[cfg(not(target_family = "wasm"))]
                if let Some(err) = cubecl_common::future::block_on(err_future) {
                    log::error!(
                        "[cubecl-wgpu] WGSL compilation failed for kernel `{entrypoint_name}`:\n{err}\n--- shader source ({} bytes) ---\n{source}\n--- end shader ---",
                        source.len()
                    );
                    return Err(CompilationError::Generic {
                        reason: format!(
                            "WGSL compilation failed for kernel `{entrypoint_name}`: {err}"
                        ),
                        backtrace: cubecl_common::backtrace::BackTrace::capture(),
                    });
                }

                // wasm: the future is not blocked on. It is awaited on a spawned
                // local task that only logs, so a validation failure here does not
                // produce an `Err`. Owned copies are moved into the `'static` task.
                #[cfg(target_family = "wasm")]
                {
                    let entrypoint_name = entrypoint_name.to_string();
                    let source = source.to_string();
                    wasm_bindgen_futures::spawn_local(async move {
                        if let Some(err) = err_future.await {
                            log::error!(
                                "[cubecl-wgpu] WGSL compilation failed for kernel `{entrypoint_name}`:\n{err}\n--- shader source ({} bytes) ---\n{source}\n--- end shader ---",
                                source.len()
                            );
                        }
                    });
                }

                Ok(module)
            }
        }
    }

    /// Builds a compute pipeline for `entrypoint_name` from an existing `module`.
    ///
    /// When `repr` is one of the compiled-in representations, an explicit bind
    /// group layout is built from that representation's binding info. Otherwise
    /// (`None`, or a variant whose feature is compiled out) the pipeline is
    /// created with `layout: None`, which in wgpu means the layout is derived
    /// from the shader.
    #[allow(unused_variables)]
    pub fn create_pipeline(
        &self,
        entrypoint_name: &str,
        repr: Option<AutoRepresentationRef<'_>>,
        module: ShaderModule,
        bindings: &KernelArguments,
    ) -> Arc<ComputePipeline> {
        let bindings_info = match repr {
            Some(AutoRepresentationRef::Wgsl(repr)) => Some(wgsl::bindings(repr, bindings)),
            #[cfg(all(feature = "msl", target_os = "macos"))]
            Some(AutoRepresentationRef::Msl(repr)) => Some(cpp_metal::bindings(repr, bindings)),
            #[cfg(feature = "spirv")]
            Some(AutoRepresentationRef::SpirV(repr)) => Some(vulkan::bindings(repr, bindings)),
            _ => None,
        };

        let layout = bindings_info.map(|bindings| {
            // As used below: a mutable sequence of per-buffer `Visibility`
            // values, an optional trailing info buffer, and a flag saying
            // whether that info buffer is a uniform (shape defined by the
            // backend `bindings` helpers — confirm there).
            let (mut bindings, info, uniform_info) = bindings;
            // Unless `exclusive_memory_only` is set, every buffer is declared
            // read-write, so a `Read` visibility only reaches the layout under
            // that cfg.
            if !cfg!(exclusive_memory_only) {
                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
            }

            // The info buffer, if present, is a uniform or a read-only storage buffer.
            let info = info.map(|_| match uniform_info {
                true => BufferBindingType::Uniform,
                false => BufferBindingType::Storage { read_only: true },
            });

            // Binding indices are positional: buffers first, then the optional
            // info buffer as the last entry.
            let bindings = bindings
                .into_iter()
                .map(|visibility| BufferBindingType::Storage {
                    read_only: matches!(visibility, cubecl_runtime::kernel::Visibility::Read),
                })
                .chain(info)
                .enumerate()
                .map(|(i, ty)| BindGroupLayoutEntry {
                    binding: i as u32,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                })
                .collect::<Vec<_>>();
            let layout = self
                .device
                .create_bind_group_layout(&BindGroupLayoutDescriptor {
                    label: None,
                    entries: &bindings,
                });
            // A single bind group (group 0) with no immediates.
            self.device
                .create_pipeline_layout(&PipelineLayoutDescriptor {
                    label: None,
                    bind_group_layouts: &[Some(&layout)],
                    immediate_size: 0,
                })
        });

        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(entrypoint_name),
                layout: layout.as_ref(),
                module: &module,
                entry_point: Some(entrypoint_name),
                // wgpu's automatic zero-initialisation of workgroup memory is disabled.
                compilation_options: wgpu::PipelineCompilationOptions {
                    zero_initialize_workgroup_memory: false,
                    ..Default::default()
                },
                cache: None,
            });
        Arc::new(pipeline)
    }
}
236
237pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
238 if let Some(result) = request_vulkan_device(adapter).await {
239 return result;
240 }
241 if let Some(result) = request_metal_device(adapter).await {
242 return result;
243 }
244 wgsl::request_device(adapter).await
245}
246
247#[cfg(feature = "spirv")]
248async fn request_vulkan_device(adapter: &Adapter) -> Option<(Device, Queue)> {
249 if is_vulkan(adapter) {
250 vulkan::request_vulkan_device(adapter).await
251 } else {
252 None
253 }
254}
255
256#[cfg(not(feature = "spirv"))]
257async fn request_vulkan_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
258 None
259}
260
261#[cfg(all(feature = "msl", target_os = "macos"))]
262async fn request_metal_device(adapter: &Adapter) -> Option<(Device, Queue)> {
263 if is_metal(adapter) {
264 Some(metal::request_metal_device(adapter).await)
265 } else {
266 None
267 }
268}
269
270#[cfg(not(all(feature = "msl", target_os = "macos")))]
271async fn request_metal_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
272 None
273}
274
275pub fn register_features(
276 adapter: &Adapter,
277 props: &mut DeviceProperties,
278 comp_options: &mut WgpuCompilationOptions,
279 memory_config: &MemoryConfiguration,
280) {
281 if register_vulkan_features(adapter, props, comp_options, memory_config) {
282 return;
283 }
284 if register_metal_features(adapter, props, comp_options, memory_config) {
285 return;
286 }
287 wgsl::register_wgsl_features(adapter, props, comp_options);
288}
289
290#[cfg(feature = "spirv")]
291pub fn register_vulkan_features(
292 adapter: &Adapter,
293 props: &mut DeviceProperties,
294 comp_options: &mut WgpuCompilationOptions,
295 memory_config: &MemoryConfiguration,
296) -> bool {
297 if is_vulkan(adapter) {
298 vulkan::register_vulkan_features(adapter, props, comp_options, memory_config)
299 } else {
300 false
301 }
302}
303
304#[cfg(not(feature = "spirv"))]
305pub fn register_vulkan_features(
306 _adapter: &Adapter,
307 _props: &mut DeviceProperties,
308 _comp_options: &mut WgpuCompilationOptions,
309 _memory_config: &MemoryConfiguration,
310) -> bool {
311 false
312}
313
314#[cfg(all(feature = "msl", target_os = "macos"))]
315pub fn register_metal_features(
316 adapter: &Adapter,
317 props: &mut DeviceProperties,
318 comp_options: &mut WgpuCompilationOptions,
319 _memory_config: &MemoryConfiguration,
320) -> bool {
321 if is_metal(adapter) {
322 metal::register_metal_features(adapter, props, comp_options);
323 true
324 } else {
325 false
326 }
327}
328
329#[cfg(not(all(feature = "msl", target_os = "macos")))]
330pub fn register_metal_features(
331 _adapter: &Adapter,
332 _props: &mut DeviceProperties,
333 _comp_options: &mut WgpuCompilationOptions,
334 _memory_config: &MemoryConfiguration,
335) -> bool {
336 false
337}
338
/// Reports whether `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the returned HAL handle is never used; it is only tested for
    // presence and then dropped at the end of this expression.
    unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() }.is_some()
}

/// Reports whether `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the returned HAL handle is never used; it is only tested for
    // presence and then dropped at the end of this expression.
    unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() }.is_some()
}