Skip to main content

cubecl_spirv/
target.rs

1use cubecl_core::prelude::{KernelArg, Visibility};
2use rspirv::{
3    dr::Operand,
4    spirv::{
5        self, AddressingModel, Capability, Decoration, ExecutionMode, ExecutionModel, MemoryModel,
6        StorageClass, Word,
7    },
8};
9use std::{fmt::Debug, iter};
10
11use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
12
13pub trait SpirvTarget:
14    TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
15{
16    fn set_modes(
17        &mut self,
18        b: &mut SpirvCompiler<Self>,
19        main: Word,
20        builtins: Vec<Word>,
21        cube_dims: Vec<u32>,
22    );
23    fn generate_binding(
24        &mut self,
25        b: &mut SpirvCompiler<Self>,
26        binding: KernelArg,
27        name: String,
28    ) -> Word;
29    fn generate_info_binding(&mut self, b: &mut SpirvCompiler<Self>, offset: u32) -> Word;
30    fn info_storage_class(b: &mut SpirvCompiler<Self>) -> StorageClass;
31
32    fn set_kernel_name(&mut self, name: impl Into<String>);
33}
34
/// SPIR-V target that emits kernels as `GLCompute` entry points.
#[derive(Clone)]
pub struct GLCompute {
    /// Name given to the entry point when the kernel is emitted.
    kernel_name: String,
}
39
40impl Default for GLCompute {
41    fn default() -> Self {
42        Self {
43            kernel_name: "main".into(),
44        }
45    }
46}
47
48impl Debug for GLCompute {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.write_str("gl_compute")
51    }
52}
53
54impl SpirvTarget for GLCompute {
55    fn set_modes(
56        &mut self,
57        b: &mut SpirvCompiler<Self>,
58        main: Word,
59        builtins: Vec<Word>,
60        cube_dims: Vec<u32>,
61    ) {
62        let interface: Vec<u32> = builtins
63            .into_iter()
64            .chain(b.state.buffers.iter().copied())
65            .chain(iter::once(b.state.info))
66            .chain(b.state.shared_arrays.values().map(|it| it.id))
67            .chain(b.state.shared.values().map(|it| it.id))
68            .collect();
69
70        let version = b.compilation_options.vulkan.max_spirv_version;
71
72        b.capability(Capability::Shader);
73        b.capability(Capability::VulkanMemoryModel);
74        b.capability(Capability::VulkanMemoryModelDeviceScope);
75        b.capability(Capability::GroupNonUniform);
76
77        if b.compilation_options.vulkan.supports_explicit_smem {
78            b.extension("SPV_KHR_workgroup_memory_explicit_layout");
79        }
80
81        if b.addr_type.size_bits() == 64 {
82            b.extension("SPV_EXT_shader_64bit_indexing");
83            b.capability(Capability::Shader64BitIndexingEXT);
84            b.execution_mode(main, ExecutionMode::Shader64BitIndexingEXT, []);
85        }
86
87        let caps: Vec<_> = b.capabilities.iter().copied().collect();
88        for cap in caps.iter() {
89            b.capability(*cap);
90        }
91
92        if caps.contains(&Capability::CooperativeMatrixKHR) {
93            b.extension("SPV_KHR_cooperative_matrix");
94        }
95
96        if caps.contains(&Capability::AtomicFloat16AddEXT) {
97            b.extension("SPV_EXT_shader_atomic_float16_add");
98        }
99
100        if caps.contains(&Capability::AtomicFloat32AddEXT)
101            | caps.contains(&Capability::AtomicFloat64AddEXT)
102        {
103            b.extension("SPV_EXT_shader_atomic_float_add");
104        }
105
106        if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
107            | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
108            | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
109        {
110            b.extension("SPV_EXT_shader_atomic_float_min_max");
111        }
112
113        if caps.contains(&Capability::AtomicFloat16VectorNV) {
114            b.extension("SPV_NV_shader_atomic_fp16_vector");
115        }
116
117        if caps.contains(&Capability::BFloat16TypeKHR)
118            || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
119            || caps.contains(&Capability::BFloat16DotProductKHR)
120        {
121            b.extension("SPV_KHR_bfloat16");
122        }
123
124        if caps.contains(&Capability::Float8EXT)
125            || caps.contains(&Capability::Float8CooperativeMatrixEXT)
126        {
127            b.extension("SPV_EXT_float8");
128        }
129
130        if caps.contains(&Capability::FloatControls2) {
131            b.extension("SPV_KHR_float_controls2");
132        }
133
134        if b.debug_symbols {
135            b.extension("SPV_KHR_non_semantic_info");
136        }
137
138        if version < (1, 5) {
139            b.extension("SPV_KHR_vulkan_memory_model");
140            if caps.contains(&Capability::StorageBuffer8BitAccess) {
141                b.extension("SPV_KHR_8bit_storage");
142            }
143        }
144
145        if version < (1, 3) {
146            b.extension("SPV_KHR_storage_buffer_storage_class");
147
148            if caps.contains(&Capability::StorageBuffer16BitAccess) {
149                b.extension("SPV_KHR_16bit_storage");
150            }
151        }
152
153        b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
154        b.entry_point(
155            ExecutionModel::GLCompute,
156            main,
157            &self.kernel_name,
158            interface,
159        );
160        b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
161    }
162
163    fn generate_binding(
164        &mut self,
165        b: &mut SpirvCompiler<Self>,
166        binding: KernelArg,
167        name: String,
168    ) -> Word {
169        let index = binding.id;
170        let item = b.compile_type(binding.ty);
171        let item_size = item.size();
172        match item.elem().size() {
173            1 => {
174                b.capabilities.insert(Capability::StorageBuffer8BitAccess);
175            }
176            2 => {
177                b.capabilities.insert(Capability::StorageBuffer16BitAccess);
178            }
179            _ => {}
180        }
181
182        let item = match binding.size {
183            Some(size) => Item::Array(Box::new(item), size as u32),
184            None => Item::RuntimeArray(Box::new(item)),
185        };
186        let arr = item.id(b); // pre-generate type
187
188        if !b.state.array_types.contains(&arr) {
189            b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
190            b.state.array_types.insert(arr);
191        }
192
193        let struct_ty = b.id();
194        b.type_struct_id(Some(struct_ty), vec![arr]);
195
196        let storage_class = StorageClass::StorageBuffer;
197        let ptr_ty = b.type_pointer(None, storage_class, struct_ty);
198        let var = b.variable(ptr_ty, None, storage_class, None);
199
200        b.debug_name(var, name);
201
202        if matches!(binding.visibility, Visibility::Read) {
203            b.decorate(var, Decoration::NonWritable, vec![]);
204        }
205
206        b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
207        b.decorate(var, Decoration::Binding, vec![index.into()]);
208        b.decorate(struct_ty, Decoration::Block, vec![]);
209        b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
210
211        var
212    }
213
214    /// Generate info binding struct and variable.
215    /// SPIR-V structs have explicit offsets so unlike other targets we don't need to pad the length.
216    fn generate_info_binding(&mut self, b: &mut SpirvCompiler<Self>, index: u32) -> Word {
217        let address_type = b.addr_type;
218        let struct_ty = b.id();
219
220        let mut fields = Vec::new();
221
222        let scalars = b.info.scalars.clone();
223
224        for scalar in scalars {
225            let scalar_ty = b.compile_storage_type(scalar.ty);
226            match scalar_ty.size() {
227                1 => {
228                    b.capabilities.insert(Capability::StorageBuffer8BitAccess);
229                    b.capabilities
230                        .insert(Capability::UniformAndStorageBuffer8BitAccess);
231                }
232                2 => {
233                    b.capabilities.insert(Capability::StorageBuffer16BitAccess);
234                    b.capabilities
235                        .insert(Capability::UniformAndStorageBuffer16BitAccess);
236                }
237                _ => {}
238            }
239
240            let arr_ty = Item::Array(
241                Box::new(Item::Scalar(scalar_ty)),
242                scalar.padded_size() as u32,
243            );
244            let arr_ty_id = arr_ty.id(b);
245
246            if !b.state.array_types.contains(&arr_ty_id) {
247                b.decorate(
248                    arr_ty_id,
249                    Decoration::ArrayStride,
250                    [(scalar.ty.size() as u32).into()],
251                );
252                b.state.array_types.insert(arr_ty_id);
253            }
254
255            b.member_decorate(
256                struct_ty,
257                fields.len() as u32,
258                Decoration::Offset,
259                [(scalar.offset as u32).into()],
260            );
261            fields.push(arr_ty_id);
262        }
263
264        if let Some(field) = b.info.sized_meta {
265            let scalar_ty = b.compile_storage_type(field.ty);
266            let arr_ty = Item::Array(Box::new(Item::Scalar(scalar_ty)), field.size as u32);
267            let arr_ty_id = arr_ty.id(b);
268
269            if !b.state.array_types.contains(&arr_ty_id) {
270                b.decorate(
271                    arr_ty_id,
272                    Decoration::ArrayStride,
273                    [(address_type.size() as u32).into()],
274                );
275                b.state.array_types.insert(arr_ty_id);
276            }
277
278            b.member_decorate(
279                struct_ty,
280                fields.len() as u32,
281                Decoration::Offset,
282                [(field.offset as u32).into()],
283            );
284            fields.push(arr_ty_id);
285        }
286
287        if b.info.has_dynamic_meta {
288            let offset = b.info.dynamic_meta_offset;
289            let scalar_ty = b.compile_storage_type(address_type);
290            let arr_ty = Item::RuntimeArray(Box::new(Item::Scalar(scalar_ty)));
291            let arr_ty_id = arr_ty.id(b);
292
293            if !b.state.array_types.contains(&arr_ty_id) {
294                b.decorate(
295                    arr_ty_id,
296                    Decoration::ArrayStride,
297                    [(address_type.size() as u32).into()],
298                );
299                b.state.array_types.insert(arr_ty_id);
300            }
301
302            b.member_decorate(
303                struct_ty,
304                fields.len() as u32,
305                Decoration::Offset,
306                [Operand::LiteralBit32(offset as u32)],
307            );
308            fields.push(arr_ty_id);
309        }
310
311        b.type_struct_id(Some(struct_ty), fields);
312
313        let location = Self::info_storage_class(b);
314        let ptr_ty = b.type_pointer(None, location, struct_ty);
315        let var = b.variable(ptr_ty, None, location, None);
316
317        b.debug_name(var, "info");
318        b.decorate(var, Decoration::NonWritable, vec![]);
319
320        b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
321        b.decorate(var, Decoration::Binding, vec![index.into()]);
322        b.decorate(struct_ty, Decoration::Block, vec![]);
323
324        var
325    }
326
327    fn info_storage_class(b: &mut SpirvCompiler<Self>) -> StorageClass {
328        if !b
329            .compilation_options
330            .vulkan
331            .supports_uniform_standard_layout
332        {
333            return StorageClass::StorageBuffer;
334        }
335        let is_dynamic = b.info.metadata.num_extended_meta() > 0;
336        if b.compilation_options.vulkan.supports_uniform_unsized_array || !is_dynamic {
337            StorageClass::Uniform
338        } else {
339            StorageClass::StorageBuffer
340        }
341    }
342
343    fn set_kernel_name(&mut self, name: impl Into<String>) {
344        self.kernel_name = name.into();
345    }
346}