1use super::wgsl;
2use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_ir::DeviceProperties;
5use cubecl_runtime::compiler::CompilationError;
6use std::{borrow::Cow, sync::Arc};
7use wgpu::{
8 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
9 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
10};
11
12#[cfg(not(target_family = "wasm"))]
13use crate::errors::{fetch_error, track_error};
14
15#[cfg(feature = "spirv")]
16use super::vulkan;
17
18#[cfg(all(feature = "msl", target_os = "macos"))]
19use super::metal;
20#[cfg(all(feature = "msl", target_os = "macos"))]
21use cubecl_cpp::metal as cpp_metal;
22
23impl WgpuServer {
24 pub fn create_pipeline(
25 &mut self,
26 kernel: CompiledKernel<AutoCompiler>,
27 mode: ExecutionMode,
28 ) -> Result<Arc<ComputePipeline>, CompilationError> {
29 let module = match &kernel.repr {
30 #[cfg(feature = "spirv")]
31 Some(AutoRepresentation::SpirV(repr)) => {
32 let spirv = repr.assemble();
33 unsafe {
34 self.device.create_shader_module_passthrough(
35 wgpu::ShaderModuleDescriptorPassthrough::SpirV(
36 wgpu::ShaderModuleDescriptorSpirV {
37 label: Some(&kernel.entrypoint_name),
38 source: Cow::Borrowed(&spirv),
39 },
40 ),
41 )
42 }
43 }
44 #[cfg(all(feature = "msl", target_os = "macos"))]
45 Some(AutoRepresentation::Msl(repr)) => {
46 let source = &kernel.source;
47 unsafe {
48 self.device.create_shader_module_passthrough(
49 wgpu::ShaderModuleDescriptorPassthrough::Msl(
50 wgpu::ShaderModuleDescriptorMsl {
51 entry_point: kernel.entrypoint_name.clone(),
52 label: Some(&kernel.entrypoint_name),
53 source: Cow::Borrowed(source),
54 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
55 },
56 ),
57 )
58 }
59 }
60 _ => {
61 let source = &kernel.source;
62
63 let checks = wgpu::ShaderRuntimeChecks {
64 bounds_checks: false,
68 force_loop_bounding: mode == ExecutionMode::Checked,
70 };
71
72 #[cfg(not(target_family = "wasm"))]
73 track_error(&self.device, wgpu::ErrorFilter::Validation);
74
75 unsafe {
78 self.device.create_shader_module_trusted(
79 ShaderModuleDescriptor {
80 label: None,
81 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
82 },
83 checks,
84 )
85 }
86 }
87 };
88
89 #[cfg(not(target_family = "wasm"))]
90 if let Some(err) = cubecl_common::future::block_on(fetch_error(&self.device)) {
91 return Err(CompilationError::Generic {
92 reason: format!("{err}"),
93 backtrace: cubecl_common::backtrace::BackTrace::capture(),
94 });
95 }
96
97 let bindings_info = match &kernel.repr {
98 Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
99 #[cfg(all(feature = "msl", target_os = "macos"))]
100 Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
101 #[cfg(feature = "spirv")]
102 Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
103 _ => None,
104 };
105
106 let layout = bindings_info.map(|bindings| {
107 let (mut bindings, meta) = bindings;
108 if !cfg!(exclusive_memory_only) {
111 bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
112 }
113
114 let bindings = bindings
115 .into_iter()
116 .chain(meta)
117 .enumerate()
118 .map(|(i, visibility)| BindGroupLayoutEntry {
119 binding: i as u32,
120 visibility: ShaderStages::COMPUTE,
121 ty: BindingType::Buffer {
122 ty: BufferBindingType::Storage {
123 read_only: matches!(
124 visibility,
125 cubecl_runtime::kernel::Visibility::Read
126 ),
127 },
128 has_dynamic_offset: false,
129 min_binding_size: None,
130 },
131 count: None,
132 })
133 .collect::<Vec<_>>();
134 let layout = self
135 .device
136 .create_bind_group_layout(&BindGroupLayoutDescriptor {
137 label: None,
138 entries: &bindings,
139 });
140 self.device
141 .create_pipeline_layout(&PipelineLayoutDescriptor {
142 label: None,
143 bind_group_layouts: &[&layout],
144 push_constant_ranges: &[],
145 })
146 });
147
148 let pipeline = self
149 .device
150 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
151 label: Some(&kernel.entrypoint_name),
152 layout: layout.as_ref(),
153 module: &module,
154 entry_point: Some(&kernel.entrypoint_name),
155 compilation_options: wgpu::PipelineCompilationOptions {
156 zero_initialize_workgroup_memory: false,
157 ..Default::default()
158 },
159 cache: None,
160 });
161 Ok(Arc::new(pipeline))
162 }
163}
164
165#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
166pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
167 wgsl::request_device(adapter).await
168}
169
/// Requests a device for `adapter`, using the Vulkan-specific path when the adapter is
/// backed by Vulkan and the generic WGSL path otherwise.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
178
/// Requests a Metal device for `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_metal(adapter) {
        panic!("metal device not found!");
    }
    metal::request_metal_device(adapter).await
}
189
190#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
191pub fn register_features(
192 adapter: &Adapter,
193 props: &mut DeviceProperties,
194 comp_options: &mut WgpuCompilationOptions,
195) {
196 wgsl::register_wgsl_features(adapter, props, comp_options);
197}
198
/// Registers device properties and compilation options for `adapter`.
///
/// Vulkan-backed adapters get the Vulkan feature set. Every other adapter falls back to
/// the WGSL feature set.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }
    vulkan::register_vulkan_features(adapter, props, comp_options);
}
211
/// Registers Metal device properties and compilation options for `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_metal(adapter) {
        panic!("metal device not found!");
    }
    metal::register_metal_features(adapter, props, comp_options);
}
224
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: NOTE(review): the HAL adapter handle is only checked for presence and dropped
    // immediately, and it is never used to issue work. Confirm this satisfies `as_hal`'s
    // contract.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}
229
/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: NOTE(review): the HAL adapter handle is only checked for presence and dropped
    // immediately, and it is never used to issue work. Confirm this satisfies `as_hal`'s
    // contract.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}