1use cubecl_core::ir::{self, Builtin, ElemType, UIntKind};
2use rspirv::spirv::{BuiltIn, Word};
3
4use crate::{
5 SpirvCompiler, SpirvTarget,
6 item::{Elem, Item},
7 variable::Variable,
8};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_builtin(&mut self, builtin: Builtin, ty: Item) -> Variable {
12 match builtin {
13 Builtin::UnitPos => Variable::Builtin(
14 self.insert_global(builtin, |b| {
15 let id = b.load_builtin(BuiltIn::LocalInvocationIndex, &ty);
16 b.debug_name(id, "UNIT_POS");
17 id
18 }),
19 ty,
20 ),
21 Builtin::UnitPosX => Variable::Builtin(
22 self.insert_global(builtin, |b| {
23 let id = b.extract(BuiltIn::LocalInvocationId, 0, &ty);
24 b.debug_name(id, "UNIT_POS_X");
25 id
26 }),
27 ty,
28 ),
29 Builtin::UnitPosY => Variable::Builtin(
30 self.insert_global(builtin, |b| {
31 let id = b.extract(BuiltIn::LocalInvocationId, 1, &ty);
32 b.debug_name(id, "UNIT_POS_Y");
33 id
34 }),
35 ty,
36 ),
37 Builtin::UnitPosZ => Variable::Builtin(
38 self.insert_global(builtin, |b| {
39 let id = b.extract(BuiltIn::LocalInvocationId, 2, &ty);
40 b.debug_name(id, "UNIT_POS_Z");
41 id
42 }),
43 ty,
44 ),
45 Builtin::CubePosX => Variable::Builtin(
46 self.insert_global(builtin, |b| {
47 let id = b.extract(BuiltIn::WorkgroupId, 0, &ty);
48 b.debug_name(id, "CUBE_POS_X");
49 id
50 }),
51 ty,
52 ),
53 Builtin::CubePosY => Variable::Builtin(
54 self.insert_global(builtin, |b| {
55 let id = b.extract(BuiltIn::WorkgroupId, 1, &ty);
56 b.debug_name(id, "CUBE_POS_Y");
57 id
58 }),
59 ty,
60 ),
61 Builtin::CubePosZ => Variable::Builtin(
62 self.insert_global(builtin, |b| {
63 let id = b.extract(BuiltIn::WorkgroupId, 2, &ty);
64 b.debug_name(id, "CUBE_POS_Z");
65 id
66 }),
67 ty,
68 ),
69 Builtin::CubePosCluster
70 | Builtin::CubePosClusterX
71 | Builtin::CubePosClusterY
72 | Builtin::CubePosClusterZ => self.constant_var(0, ty),
73 Builtin::CubeDim => Variable::Builtin(self.state.cube_size, ty),
74 Builtin::CubeDimX => Variable::Builtin(self.state.cube_dims[0], ty),
75 Builtin::CubeDimY => Variable::Builtin(self.state.cube_dims[1], ty),
76 Builtin::CubeDimZ => Variable::Builtin(self.state.cube_dims[2], ty),
77 Builtin::CubeClusterDim
78 | Builtin::CubeClusterDimX
79 | Builtin::CubeClusterDimY
80 | Builtin::CubeClusterDimZ => self.constant_var(1, ty),
81 Builtin::CubeCount => Variable::Builtin(
82 self.insert_global(builtin, |b: &mut SpirvCompiler<T>| {
83 let ty_id = ty.id(b);
84 let x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
85 let y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
86 let z = b.compile_variable(builtin_u32(Builtin::CubeCountZ)).id(b);
87
88 let x = Item::builtin_u32().cast_to(b, None, x, &ty);
89 let y = Item::builtin_u32().cast_to(b, None, y, &ty);
90 let z = Item::builtin_u32().cast_to(b, None, z, &ty);
91
92 let count = b.i_mul(ty_id, None, x, y).unwrap();
93 let count = b.i_mul(ty_id, None, count, z).unwrap();
94 b.debug_name(count, "CUBE_COUNT");
95 count
96 }),
97 ty,
98 ),
99 Builtin::CubeCountX => Variable::Builtin(
100 self.insert_global(builtin, |b| {
101 let id = b.extract(BuiltIn::NumWorkgroups, 0, &ty);
102 b.debug_name(id, "CUBE_COUNT_X");
103 id
104 }),
105 ty,
106 ),
107 Builtin::CubeCountY => Variable::Builtin(
108 self.insert_global(builtin, |b| {
109 let id = b.extract(BuiltIn::NumWorkgroups, 1, &ty);
110 b.debug_name(id, "CUBE_COUNT_Y");
111 id
112 }),
113 ty,
114 ),
115 Builtin::CubeCountZ => Variable::Builtin(
116 self.insert_global(builtin, |b| {
117 let id = b.extract(BuiltIn::NumWorkgroups, 2, &ty);
118 b.debug_name(id, "CUBE_COUNT_Z");
119 id
120 }),
121 ty,
122 ),
123 Builtin::PlaneDim => {
124 let id = self.insert_global(builtin, |b| {
125 let id = b.load_builtin(BuiltIn::SubgroupSize, &ty);
126 b.debug_name(id, "PLANE_DIM");
127 id
128 });
129 Variable::Builtin(id, ty)
130 }
131 Builtin::UnitPosPlane => {
132 let id = self.insert_global(builtin, |b| {
133 let id = b.load_builtin(BuiltIn::SubgroupLocalInvocationId, &ty);
134 b.debug_name(id, "UNIT_POS_PLANE");
135 id
136 });
137 Variable::Builtin(id, ty)
138 }
139 Builtin::CubePos => {
140 let id = self.insert_global(builtin, |b| {
141 let x = b.compile_variable(builtin_u32(Builtin::CubePosX)).id(b);
142 let y = b.compile_variable(builtin_u32(Builtin::CubePosY)).id(b);
143 let z = b.compile_variable(builtin_u32(Builtin::CubePosZ)).id(b);
144
145 let x = Item::builtin_u32().cast_to(b, None, x, &ty);
146 let y = Item::builtin_u32().cast_to(b, None, y, &ty);
147 let z = Item::builtin_u32().cast_to(b, None, z, &ty);
148
149 let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
150 let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
151
152 let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
153 let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
154
155 let ty = ty.id(b);
156 let id = b.i_mul(ty, None, z, groups_y).unwrap();
157 let id = b.i_add(ty, None, id, y).unwrap();
158 let id = b.i_mul(ty, None, id, groups_x).unwrap();
159 let id = b.i_add(ty, None, id, x).unwrap();
160 b.debug_name(id, "CUBE_POS");
161 id
162 });
163 Variable::Builtin(id, ty)
164 }
165 Builtin::AbsolutePos => {
166 let id = self.insert_global(builtin, |b| {
167 let x = b.compile_variable(builtin_u32(Builtin::AbsolutePosX)).id(b);
168 let y = b.compile_variable(builtin_u32(Builtin::AbsolutePosY)).id(b);
169 let z = b.compile_variable(builtin_u32(Builtin::AbsolutePosZ)).id(b);
170
171 let x = Item::builtin_u32().cast_to(b, None, x, &ty);
172 let y = Item::builtin_u32().cast_to(b, None, y, &ty);
173 let z = Item::builtin_u32().cast_to(b, None, z, &ty);
174
175 let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
176 let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
177
178 let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
179 let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
180
181 let size_x = ty.const_u32(b, b.cube_dim.x);
182 let size_y = ty.const_u32(b, b.cube_dim.y);
183
184 let ty = ty.id(b);
185 let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
186 let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
187 let id = b.i_mul(ty, None, z, size_y).unwrap();
188 let id = b.i_add(ty, None, id, y).unwrap();
189 let id = b.i_mul(ty, None, id, size_x).unwrap();
190 let id = b.i_add(ty, None, id, x).unwrap();
191 b.debug_name(id, "ABSOLUTE_POS");
192 id
193 });
194 Variable::Builtin(id, ty)
195 }
196 Builtin::AbsolutePosX => {
197 let id = self.insert_global(builtin, |b| {
198 let id = b.extract(BuiltIn::GlobalInvocationId, 0, &ty);
199 b.debug_name(id, "ABSOLUTE_POS_X");
200 id
201 });
202
203 Variable::Builtin(id, ty)
204 }
205 Builtin::AbsolutePosY => {
206 let id = self.insert_global(builtin, |b| {
207 let id = b.extract(BuiltIn::GlobalInvocationId, 1, &ty);
208 b.debug_name(id, "ABSOLUTE_POS_Y");
209 id
210 });
211
212 Variable::Builtin(id, ty)
213 }
214 Builtin::AbsolutePosZ => {
215 let id = self.insert_global(builtin, |b| {
216 let id = b.extract(BuiltIn::GlobalInvocationId, 2, &ty);
217 b.debug_name(id, "ABSOLUTE_POS_Z");
218 id
219 });
220
221 Variable::Builtin(id, ty)
222 }
223 }
224 }
225
226 fn constant_var(&mut self, value: u32, ty: Item) -> Variable {
227 let id = ty.const_u32(self, value);
228 Variable::Builtin(id, ty.clone())
229 }
230
231 fn extract(&mut self, builtin: BuiltIn, idx: u32, ty: &Item) -> Word {
232 let composite_id = self.vec_global(builtin);
233 let ty = ty.id(self);
234 self.composite_extract(ty, None, composite_id, vec![idx])
235 .unwrap()
236 }
237
238 fn vec_global(&mut self, builtin: BuiltIn) -> Word {
239 let item = Item::Vector(Elem::Int(32, false), 3);
240
241 self.insert_builtin(builtin, |b| b.load_builtin(builtin, &item))
242 }
243
244 fn load_builtin(&mut self, builtin: BuiltIn, item: &Item) -> Word {
245 let item_id = item.id(self);
246 let id = self.builtin(builtin, item.clone());
247 self.load(item_id, None, id, None, vec![]).unwrap()
248 }
249}
250
251fn builtin_u32(builtin: Builtin) -> ir::Variable {
252 ir::Variable::builtin(builtin, ElemType::UInt(UIntKind::U32).into())
253}