1use cubecl_core::prelude::{Binding, Location, Visibility};
2use rspirv::spirv::{
3 self, AddressingModel, Capability, Decoration, ExecutionModel, MemoryModel, StorageClass, Word,
4};
5use std::{fmt::Debug, iter};
6
7use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
8
9pub trait SpirvTarget:
10 TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
11{
12 fn set_modes(
13 &mut self,
14 b: &mut SpirvCompiler<Self>,
15 main: Word,
16 builtins: Vec<Word>,
17 cube_dims: Vec<u32>,
18 );
19 fn generate_binding(
20 &mut self,
21 b: &mut SpirvCompiler<Self>,
22 binding: Binding,
23 name: String,
24 ) -> Word;
25
26 fn set_kernel_name(&mut self, name: impl Into<String>);
27}
28
/// SPIR-V target that exports the kernel through a `GLCompute` entry point.
///
/// The entry point's exported name is stored in `kernel_name`; it starts out as
/// `"main"` and is replaced via `set_kernel_name`.
#[derive(Clone)]
pub struct GLCompute {
    /// Name under which the entry point is exported.
    kernel_name: String,
}

impl Default for GLCompute {
    /// Creates a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        GLCompute { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Always renders as the fixed tag `gl_compute`; the kernel name is not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
47
48impl SpirvTarget for GLCompute {
49 fn set_modes(
50 &mut self,
51 b: &mut SpirvCompiler<Self>,
52 main: Word,
53 builtins: Vec<Word>,
54 cube_dims: Vec<u32>,
55 ) {
56 let interface: Vec<u32> = builtins
57 .into_iter()
58 .chain(b.state.buffers.iter().copied())
59 .chain(iter::once(b.state.info))
60 .chain(b.state.scalar_bindings.values().copied())
61 .chain(b.state.shared_arrays.values().map(|it| it.id))
62 .chain(b.state.shared.values().map(|it| it.id))
63 .collect();
64
65 b.capability(Capability::Shader);
66 b.capability(Capability::VulkanMemoryModel);
67 b.capability(Capability::VulkanMemoryModelDeviceScope);
68 b.capability(Capability::GroupNonUniform);
69
70 if b.compilation_options.supports_explicit_smem {
71 b.extension("SPV_KHR_workgroup_memory_explicit_layout");
72 }
73
74 let caps: Vec<_> = b.capabilities.iter().copied().collect();
75 for cap in caps.iter() {
76 b.capability(*cap);
77 }
78
79 if caps.contains(&Capability::CooperativeMatrixKHR) {
80 b.extension("SPV_KHR_cooperative_matrix");
81 }
82
83 if caps.contains(&Capability::AtomicFloat16AddEXT) {
84 b.extension("SPV_EXT_shader_atomic_float16_add");
85 }
86
87 if caps.contains(&Capability::AtomicFloat32AddEXT)
88 | caps.contains(&Capability::AtomicFloat64AddEXT)
89 {
90 b.extension("SPV_EXT_shader_atomic_float_add");
91 }
92
93 if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
94 | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
95 | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
96 {
97 b.extension("SPV_EXT_shader_atomic_float_min_max");
98 }
99
100 if caps.contains(&Capability::BFloat16TypeKHR)
101 || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
102 || caps.contains(&Capability::BFloat16DotProductKHR)
103 {
104 b.extension("SPV_KHR_bfloat16");
105 }
106
107 if caps.contains(&Capability::Float8EXT)
108 || caps.contains(&Capability::Float8CooperativeMatrixEXT)
109 {
110 b.extension("SPV_EXT_float8");
111 }
112
113 if caps.contains(&Capability::FloatControls2) {
114 b.extension("SPV_KHR_float_controls2");
115 }
116
117 if b.debug_symbols {
118 b.extension("SPV_KHR_non_semantic_info");
119 }
120
121 b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
122 b.entry_point(
123 ExecutionModel::GLCompute,
124 main,
125 &self.kernel_name,
126 interface,
127 );
128 b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
129 }
130
131 fn generate_binding(
132 &mut self,
133 b: &mut SpirvCompiler<Self>,
134 binding: Binding,
135 name: String,
136 ) -> Word {
137 let index = binding.id;
138 let item = b.compile_type(binding.ty);
139 let item_size = item.size();
140 let item = match binding.size {
141 Some(size) => Item::Array(Box::new(item), size as u32),
142 None => Item::RuntimeArray(Box::new(item)),
143 };
144 let arr = item.id(b); if !b.state.array_types.contains(&arr) {
147 b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
148 b.state.array_types.insert(arr);
149 }
150
151 let struct_ty = b.id();
152 b.type_struct_id(Some(struct_ty), vec![arr]);
153
154 let location = match binding.location {
155 Location::Cube => StorageClass::Workgroup,
156 Location::Storage => StorageClass::StorageBuffer,
157 };
158 let ptr_ty = b.type_pointer(None, location, struct_ty);
159 let var = b.variable(ptr_ty, None, location, None);
160
161 b.debug_name(var, name);
162
163 if matches!(binding.visibility, Visibility::Read) {
164 b.decorate(var, Decoration::NonWritable, vec![]);
165 }
166
167 b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
168 b.decorate(var, Decoration::Binding, vec![index.into()]);
169 b.decorate(struct_ty, Decoration::Block, vec![]);
170 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
171
172 var
173 }
174
175 fn set_kernel_name(&mut self, name: impl Into<String>) {
176 self.kernel_name = name.into();
177 }
178}