1use super::wgsl;
2use crate::AutoRepresentationRef;
3use crate::WgpuServer;
4use cubecl_core::{ExecutionMode, WgpuCompilationOptions, hash::StableHash, server::Bindings};
5use cubecl_ir::DeviceProperties;
6use cubecl_runtime::{compiler::CompilationError, id::KernelId};
7use std::{borrow::Cow, sync::Arc};
8use wgpu::{
9 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
10 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModule, ShaderModuleDescriptor,
11 ShaderStages,
12};
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23 #[allow(
27 clippy::type_complexity,
28 reason = "required because of error propagation"
29 )]
30 #[allow(unused_variables)]
31 pub fn load_cached_pipeline(
32 &self,
33 kernel_id: &KernelId,
34 bindings: &Bindings,
35 mode: ExecutionMode,
36 ) -> Result<Option<Result<Arc<ComputePipeline>, (u64, StableHash)>>, CompilationError> {
37 #[cfg(not(feature = "spirv"))]
38 let res = Ok(None);
39 #[cfg(feature = "spirv")]
40 let res = if let Some(cache) = &self.spirv_cache {
41 let key = (self.utilities.properties_hash, kernel_id.stable_hash());
42 if let Some(entry) = cache.get(&key) {
43 log::trace!("Using SPIR-V cache");
44
45 let repr = AutoRepresentationRef::SpirV(&entry.kernel);
46 let module = self.create_module(&entry.entrypoint_name, Some(repr), "", mode)?;
47 let pipeline =
48 self.create_pipeline(&entry.entrypoint_name, Some(repr), module, bindings);
49 Ok(Some(Ok(pipeline)))
50 } else {
51 Ok(Some(Err(key)))
52 }
53 } else {
54 Ok(None)
55 };
56
57 res
58 }
59
60 pub fn create_module(
61 &self,
62 entrypoint_name: &str,
63 repr: Option<AutoRepresentationRef<'_>>,
64 source: &str,
65 mode: ExecutionMode,
66 ) -> Result<ShaderModule, CompilationError> {
67 #[allow(unused_assignments)]
68 #[cfg(not(target_family = "wasm"))]
69 let mut error_scope = None;
70
71 match repr {
72 #[cfg(feature = "spirv")]
73 Some(AutoRepresentationRef::SpirV(repr)) => unsafe {
74 Ok(self.device.create_shader_module_passthrough(
75 wgpu::ShaderModuleDescriptorPassthrough {
76 label: Some(entrypoint_name),
77 spirv: Some(Cow::Borrowed(&repr.assembled_module)),
78 ..Default::default()
79 },
80 ))
81 },
82 #[cfg(all(feature = "msl", target_os = "macos"))]
83 Some(AutoRepresentationRef::Msl(repr)) => unsafe {
84 Ok(self.device.create_shader_module_passthrough(
85 wgpu::ShaderModuleDescriptorPassthrough {
86 entry_point: entrypoint_name.to_string(),
87 label: Some(entrypoint_name),
88 msl: Some(Cow::Borrowed(source)),
89 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
90 ..Default::default()
91 },
92 ))
93 },
94 _ => {
95 let _ = entrypoint_name; let checks = wgpu::ShaderRuntimeChecks {
97 bounds_checks: false,
101 force_loop_bounding: mode == ExecutionMode::Checked,
103 ray_query_initialization_tracking: false,
104 };
105
106 #[cfg(not(target_family = "wasm"))]
107 {
108 error_scope = Some(self.device.push_error_scope(wgpu::ErrorFilter::Validation));
109 }
110
111 let module = unsafe {
114 self.device.create_shader_module_trusted(
115 ShaderModuleDescriptor {
116 label: None,
117 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
118 },
119 checks,
120 )
121 };
122
123 #[cfg(not(target_family = "wasm"))]
124 if let Some(scope) = error_scope
125 && let Some(err) = cubecl_common::future::block_on(scope.pop())
126 {
127 return Err(CompilationError::Generic {
128 reason: format!("{err}"),
129 backtrace: cubecl_common::backtrace::BackTrace::capture(),
130 });
131 }
132
133 Ok(module)
134 }
135 }
136 }
137
138 #[allow(unused_variables)]
139 pub fn create_pipeline(
140 &self,
141 entrypoint_name: &str,
142 repr: Option<AutoRepresentationRef<'_>>,
143 module: ShaderModule,
144 bindings: &Bindings,
145 ) -> Arc<ComputePipeline> {
146 let bindings_info = match repr {
147 Some(AutoRepresentationRef::Wgsl(repr)) => Some(wgsl::bindings(repr)),
148 #[cfg(all(feature = "msl", target_os = "macos"))]
149 Some(AutoRepresentationRef::Msl(repr)) => Some(cpp_metal::bindings(repr)),
150 #[cfg(feature = "spirv")]
151 Some(AutoRepresentationRef::SpirV(repr)) => Some(vulkan::bindings(repr, bindings)),
152 _ => None,
153 };
154
155 let layout = bindings_info.map(|bindings| {
156 let (mut bindings, meta) = bindings;
157 if !cfg!(exclusive_memory_only) {
160 bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
161 }
162
163 let bindings = bindings
164 .into_iter()
165 .chain(meta)
166 .enumerate()
167 .map(|(i, visibility)| BindGroupLayoutEntry {
168 binding: i as u32,
169 visibility: ShaderStages::COMPUTE,
170 ty: BindingType::Buffer {
171 ty: BufferBindingType::Storage {
172 read_only: matches!(
173 visibility,
174 cubecl_runtime::kernel::Visibility::Read
175 ),
176 },
177 has_dynamic_offset: false,
178 min_binding_size: None,
179 },
180 count: None,
181 })
182 .collect::<Vec<_>>();
183 let layout = self
184 .device
185 .create_bind_group_layout(&BindGroupLayoutDescriptor {
186 label: None,
187 entries: &bindings,
188 });
189 self.device
190 .create_pipeline_layout(&PipelineLayoutDescriptor {
191 label: None,
192 bind_group_layouts: &[&layout],
193 immediate_size: 0,
194 })
195 });
196
197 let pipeline = self
198 .device
199 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
200 label: Some(entrypoint_name),
201 layout: layout.as_ref(),
202 module: &module,
203 entry_point: Some(entrypoint_name),
204 compilation_options: wgpu::PipelineCompilationOptions {
205 zero_initialize_workgroup_memory: false,
206 ..Default::default()
207 },
208 cache: None,
209 });
210 Arc::new(pipeline)
211 }
212}
213
214#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
215pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
216 wgsl::request_device(adapter).await
217}
218
/// Requests a device and queue from `adapter` (build with the `spirv` feature).
///
/// Vulkan adapters go through `vulkan::request_vulkan_device`; every other adapter falls back
/// to `wgsl::request_device`.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }

    vulkan::request_vulkan_device(adapter).await
}
227
/// Requests a device and queue from `adapter` (build with the `msl` feature on macOS).
///
/// # Panics
///
/// Panics with `"metal device not found!"` when `adapter` is not a Metal adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    use super::metal;

    assert!(is_metal(adapter), "metal device not found!");

    metal::request_metal_device(adapter).await
}
238
239#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
240pub fn register_features(
241 adapter: &Adapter,
242 props: &mut DeviceProperties,
243 comp_options: &mut WgpuCompilationOptions,
244) {
245 wgsl::register_wgsl_features(adapter, props, comp_options);
246}
247
/// Records the adapter's capabilities in `props` and `comp_options` (build with the `spirv`
/// feature).
///
/// Vulkan adapters are handled by `vulkan::register_vulkan_features`; every other adapter
/// falls back to `wgsl::register_wgsl_features`.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }

    vulkan::register_vulkan_features(adapter, props, comp_options);
}
260
/// Records the adapter's capabilities in `props` and `comp_options` (build with the `msl`
/// feature on macOS) by delegating to `metal::register_metal_features`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` when `adapter` is not a Metal adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");

    metal::register_metal_features(adapter, props, comp_options);
}
273
/// Returns `true` when `adapter` exposes a Vulkan HAL adapter.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; it is never dereferenced or used.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}
278
/// Returns `true` when `adapter` exposes a Metal HAL adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence; it is never dereferenced or used.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}