1use super::wgsl;
2use crate::AutoRepresentationRef;
3use crate::WgpuServer;
4use cubecl_core::MemoryConfiguration;
5use cubecl_core::{
6 ExecutionMode, WgpuCompilationOptions, hash::StableHash, server::KernelArguments,
7};
8use cubecl_ir::DeviceProperties;
9use cubecl_runtime::{compiler::CompilationError, id::KernelId};
10use std::{borrow::Cow, sync::Arc};
11use wgpu::{
12 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
13 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModule, ShaderModuleDescriptor,
14 ShaderStages,
15};
16
17#[cfg(feature = "spirv")]
18use super::vulkan;
19
20#[cfg(all(feature = "msl", target_os = "macos"))]
21use super::metal;
22#[cfg(all(feature = "msl", target_os = "macos"))]
23use cubecl_cpp::metal as cpp_metal;
24
25impl WgpuServer {
26 #[allow(
30 clippy::type_complexity,
31 reason = "required because of error propagation"
32 )]
33 #[allow(unused_variables)]
34 pub fn load_cached_pipeline(
35 &self,
36 kernel_id: &KernelId,
37 bindings: &KernelArguments,
38 mode: ExecutionMode,
39 ) -> Result<Option<Result<Arc<ComputePipeline>, (u64, StableHash)>>, CompilationError> {
40 #[cfg(not(feature = "spirv"))]
41 let res = Ok(None);
42 #[cfg(feature = "spirv")]
43 let res = if let Some(cache) = &self.spirv_cache {
44 let key = (self.utilities.properties_hash, kernel_id.stable_hash());
45 if let Some(entry) = cache.get(&key) {
46 log::trace!("Using SPIR-V cache");
47
48 let repr = AutoRepresentationRef::SpirV(&entry.kernel);
49 let module = self.create_module(&entry.entrypoint_name, Some(repr), "", mode)?;
50 let pipeline =
51 self.create_pipeline(&entry.entrypoint_name, Some(repr), module, bindings);
52 Ok(Some(Ok(pipeline)))
53 } else {
54 Ok(Some(Err(key)))
55 }
56 } else {
57 Ok(None)
58 };
59
60 res
61 }
62
63 pub fn create_module(
64 &self,
65 entrypoint_name: &str,
66 repr: Option<AutoRepresentationRef<'_>>,
67 source: &str,
68 mode: ExecutionMode,
69 ) -> Result<ShaderModule, CompilationError> {
70 #[allow(unused_assignments)]
71 #[cfg(not(target_family = "wasm"))]
72 let mut error_scope = None;
73
74 match repr {
75 #[cfg(feature = "spirv")]
76 Some(AutoRepresentationRef::SpirV(repr)) => unsafe {
77 Ok(self.device.create_shader_module_passthrough(
78 wgpu::ShaderModuleDescriptorPassthrough {
79 label: Some(entrypoint_name),
80 spirv: Some(Cow::Borrowed(&repr.assembled_module)),
81 ..Default::default()
82 },
83 ))
84 },
85 #[cfg(all(feature = "msl", target_os = "macos"))]
86 Some(AutoRepresentationRef::Msl(repr)) => unsafe {
87 Ok(self.device.create_shader_module_passthrough(
88 wgpu::ShaderModuleDescriptorPassthrough {
89 label: Some(entrypoint_name),
90 msl: Some(Cow::Borrowed(source)),
91 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
92 ..Default::default()
93 },
94 ))
95 },
96 _ => {
97 let _ = entrypoint_name; let checks = wgpu::ShaderRuntimeChecks {
99 bounds_checks: false,
103 force_loop_bounding: mode == ExecutionMode::Checked,
105 ..wgpu::ShaderRuntimeChecks::unchecked()
106 };
107
108 #[cfg(not(target_family = "wasm"))]
109 {
110 error_scope = Some(self.device.push_error_scope(wgpu::ErrorFilter::Validation));
111 }
112
113 let module = unsafe {
116 self.device.create_shader_module_trusted(
117 ShaderModuleDescriptor {
118 label: None,
119 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
120 },
121 checks,
122 )
123 };
124
125 #[cfg(not(target_family = "wasm"))]
126 if let Some(scope) = error_scope
127 && let Some(err) = cubecl_common::future::block_on(scope.pop())
128 {
129 return Err(CompilationError::Generic {
130 reason: format!("{err}"),
131 backtrace: cubecl_common::backtrace::BackTrace::capture(),
132 });
133 }
134
135 Ok(module)
136 }
137 }
138 }
139
140 #[allow(unused_variables)]
141 pub fn create_pipeline(
142 &self,
143 entrypoint_name: &str,
144 repr: Option<AutoRepresentationRef<'_>>,
145 module: ShaderModule,
146 bindings: &KernelArguments,
147 ) -> Arc<ComputePipeline> {
148 let bindings_info = match repr {
149 Some(AutoRepresentationRef::Wgsl(repr)) => Some(wgsl::bindings(repr, bindings)),
150 #[cfg(all(feature = "msl", target_os = "macos"))]
151 Some(AutoRepresentationRef::Msl(repr)) => Some(cpp_metal::bindings(repr, bindings)),
152 #[cfg(feature = "spirv")]
153 Some(AutoRepresentationRef::SpirV(repr)) => Some(vulkan::bindings(repr, bindings)),
154 _ => None,
155 };
156
157 let layout = bindings_info.map(|bindings| {
158 let (mut bindings, info, uniform_info) = bindings;
159 if !cfg!(exclusive_memory_only) {
162 bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
163 }
164
165 let info = info.map(|_| match uniform_info {
166 true => BufferBindingType::Uniform,
167 false => BufferBindingType::Storage { read_only: true },
168 });
169
170 let bindings = bindings
171 .into_iter()
172 .map(|visibility| BufferBindingType::Storage {
173 read_only: matches!(visibility, cubecl_runtime::kernel::Visibility::Read),
174 })
175 .chain(info)
176 .enumerate()
177 .map(|(i, ty)| BindGroupLayoutEntry {
178 binding: i as u32,
179 visibility: ShaderStages::COMPUTE,
180 ty: BindingType::Buffer {
181 ty,
182 has_dynamic_offset: false,
183 min_binding_size: None,
184 },
185 count: None,
186 })
187 .collect::<Vec<_>>();
188 let layout = self
189 .device
190 .create_bind_group_layout(&BindGroupLayoutDescriptor {
191 label: None,
192 entries: &bindings,
193 });
194 self.device
195 .create_pipeline_layout(&PipelineLayoutDescriptor {
196 label: None,
197 bind_group_layouts: &[Some(&layout)],
198 immediate_size: 0,
199 })
200 });
201
202 let pipeline = self
203 .device
204 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
205 label: Some(entrypoint_name),
206 layout: layout.as_ref(),
207 module: &module,
208 entry_point: Some(entrypoint_name),
209 compilation_options: wgpu::PipelineCompilationOptions {
210 zero_initialize_workgroup_memory: false,
211 ..Default::default()
212 },
213 cache: None,
214 });
215 Arc::new(pipeline)
216 }
217}
218
219pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
220 if let Some(result) = request_vulkan_device(adapter).await {
221 return result;
222 }
223 if let Some(result) = request_metal_device(adapter).await {
224 return result;
225 }
226 wgsl::request_device(adapter).await
227}
228
229#[cfg(feature = "spirv")]
230async fn request_vulkan_device(adapter: &Adapter) -> Option<(Device, Queue)> {
231 if is_vulkan(adapter) {
232 vulkan::request_vulkan_device(adapter).await
233 } else {
234 None
235 }
236}
237
238#[cfg(not(feature = "spirv"))]
239async fn request_vulkan_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
240 None
241}
242
243#[cfg(all(feature = "msl", target_os = "macos"))]
244async fn request_metal_device(adapter: &Adapter) -> Option<(Device, Queue)> {
245 if is_metal(adapter) {
246 Some(metal::request_metal_device(adapter).await)
247 } else {
248 None
249 }
250}
251
252#[cfg(not(all(feature = "msl", target_os = "macos")))]
253async fn request_metal_device(_adapter: &Adapter) -> Option<(Device, Queue)> {
254 None
255}
256
257pub fn register_features(
258 adapter: &Adapter,
259 props: &mut DeviceProperties,
260 comp_options: &mut WgpuCompilationOptions,
261 memory_config: &MemoryConfiguration,
262) {
263 if register_vulkan_features(adapter, props, comp_options, memory_config) {
264 return;
265 }
266 if register_metal_features(adapter, props, comp_options, memory_config) {
267 return;
268 }
269 wgsl::register_wgsl_features(adapter, props, comp_options);
270}
271
272#[cfg(feature = "spirv")]
273pub fn register_vulkan_features(
274 adapter: &Adapter,
275 props: &mut DeviceProperties,
276 comp_options: &mut WgpuCompilationOptions,
277 memory_config: &MemoryConfiguration,
278) -> bool {
279 if is_vulkan(adapter) {
280 vulkan::register_vulkan_features(adapter, props, comp_options, memory_config)
281 } else {
282 false
283 }
284}
285
286#[cfg(not(feature = "spirv"))]
287pub fn register_vulkan_features(
288 _adapter: &Adapter,
289 _props: &mut DeviceProperties,
290 _comp_options: &mut WgpuCompilationOptions,
291 _memory_config: &MemoryConfiguration,
292) -> bool {
293 false
294}
295
296#[cfg(all(feature = "msl", target_os = "macos"))]
297pub fn register_metal_features(
298 adapter: &Adapter,
299 props: &mut DeviceProperties,
300 comp_options: &mut WgpuCompilationOptions,
301 _memory_config: &MemoryConfiguration,
302) -> bool {
303 if is_metal(adapter) {
304 metal::register_metal_features(adapter, props, comp_options);
305 true
306 } else {
307 false
308 }
309}
310
311#[cfg(not(all(feature = "msl", target_os = "macos")))]
312pub fn register_metal_features(
313 _adapter: &Adapter,
314 _props: &mut DeviceProperties,
315 _comp_options: &mut WgpuCompilationOptions,
316 _memory_config: &MemoryConfiguration,
317) -> bool {
318 false
319}
320
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only checked for presence. It is never dereferenced or used
    // to touch the underlying Vulkan objects.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}

/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only checked for presence. It is never dereferenced or used
    // to touch the underlying Metal objects.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}