1use cubecl_core::compute::{Binding, Location, Visibility};
2use hashbrown::HashMap;
3use rspirv::{
4 dr,
5 spirv::{
6 self, AddressingModel, Capability, Decoration, ExecutionModel, MemoryModel, StorageClass,
7 Word,
8 },
9};
10use std::{fmt::Debug, iter};
11
12use crate::{
13 SpirvCompiler, debug,
14 extensions::{TargetExtensions, glcompute},
15 item::Item,
16};
17
18pub trait SpirvTarget:
19 TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
20{
21 fn extensions(&mut self, b: &mut SpirvCompiler<Self>) -> HashMap<String, Word>;
22 fn set_modes(
23 &mut self,
24 b: &mut SpirvCompiler<Self>,
25 main: Word,
26 builtins: Vec<Word>,
27 cube_dims: Vec<u32>,
28 );
29 fn generate_binding(
30 &mut self,
31 b: &mut SpirvCompiler<Self>,
32 binding: Binding,
33 name: String,
34 ) -> Word;
35
36 fn set_kernel_name(&mut self, name: impl Into<String>);
37}
38
/// SPIR-V target producing `GLCompute` entry points with the Vulkan memory model.
///
/// The only state is the exported entry-point name, which defaults to `"main"`.
#[derive(Clone)]
pub struct GLCompute {
    // Name given to the kernel's `OpEntryPoint`.
    kernel_name: String,
}

impl Default for GLCompute {
    /// Builds a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        GLCompute { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Formats as the fixed tag `gl_compute`; the kernel name is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
57
58impl SpirvTarget for GLCompute {
59 fn set_modes(
60 &mut self,
61 b: &mut SpirvCompiler<Self>,
62 main: Word,
63 builtins: Vec<Word>,
64 cube_dims: Vec<u32>,
65 ) {
66 let interface: Vec<u32> = builtins
67 .into_iter()
68 .chain(b.state.buffers.iter().copied())
69 .chain(iter::once(b.state.info))
70 .chain(b.state.scalar_bindings.values().copied())
71 .chain(b.state.shared_memories.values().map(|it| it.id))
72 .collect();
73
74 b.capability(Capability::Shader);
75 b.capability(Capability::VulkanMemoryModel);
76 b.capability(Capability::VulkanMemoryModelDeviceScope);
77
78 let caps: Vec<_> = b.capabilities.iter().copied().collect();
79 for cap in caps.iter() {
80 b.capability(*cap);
81 }
82 if b.float_controls {
83 let inst = dr::Instruction::new(
84 spirv::Op::Capability,
85 None,
86 None,
87 vec![dr::Operand::LiteralBit32(6029)],
88 );
89 b.module_mut().capabilities.push(inst);
90 }
91
92 if caps.contains(&Capability::CooperativeMatrixKHR) {
93 b.extension("SPV_KHR_cooperative_matrix");
94 }
95
96 if caps.contains(&Capability::AtomicFloat16AddEXT) {
97 b.extension("SPV_EXT_shader_atomic_float16_add");
98 }
99
100 if caps.contains(&Capability::AtomicFloat32AddEXT)
101 | caps.contains(&Capability::AtomicFloat64AddEXT)
102 {
103 b.extension("SPV_EXT_shader_atomic_float_add");
104 }
105
106 if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
107 | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
108 | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
109 {
110 b.extension("SPV_EXT_shader_atomic_float_min_max");
111 }
112
113 if b.float_controls {
114 b.extension("SPV_KHR_float_controls2");
115 }
116
117 if b.debug_symbols {
118 b.extension("SPV_KHR_non_semantic_info");
119 }
120
121 b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
122 b.entry_point(
123 ExecutionModel::GLCompute,
124 main,
125 &self.kernel_name,
126 interface,
127 );
128 b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
129
130 if b.float_controls {
131 b.declare_float_execution_modes(main);
132 }
133 }
134
135 fn generate_binding(
136 &mut self,
137 b: &mut SpirvCompiler<Self>,
138 binding: Binding,
139 name: String,
140 ) -> Word {
141 let index = binding.id;
142 let item = b.compile_item(binding.item);
143 let item = match binding.size {
144 Some(size) => Item::Array(Box::new(item), size as u32),
145 None => Item::RuntimeArray(Box::new(item)),
146 };
147 let arr = item.id(b); let struct_ty = b.id();
149 b.type_struct_id(Some(struct_ty), vec![arr]);
150
151 let location = match binding.location {
152 Location::Cube => StorageClass::Workgroup,
153 Location::Storage => StorageClass::StorageBuffer,
154 };
155 let ptr_ty = b.type_pointer(None, location, struct_ty);
156 let var = b.variable(ptr_ty, None, location, None);
157
158 b.debug_name(var, name);
159
160 if matches!(binding.visibility, Visibility::Read) {
161 b.decorate(var, Decoration::NonWritable, vec![]);
162 }
163
164 b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
165 b.decorate(var, Decoration::Binding, vec![index.into()]);
166 b.decorate(struct_ty, Decoration::Block, vec![]);
167 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
168
169 var
170 }
171
172 fn extensions(&mut self, b: &mut SpirvCompiler<Self>) -> HashMap<String, Word> {
173 let mut extensions = HashMap::new();
174 extensions.insert(
175 glcompute::STD_NAME.to_string(),
176 b.ext_inst_import(glcompute::STD_NAME),
177 );
178 if b.debug_symbols {
179 extensions.insert(
180 debug::DEBUG_EXT_NAME.to_string(),
181 b.ext_inst_import(debug::DEBUG_EXT_NAME),
182 );
183 extensions.insert(
184 debug::PRINT_EXT_NAME.to_string(),
185 b.ext_inst_import(debug::PRINT_EXT_NAME),
186 );
187 }
188 extensions
189 }
190
191 fn set_kernel_name(&mut self, name: impl Into<String>) {
192 self.kernel_name = name.into();
193 }
194}