cubecl_wgpu/backend/base.rs

1use super::wgsl;
2use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_ir::DeviceProperties;
5use cubecl_runtime::compiler::CompilationError;
6use std::{borrow::Cow, sync::Arc};
7use wgpu::{
8    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
9    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
10};
11
12#[cfg(not(target_family = "wasm"))]
13use crate::errors::{fetch_error, track_error};
14
15#[cfg(feature = "spirv")]
16use super::vulkan;
17
18#[cfg(all(feature = "msl", target_os = "macos"))]
19use super::metal;
20#[cfg(all(feature = "msl", target_os = "macos"))]
21use cubecl_cpp::metal as cpp_metal;
22
23impl WgpuServer {
24    pub fn create_pipeline(
25        &mut self,
26        kernel: CompiledKernel<AutoCompiler>,
27        mode: ExecutionMode,
28    ) -> Result<Arc<ComputePipeline>, CompilationError> {
29        let module = match &kernel.repr {
30            #[cfg(feature = "spirv")]
31            Some(AutoRepresentation::SpirV(repr)) => {
32                let spirv = repr.assemble();
33                unsafe {
34                    self.device.create_shader_module_passthrough(
35                        wgpu::ShaderModuleDescriptorPassthrough::SpirV(
36                            wgpu::ShaderModuleDescriptorSpirV {
37                                label: Some(&kernel.entrypoint_name),
38                                source: Cow::Borrowed(&spirv),
39                            },
40                        ),
41                    )
42                }
43            }
44            #[cfg(all(feature = "msl", target_os = "macos"))]
45            Some(AutoRepresentation::Msl(repr)) => {
46                let source = &kernel.source;
47                unsafe {
48                    self.device.create_shader_module_passthrough(
49                        wgpu::ShaderModuleDescriptorPassthrough::Msl(
50                            wgpu::ShaderModuleDescriptorMsl {
51                                entry_point: kernel.entrypoint_name.clone(),
52                                label: Some(&kernel.entrypoint_name),
53                                source: Cow::Borrowed(source),
54                                num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
55                            },
56                        ),
57                    )
58                }
59            }
60            _ => {
61                let source = &kernel.source;
62
63                let checks = wgpu::ShaderRuntimeChecks {
64                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
65                    // checked by cube (if enabled).
66                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
67                    bounds_checks: false,
68                    // Loop bounds are only checked in checked mode.
69                    force_loop_bounding: mode == ExecutionMode::Checked,
70                };
71
72                #[cfg(not(target_family = "wasm"))]
73                track_error(&self.device, wgpu::ErrorFilter::Validation);
74
75                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
76                // is only available through the use of unsafe code.
77                unsafe {
78                    self.device.create_shader_module_trusted(
79                        ShaderModuleDescriptor {
80                            label: None,
81                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
82                        },
83                        checks,
84                    )
85                }
86            }
87        };
88
89        #[cfg(not(target_family = "wasm"))]
90        if let Some(err) = cubecl_common::future::block_on(fetch_error(&self.device)) {
91            return Err(CompilationError::Generic {
92                reason: format!("{err}"),
93                backtrace: cubecl_common::backtrace::BackTrace::capture(),
94            });
95        }
96
97        let bindings_info = match &kernel.repr {
98            Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
99            #[cfg(all(feature = "msl", target_os = "macos"))]
100            Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
101            #[cfg(feature = "spirv")]
102            Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
103            _ => None,
104        };
105
106        let layout = bindings_info.map(|bindings| {
107            let (mut bindings, meta) = bindings;
108            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
109            // and since we can't be sure, we'll assume everything is read-write.
110            if !cfg!(exclusive_memory_only) {
111                bindings.fill(cubecl_runtime::kernel::Visibility::ReadWrite);
112            }
113
114            let bindings = bindings
115                .into_iter()
116                .chain(meta)
117                .enumerate()
118                .map(|(i, visibility)| BindGroupLayoutEntry {
119                    binding: i as u32,
120                    visibility: ShaderStages::COMPUTE,
121                    ty: BindingType::Buffer {
122                        ty: BufferBindingType::Storage {
123                            read_only: matches!(
124                                visibility,
125                                cubecl_runtime::kernel::Visibility::Read
126                            ),
127                        },
128                        has_dynamic_offset: false,
129                        min_binding_size: None,
130                    },
131                    count: None,
132                })
133                .collect::<Vec<_>>();
134            let layout = self
135                .device
136                .create_bind_group_layout(&BindGroupLayoutDescriptor {
137                    label: None,
138                    entries: &bindings,
139                });
140            self.device
141                .create_pipeline_layout(&PipelineLayoutDescriptor {
142                    label: None,
143                    bind_group_layouts: &[&layout],
144                    push_constant_ranges: &[],
145                })
146        });
147
148        let pipeline = self
149            .device
150            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
151                label: Some(&kernel.entrypoint_name),
152                layout: layout.as_ref(),
153                module: &module,
154                entry_point: Some(&kernel.entrypoint_name),
155                compilation_options: wgpu::PipelineCompilationOptions {
156                    zero_initialize_workgroup_memory: false,
157                    ..Default::default()
158                },
159                cache: None,
160            });
161        Ok(Arc::new(pipeline))
162    }
163}
164
165#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
166pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
167    wgsl::request_device(adapter).await
168}
169
/// Requests a device and queue, taking the Vulkan (SPIR-V) path when possible.
///
/// Vulkan-backed adapters use `vulkan::request_vulkan_device`. Every other backend falls
/// back to `wgsl::request_device`.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }

    vulkan::request_vulkan_device(adapter).await
}
178
/// Requests a Metal device and queue from `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    // `metal` is already in scope through the file-level import, which has the same cfg, so
    // no local `use` is needed here.
    assert!(is_metal(adapter), "metal device not found!");
    metal::request_metal_device(adapter).await
}
189
190#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
191pub fn register_features(
192    adapter: &Adapter,
193    props: &mut DeviceProperties,
194    comp_options: &mut WgpuCompilationOptions,
195) {
196    wgsl::register_wgsl_features(adapter, props, comp_options);
197}
198
/// Registers device properties and compilation options for `adapter`.
///
/// Vulkan-backed adapters get the SPIR-V feature set. Every other backend gets the WGSL
/// feature set.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }

    vulkan::register_vulkan_features(adapter, props, comp_options);
}
211
/// Registers Metal device properties and compilation options for `adapter`.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
224
/// Returns `true` when `adapter` is backed by wgpu's Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter is only probed for presence. Its raw handles are never
    // accessed, let alone destroyed.
    let vulkan_hal = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    vulkan_hal.is_some()
}
229
/// Returns `true` when `adapter` is backed by wgpu's Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL adapter is only probed for presence. Its raw handles are never
    // accessed, let alone destroyed.
    let metal_hal = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    metal_hal.is_some()
}