1use cubecl_core::ir::{self, Builtin, ElemType, UIntKind};
2use rspirv::spirv::{BuiltIn, Word};
3
4use crate::{
5 SpirvCompiler, SpirvTarget,
6 item::{Elem, Item},
7 variable::Variable,
8};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_builtin(&mut self, builtin: Builtin, ty: Item) -> Variable {
12 match builtin {
13 Builtin::UnitPos => Variable::Builtin(
14 self.insert_global(builtin, |b| {
15 let id = b.load_builtin(BuiltIn::LocalInvocationIndex, &ty);
16 b.debug_name(id, "UNIT_POS");
17 id
18 }),
19 ty,
20 ),
21 Builtin::UnitPosX => Variable::Builtin(
22 self.insert_global(builtin, |b| {
23 let id = b.extract(BuiltIn::LocalInvocationId, 0, &ty);
24 b.debug_name(id, "UNIT_POS_X");
25 id
26 }),
27 ty,
28 ),
29 Builtin::UnitPosY => Variable::Builtin(
30 self.insert_global(builtin, |b| {
31 let id = b.extract(BuiltIn::LocalInvocationId, 1, &ty);
32 b.debug_name(id, "UNIT_POS_Y");
33 id
34 }),
35 ty,
36 ),
37 Builtin::UnitPosZ => Variable::Builtin(
38 self.insert_global(builtin, |b| {
39 let id = b.extract(BuiltIn::LocalInvocationId, 2, &ty);
40 b.debug_name(id, "UNIT_POS_Z");
41 id
42 }),
43 ty,
44 ),
45 Builtin::CubePosX => Variable::Builtin(
46 self.insert_global(builtin, |b| {
47 let id = b.extract(BuiltIn::WorkgroupId, 0, &ty);
48 b.debug_name(id, "CUBE_POS_X");
49 id
50 }),
51 ty,
52 ),
53 Builtin::CubePosY => Variable::Builtin(
54 self.insert_global(builtin, |b| {
55 let id = b.extract(BuiltIn::WorkgroupId, 1, &ty);
56 b.debug_name(id, "CUBE_POS_Y");
57 id
58 }),
59 ty,
60 ),
61 Builtin::CubePosZ => Variable::Builtin(
62 self.insert_global(builtin, |b| {
63 let id = b.extract(BuiltIn::WorkgroupId, 2, &ty);
64 b.debug_name(id, "CUBE_POS_Z");
65 id
66 }),
67 ty,
68 ),
69 Builtin::CubePosCluster
70 | Builtin::CubePosClusterX
71 | Builtin::CubePosClusterY
72 | Builtin::CubePosClusterZ => self.constant_var(0, ty),
73 Builtin::CubeDim => Variable::Builtin(self.state.cube_size, ty),
74 Builtin::CubeDimX => Variable::Builtin(self.state.cube_dims[0], ty),
75 Builtin::CubeDimY => Variable::Builtin(self.state.cube_dims[1], ty),
76 Builtin::CubeDimZ => Variable::Builtin(self.state.cube_dims[2], ty),
77 Builtin::CubeClusterDim
78 | Builtin::CubeClusterDimX
79 | Builtin::CubeClusterDimY
80 | Builtin::CubeClusterDimZ => self.constant_var(1, ty),
81 Builtin::CubeCount => Variable::Builtin(
82 self.insert_global(builtin, |b: &mut SpirvCompiler<T>| {
83 let ty_id = ty.id(b);
84 let x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
85 let y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
86 let z = b.compile_variable(builtin_u32(Builtin::CubeCountZ)).id(b);
87
88 let x = Item::builtin_u32().cast_to(b, None, x, &ty);
89 let y = Item::builtin_u32().cast_to(b, None, y, &ty);
90 let z = Item::builtin_u32().cast_to(b, None, z, &ty);
91
92 let count = b.i_mul(ty_id, None, x, y).unwrap();
93 let count = b.i_mul(ty_id, None, count, z).unwrap();
94 b.debug_name(count, "CUBE_COUNT");
95 count
96 }),
97 ty,
98 ),
99 Builtin::CubeCountX => Variable::Builtin(
100 self.insert_global(builtin, |b| {
101 let id = b.extract(BuiltIn::NumWorkgroups, 0, &ty);
102 b.debug_name(id, "CUBE_COUNT_X");
103 id
104 }),
105 ty,
106 ),
107 Builtin::CubeCountY => Variable::Builtin(
108 self.insert_global(builtin, |b| {
109 let id = b.extract(BuiltIn::NumWorkgroups, 1, &ty);
110 b.debug_name(id, "CUBE_COUNT_Y");
111 id
112 }),
113 ty,
114 ),
115 Builtin::CubeCountZ => Variable::Builtin(
116 self.insert_global(builtin, |b| {
117 let id = b.extract(BuiltIn::NumWorkgroups, 2, &ty);
118 b.debug_name(id, "CUBE_COUNT_Z");
119 id
120 }),
121 ty,
122 ),
123 Builtin::PlaneDim => {
124 let id = self.insert_global(builtin, |b| {
125 let id = b.load_builtin(BuiltIn::SubgroupSize, &ty);
126 b.debug_name(id, "PLANE_DIM");
127 id
128 });
129 Variable::Builtin(id, ty)
130 }
131 Builtin::PlanePos => {
132 let id = self.insert_global(builtin, |b| {
133 let id = b.load_builtin(BuiltIn::SubgroupId, &ty);
134 b.debug_name(id, "PLANE_POS");
135 id
136 });
137 Variable::Builtin(id, ty)
138 }
139 Builtin::UnitPosPlane => {
140 let id = self.insert_global(builtin, |b| {
141 let id = b.load_builtin(BuiltIn::SubgroupLocalInvocationId, &ty);
142 b.debug_name(id, "UNIT_POS_PLANE");
143 id
144 });
145 Variable::Builtin(id, ty)
146 }
147 Builtin::CubePos => {
148 let id = self.insert_global(builtin, |b| {
149 let x = b.compile_variable(builtin_u32(Builtin::CubePosX)).id(b);
150 let y = b.compile_variable(builtin_u32(Builtin::CubePosY)).id(b);
151 let z = b.compile_variable(builtin_u32(Builtin::CubePosZ)).id(b);
152
153 let x = Item::builtin_u32().cast_to(b, None, x, &ty);
154 let y = Item::builtin_u32().cast_to(b, None, y, &ty);
155 let z = Item::builtin_u32().cast_to(b, None, z, &ty);
156
157 let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
158 let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
159
160 let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
161 let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
162
163 let ty = ty.id(b);
164 let id = b.i_mul(ty, None, z, groups_y).unwrap();
165 let id = b.i_add(ty, None, id, y).unwrap();
166 let id = b.i_mul(ty, None, id, groups_x).unwrap();
167 let id = b.i_add(ty, None, id, x).unwrap();
168 b.debug_name(id, "CUBE_POS");
169 id
170 });
171 Variable::Builtin(id, ty)
172 }
173 Builtin::AbsolutePos => {
174 let id = self.insert_global(builtin, |b| {
175 let x = b.compile_variable(builtin_u32(Builtin::AbsolutePosX)).id(b);
176 let y = b.compile_variable(builtin_u32(Builtin::AbsolutePosY)).id(b);
177 let z = b.compile_variable(builtin_u32(Builtin::AbsolutePosZ)).id(b);
178
179 let x = Item::builtin_u32().cast_to(b, None, x, &ty);
180 let y = Item::builtin_u32().cast_to(b, None, y, &ty);
181 let z = Item::builtin_u32().cast_to(b, None, z, &ty);
182
183 let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
184 let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
185
186 let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
187 let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
188
189 let size_x = ty.const_u32(b, b.cube_dim.x);
190 let size_y = ty.const_u32(b, b.cube_dim.y);
191
192 let ty = ty.id(b);
193 let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
194 let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
195 let id = b.i_mul(ty, None, z, size_y).unwrap();
196 let id = b.i_add(ty, None, id, y).unwrap();
197 let id = b.i_mul(ty, None, id, size_x).unwrap();
198 let id = b.i_add(ty, None, id, x).unwrap();
199 b.debug_name(id, "ABSOLUTE_POS");
200 id
201 });
202 Variable::Builtin(id, ty)
203 }
204 Builtin::AbsolutePosX => {
205 let id = self.insert_global(builtin, |b| {
206 let id = b.extract(BuiltIn::GlobalInvocationId, 0, &ty);
207 b.debug_name(id, "ABSOLUTE_POS_X");
208 id
209 });
210
211 Variable::Builtin(id, ty)
212 }
213 Builtin::AbsolutePosY => {
214 let id = self.insert_global(builtin, |b| {
215 let id = b.extract(BuiltIn::GlobalInvocationId, 1, &ty);
216 b.debug_name(id, "ABSOLUTE_POS_Y");
217 id
218 });
219
220 Variable::Builtin(id, ty)
221 }
222 Builtin::AbsolutePosZ => {
223 let id = self.insert_global(builtin, |b| {
224 let id = b.extract(BuiltIn::GlobalInvocationId, 2, &ty);
225 b.debug_name(id, "ABSOLUTE_POS_Z");
226 id
227 });
228
229 Variable::Builtin(id, ty)
230 }
231 }
232 }
233
234 fn constant_var(&mut self, value: u32, ty: Item) -> Variable {
235 let id = ty.const_u32(self, value);
236 Variable::Builtin(id, ty.clone())
237 }
238
239 fn extract(&mut self, builtin: BuiltIn, idx: u32, ty: &Item) -> Word {
240 let composite_id = self.vec_global(builtin);
241 let ty = ty.id(self);
242 self.composite_extract(ty, None, composite_id, vec![idx])
243 .unwrap()
244 }
245
246 fn vec_global(&mut self, builtin: BuiltIn) -> Word {
247 let item = Item::Vector(Elem::Int(32, false), 3);
248
249 self.insert_builtin(builtin, |b| b.load_builtin(builtin, &item))
250 }
251
252 fn load_builtin(&mut self, builtin: BuiltIn, item: &Item) -> Word {
253 let item_id = item.id(self);
254 let id = self.builtin(builtin, item.clone());
255 self.load(item_id, None, id, None, vec![]).unwrap()
256 }
257}
258
259fn builtin_u32(builtin: Builtin) -> ir::Variable {
260 ir::Variable::builtin(builtin, ElemType::UInt(UIntKind::U32).into())
261}