cubecl_wgpu/backend/
base.rs

1use std::{borrow::Cow, sync::Arc};
2
3use cubecl_core::{ExecutionMode, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::DeviceProperties;
5use wgpu::{
6    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
7    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
8};
9
10use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
11
12use super::wgsl;
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23    pub fn create_pipeline(
24        &mut self,
25        kernel: CompiledKernel<AutoCompiler>,
26        mode: ExecutionMode,
27    ) -> Arc<ComputePipeline> {
28        let module = match &kernel.repr {
29            #[cfg(feature = "spirv")]
30            Some(AutoRepresentation::SpirV(repr)) => {
31                let spirv = repr.assemble();
32                unsafe {
33                    self.device.create_shader_module_passthrough(
34                        wgpu::ShaderModuleDescriptorPassthrough::SpirV(
35                            wgpu::ShaderModuleDescriptorSpirV {
36                                label: Some(&kernel.entrypoint_name),
37                                source: Cow::Borrowed(&spirv),
38                            },
39                        ),
40                    )
41                }
42            }
43            #[cfg(all(feature = "msl", target_os = "macos"))]
44            Some(AutoRepresentation::Msl(repr)) => {
45                let source = &kernel.source;
46                unsafe {
47                    self.device.create_shader_module_passthrough(
48                        wgpu::ShaderModuleDescriptorPassthrough::Msl(
49                            wgpu::ShaderModuleDescriptorMsl {
50                                entry_point: kernel.entrypoint_name.clone(),
51                                label: Some(&kernel.entrypoint_name),
52                                source: Cow::Borrowed(source),
53                                num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
54                            },
55                        ),
56                    )
57                }
58            }
59            _ => {
60                let source = &kernel.source;
61
62                let checks = wgpu::ShaderRuntimeChecks {
63                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
64                    // checked by cube (if enabled).
65                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
66                    bounds_checks: false,
67                    // Loop bounds are only checked in checked mode.
68                    force_loop_bounding: mode == ExecutionMode::Checked,
69                };
70
71                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
72                // is only available through the use of unsafe code.
73                unsafe {
74                    self.device.create_shader_module_trusted(
75                        ShaderModuleDescriptor {
76                            label: None,
77                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
78                        },
79                        checks,
80                    )
81                }
82            }
83        };
84        let bindings_info = match &kernel.repr {
85            Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
86            #[cfg(all(feature = "msl", target_os = "macos"))]
87            Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
88            #[cfg(feature = "spirv")]
89            Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
90            _ => None,
91        };
92
93        let layout = bindings_info.map(|bindings| {
94            let (mut bindings, meta) = bindings;
95            // When slices are shared, it needs to be read-write if ANY of the slices is read-write,
96            // and since we can't be sure, we'll assume everything is read-write.
97            if !cfg!(exclusive_memory_only) {
98                bindings.fill(cubecl_core::compute::Visibility::ReadWrite);
99            }
100
101            let bindings = bindings
102                .into_iter()
103                .chain(meta)
104                .enumerate()
105                .map(|(i, visibility)| BindGroupLayoutEntry {
106                    binding: i as u32,
107                    visibility: ShaderStages::COMPUTE,
108                    ty: BindingType::Buffer {
109                        ty: BufferBindingType::Storage {
110                            read_only: matches!(visibility, cubecl_core::compute::Visibility::Read),
111                        },
112                        has_dynamic_offset: false,
113                        min_binding_size: None,
114                    },
115                    count: None,
116                })
117                .collect::<Vec<_>>();
118            let layout = self
119                .device
120                .create_bind_group_layout(&BindGroupLayoutDescriptor {
121                    label: None,
122                    entries: &bindings,
123                });
124            self.device
125                .create_pipeline_layout(&PipelineLayoutDescriptor {
126                    label: None,
127                    bind_group_layouts: &[&layout],
128                    push_constant_ranges: &[],
129                })
130        });
131
132        Arc::new(
133            self.device
134                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
135                    label: Some(&kernel.entrypoint_name),
136                    layout: layout.as_ref(),
137                    module: &module,
138                    entry_point: Some(&kernel.entrypoint_name),
139                    compilation_options: wgpu::PipelineCompilationOptions {
140                        zero_initialize_workgroup_memory: false,
141                        ..Default::default()
142                    },
143                    cache: None,
144                }),
145        )
146    }
147}
148
149#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
150pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
151    wgsl::request_device(adapter).await
152}
153
/// Requests a device and queue from `adapter` when the `spirv` feature is enabled.
///
/// Adapters that are not backed by Vulkan go through the WGSL request path;
/// Vulkan adapters use the dedicated Vulkan request.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }

    vulkan::request_vulkan_device(adapter).await
}
162
/// Requests a device and queue from `adapter` when the `msl` feature is enabled on macOS.
///
/// # Panics
/// Panics with "metal device not found!" if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    // `metal` comes from the file-level import, which is gated on the same cfg.
    assert!(is_metal(adapter), "metal device not found!");

    metal::request_metal_device(adapter).await
}
173
174#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
175pub fn register_features(
176    adapter: &Adapter,
177    props: &mut DeviceProperties,
178    comp_options: &mut WgpuCompilationOptions,
179) {
180    wgsl::register_wgsl_features(adapter, props, comp_options);
181}
182
/// Registers the features of `adapter` into `props` and `comp_options` when the
/// `spirv` feature is enabled.
///
/// Adapters that are not backed by Vulkan are registered through the WGSL
/// path; Vulkan adapters use the dedicated Vulkan registration.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }

    vulkan::register_vulkan_features(adapter, props, comp_options);
}
195
/// Registers the features of `adapter` into `props` and `comp_options` when the
/// `msl` feature is enabled on macOS.
///
/// # Panics
/// Panics with "metal device not found!" if `adapter` is not backed by Metal.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");

    metal::register_metal_features(adapter, props, comp_options);
}
208
/// Returns `true` when `adapter` exposes a Vulkan HAL adapter.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence and then dropped; the raw
    // adapter behind it is never touched.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Vulkan>() };
    hal_adapter.is_some()
}
213
/// Returns `true` when `adapter` exposes a Metal HAL adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the HAL handle is only tested for presence and then dropped; the raw
    // adapter behind it is never touched.
    let hal_adapter = unsafe { adapter.as_hal::<wgpu::hal::api::Metal>() };
    hal_adapter.is_some()
}