1use cubecl_core::prelude::{KernelArg, Visibility};
2use rspirv::{
3 dr::Operand,
4 spirv::{
5 self, AddressingModel, Capability, Decoration, ExecutionMode, ExecutionModel, MemoryModel,
6 StorageClass, Word,
7 },
8};
9use std::{fmt::Debug, iter};
10
11use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
12
13pub trait SpirvTarget:
14 TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
15{
16 fn set_modes(
17 &mut self,
18 b: &mut SpirvCompiler<Self>,
19 main: Word,
20 builtins: Vec<Word>,
21 cube_dims: Vec<u32>,
22 );
23 fn generate_binding(
24 &mut self,
25 b: &mut SpirvCompiler<Self>,
26 binding: KernelArg,
27 name: String,
28 ) -> Word;
29 fn generate_info_binding(&mut self, b: &mut SpirvCompiler<Self>, offset: u32) -> Word;
30 fn info_storage_class(b: &mut SpirvCompiler<Self>) -> StorageClass;
31
32 fn set_kernel_name(&mut self, name: impl Into<String>);
33}
34
/// SPIR-V target that emits a Vulkan `GLCompute` entry point.
///
/// The only per-instance state is the name under which the entry point is
/// exported. It starts out as `"main"`.
#[derive(Clone)]
pub struct GLCompute {
    /// Name given to the generated entry point.
    kernel_name: String,
}

