// cubecl_spirv/globals.rs

1use cubecl_core::ir::{self, Builtin, ElemType, UIntKind};
2use rspirv::spirv::{BuiltIn, Word};
3
4use crate::{
5    SpirvCompiler, SpirvTarget,
6    item::{Elem, Item},
7    variable::Variable,
8};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_builtin(&mut self, builtin: Builtin, ty: Item) -> Variable {
12        match builtin {
13            Builtin::UnitPos => Variable::Builtin(
14                self.insert_global(builtin, |b| {
15                    let id = b.load_builtin(BuiltIn::LocalInvocationIndex, &ty);
16                    b.debug_name(id, "UNIT_POS");
17                    id
18                }),
19                ty,
20            ),
21            Builtin::UnitPosX => Variable::Builtin(
22                self.insert_global(builtin, |b| {
23                    let id = b.extract(BuiltIn::LocalInvocationId, 0, &ty);
24                    b.debug_name(id, "UNIT_POS_X");
25                    id
26                }),
27                ty,
28            ),
29            Builtin::UnitPosY => Variable::Builtin(
30                self.insert_global(builtin, |b| {
31                    let id = b.extract(BuiltIn::LocalInvocationId, 1, &ty);
32                    b.debug_name(id, "UNIT_POS_Y");
33                    id
34                }),
35                ty,
36            ),
37            Builtin::UnitPosZ => Variable::Builtin(
38                self.insert_global(builtin, |b| {
39                    let id = b.extract(BuiltIn::LocalInvocationId, 2, &ty);
40                    b.debug_name(id, "UNIT_POS_Z");
41                    id
42                }),
43                ty,
44            ),
45            Builtin::CubePosX => Variable::Builtin(
46                self.insert_global(builtin, |b| {
47                    let id = b.extract(BuiltIn::WorkgroupId, 0, &ty);
48                    b.debug_name(id, "CUBE_POS_X");
49                    id
50                }),
51                ty,
52            ),
53            Builtin::CubePosY => Variable::Builtin(
54                self.insert_global(builtin, |b| {
55                    let id = b.extract(BuiltIn::WorkgroupId, 1, &ty);
56                    b.debug_name(id, "CUBE_POS_Y");
57                    id
58                }),
59                ty,
60            ),
61            Builtin::CubePosZ => Variable::Builtin(
62                self.insert_global(builtin, |b| {
63                    let id = b.extract(BuiltIn::WorkgroupId, 2, &ty);
64                    b.debug_name(id, "CUBE_POS_Z");
65                    id
66                }),
67                ty,
68            ),
69            Builtin::CubePosCluster
70            | Builtin::CubePosClusterX
71            | Builtin::CubePosClusterY
72            | Builtin::CubePosClusterZ => self.constant_var(0, ty),
73            Builtin::CubeDim => Variable::Builtin(self.state.cube_size, ty),
74            Builtin::CubeDimX => Variable::Builtin(self.state.cube_dims[0], ty),
75            Builtin::CubeDimY => Variable::Builtin(self.state.cube_dims[1], ty),
76            Builtin::CubeDimZ => Variable::Builtin(self.state.cube_dims[2], ty),
77            Builtin::CubeClusterDim
78            | Builtin::CubeClusterDimX
79            | Builtin::CubeClusterDimY
80            | Builtin::CubeClusterDimZ => self.constant_var(1, ty),
81            Builtin::CubeCount => Variable::Builtin(
82                self.insert_global(builtin, |b: &mut SpirvCompiler<T>| {
83                    let ty_id = ty.id(b);
84                    let x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
85                    let y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
86                    let z = b.compile_variable(builtin_u32(Builtin::CubeCountZ)).id(b);
87
88                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
89                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
90                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);
91
92                    let count = b.i_mul(ty_id, None, x, y).unwrap();
93                    let count = b.i_mul(ty_id, None, count, z).unwrap();
94                    b.debug_name(count, "CUBE_COUNT");
95                    count
96                }),
97                ty,
98            ),
99            Builtin::CubeCountX => Variable::Builtin(
100                self.insert_global(builtin, |b| {
101                    let id = b.extract(BuiltIn::NumWorkgroups, 0, &ty);
102                    b.debug_name(id, "CUBE_COUNT_X");
103                    id
104                }),
105                ty,
106            ),
107            Builtin::CubeCountY => Variable::Builtin(
108                self.insert_global(builtin, |b| {
109                    let id = b.extract(BuiltIn::NumWorkgroups, 1, &ty);
110                    b.debug_name(id, "CUBE_COUNT_Y");
111                    id
112                }),
113                ty,
114            ),
115            Builtin::CubeCountZ => Variable::Builtin(
116                self.insert_global(builtin, |b| {
117                    let id = b.extract(BuiltIn::NumWorkgroups, 2, &ty);
118                    b.debug_name(id, "CUBE_COUNT_Z");
119                    id
120                }),
121                ty,
122            ),
123            Builtin::PlaneDim => {
124                let id = self.insert_global(builtin, |b| {
125                    let id = b.load_builtin(BuiltIn::SubgroupSize, &ty);
126                    b.debug_name(id, "PLANE_DIM");
127                    id
128                });
129                Variable::Builtin(id, ty)
130            }
131            Builtin::UnitPosPlane => {
132                let id = self.insert_global(builtin, |b| {
133                    let id = b.load_builtin(BuiltIn::SubgroupLocalInvocationId, &ty);
134                    b.debug_name(id, "UNIT_POS_PLANE");
135                    id
136                });
137                Variable::Builtin(id, ty)
138            }
139            Builtin::CubePos => {
140                let id = self.insert_global(builtin, |b| {
141                    let x = b.compile_variable(builtin_u32(Builtin::CubePosX)).id(b);
142                    let y = b.compile_variable(builtin_u32(Builtin::CubePosY)).id(b);
143                    let z = b.compile_variable(builtin_u32(Builtin::CubePosZ)).id(b);
144
145                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
146                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
147                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);
148
149                    let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
150                    let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
151
152                    let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
153                    let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
154
155                    let ty = ty.id(b);
156                    let id = b.i_mul(ty, None, z, groups_y).unwrap();
157                    let id = b.i_add(ty, None, id, y).unwrap();
158                    let id = b.i_mul(ty, None, id, groups_x).unwrap();
159                    let id = b.i_add(ty, None, id, x).unwrap();
160                    b.debug_name(id, "CUBE_POS");
161                    id
162                });
163                Variable::Builtin(id, ty)
164            }
165            Builtin::AbsolutePos => {
166                let id = self.insert_global(builtin, |b| {
167                    let x = b.compile_variable(builtin_u32(Builtin::AbsolutePosX)).id(b);
168                    let y = b.compile_variable(builtin_u32(Builtin::AbsolutePosY)).id(b);
169                    let z = b.compile_variable(builtin_u32(Builtin::AbsolutePosZ)).id(b);
170
171                    let x = Item::builtin_u32().cast_to(b, None, x, &ty);
172                    let y = Item::builtin_u32().cast_to(b, None, y, &ty);
173                    let z = Item::builtin_u32().cast_to(b, None, z, &ty);
174
175                    let groups_x = b.compile_variable(builtin_u32(Builtin::CubeCountX)).id(b);
176                    let groups_y = b.compile_variable(builtin_u32(Builtin::CubeCountY)).id(b);
177
178                    let groups_x = Item::builtin_u32().cast_to(b, None, groups_x, &ty);
179                    let groups_y = Item::builtin_u32().cast_to(b, None, groups_y, &ty);
180
181                    let size_x = ty.const_u32(b, b.cube_dim.x);
182                    let size_y = ty.const_u32(b, b.cube_dim.y);
183
184                    let ty = ty.id(b);
185                    let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
186                    let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
187                    let id = b.i_mul(ty, None, z, size_y).unwrap();
188                    let id = b.i_add(ty, None, id, y).unwrap();
189                    let id = b.i_mul(ty, None, id, size_x).unwrap();
190                    let id = b.i_add(ty, None, id, x).unwrap();
191                    b.debug_name(id, "ABSOLUTE_POS");
192                    id
193                });
194                Variable::Builtin(id, ty)
195            }
196            Builtin::AbsolutePosX => {
197                let id = self.insert_global(builtin, |b| {
198                    let id = b.extract(BuiltIn::GlobalInvocationId, 0, &ty);
199                    b.debug_name(id, "ABSOLUTE_POS_X");
200                    id
201                });
202
203                Variable::Builtin(id, ty)
204            }
205            Builtin::AbsolutePosY => {
206                let id = self.insert_global(builtin, |b| {
207                    let id = b.extract(BuiltIn::GlobalInvocationId, 1, &ty);
208                    b.debug_name(id, "ABSOLUTE_POS_Y");
209                    id
210                });
211
212                Variable::Builtin(id, ty)
213            }
214            Builtin::AbsolutePosZ => {
215                let id = self.insert_global(builtin, |b| {
216                    let id = b.extract(BuiltIn::GlobalInvocationId, 2, &ty);
217                    b.debug_name(id, "ABSOLUTE_POS_Z");
218                    id
219                });
220
221                Variable::Builtin(id, ty)
222            }
223        }
224    }
225
226    fn constant_var(&mut self, value: u32, ty: Item) -> Variable {
227        let id = ty.const_u32(self, value);
228        Variable::Builtin(id, ty.clone())
229    }
230
231    fn extract(&mut self, builtin: BuiltIn, idx: u32, ty: &Item) -> Word {
232        let composite_id = self.vec_global(builtin);
233        let ty = ty.id(self);
234        self.composite_extract(ty, None, composite_id, vec![idx])
235            .unwrap()
236    }
237
238    fn vec_global(&mut self, builtin: BuiltIn) -> Word {
239        let item = Item::Vector(Elem::Int(32, false), 3);
240
241        self.insert_builtin(builtin, |b| b.load_builtin(builtin, &item))
242    }
243
244    fn load_builtin(&mut self, builtin: BuiltIn, item: &Item) -> Word {
245        let item_id = item.id(self);
246        let id = self.builtin(builtin, item.clone());
247        self.load(item_id, None, id, None, vec![]).unwrap()
248    }
249}
250
251fn builtin_u32(builtin: Builtin) -> ir::Variable {
252    ir::Variable::builtin(builtin, ElemType::UInt(UIntKind::U32).into())
253}