cubecl_wgpu/backend/
base.rs

1use super::wgsl;
2use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::{DeviceProperties, compiler::CompilationError};
5use std::{borrow::Cow, sync::Arc};
6use wgpu::{
7    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
8    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
9};
10
11#[cfg(not(target_family = "wasm"))]
12use crate::errors::{fetch_error, track_error};
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23    pub fn create_pipeline(
24        &mut self,
25        kernel: CompiledKernel<AutoCompiler>,
26        mode: ExecutionMode,
27    ) -> Result<Arc<ComputePipeline>, CompilationError> {
28        let module = match &kernel.repr {
29            #[cfg(feature = "spirv")]
30            Some(AutoRepresentation::SpirV(repr)) => {
31                let spirv = repr.assemble();
32                unsafe {
33                    self.device.create_shader_module_passthrough(
34                        wgpu::ShaderModuleDescriptorPassthrough::SpirV(
35                            wgpu::ShaderModuleDescriptorSpirV {
36                                label: Some(&kernel.entrypoint_name),
37                                source: Cow::Borrowed(&spirv),
38                            },
39                        ),
40                    )
41                }
42            }
43            #[cfg(all(feature = "msl", target_os = "macos"))]
44            Some(AutoRepresentation::Msl(repr)) => {
45                let source = &kernel.source;
46                unsafe {
47                    self.device.create_shader_module_passthrough(
48                        wgpu::ShaderModuleDescriptorPassthrough::Msl(
49                            wgpu::ShaderModuleDescriptorMsl {
50                                entry_point: kernel.entrypoint_name.clone(),
51                                label: Some(&kernel.entrypoint_name),
52                                source: Cow::Borrowed(source),
53                                num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
54                            },
55                        ),
56                    )
57                }
58            }
59            _ => {
60                let source = &kernel.source;
61
62                let checks = wgpu::ShaderRuntimeChecks {
63                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
64                    // checked by cube (if enabled).
65                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
66                    bounds_checks: false,
67                    // Loop bounds are only checked in checked mode.
68                    force_loop_bounding: mode == ExecutionMode::Checked,
69                };
70
71                #[cfg(not(target_family = "wasm"))]
72                track_error(&self.device, wgpu::ErrorFilter::Validation);
73
74                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
75                // is only available through the use of unsafe code.
76                unsafe {
77                    self.device.create_shader_module_trusted(
78                        ShaderModuleDescriptor {
79                            label: None,
80                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
81                        },
82                        checks,
83                    )
84                }
85            }
86        };
87
88        #[cfg(not(target_family = "wasm"))]
89        if let Some(err) = cubecl_common::future::block_on(fetch_error(&self.device)) {
90            return Err(CompilationError::Generic {
91                reason: format!("{err}"),
92                backtrace: cubecl_common::backtrace::BackTrace::capture(),
93            });
94        }
95
96        let bindings_info = match &kernel.repr {
97            Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
98            #[cfg(all(feature = "msl", target_os = "macos"))]
99            Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
100            #[cfg(feature = "spirv")]
101            Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
102            _ => None,
103        };
104
105        let layout = bindings_info.map(|bindings| {
106            let (mut bindings, meta) = bindings;
107            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
108            // and since we can't be sure, we'll assume everything is read-write.
109            if !cfg!(exclusive_memory_only) {
110                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
111            }
112
113            let bindings = bindings
114                .into_iter()
115                .chain(meta)
116                .enumerate()
117                .map(|(i, visibility)| BindGroupLayoutEntry {
118                    binding: i as u32,
119                    visibility: ShaderStages::COMPUTE,
120                    ty: BindingType::Buffer {
121                        ty: BufferBindingType::Storage {
122                            read_only: matches!(
123                                visibility,
124                                cubecl_runtime::kernel::Visibility::Read
125                            ),
126                        },
127                        has_dynamic_offset: false,
128                        min_binding_size: None,
129                    },
130                    count: None,
131                })
132                .collect::<Vec<_>>();
133            let layout = self
134                .device
135                .create_bind_group_layout(&BindGroupLayoutDescriptor {
136                    label: None,
137                    entries: &bindings,
138                });
139            self.device
140                .create_pipeline_layout(&PipelineLayoutDescriptor {
141                    label: None,
142                    bind_group_layouts: &[&layout],
143                    push_constant_ranges: &[],
144                })
145        });
146
147        let pipeline = self
148            .device
149            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
150                label: Some(&kernel.entrypoint_name),
151                layout: layout.as_ref(),
152                module: &module,
153                entry_point: Some(&kernel.entrypoint_name),
154                compilation_options: wgpu::PipelineCompilationOptions {
155                    zero_initialize_workgroup_memory: false,
156                    ..Default::default()
157                },
158                cache: None,
159            });
160        Ok(Arc::new(pipeline))
161    }
162}
163
164#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
165pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
166    wgsl::request_device(adapter).await
167}
168
/// Requests a device and queue from `adapter`.
///
/// Adapters backed by wgpu's Vulkan HAL are delegated to the `vulkan` backend module; every
/// other adapter goes through the WGSL backend module.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }
    vulkan::request_vulkan_device(adapter).await
}
177
/// Requests a device and queue from `adapter` through the `metal` backend module.
///
/// # Panics
///
/// Panics with "metal device not found!" if `adapter` is not backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::request_metal_device(adapter).await
}
188
189#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
190pub fn register_features(
191    adapter: &Adapter,
192    props: &mut DeviceProperties,
193    comp_options: &mut WgpuCompilationOptions,
194) {
195    wgsl::register_wgsl_features(adapter, props, comp_options);
196}
197
/// Delegates feature registration for `adapter` into `props` and `comp_options`.
///
/// Adapters backed by wgpu's Vulkan HAL are handled by the `vulkan` backend module; every other
/// adapter is handled by the WGSL backend module.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }
    vulkan::register_vulkan_features(adapter, props, comp_options);
}
210
/// Delegates feature registration for `adapter` into `props` and `comp_options` to the `metal`
/// backend module.
///
/// # Panics
///
/// Panics with "metal device not found!" if `adapter` is not backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
223
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter handle is only probed for presence and never used.
    let vulkan_hal = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    vulkan_hal.is_some()
}
228
/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter handle is only probed for presence and never used.
    let metal_hal = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    metal_hal.is_some()
}