1use std::{borrow::Cow, sync::Arc};
2
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::DeviceProperties;
5use wgpu::{
6 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
7 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
8};
9
10use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
11
12use super::wgsl;
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23 pub fn create_pipeline(
24 &mut self,
25 kernel: CompiledKernel<AutoCompiler>,
26 mode: ExecutionMode,
27 ) -> Arc<ComputePipeline> {
28 let module = match &kernel.repr {
29 #[cfg(feature = "spirv")]
30 Some(AutoRepresentation::SpirV(repr)) => {
31 let spirv = repr.assemble();
32 unsafe {
33 self.device.create_shader_module_passthrough(
34 wgpu::ShaderModuleDescriptorPassthrough::SpirV(
35 wgpu::ShaderModuleDescriptorSpirV {
36 label: Some(&kernel.entrypoint_name),
37 source: Cow::Borrowed(&spirv),
38 },
39 ),
40 )
41 }
42 }
43 #[cfg(all(feature = "msl", target_os = "macos"))]
44 Some(AutoRepresentation::Msl(repr)) => {
45 let source = &kernel.source;
46 unsafe {
47 self.device.create_shader_module_passthrough(
48 wgpu::ShaderModuleDescriptorPassthrough::Msl(
49 wgpu::ShaderModuleDescriptorMsl {
50 entry_point: kernel.entrypoint_name.clone(),
51 label: Some(&kernel.entrypoint_name),
52 source: Cow::Borrowed(source),
53 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
54 },
55 ),
56 )
57 }
58 }
59 _ => {
60 let source = &kernel.source;
61
62 let checks = wgpu::ShaderRuntimeChecks {
63 bounds_checks: false,
67 force_loop_bounding: mode == ExecutionMode::Checked,
69 };
70
71 unsafe {
74 self.device.create_shader_module_trusted(
75 ShaderModuleDescriptor {
76 label: None,
77 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
78 },
79 checks,
80 )
81 }
82 }
83 };
84 let bindings_info = match &kernel.repr {
85 Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
86 #[cfg(all(feature = "msl", target_os = "macos"))]
87 Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
88 #[cfg(feature = "spirv")]
89 Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
90 _ => None,
91 };
92
93 let layout = bindings_info.map(|bindings| {
94 let (mut bindings, meta) = bindings;
95 if !cfg!(exclusive_memory_only) {
98 bindings.fill(cubecl_core::compute::Visibility::ReadWrite);
99 }
100
101 let bindings = bindings
102 .into_iter()
103 .chain(meta)
104 .enumerate()
105 .map(|(i, visibility)| BindGroupLayoutEntry {
106 binding: i as u32,
107 visibility: ShaderStages::COMPUTE,
108 ty: BindingType::Buffer {
109 ty: BufferBindingType::Storage {
110 read_only: matches!(visibility, cubecl_core::compute::Visibility::Read),
111 },
112 has_dynamic_offset: false,
113 min_binding_size: None,
114 },
115 count: None,
116 })
117 .collect::<Vec<_>>();
118 let layout = self
119 .device
120 .create_bind_group_layout(&BindGroupLayoutDescriptor {
121 label: None,
122 entries: &bindings,
123 });
124 self.device
125 .create_pipeline_layout(&PipelineLayoutDescriptor {
126 label: None,
127 bind_group_layouts: &[&layout],
128 push_constant_ranges: &[],
129 })
130 });
131
132 Arc::new(
133 self.device
134 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
135 label: Some(&kernel.entrypoint_name),
136 layout: layout.as_ref(),
137 module: &module,
138 entry_point: Some(&kernel.entrypoint_name),
139 compilation_options: wgpu::PipelineCompilationOptions {
140 zero_initialize_workgroup_memory: false,
141 ..Default::default()
142 },
143 cache: None,
144 }),
145 )
146 }
147}
148
149#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
150pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
151 wgsl::request_device(adapter).await
152}
153
/// Requests a device, using the Vulkan (SPIR-V) path when the adapter is backed by Vulkan.
///
/// Adapters on any other backend are served by the WGSL path.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
162
/// Requests a device through the Metal (MSL) path.
///
/// Relies on the module-level `metal` import, which is gated on the same cfg as this function.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::request_metal_device(adapter).await
}
173
174#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
175pub fn register_features(
176 adapter: &Adapter,
177 props: &mut DeviceProperties,
178 comp_options: &mut WgpuCompilationOptions,
179) {
180 wgsl::register_wgsl_features(adapter, props, comp_options);
181}
182
/// Fills `props` and `comp_options` with the features of `adapter`.
///
/// Vulkan-backed adapters are handled by the Vulkan (SPIR-V) registration. Every other
/// backend is handled by the WGSL registration.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }
    vulkan::register_vulkan_features(adapter, props, comp_options);
}
195
/// Fills `props` and `comp_options` with the Metal (MSL) features of `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
208
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter is only checked for presence and then dropped; it is never
    // used to access the underlying Vulkan objects.
    let vulkan_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    vulkan_adapter.is_some()
}
213
/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter is only checked for presence and then dropped; it is never
    // used to access the underlying Metal objects.
    let metal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    metal_adapter.is_some()
}