impl Default for GLCompute {
    /// Creates a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        GLCompute { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Always renders as `gl_compute`; the kernel name is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
53
54impl SpirvTarget for GLCompute {
55 fn set_modes(
56 &mut self,
57 b: &mut SpirvCompiler<Self>,
58 main: Word,
59 builtins: Vec<Word>,
60 cube_dims: Vec<u32>,
61 ) {
62 let interface: Vec<u32> = builtins
63 .into_iter()
64 .chain(b.state.buffers.iter().copied())
65 .chain(iter::once(b.state.info))
66 .chain(b.state.shared_arrays.values().map(|it| it.id))
67 .chain(b.state.shared.values().map(|it| it.id))
68 .collect();
69
70 let version = b.compilation_options.vulkan.max_spirv_version;
71
72 b.capability(Capability::Shader);
73 b.capability(Capability::VulkanMemoryModel);
74 b.capability(Capability::VulkanMemoryModelDeviceScope);
75 b.capability(Capability::GroupNonUniform);
76
77 if b.compilation_options.vulkan.supports_explicit_smem {
78 b.extension("SPV_KHR_workgroup_memory_explicit_layout");
79 }
80
81 if b.addr_type.size_bits() == 64 {
82 b.extension("SPV_EXT_shader_64bit_indexing");
83 b.capability(Capability::Shader64BitIndexingEXT);
84 b.execution_mode(main, ExecutionMode::Shader64BitIndexingEXT, []);
85 }
86
87 let caps: Vec<_> = b.capabilities.iter().copied().collect();
88 for cap in caps.iter() {
89 b.capability(*cap);
90 }
91
92 if caps.contains(&Capability::CooperativeMatrixKHR) {
93 b.extension("SPV_KHR_cooperative_matrix");
94 }
95
96 if caps.contains(&Capability::AtomicFloat16AddEXT) {
97 b.extension("SPV_EXT_shader_atomic_float16_add");
98 }
99
100 if caps.contains(&Capability::AtomicFloat32AddEXT)
101 | caps.contains(&Capability::AtomicFloat64AddEXT)
102 {
103 b.extension("SPV_EXT_shader_atomic_float_add");
104 }
105
106 if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
107 | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
108 | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
109 {
110 b.extension("SPV_EXT_shader_atomic_float_min_max");
111 }
112
113 if caps.contains(&Capability::AtomicFloat16VectorNV) {
114 b.extension("SPV_NV_shader_atomic_fp16_vector");
115 }
116
117 if caps.contains(&Capability::BFloat16TypeKHR)
118 || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
119 || caps.contains(&Capability::BFloat16DotProductKHR)
120 {
121 b.extension("SPV_KHR_bfloat16");
122 }
123
124 if caps.contains(&Capability::Float8EXT)
125 || caps.contains(&Capability::Float8CooperativeMatrixEXT)
126 {
127 b.extension("SPV_EXT_float8");
128 }
129
130 if caps.contains(&Capability::FloatControls2) {
131 b.extension("SPV_KHR_float_controls2");
132 }
133
134 if b.debug_symbols {
135 b.extension("SPV_KHR_non_semantic_info");
136 }
137
138 if version < (1, 5) {
139 b.extension("SPV_KHR_vulkan_memory_model");
140 if caps.contains(&Capability::StorageBuffer8BitAccess) {
141 b.extension("SPV_KHR_8bit_storage");
142 }
143 }
144
145 if version < (1, 3) {
146 b.extension("SPV_KHR_storage_buffer_storage_class");
147
148 if caps.contains(&Capability::StorageBuffer16BitAccess) {
149 b.extension("SPV_KHR_16bit_storage");
150 }
151 }
152
153 b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
154 b.entry_point(
155 ExecutionModel::GLCompute,
156 main,
157 &self.kernel_name,
158 interface,
159 );
160 b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
161 }
162
163 fn generate_binding(
164 &mut self,
165 b: &mut SpirvCompiler<Self>,
166 binding: KernelArg,
167 name: String,
168 ) -> Word {
169 let index = binding.id;
170 let item = b.compile_type(binding.ty);
171 let item_size = item.size();
172 match item.elem().size() {
173 1 => {
174 b.capabilities.insert(Capability::StorageBuffer8BitAccess);
175 }
176 2 => {
177 b.capabilities.insert(Capability::StorageBuffer16BitAccess);
178 }
179 _ => {}
180 }
181
182 let item = match binding.size {
183 Some(size) => Item::Array(Box::new(item), size as u32),
184 None => Item::RuntimeArray(Box::new(item)),
185 };
186 let arr = item.id(b); if !b.state.array_types.contains(&arr) {
189 b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
190 b.state.array_types.insert(arr);
191 }
192
193 let struct_ty = b.id();
194 b.type_struct_id(Some(struct_ty), vec![arr]);
195
196 let storage_class = StorageClass::StorageBuffer;
197 let ptr_ty = b.type_pointer(None, storage_class, struct_ty);
198 let var = b.variable(ptr_ty, None, storage_class, None);
199
200 b.debug_name(var, name);
201
202 if matches!(binding.visibility, Visibility::Read) {
203 b.decorate(var, Decoration::NonWritable, vec![]);
204 }
205
206 b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
207 b.decorate(var, Decoration::Binding, vec![index.into()]);
208 b.decorate(struct_ty, Decoration::Block, vec![]);
209 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
210
211 var
212 }
213
214 fn generate_info_binding(&mut self, b: &mut SpirvCompiler<Self>, index: u32) -> Word {
217 let address_type = b.addr_type;
218 let struct_ty = b.id();
219
220 let mut fields = Vec::new();
221
222 let scalars = b.info.scalars.clone();
223
224 for scalar in scalars {
225 let scalar_ty = b.compile_storage_type(scalar.ty);
226 match scalar_ty.size() {
227 1 => {
228 b.capabilities.insert(Capability::StorageBuffer8BitAccess);
229 b.capabilities
230 .insert(Capability::UniformAndStorageBuffer8BitAccess);
231 }
232 2 => {
233 b.capabilities.insert(Capability::StorageBuffer16BitAccess);
234 b.capabilities
235 .insert(Capability::UniformAndStorageBuffer16BitAccess);
236 }
237 _ => {}
238 }
239
240 let arr_ty = Item::Array(
241 Box::new(Item::Scalar(scalar_ty)),
242 scalar.padded_size() as u32,
243 );
244 let arr_ty_id = arr_ty.id(b);
245
246 if !b.state.array_types.contains(&arr_ty_id) {
247 b.decorate(
248 arr_ty_id,
249 Decoration::ArrayStride,
250 [(scalar.ty.size() as u32).into()],
251 );
252 b.state.array_types.insert(arr_ty_id);
253 }
254
255 b.member_decorate(
256 struct_ty,
257 fields.len() as u32,
258 Decoration::Offset,
259 [(scalar.offset as u32).into()],
260 );
261 fields.push(arr_ty_id);
262 }
263
264 if let Some(field) = b.info.sized_meta {
265 let scalar_ty = b.compile_storage_type(field.ty);
266 let arr_ty = Item::Array(Box::new(Item::Scalar(scalar_ty)), field.size as u32);
267 let arr_ty_id = arr_ty.id(b);
268
269 if !b.state.array_types.contains(&arr_ty_id) {
270 b.decorate(
271 arr_ty_id,
272 Decoration::ArrayStride,
273 [(address_type.size() as u32).into()],
274 );
275 b.state.array_types.insert(arr_ty_id);
276 }
277
278 b.member_decorate(
279 struct_ty,
280 fields.len() as u32,
281 Decoration::Offset,
282 [(field.offset as u32).into()],
283 );
284 fields.push(arr_ty_id);
285 }
286
287 if b.info.has_dynamic_meta {
288 let offset = b.info.dynamic_meta_offset;
289 let scalar_ty = b.compile_storage_type(address_type);
290 let arr_ty = Item::RuntimeArray(Box::new(Item::Scalar(scalar_ty)));
291 let arr_ty_id = arr_ty.id(b);
292
293 if !b.state.array_types.contains(&arr_ty_id) {
294 b.decorate(
295 arr_ty_id,
296 Decoration::ArrayStride,
297 [(address_type.size() as u32).into()],
298 );
299 b.state.array_types.insert(arr_ty_id);
300 }
301
302 b.member_decorate(
303 struct_ty,
304 fields.len() as u32,
305 Decoration::Offset,
306 [Operand::LiteralBit32(offset as u32)],
307 );
308 fields.push(arr_ty_id);
309 }
310
311 b.type_struct_id(Some(struct_ty), fields);
312
313 let location = Self::info_storage_class(b);
314 let ptr_ty = b.type_pointer(None, location, struct_ty);
315 let var = b.variable(ptr_ty, None, location, None);
316
317 b.debug_name(var, "info");
318 b.decorate(var, Decoration::NonWritable, vec![]);
319
320 b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
321 b.decorate(var, Decoration::Binding, vec![index.into()]);
322 b.decorate(struct_ty, Decoration::Block, vec![]);
323
324 var
325 }
326
327 fn info_storage_class(b: &mut SpirvCompiler<Self>) -> StorageClass {
328 if !b
329 .compilation_options
330 .vulkan
331 .supports_uniform_standard_layout
332 {
333 return StorageClass::StorageBuffer;
334 }
335 let is_dynamic = b.info.metadata.num_extended_meta() > 0;
336 if b.compilation_options.vulkan.supports_uniform_unsized_array || !is_dynamic {
337 StorageClass::Uniform
338 } else {
339 StorageClass::StorageBuffer
340 }
341 }
342
343 fn set_kernel_name(&mut self, name: impl Into<String>) {
344 self.kernel_name = name.into();
345 }
346}