1use super::wgsl;
2use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::{DeviceProperties, compiler::CompilationError};
5use std::{borrow::Cow, sync::Arc};
6use wgpu::{
7 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
8 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
9};
10
11#[cfg(not(target_family = "wasm"))]
12use crate::errors::{fetch_error, track_error};
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23 pub fn create_pipeline(
24 &mut self,
25 kernel: CompiledKernel<AutoCompiler>,
26 mode: ExecutionMode,
27 ) -> Result<Arc<ComputePipeline>, CompilationError> {
28 let module = match &kernel.repr {
29 #[cfg(feature = "spirv")]
30 Some(AutoRepresentation::SpirV(repr)) => {
31 let spirv = repr.assemble();
32 unsafe {
33 self.device.create_shader_module_passthrough(
34 wgpu::ShaderModuleDescriptorPassthrough::SpirV(
35 wgpu::ShaderModuleDescriptorSpirV {
36 label: Some(&kernel.entrypoint_name),
37 source: Cow::Borrowed(&spirv),
38 },
39 ),
40 )
41 }
42 }
43 #[cfg(all(feature = "msl", target_os = "macos"))]
44 Some(AutoRepresentation::Msl(repr)) => {
45 let source = &kernel.source;
46 unsafe {
47 self.device.create_shader_module_passthrough(
48 wgpu::ShaderModuleDescriptorPassthrough::Msl(
49 wgpu::ShaderModuleDescriptorMsl {
50 entry_point: kernel.entrypoint_name.clone(),
51 label: Some(&kernel.entrypoint_name),
52 source: Cow::Borrowed(source),
53 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
54 },
55 ),
56 )
57 }
58 }
59 _ => {
60 let source = &kernel.source;
61
62 let checks = wgpu::ShaderRuntimeChecks {
63 bounds_checks: false,
67 force_loop_bounding: mode == ExecutionMode::Checked,
69 };
70
71 #[cfg(not(target_family = "wasm"))]
72 track_error(&self.device, wgpu::ErrorFilter::Validation);
73
74 unsafe {
77 self.device.create_shader_module_trusted(
78 ShaderModuleDescriptor {
79 label: None,
80 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
81 },
82 checks,
83 )
84 }
85 }
86 };
87
88 #[cfg(not(target_family = "wasm"))]
89 if let Some(err) = cubecl_common::future::block_on(fetch_error(&self.device)) {
90 return Err(CompilationError::Generic {
91 reason: format!("{err}"),
92 backtrace: cubecl_common::backtrace::BackTrace::capture(),
93 });
94 }
95
96 let bindings_info = match &kernel.repr {
97 Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
98 #[cfg(all(feature = "msl", target_os = "macos"))]
99 Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
100 #[cfg(feature = "spirv")]
101 Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
102 _ => None,
103 };
104
105 let layout = bindings_info.map(|bindings| {
106 let (mut bindings, meta) = bindings;
107 if !cfg!(exclusive_memory_only) {
110 bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
111 }
112
113 let bindings = bindings
114 .into_iter()
115 .chain(meta)
116 .enumerate()
117 .map(|(i, visibility)| BindGroupLayoutEntry {
118 binding: i as u32,
119 visibility: ShaderStages::COMPUTE,
120 ty: BindingType::Buffer {
121 ty: BufferBindingType::Storage {
122 read_only: matches!(
123 visibility,
124 cubecl_runtime::kernel::Visibility::Read
125 ),
126 },
127 has_dynamic_offset: false,
128 min_binding_size: None,
129 },
130 count: None,
131 })
132 .collect::<Vec<_>>();
133 let layout = self
134 .device
135 .create_bind_group_layout(&BindGroupLayoutDescriptor {
136 label: None,
137 entries: &bindings,
138 });
139 self.device
140 .create_pipeline_layout(&PipelineLayoutDescriptor {
141 label: None,
142 bind_group_layouts: &[&layout],
143 push_constant_ranges: &[],
144 })
145 });
146
147 let pipeline = self
148 .device
149 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
150 label: Some(&kernel.entrypoint_name),
151 layout: layout.as_ref(),
152 module: &module,
153 entry_point: Some(&kernel.entrypoint_name),
154 compilation_options: wgpu::PipelineCompilationOptions {
155 zero_initialize_workgroup_memory: false,
156 ..Default::default()
157 },
158 cache: None,
159 });
160 Ok(Arc::new(pipeline))
161 }
162}
163
164#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
165pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
166 wgsl::request_device(adapter).await
167}
168
/// Requests a device and queue for `adapter` when the `spirv` feature is enabled.
///
/// Adapters backed by Vulkan go through the Vulkan-specific device request; any other
/// backend falls back to the WGSL device request.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
177
/// Requests a device and queue on a Metal adapter (`msl` feature on macOS).
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal; this
/// configuration has no fallback path.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    // `metal` is already in scope via the module-level, identically cfg-gated import.
    assert!(is_metal(adapter), "metal device not found!");
    metal::request_metal_device(adapter).await
}
188
189#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
190pub fn register_features(
191 adapter: &Adapter,
192 props: &mut DeviceProperties,
193 comp_options: &mut WgpuCompilationOptions,
194) {
195 wgsl::register_wgsl_features(adapter, props, comp_options);
196}
197
/// Registers device properties and compilation options for `adapter` (`spirv` feature).
///
/// Vulkan adapters use the Vulkan feature registration; every other backend uses the
/// WGSL feature registration.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }
    vulkan::register_vulkan_features(adapter, props, comp_options);
}
210
/// Registers Metal device properties and compilation options (`msl` feature on macOS).
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
223
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter handle is only tested for presence; it is never used,
    // mutated or destroyed, and it is dropped before this function returns.
    let vulkan_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    vulkan_adapter.is_some()
}
228
/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter handle is only tested for presence; it is never used,
    // mutated or destroyed, and it is dropped before this function returns.
    let metal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    metal_adapter.is_some()
}