1use cubecl_core::prelude::{Binding, Location, Visibility};
2use rspirv::spirv::{
3 self, AddressingModel, Capability, Decoration, ExecutionMode, ExecutionModel, MemoryModel,
4 StorageClass, Word,
5};
6use std::{fmt::Debug, iter};
7
8use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
9
10pub trait SpirvTarget:
11 TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
12{
13 fn set_modes(
14 &mut self,
15 b: &mut SpirvCompiler<Self>,
16 main: Word,
17 builtins: Vec<Word>,
18 cube_dims: Vec<u32>,
19 );
20 fn generate_binding(
21 &mut self,
22 b: &mut SpirvCompiler<Self>,
23 binding: Binding,
24 name: String,
25 ) -> Word;
26
27 fn set_kernel_name(&mut self, name: impl Into<String>);
28}
29
/// SPIR-V target that emits a `GLCompute` entry point.
#[derive(Clone)]
pub struct GLCompute {
    /// Name used for the exported entry point; defaults to `"main"`.
    kernel_name: String,
}

impl Default for GLCompute {
    /// Creates a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        GLCompute { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Always renders as the fixed tag `gl_compute`, regardless of the kernel name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
48
49impl SpirvTarget for GLCompute {
50 fn set_modes(
51 &mut self,
52 b: &mut SpirvCompiler<Self>,
53 main: Word,
54 builtins: Vec<Word>,
55 cube_dims: Vec<u32>,
56 ) {
57 let interface: Vec<u32> = builtins
58 .into_iter()
59 .chain(b.state.buffers.iter().copied())
60 .chain(iter::once(b.state.info))
61 .chain(b.state.scalar_bindings.values().copied())
62 .chain(b.state.shared_arrays.values().map(|it| it.id))
63 .chain(b.state.shared.values().map(|it| it.id))
64 .collect();
65
66 b.capability(Capability::Shader);
67 b.capability(Capability::VulkanMemoryModel);
68 b.capability(Capability::VulkanMemoryModelDeviceScope);
69 b.capability(Capability::GroupNonUniform);
70
71 if b.compilation_options.supports_explicit_smem {
72 b.extension("SPV_KHR_workgroup_memory_explicit_layout");
73 }
74
75 if b.addr_type.size_bits() == 64 {
76 b.extension("SPV_EXT_shader_64bit_indexing");
77 b.capability(Capability::Shader64BitIndexingEXT);
78 b.execution_mode(main, ExecutionMode::Shader64BitIndexingEXT, []);
79 }
80
81 let caps: Vec<_> = b.capabilities.iter().copied().collect();
82 for cap in caps.iter() {
83 b.capability(*cap);
84 }
85
86 if caps.contains(&Capability::CooperativeMatrixKHR) {
87 b.extension("SPV_KHR_cooperative_matrix");
88 }
89
90 if caps.contains(&Capability::AtomicFloat16AddEXT) {
91 b.extension("SPV_EXT_shader_atomic_float16_add");
92 }
93
94 if caps.contains(&Capability::AtomicFloat32AddEXT)
95 | caps.contains(&Capability::AtomicFloat64AddEXT)
96 {
97 b.extension("SPV_EXT_shader_atomic_float_add");
98 }
99
100 if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
101 | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
102 | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
103 {
104 b.extension("SPV_EXT_shader_atomic_float_min_max");
105 }
106
107 if caps.contains(&Capability::BFloat16TypeKHR)
108 || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
109 || caps.contains(&Capability::BFloat16DotProductKHR)
110 {
111 b.extension("SPV_KHR_bfloat16");
112 }
113
114 if caps.contains(&Capability::Float8EXT)
115 || caps.contains(&Capability::Float8CooperativeMatrixEXT)
116 {
117 b.extension("SPV_EXT_float8");
118 }
119
120 if caps.contains(&Capability::FloatControls2) {
121 b.extension("SPV_KHR_float_controls2");
122 }
123
124 if b.debug_symbols {
125 b.extension("SPV_KHR_non_semantic_info");
126 }
127
128 b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
129 b.entry_point(
130 ExecutionModel::GLCompute,
131 main,
132 &self.kernel_name,
133 interface,
134 );
135 b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
136 }
137
138 fn generate_binding(
139 &mut self,
140 b: &mut SpirvCompiler<Self>,
141 binding: Binding,
142 name: String,
143 ) -> Word {
144 let index = binding.id;
145 let item = b.compile_type(binding.ty);
146 let item_size = item.size();
147 let item = match binding.size {
148 Some(size) => Item::Array(Box::new(item), size as u32),
149 None => Item::RuntimeArray(Box::new(item)),
150 };
151 let arr = item.id(b); if !b.state.array_types.contains(&arr) {
154 b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
155 b.state.array_types.insert(arr);
156 }
157
158 let struct_ty = b.id();
159 b.type_struct_id(Some(struct_ty), vec![arr]);
160
161 let location = match binding.location {
162 Location::Cube => StorageClass::Workgroup,
163 Location::Storage => StorageClass::StorageBuffer,
164 };
165 let ptr_ty = b.type_pointer(None, location, struct_ty);
166 let var = b.variable(ptr_ty, None, location, None);
167
168 b.debug_name(var, name);
169
170 if matches!(binding.visibility, Visibility::Read) {
171 b.decorate(var, Decoration::NonWritable, vec![]);
172 }
173
174 b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
175 b.decorate(var, Decoration::Binding, vec![index.into()]);
176 b.decorate(struct_ty, Decoration::Block, vec![]);
177 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
178
179 var
180 }
181
182 fn set_kernel_name(&mut self, name: impl Into<String>) {
183 self.kernel_name = name.into();
184 }
185}