1use cubecl_core::prelude::{Binding, Location, Visibility};
2use rspirv::spirv::{
3 self, AddressingModel, Capability, Decoration, ExecutionModel, MemoryModel, StorageClass, Word,
4};
5use std::{fmt::Debug, iter};
6
7use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
8
9pub trait SpirvTarget:
10 TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
11{
12 fn set_modes(
13 &mut self,
14 b: &mut SpirvCompiler<Self>,
15 main: Word,
16 builtins: Vec<Word>,
17 cube_dims: Vec<u32>,
18 );
19 fn generate_binding(
20 &mut self,
21 b: &mut SpirvCompiler<Self>,
22 binding: Binding,
23 name: String,
24 ) -> Word;
25
26 fn set_kernel_name(&mut self, name: impl Into<String>);
27}
28
/// SPIR-V target that exports the kernel through a `GLCompute` entry point.
///
/// The only state it carries is the exported entry-point name.
#[derive(Clone)]
pub struct GLCompute {
    /// Name of the exported entry point. It is `"main"` by default and can be
    /// replaced through `set_kernel_name`.
    kernel_name: String,
}
33
34impl Default for GLCompute {
35 fn default() -> Self {
36 Self {
37 kernel_name: "main".into(),
38 }
39 }
40}
41
42impl Debug for GLCompute {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 f.write_str("gl_compute")
45 }
46}
47
48impl SpirvTarget for GLCompute {
49 fn set_modes(
50 &mut self,
51 b: &mut SpirvCompiler<Self>,
52 main: Word,
53 builtins: Vec<Word>,
54 cube_dims: Vec<u32>,
55 ) {
56 let interface: Vec<u32> = builtins
57 .into_iter()
58 .chain(b.state.buffers.iter().copied())
59 .chain(iter::once(b.state.info))
60 .chain(b.state.scalar_bindings.values().copied())
61 .chain(b.state.shared_memories.values().map(|it| it.id))
62 .collect();
63
64 b.capability(Capability::Shader);
65 b.capability(Capability::VulkanMemoryModel);
66 b.capability(Capability::VulkanMemoryModelDeviceScope);
67 b.capability(Capability::GroupNonUniform);
68
69 if b.compilation_options.supports_explicit_smem {
70 b.extension("SPV_KHR_workgroup_memory_explicit_layout");
71 }
72
73 let caps: Vec<_> = b.capabilities.iter().copied().collect();
74 for cap in caps.iter() {
75 b.capability(*cap);
76 }
77
78 if caps.contains(&Capability::CooperativeMatrixKHR) {
79 b.extension("SPV_KHR_cooperative_matrix");
80 }
81
82 if caps.contains(&Capability::AtomicFloat16AddEXT) {
83 b.extension("SPV_EXT_shader_atomic_float16_add");
84 }
85
86 if caps.contains(&Capability::AtomicFloat32AddEXT)
87 | caps.contains(&Capability::AtomicFloat64AddEXT)
88 {
89 b.extension("SPV_EXT_shader_atomic_float_add");
90 }
91
92 if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
93 | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
94 | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
95 {
96 b.extension("SPV_EXT_shader_atomic_float_min_max");
97 }
98
99 if caps.contains(&Capability::BFloat16TypeKHR)
100 || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
101 || caps.contains(&Capability::BFloat16DotProductKHR)
102 {
103 b.extension("SPV_KHR_bfloat16");
104 }
105
106 if caps.contains(&Capability::Float8EXT)
107 || caps.contains(&Capability::Float8CooperativeMatrixEXT)
108 {
109 b.extension("SPV_EXT_float8");
110 }
111
112 if caps.contains(&Capability::FloatControls2) {
113 b.extension("SPV_KHR_float_controls2");
114 }
115
116 if b.debug_symbols {
117 b.extension("SPV_KHR_non_semantic_info");
118 }
119
120 b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
121 b.entry_point(
122 ExecutionModel::GLCompute,
123 main,
124 &self.kernel_name,
125 interface,
126 );
127 b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
128 }
129
130 fn generate_binding(
131 &mut self,
132 b: &mut SpirvCompiler<Self>,
133 binding: Binding,
134 name: String,
135 ) -> Word {
136 let index = binding.id;
137 let item = b.compile_type(binding.ty);
138 let item_size = item.size();
139 let item = match binding.size {
140 Some(size) => Item::Array(Box::new(item), size as u32),
141 None => Item::RuntimeArray(Box::new(item)),
142 };
143 let arr = item.id(b); if !b.state.array_types.contains(&arr) {
146 b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
147 b.state.array_types.insert(arr);
148 }
149
150 let struct_ty = b.id();
151 b.type_struct_id(Some(struct_ty), vec![arr]);
152
153 let location = match binding.location {
154 Location::Cube => StorageClass::Workgroup,
155 Location::Storage => StorageClass::StorageBuffer,
156 };
157 let ptr_ty = b.type_pointer(None, location, struct_ty);
158 let var = b.variable(ptr_ty, None, location, None);
159
160 b.debug_name(var, name);
161
162 if matches!(binding.visibility, Visibility::Read) {
163 b.decorate(var, Decoration::NonWritable, vec![]);
164 }
165
166 b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
167 b.decorate(var, Decoration::Binding, vec![index.into()]);
168 b.decorate(struct_ty, Decoration::Block, vec![]);
169 b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
170
171 var
172 }
173
174 fn set_kernel_name(&mut self, name: impl Into<String>) {
175 self.kernel_name = name.into();
176 }
177}