Skip to main content

cubecl_spirv/
globals.rs

1use cubecl_core::ir::{self, Builtin, ElemType, UIntKind};
2use rspirv::spirv::{BuiltIn, Word};
3
4use crate::{
5    SpirvCompiler, SpirvTarget,
6    item::{Elem, Item},
7    variable::Variable,
8};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_builtin(&mut self, builtin: Builtin, ty: Item) -> Variable {
12        match builtin {
13            Builtin::UnitPos => Variable::Builtin(
14                self.insert_global(builtin, |b| {
15                    let id = b.load_builtin(BuiltIn::LocalInvocationIndex, &ty);
16                    b.debug_name(id, "UNIT_POS");
17                    id
18                }),
19                ty,
20            ),
21            Builtin::UnitPosX => Variable::Builtin(
22                self.insert_global(builtin, |b| {
23                    let id = b.extract(BuiltIn::LocalInvocationId, 0, &ty);
24                    b.debug_name(id, "UNIT_POS_X");
25                    id
26                }),
27                ty,
28            ),
29            Builtin::UnitPosY => Variable::Builtin(
30                self.insert_global(builtin, |b| {
31                    let id = b.extract(BuiltIn::LocalInvocationId, 1, &ty);
32                    b.debug_name(id, "UNIT_POS_Y");
33                    id
34                }),
35                ty,
36            ),
37            Builtin::UnitPosZ => Variable::Builtin(
38                self.insert_global(builtin, |b| {
39                    let id = b.extract(BuiltIn::LocalInvocationId, 2, &ty);
40                    b.debug_name(id, "UNIT_POS_Z");
41                    id
42                }),
43                ty,
44            ),
45            Builtin::CubePosX => Variable::Builtin(
46                self.insert_global(builtin, |b| {
47                    let id = b.extract(BuiltIn::WorkgroupId, 0, &ty);
48                    b.debug_name(id, "CUBE_POS_X");
49                    id
50                }),
51                ty,
52            ),
53            Builtin::CubePosY => Variable::Builtin(
54                self.insert_global(builtin, |b| {
55                    let id = b.extract(BuiltIn::WorkgroupId, 1, &ty);
56                    b.debug_name(id, "CUBE_POS_Y");
57                    id
58                }),
59                ty,
60            ),
61            Builtin::CubePosZ => Variable::Builtin(
62                self.insert_global(builtin, |b| {
63                    let id = b.extract(BuiltIn::WorkgroupId, 2, &ty);
64                    b.debug_name(id, "CUBE_POS_Z");
65                    id
66                }),
67                ty,
68            ),
69            Builtin::CubePosCluster
70            | Builtin::CubePosClusterX
71            | Builtin::CubePosClusterY
72            | Builtin::CubePosClusterZ => self.constant_var(0, ty),
73            Builtin::CubeDim => Variable::Builtin(self.state.cube_size, ty),
74            Builtin::CubeDimX => Variable::Builtin(self.state.cube_dims[0], ty),
75            Builtin::CubeDimY => Variable::Builtin(self.state.cube_dims[1], ty),
76            Builtin::CubeDimZ => Variable::Builtin(self.state.cube_dims[2], ty),
77            Builtin::CubeClusterDim
78            | Builtin::CubeClusterDimX
79            | Builtin::CubeClusterDimY
80            | Builtin::CubeClusterDimZ => self.constant_var(1, ty),
81            Builtin::CubeCount => Variable::Builtin(
82                self.insert_global(builtin, |b: &mut SpirvCompiler<T>| {
83                    let ty_id = ty.id(b);
84                    let x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
85                    let y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
86                    let z = b.compile_variable(builtin_u32(Builtin::CubeCountZ)).id(b);
87
88                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
89                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
90                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);
91
92                    let count = b.i_mul(ty_id, None, x, y).unwrap();
93                    let count = b.i_mul(ty_id, None, count, z).unwrap();
94                    b.debug_name(count, "CUBE_COUNT");
95                    count
96                }),
97                ty,
98            ),
99            Builtin::CubeCountX => Variable::Builtin(
100                self.insert_global(builtin, |b| {
101                    let id = b.extract(BuiltIn::NumWorkgroups, 0, &ty);
102                    b.debug_name(id, "CUBE_COUNT_X");
103                    id
104                }),
105                ty,
106            ),
107            Builtin::CubeCountY => Variable::Builtin(
108                self.insert_global(builtin, |b| {
109                    let id = b.extract(BuiltIn::NumWorkgroups, 1, &ty);
110                    b.debug_name(id, "CUBE_COUNT_Y");
111                    id
112                }),
113                ty,
114            ),
115            Builtin::CubeCountZ => Variable::Builtin(
116                self.insert_global(builtin, |b| {
117                    let id = b.extract(BuiltIn::NumWorkgroups, 2, &ty);
118                    b.debug_name(id, "CUBE_COUNT_Z");
119                    id
120                }),
121                ty,
122            ),
123            Builtin::PlaneDim => {
124                let id = self.insert_global(builtin, |b| {
125                    let id = b.load_builtin(BuiltIn::SubgroupSize, &ty);
126                    b.debug_name(id, "PLANE_DIM");
127                    id
128                });
129                Variable::Builtin(id, ty)
130            }
131            Builtin::PlanePos => {
132                let id = self.insert_global(builtin, |b| {
133                    let id = b.load_builtin(BuiltIn::SubgroupId, &ty);
134                    b.debug_name(id, "PLANE_POS");
135                    id
136                });
137                Variable::Builtin(id, ty)
138            }
139            Builtin::UnitPosPlane => {
140                let id = self.insert_global(builtin, |b| {
141                    let id = b.load_builtin(BuiltIn::SubgroupLocalInvocationId, &ty);
142                    b.debug_name(id, "UNIT_POS_PLANE");
143                    id
144                });
145                Variable::Builtin(id, ty)
146            }
147            Builtin::CubePos => {
148                let id = self.insert_global(builtin, |b| {
149                    let x = b.compile_variable(builtin_u32(Builtin::CubePosX)).id(b);
150                    let y = b.compile_variable(builtin_u32(Builtin::CubePosY)).id(b);
151                    let z = b.compile_variable(builtin_u32(Builtin::CubePosZ)).id(b);
152
153                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
154                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
155                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);
156
157                    let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
158                    let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
159
160                    let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
161                    let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
162
163                    let ty = ty.id(b);
164                    let id = b.i_mul(ty, None, z, groups_y).unwrap();
165                    let id = b.i_add(ty, None, id, y).unwrap();
166                    let id = b.i_mul(ty, None, id, groups_x).unwrap();
167                    let id = b.i_add(ty, None, id, x).unwrap();
168                    b.debug_name(id, "CUBE_POS");
169                    id
170                });
171                Variable::Builtin(id, ty)
172            }
173            Builtin::AbsolutePos => {
174                let id = self.insert_global(builtin, |b| {
175                    let x = b.compile_variable(builtin_u32(Builtin::AbsolutePosX)).id(b);
176                    let y = b.compile_variable(builtin_u32(Builtin::AbsolutePosY)).id(b);
177                    let z = b.compile_variable(builtin_u32(Builtin::AbsolutePosZ)).id(b);
178
179                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
180                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
181                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);
182
183                    let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
184                    let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
185
186                    let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
187                    let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
188
189                    let size_x = ty.const_u32(b, b.cube_dim.x);
190                    let size_y = ty.const_u32(b, b.cube_dim.y);
191
192                    let ty = ty.id(b);
193                    let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
194                    let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
195                    let id = b.i_mul(ty, None, z, size_y).unwrap();
196                    let id = b.i_add(ty, None, id, y).unwrap();
197                    let id = b.i_mul(ty, None, id, size_x).unwrap();
198                    let id = b.i_add(ty, None, id, x).unwrap();
199                    b.debug_name(id, "ABSOLUTE_POS");
200                    id
201                });
202                Variable::Builtin(id, ty)
203            }
204            Builtin::AbsolutePosX => {
205                let id = self.insert_global(builtin, |b| {
206                    let id = b.extract(BuiltIn::GlobalInvocationId, 0, &ty);
207                    b.debug_name(id, "ABSOLUTE_POS_X");
208                    id
209                });
210
211                Variable::Builtin(id, ty)
212            }
213            Builtin::AbsolutePosY => {
214                let id = self.insert_global(builtin, |b| {
215                    let id = b.extract(BuiltIn::GlobalInvocationId, 1, &ty);
216                    b.debug_name(id, "ABSOLUTE_POS_Y");
217                    id
218                });
219
220                Variable::Builtin(id, ty)
221            }
222            Builtin::AbsolutePosZ => {
223                let id = self.insert_global(builtin, |b| {
224                    let id = b.extract(BuiltIn::GlobalInvocationId, 2, &ty);
225                    b.debug_name(id, "ABSOLUTE_POS_Z");
226                    id
227                });
228
229                Variable::Builtin(id, ty)
230            }
231        }
232    }
233
234    fn constant_var(&mut self, value: u32, ty: Item) -> Variable {
235        let id = ty.const_u32(self, value);
236        Variable::Builtin(id, ty.clone())
237    }
238
239    fn extract(&mut self, builtin: BuiltIn, idx: u32, ty: &Item) -> Word {
240        let composite_id = self.vec_global(builtin);
241        let ty = ty.id(self);
242        self.composite_extract(ty, None, composite_id, vec![idx])
243            .unwrap()
244    }
245
246    fn vec_global(&mut self, builtin: BuiltIn) -> Word {
247        let item = Item::Vector(Elem::Int(32, false), 3);
248
249        self.insert_builtin(builtin, |b| b.load_builtin(builtin, &item))
250    }
251
252    fn load_builtin(&mut self, builtin: BuiltIn, item: &Item) -> Word {
253        let item_id = item.id(self);
254        let id = self.builtin(builtin, item.clone());
255        self.load(item_id, None, id, None, vec![]).unwrap()
256    }
257}
258
259fn builtin_u32(builtin: Builtin) -> ir::Variable {
260    ir::Variable::builtin(builtin, ElemType::UInt(UIntKind::U32).into())
261}