1use std::{borrow::Cow, sync::Arc};
2
3use cubecl_core::{ExecutionMode, Feature, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::DeviceProperties;
5use wgpu::{
6 Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
7 ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
8};
9
10use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
11
12use super::wgsl;
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23 pub fn create_pipeline(
24 &mut self,
25 kernel: CompiledKernel<AutoCompiler>,
26 mode: ExecutionMode,
27 ) -> Arc<ComputePipeline> {
28 let module = match &kernel.repr {
29 #[cfg(feature = "spirv")]
30 Some(AutoRepresentation::SpirV(repr)) => {
31 let spirv = repr.assemble();
32 unsafe {
33 self.device.create_shader_module_passthrough(
34 wgpu::ShaderModuleDescriptorPassthrough::SpirV(
35 wgpu::ShaderModuleDescriptorSpirV {
36 label: Some(&kernel.entrypoint_name),
37 source: Cow::Borrowed(&spirv),
38 },
39 ),
40 )
41 }
42 }
43 #[cfg(all(feature = "msl", target_os = "macos"))]
44 Some(AutoRepresentation::Msl(repr)) => {
45 let source = &kernel.source;
46 unsafe {
47 self.device.create_shader_module_passthrough(
48 wgpu::ShaderModuleDescriptorPassthrough::Msl(
49 wgpu::ShaderModuleDescriptorMsl {
50 entry_point: kernel.entrypoint_name.clone(),
51 label: Some(&kernel.entrypoint_name),
52 source: Cow::Borrowed(source),
53 num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
54 },
55 ),
56 )
57 }
58 }
59 _ => {
60 let source = &kernel.source;
61
62 let checks = wgpu::ShaderRuntimeChecks {
63 bounds_checks: false,
67 force_loop_bounding: mode == ExecutionMode::Checked,
69 };
70
71 unsafe {
74 self.device.create_shader_module_trusted(
75 ShaderModuleDescriptor {
76 label: None,
77 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
78 },
79 checks,
80 )
81 }
82 }
83 };
84 let bindings = match &kernel.repr {
85 Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
86 #[cfg(all(feature = "msl", target_os = "macos"))]
87 Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
88 #[cfg(feature = "spirv")]
89 Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
90 _ => None,
91 };
92 let layout = bindings.map(|bindings| {
93 let bindings = bindings
94 .into_iter()
95 .map(|(i, _visibility)| BindGroupLayoutEntry {
96 binding: i as u32,
97 visibility: ShaderStages::COMPUTE,
98 ty: BindingType::Buffer {
99 #[cfg(not(exclusive_memory_only))]
100 ty: BufferBindingType::Storage { read_only: false },
101 #[cfg(exclusive_memory_only)]
102 ty: BufferBindingType::Storage {
103 read_only: matches!(
104 _visibility,
105 cubecl_core::compute::Visibility::Read
106 ),
107 },
108 has_dynamic_offset: false,
109 min_binding_size: None,
110 },
111 count: None,
112 })
113 .collect::<Vec<_>>();
114 let layout = self
115 .device
116 .create_bind_group_layout(&BindGroupLayoutDescriptor {
117 label: None,
118 entries: &bindings,
119 });
120 self.device
121 .create_pipeline_layout(&PipelineLayoutDescriptor {
122 label: None,
123 bind_group_layouts: &[&layout],
124 push_constant_ranges: &[],
125 })
126 });
127
128 Arc::new(
129 self.device
130 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
131 label: Some(&kernel.entrypoint_name),
132 layout: layout.as_ref(),
133 module: &module,
134 entry_point: Some(&kernel.entrypoint_name),
135 compilation_options: wgpu::PipelineCompilationOptions {
136 zero_initialize_workgroup_memory: false,
137 ..Default::default()
138 },
139 cache: None,
140 }),
141 )
142 }
143}
144
145#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
146pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
147 wgsl::request_device(adapter).await
148}
149
/// Requests a device and queue, preferring the native Vulkan path.
///
/// Adapters backed by the Vulkan HAL are routed to `vulkan::request_vulkan_device`. Any other
/// adapter falls back to `wgsl::request_device`.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
158
/// Requests a device and queue through the native Metal path.
///
/// # Panics
///
/// Panics with `metal device not found!` if the adapter is not backed by the Metal HAL.
// NOTE(review): if `spirv` and `msl` are both enabled on macOS, this and the `spirv` variant
// are both compiled and collide. Confirm the features are kept mutually exclusive.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    // `metal` is already in scope through the identically gated module-level import.
    assert!(is_metal(adapter), "metal device not found!");
    metal::request_metal_device(adapter).await
}
169
170#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
171pub fn register_features(
172 adapter: &Adapter,
173 props: &mut DeviceProperties<Feature>,
174 comp_options: &mut WgpuCompilationOptions,
175) {
176 wgsl::register_wgsl_features(adapter, props, comp_options);
177}
178
/// Registers device features and compilation options, preferring the Vulkan backend.
///
/// Vulkan-backed adapters use `vulkan::register_vulkan_features`. Every other adapter uses
/// `wgsl::register_wgsl_features`.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties<Feature>,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }
    vulkan::register_vulkan_features(adapter, props, comp_options);
}
191
/// Registers the Metal backend's features and compilation options for `adapter`.
///
/// # Panics
///
/// Panics with `metal device not found!` if the adapter is not backed by the Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties<Feature>,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
204
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the raw HAL adapter is only tested for presence inside the callback. It is never
    // used or retained.
    unsafe {
        adapter.as_hal::<wgpu::hal::api::Vulkan, _, _>(|hal_adapter| hal_adapter.is_some())
    }
}
209
/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the raw HAL adapter is only tested for presence inside the callback. It is never
    // used or retained.
    unsafe {
        adapter.as_hal::<wgpu::hal::api::Metal, _, _>(|hal_adapter| hal_adapter.is_some())
    }
}