1use super::wgsl;
2use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::{DeviceProperties, compiler::CompilationError};
5use std::{borrow::Cow, sync::Arc};
6use wgpu::{
7 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
8 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
9};
10
11#[cfg(feature = "spirv")]
12use super::vulkan;
13
14#[cfg(all(feature = "msl", target_os = "macos"))]
15use super::metal;
16#[cfg(all(feature = "msl", target_os = "macos"))]
17use cubecl_cpp::metal as cpp_metal;
18
19impl WgpuServer {
20 pub fn create_pipeline(
21 &mut self,
22 kernel: CompiledKernel<AutoCompiler>,
23 mode: ExecutionMode,
24 ) -> Result<Arc<ComputePipeline>, CompilationError> {
25 let module = match &kernel.repr {
26 #[cfg(feature = "spirv")]
27 Some(AutoRepresentation::SpirV(repr)) => {
28 let spirv = repr.assemble();
29 unsafe {
30 self.device.create_shader_module_passthrough(
31 wgpu::ShaderModuleDescriptorPassthrough::SpirV(
32 wgpu::ShaderModuleDescriptorSpirV {
33 label: Some(&kernel.entrypoint_name),
34 source: Cow::Borrowed(&spirv),
35 },
36 ),
37 )
38 }
39 }
40 #[cfg(all(feature = "msl", target_os = "macos"))]
41 Some(AutoRepresentation::Msl(repr)) => {
42 let source = &kernel.source;
43 unsafe {
44 self.device.create_shader_module_passthrough(
45 wgpu::ShaderModuleDescriptorPassthrough::Msl(
46 wgpu::ShaderModuleDescriptorMsl {
47 entry_point: kernel.entrypoint_name.clone(),
48 label: Some(&kernel.entrypoint_name),
49 source: Cow::Borrowed(source),
50 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
51 },
52 ),
53 )
54 }
55 }
56 _ => {
57 let source = &kernel.source;
58
59 let checks = wgpu::ShaderRuntimeChecks {
60 bounds_checks: false,
64 force_loop_bounding: mode == ExecutionMode::Checked,
66 };
67
68 unsafe {
71 self.device.create_shader_module_trusted(
72 ShaderModuleDescriptor {
73 label: None,
74 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
75 },
76 checks,
77 )
78 }
79 }
80 };
81
82 #[cfg(not(target_family = "wasm"))]
83 {
84 let info = cubecl_common::future::block_on(module.get_compilation_info());
85 let mut errors = info
86 .messages
87 .into_iter()
88 .filter(|msg| matches!(msg.message_type, wgpu::CompilationMessageType::Error))
89 .map(|msg| CompilationError::Generic {
90 context: format!("{msg:?}"),
91 })
92 .collect::<Vec<_>>();
93
94 if !errors.is_empty() {
95 if errors.len() > 1 {
96 return Err(CompilationError::Multiple {
97 context: format!("Unable to compile kernel: {}", kernel.entrypoint_name),
98 errors,
99 });
100 } else {
101 return Err(errors.remove(0));
102 }
103 }
104 }
105
106 let bindings_info = match &kernel.repr {
107 Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
108 #[cfg(all(feature = "msl", target_os = "macos"))]
109 Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
110 #[cfg(feature = "spirv")]
111 Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
112 _ => None,
113 };
114
115 let layout = bindings_info.map(|bindings| {
116 let (mut bindings, meta) = bindings;
117 if !cfg!(exclusive_memory_only) {
120 bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
121 }
122
123 let bindings = bindings
124 .into_iter()
125 .chain(meta)
126 .enumerate()
127 .map(|(i, visibility)| BindGroupLayoutEntry {
128 binding: i as u32,
129 visibility: ShaderStages::COMPUTE,
130 ty: BindingType::Buffer {
131 ty: BufferBindingType::Storage {
132 read_only: matches!(
133 visibility,
134 cubecl_runtime::kernel::Visibility::Read
135 ),
136 },
137 has_dynamic_offset: false,
138 min_binding_size: None,
139 },
140 count: None,
141 })
142 .collect::<Vec<_>>();
143 let layout = self
144 .device
145 .create_bind_group_layout(&BindGroupLayoutDescriptor {
146 label: None,
147 entries: &bindings,
148 });
149 self.device
150 .create_pipeline_layout(&PipelineLayoutDescriptor {
151 label: None,
152 bind_group_layouts: &[&layout],
153 push_constant_ranges: &[],
154 })
155 });
156
157 let pipeline = self
158 .device
159 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
160 label: Some(&kernel.entrypoint_name),
161 layout: layout.as_ref(),
162 module: &module,
163 entry_point: Some(&kernel.entrypoint_name),
164 compilation_options: wgpu::PipelineCompilationOptions {
165 zero_initialize_workgroup_memory: false,
166 ..Default::default()
167 },
168 cache: None,
169 });
170 Ok(Arc::new(pipeline))
171 }
172}
173
174#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
175pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
176 wgsl::request_device(adapter).await
177}
178
/// Requests a device and queue from `adapter`, taking the Vulkan (SPIR-V) path when the
/// adapter is backed by the Vulkan HAL and the WGSL path otherwise.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
187
/// Requests a device and queue from a Metal-backed `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by the Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_metal(adapter) {
        panic!("metal device not found!");
    }
    metal::request_metal_device(adapter).await
}
198
199#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
200pub fn register_features(
201 adapter: &Adapter,
202 props: &mut DeviceProperties,
203 comp_options: &mut WgpuCompilationOptions,
204) {
205 wgsl::register_wgsl_features(adapter, props, comp_options);
206}
207
/// Registers device properties and compilation options for `adapter`.
///
/// Adapters backed by the Vulkan HAL use the Vulkan registration; any other adapter uses the
/// WGSL registration.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if is_vulkan(adapter) {
        vulkan::register_vulkan_features(adapter, props, comp_options);
        return;
    }
    wgsl::register_wgsl_features(adapter, props, comp_options);
}
220
/// Registers device properties and compilation options for a Metal-backed `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by the Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_metal(adapter) {
        panic!("metal device not found!");
    }
    metal::register_metal_features(adapter, props, comp_options);
}
233
/// Returns `true` when `adapter` is backed by the Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter handle is only probed for presence and never used.
    let vulkan_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    vulkan_adapter.is_some()
}
238
/// Returns `true` when `adapter` is backed by the Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter handle is only probed for presence and never used.
    let metal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    metal_adapter.is_some()
}