cubecl_wgpu/backend/base.rs

1use super::wgsl;
2use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::{DeviceProperties, compiler::CompilationError};
5use std::{borrow::Cow, sync::Arc};
6use wgpu::{
7    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
8    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
9};
10
11#[cfg(feature = "spirv")]
12use super::vulkan;
13
14#[cfg(all(feature = "msl", target_os = "macos"))]
15use super::metal;
16#[cfg(all(feature = "msl", target_os = "macos"))]
17use cubecl_cpp::metal as cpp_metal;
18
19impl WgpuServer {
20    pub fn create_pipeline(
21        &mut self,
22        kernel: CompiledKernel<AutoCompiler>,
23        mode: ExecutionMode,
24    ) -> Result<Arc<ComputePipeline>, CompilationError> {
25        let module = match &kernel.repr {
26            #[cfg(feature = "spirv")]
27            Some(AutoRepresentation::SpirV(repr)) => {
28                let spirv = repr.assemble();
29                unsafe {
30                    self.device.create_shader_module_passthrough(
31                        wgpu::ShaderModuleDescriptorPassthrough::SpirV(
32                            wgpu::ShaderModuleDescriptorSpirV {
33                                label: Some(&kernel.entrypoint_name),
34                                source: Cow::Borrowed(&spirv),
35                            },
36                        ),
37                    )
38                }
39            }
40            #[cfg(all(feature = "msl", target_os = "macos"))]
41            Some(AutoRepresentation::Msl(repr)) => {
42                let source = &kernel.source;
43                unsafe {
44                    self.device.create_shader_module_passthrough(
45                        wgpu::ShaderModuleDescriptorPassthrough::Msl(
46                            wgpu::ShaderModuleDescriptorMsl {
47                                entry_point: kernel.entrypoint_name.clone(),
48                                label: Some(&kernel.entrypoint_name),
49                                source: Cow::Borrowed(source),
50                                num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
51                            },
52                        ),
53                    )
54                }
55            }
56            _ => {
57                let source = &kernel.source;
58
59                let checks = wgpu::ShaderRuntimeChecks {
60                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
61                    // checked by cube (if enabled).
62                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
63                    bounds_checks: false,
64                    // Loop bounds are only checked in checked mode.
65                    force_loop_bounding: mode == ExecutionMode::Checked,
66                };
67
68                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
69                // is only available through the use of unsafe code.
70                unsafe {
71                    self.device.create_shader_module_trusted(
72                        ShaderModuleDescriptor {
73                            label: None,
74                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
75                        },
76                        checks,
77                    )
78                }
79            }
80        };
81
82        #[cfg(not(target_family = "wasm"))]
83        {
84            let info = cubecl_common::future::block_on(module.get_compilation_info());
85            let mut errors = info
86                .messages
87                .into_iter()
88                .filter(|msg| matches!(msg.message_type, wgpu::CompilationMessageType::Error))
89                .map(|msg| CompilationError::Generic {
90                    context: format!("{msg:?}"),
91                })
92                .collect::<Vec<_>>();
93
94            if !errors.is_empty() {
95                if errors.len() > 1 {
96                    return Err(CompilationError::Multiple {
97                        context: format!("Unable to compile kernel: {}", kernel.entrypoint_name),
98                        errors,
99                    });
100                } else {
101                    return Err(errors.remove(0));
102                }
103            }
104        }
105
106        let bindings_info = match &kernel.repr {
107            Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
108            #[cfg(all(feature = "msl", target_os = "macos"))]
109            Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
110            #[cfg(feature = "spirv")]
111            Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
112            _ => None,
113        };
114
115        let layout = bindings_info.map(|bindings| {
116            let (mut bindings, meta) = bindings;
117            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
118            // and since we can't be sure, we'll assume everything is read-write.
119            if !cfg!(exclusive_memory_only) {
120                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
121            }
122
123            let bindings = bindings
124                .into_iter()
125                .chain(meta)
126                .enumerate()
127                .map(|(i, visibility)| BindGroupLayoutEntry {
128                    binding: i as u32,
129                    visibility: ShaderStages::COMPUTE,
130                    ty: BindingType::Buffer {
131                        ty: BufferBindingType::Storage {
132                            read_only: matches!(
133                                visibility,
134                                cubecl_runtime::kernel::Visibility::Read
135                            ),
136                        },
137                        has_dynamic_offset: false,
138                        min_binding_size: None,
139                    },
140                    count: None,
141                })
142                .collect::<Vec<_>>();
143            let layout = self
144                .device
145                .create_bind_group_layout(&BindGroupLayoutDescriptor {
146                    label: None,
147                    entries: &bindings,
148                });
149            self.device
150                .create_pipeline_layout(&PipelineLayoutDescriptor {
151                    label: None,
152                    bind_group_layouts: &[&layout],
153                    push_constant_ranges: &[],
154                })
155        });
156
157        let pipeline = self
158            .device
159            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
160                label: Some(&kernel.entrypoint_name),
161                layout: layout.as_ref(),
162                module: &module,
163                entry_point: Some(&kernel.entrypoint_name),
164                compilation_options: wgpu::PipelineCompilationOptions {
165                    zero_initialize_workgroup_memory: false,
166                    ..Default::default()
167                },
168                cache: None,
169            });
170        Ok(Arc::new(pipeline))
171    }
172}
173
174#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
175pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
176    wgsl::request_device(adapter).await
177}
178
/// Requests a device and queue for `adapter`.
///
/// Vulkan-backed adapters go through the Vulkan-specific setup; any other adapter falls back
/// to the WGSL setup.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
187
/// Requests a device and queue through the Metal-specific setup.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    // `metal` is already in scope through the file-level import guarded by the same cfg.
    if !is_metal(adapter) {
        panic!("metal device not found!");
    }
    metal::request_metal_device(adapter).await
}
198
199#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
200pub fn register_features(
201    adapter: &Adapter,
202    props: &mut DeviceProperties,
203    comp_options: &mut WgpuCompilationOptions,
204) {
205    wgsl::register_wgsl_features(adapter, props, comp_options);
206}
207
/// Fills `props` and `comp_options` for `adapter`.
///
/// Vulkan-backed adapters use the Vulkan-specific registration; any other adapter uses the
/// WGSL registration.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }
    vulkan::register_vulkan_features(adapter, props, comp_options);
}
220
/// Fills `props` and `comp_options` for a Metal-backed `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
233
/// Returns `true` when `adapter` is backed by wgpu's Vulkan backend.
///
/// This reads the adapter's reported backend, which avoids the unsafe `as_hal` probe.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    adapter.get_info().backend == wgpu::Backend::Vulkan
}
238
/// Returns `true` when `adapter` is backed by wgpu's Metal backend.
///
/// This reads the adapter's reported backend, which avoids the unsafe `as_hal` probe.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    adapter.get_info().backend == wgpu::Backend::Metal
}