1use cubecl_core::ir::{self, Builtin, UIntKind};
2use rspirv::spirv::{BuiltIn, Word};
3
4use crate::{
5 SpirvCompiler, SpirvTarget,
6 item::{Elem, Item},
7 variable::Variable,
8};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_builtin(&mut self, builtin: Builtin) -> Variable {
12 match builtin {
13 Builtin::UnitPos => Variable::LocalInvocationIndex(self.insert_global(|b| {
14 let id = b.load_builtin(
15 BuiltIn::LocalInvocationIndex,
16 Item::Scalar(Elem::Int(32, false)),
17 );
18 b.debug_name(id, "UNIT_POS");
19 id
20 })),
21 Builtin::UnitPosX => Variable::LocalInvocationIdX(self.insert_global(|b| {
22 let id = b.extract(BuiltIn::LocalInvocationId, 0);
23 b.debug_name(id, "UNIT_POS_X");
24 id
25 })),
26 Builtin::UnitPosY => Variable::LocalInvocationIdY(self.insert_global(|b| {
27 let id = b.extract(BuiltIn::LocalInvocationId, 1);
28 b.debug_name(id, "UNIT_POS_Y");
29 id
30 })),
31 Builtin::UnitPosZ => Variable::LocalInvocationIdZ(self.insert_global(|b| {
32 let id = b.extract(BuiltIn::LocalInvocationId, 2);
33 b.debug_name(id, "UNIT_POS_Z");
34 id
35 })),
36 Builtin::CubePosX => Variable::WorkgroupIdX(self.insert_global(|b| {
37 let id = b.extract(BuiltIn::WorkgroupId, 0);
38 b.debug_name(id, "CUBE_POS_X");
39 id
40 })),
41 Builtin::CubePosY => Variable::WorkgroupIdY(self.insert_global(|b| {
42 let id = b.extract(BuiltIn::WorkgroupId, 1);
43 b.debug_name(id, "CUBE_POS_Y");
44 id
45 })),
46 Builtin::CubePosZ => Variable::WorkgroupIdZ(self.insert_global(|b| {
47 let id = b.extract(BuiltIn::WorkgroupId, 2);
48 b.debug_name(id, "CUBE_POS_Z");
49 id
50 })),
51 Builtin::CubePosCluster
52 | Builtin::CubePosClusterX
53 | Builtin::CubePosClusterY
54 | Builtin::CubePosClusterZ => self.constant_var(0),
55 Builtin::CubeDim => Variable::WorkgroupSize(self.state.cube_size),
56 Builtin::CubeDimX => Variable::WorkgroupSizeX(self.state.cube_dims[0]),
57 Builtin::CubeDimY => Variable::WorkgroupSizeY(self.state.cube_dims[1]),
58 Builtin::CubeDimZ => Variable::WorkgroupSizeZ(self.state.cube_dims[2]),
59 Builtin::CubeClusterDim
60 | Builtin::CubeClusterDimX
61 | Builtin::CubeClusterDimY
62 | Builtin::CubeClusterDimZ => self.constant_var(1),
63 Builtin::CubeCount => {
64 Variable::WorkgroupSize(self.insert_global(|b: &mut SpirvCompiler<T>| {
65 let int = b.type_int(32, 0);
66 let x = b.compile_variable(built_var(Builtin::CubeCountX)).id(b);
67 let y = b.compile_variable(built_var(Builtin::CubeCountY)).id(b);
68 let z = b.compile_variable(built_var(Builtin::CubeCountZ)).id(b);
69 let count = b.i_mul(int, None, x, y).unwrap();
70 let count = b.i_mul(int, None, count, z).unwrap();
71 b.debug_name(count, "CUBE_COUNT");
72 count
73 }))
74 }
75 Builtin::CubeCountX => Variable::NumWorkgroupsX(self.insert_global(|b| {
76 let id = b.extract(BuiltIn::NumWorkgroups, 0);
77 b.debug_name(id, "CUBE_COUNT_X");
78 id
79 })),
80 Builtin::CubeCountY => Variable::NumWorkgroupsY(self.insert_global(|b| {
81 let id = b.extract(BuiltIn::NumWorkgroups, 1);
82 b.debug_name(id, "CUBE_COUNT_Y");
83 id
84 })),
85 Builtin::CubeCountZ => Variable::NumWorkgroupsZ(self.insert_global(|b| {
86 let id = b.extract(BuiltIn::NumWorkgroups, 2);
87 b.debug_name(id, "CUBE_COUNT_Z");
88 id
89 })),
90 Builtin::PlaneDim => {
91 let id = self.insert_global(|b| {
92 let id =
93 b.load_builtin(BuiltIn::SubgroupSize, Item::Scalar(Elem::Int(32, false)));
94 b.debug_name(id, "PLANE_DIM");
95 id
96 });
97 Variable::SubgroupSize(id)
98 }
99 Builtin::UnitPosPlane => {
100 let id = self.insert_global(|b| {
101 let id = b.load_builtin(
102 BuiltIn::SubgroupLocalInvocationId,
103 Item::Scalar(Elem::Int(32, false)),
104 );
105 b.debug_name(id, "UNIT_POS_PLANE");
106 id
107 });
108 Variable::SubgroupSize(id)
109 }
110 Builtin::CubePos => {
111 let id = self.insert_global(|b| {
112 let x = b.compile_variable(built_var(Builtin::CubePosX)).id(b);
113 let y = b.compile_variable(built_var(Builtin::CubePosY)).id(b);
114 let z = b.compile_variable(built_var(Builtin::CubePosZ)).id(b);
115
116 let groups_x = b.compile_variable(built_var(Builtin::CubeCountX)).id(b);
117 let groups_y = b.compile_variable(built_var(Builtin::CubeCountY)).id(b);
118 let ty = Elem::Int(32, false).id(b);
119 let id = b.i_mul(ty, None, z, groups_y).unwrap();
120 let id = b.i_add(ty, None, id, y).unwrap();
121 let id = b.i_mul(ty, None, id, groups_x).unwrap();
122 let id = b.i_add(ty, None, id, x).unwrap();
123 b.debug_name(id, "CUBE_POS");
124 id
125 });
126 Variable::WorkgroupId(id)
127 }
128 Builtin::AbsolutePos => {
129 let id = self.insert_global(|b| {
130 let x = b.compile_variable(built_var(Builtin::AbsolutePosX)).id(b);
131 let y = b.compile_variable(built_var(Builtin::AbsolutePosY)).id(b);
132 let z = b.compile_variable(built_var(Builtin::AbsolutePosZ)).id(b);
133
134 let groups_x = b.compile_variable(built_var(Builtin::CubeCountX)).id(b);
135 let groups_y = b.compile_variable(built_var(Builtin::CubeCountY)).id(b);
136 let size_x = b.state.cube_dims[0];
137 let size_y = b.state.cube_dims[1];
138 let ty = Elem::Int(32, false).id(b);
139 let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
140 let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
141 let id = b.i_mul(ty, None, z, size_y).unwrap();
142 let id = b.i_add(ty, None, id, y).unwrap();
143 let id = b.i_mul(ty, None, id, size_x).unwrap();
144 let id = b.i_add(ty, None, id, x).unwrap();
145 b.debug_name(id, "ABSOLUTE_POS");
146 id
147 });
148 Variable::GlobalInvocationIndex(id)
149 }
150 Builtin::AbsolutePosX => {
151 let id = self.insert_global(|b| {
152 let id = b.extract(BuiltIn::GlobalInvocationId, 0);
153 b.debug_name(id, "ABSOLUTE_POS_X");
154 id
155 });
156
157 Variable::GlobalInvocationIdX(id)
158 }
159 Builtin::AbsolutePosY => {
160 let id = self.insert_global(|b| {
161 let id = b.extract(BuiltIn::GlobalInvocationId, 1);
162 b.debug_name(id, "ABSOLUTE_POS_Y");
163 id
164 });
165
166 Variable::GlobalInvocationIdY(id)
167 }
168 Builtin::AbsolutePosZ => {
169 let id = self.insert_global(|b| {
170 let id = b.extract(BuiltIn::GlobalInvocationId, 2);
171 b.debug_name(id, "ABSOLUTE_POS_Z");
172 id
173 });
174
175 Variable::GlobalInvocationIdZ(id)
176 }
177 }
178 }
179
180 fn constant_var(&mut self, value: u32) -> Variable {
181 let var =
182 ir::Variable::constant(ir::ConstantScalarValue::UInt(value as u64, UIntKind::U32));
183 self.compile_variable(var)
184 }
185
186 fn extract(&mut self, builtin: BuiltIn, idx: u32) -> Word {
187 let composite_id = self.vec_global(builtin);
188 let ty = Elem::Int(32, false).id(self);
189 self.composite_extract(ty, None, composite_id, vec![idx])
190 .unwrap()
191 }
192
193 fn vec_global(&mut self, builtin: BuiltIn) -> Word {
194 let item = Item::Vector(Elem::Int(32, false), 3);
195
196 self.insert_global(|b| b.load_builtin(builtin, item))
197 }
198
199 fn load_builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
200 let item_id = item.id(self);
201 let id = self.builtin(builtin, item);
202 self.load(item_id, None, id, None, vec![]).unwrap()
203 }
204}
205
206fn built_var(builtin: Builtin) -> ir::Variable {
207 ir::Variable::builtin(builtin)
208}