cubecl_wgpu/backend/base.rs

1use std::{borrow::Cow, sync::Arc};
2
3use cubecl_core::{ExecutionMode, Feature, WgpuCompilationOptions, prelude::CompiledKernel};
4use cubecl_runtime::DeviceProperties;
5use wgpu::{
6    Adapter, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferBindingType,
7    ComputePipeline, Device, PipelineLayoutDescriptor, Queue, ShaderModuleDescriptor, ShaderStages,
8};
9
10use crate::{AutoCompiler, AutoRepresentation, WgpuServer};
11
12use super::wgsl;
13
14#[cfg(feature = "spirv")]
15use super::vulkan;
16
17#[cfg(all(feature = "msl", target_os = "macos"))]
18use super::metal;
19#[cfg(all(feature = "msl", target_os = "macos"))]
20use cubecl_cpp::metal as cpp_metal;
21
22impl WgpuServer {
23    pub fn create_pipeline(
24        &mut self,
25        kernel: CompiledKernel<AutoCompiler>,
26        mode: ExecutionMode,
27    ) -> Arc<ComputePipeline> {
28        let module = match &kernel.repr {
29            #[cfg(feature = "spirv")]
30            Some(AutoRepresentation::SpirV(repr)) => {
31                let spirv = repr.assemble();
32                unsafe {
33                    self.device.create_shader_module_passthrough(
34                        wgpu::ShaderModuleDescriptorPassthrough::SpirV(
35                            wgpu::ShaderModuleDescriptorSpirV {
36                                label: Some(&kernel.entrypoint_name),
37                                source: Cow::Borrowed(&spirv),
38                            },
39                        ),
40                    )
41                }
42            }
43            #[cfg(all(feature = "msl", target_os = "macos"))]
44            Some(AutoRepresentation::Msl(repr)) => {
45                let source = &kernel.source;
46                unsafe {
47                    self.device.create_shader_module_passthrough(
48                        wgpu::ShaderModuleDescriptorPassthrough::Msl(
49                            wgpu::ShaderModuleDescriptorMsl {
50                                entry_point: kernel.entrypoint_name.clone(),
51                                label: Some(&kernel.entrypoint_name),
52                                source: Cow::Borrowed(source),
53                                num_workgroups: (repr.cube_dim.x, repr.cube_dim.y, repr.cube_dim.z),
54                            },
55                        ),
56                    )
57                }
58            }
59            _ => {
60                let source = &kernel.source;
61
62                let checks = wgpu::ShaderRuntimeChecks {
63                    // Cube does not need wgpu bounds checks - OOB behaviour is instead
64                    // checked by cube (if enabled).
65                    // This is because the WebGPU specification only makes loose guarantees that Cube can't rely on.
66                    bounds_checks: false,
67                    // Loop bounds are only checked in checked mode.
68                    force_loop_bounding: mode == ExecutionMode::Checked,
69                };
70
71                // SAFETY: Cube guarantees OOB safety when launching in checked mode. Launching in unchecked mode
72                // is only available through the use of unsafe code.
73                unsafe {
74                    self.device.create_shader_module_trusted(
75                        ShaderModuleDescriptor {
76                            label: None,
77                            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
78                        },
79                        checks,
80                    )
81                }
82            }
83        };
84        let bindings = match &kernel.repr {
85            Some(AutoRepresentation::Wgsl(repr)) => Some(wgsl::bindings(repr)),
86            #[cfg(all(feature = "msl", target_os = "macos"))]
87            Some(AutoRepresentation::Msl(repr)) => Some(cpp_metal::bindings(repr)),
88            #[cfg(feature = "spirv")]
89            Some(AutoRepresentation::SpirV(repr)) => Some(vulkan::bindings(repr)),
90            _ => None,
91        };
92        let layout = bindings.map(|bindings| {
93            let bindings = bindings
94                .into_iter()
95                .map(|(i, _visibility)| BindGroupLayoutEntry {
96                    binding: i as u32,
97                    visibility: ShaderStages::COMPUTE,
98                    ty: BindingType::Buffer {
99                        #[cfg(not(exclusive_memory_only))]
100                        ty: BufferBindingType::Storage { read_only: false },
101                        #[cfg(exclusive_memory_only)]
102                        ty: BufferBindingType::Storage {
103                            read_only: matches!(
104                                _visibility,
105                                cubecl_core::compute::Visibility::Read
106                            ),
107                        },
108                        has_dynamic_offset: false,
109                        min_binding_size: None,
110                    },
111                    count: None,
112                })
113                .collect::<Vec<_>>();
114            let layout = self
115                .device
116                .create_bind_group_layout(&BindGroupLayoutDescriptor {
117                    label: None,
118                    entries: &bindings,
119                });
120            self.device
121                .create_pipeline_layout(&PipelineLayoutDescriptor {
122                    label: None,
123                    bind_group_layouts: &[&layout],
124                    push_constant_ranges: &[],
125                })
126        });
127
128        Arc::new(
129            self.device
130                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
131                    label: Some(&kernel.entrypoint_name),
132                    layout: layout.as_ref(),
133                    module: &module,
134                    entry_point: Some(&kernel.entrypoint_name),
135                    compilation_options: wgpu::PipelineCompilationOptions {
136                        zero_initialize_workgroup_memory: false,
137                        ..Default::default()
138                    },
139                    cache: None,
140                }),
141        )
142    }
143}
144
145#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
146pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
147    wgsl::request_device(adapter).await
148}
149
/// Requests a device and queue from `adapter`.
///
/// Vulkan adapters go through the SPIR-V backend. Every other adapter falls back to WGSL.
#[cfg(feature = "spirv")]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    if !is_vulkan(adapter) {
        return wgsl::request_device(adapter).await;
    }

    vulkan::request_vulkan_device(adapter).await
}
158
/// Requests a device and queue from `adapter` through the Metal (MSL) backend.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not a Metal adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub async fn request_device(adapter: &Adapter) -> (Device, Queue) {
    // `metal` is already in scope from the module-level import under the same cfg.
    assert!(is_metal(adapter), "metal device not found!");
    metal::request_metal_device(adapter).await
}
169
170#[cfg(all(not(feature = "spirv"), not(feature = "msl")))]
171pub fn register_features(
172    adapter: &Adapter,
173    props: &mut DeviceProperties<Feature>,
174    comp_options: &mut WgpuCompilationOptions,
175) {
176    wgsl::register_wgsl_features(adapter, props, comp_options);
177}
178
/// Registers the features and compilation options supported by `adapter`.
///
/// Vulkan adapters are probed by the SPIR-V backend. Every other adapter uses the WGSL
/// registration.
#[cfg(feature = "spirv")]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties<Feature>,
    comp_options: &mut WgpuCompilationOptions,
) {
    if !is_vulkan(adapter) {
        wgsl::register_wgsl_features(adapter, props, comp_options);
        return;
    }

    vulkan::register_vulkan_features(adapter, props, comp_options);
}
191
/// Registers the features and compilation options supported by `adapter` through the Metal
/// (MSL) backend.
///
/// # Panics
///
/// Panics with `"metal device not found!"` if `adapter` is not a Metal adapter.
#[cfg(all(feature = "msl", target_os = "macos"))]
pub fn register_features(
    adapter: &Adapter,
    props: &mut DeviceProperties<Feature>,
    comp_options: &mut WgpuCompilationOptions,
) {
    assert!(is_metal(adapter), "metal device not found!");
    metal::register_metal_features(adapter, props, comp_options);
}
204
/// Returns `true` when `adapter` is backed by the Vulkan HAL.
#[cfg(feature = "spirv")]
fn is_vulkan(adapter: &Adapter) -> bool {
    // SAFETY: the callback only checks whether a Vulkan HAL adapter exists; the raw handle is
    // never used or retained.
    unsafe {
        adapter.as_hal::<wgpu::hal::api::Vulkan, _, _>(|hal_adapter| matches!(hal_adapter, Some(_)))
    }
}
209
/// Returns `true` when `adapter` is backed by the Metal HAL.
#[cfg(all(feature = "msl", target_os = "macos"))]
fn is_metal(adapter: &Adapter) -> bool {
    // SAFETY: the callback only checks whether a Metal HAL adapter exists; the raw handle is
    // never used or retained.
    unsafe {
        adapter.as_hal::<wgpu::hal::api::Metal, _, _>(|hal_adapter| matches!(hal_adapter, Some(_)))
    }
}