cubecl_spirv/
globals.rs

1use cubecl_core::ir::{self, Builtin, UIntKind};
2use rspirv::spirv::{BuiltIn, Word};
3
4use crate::{
5    SpirvCompiler, SpirvTarget,
6    item::{Elem, Item},
7    variable::Variable,
8};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_builtin(&mut self, builtin: Builtin) -> Variable {
12        match builtin {
13            Builtin::UnitPos => Variable::LocalInvocationIndex(self.insert_global(|b| {
14                let id = b.load_builtin(
15                    BuiltIn::LocalInvocationIndex,
16                    Item::Scalar(Elem::Int(32, false)),
17                );
18                b.debug_name(id, "UNIT_POS");
19                id
20            })),
21            Builtin::UnitPosX => Variable::LocalInvocationIdX(self.insert_global(|b| {
22                let id = b.extract(BuiltIn::LocalInvocationId, 0);
23                b.debug_name(id, "UNIT_POS_X");
24                id
25            })),
26            Builtin::UnitPosY => Variable::LocalInvocationIdY(self.insert_global(|b| {
27                let id = b.extract(BuiltIn::LocalInvocationId, 1);
28                b.debug_name(id, "UNIT_POS_Y");
29                id
30            })),
31            Builtin::UnitPosZ => Variable::LocalInvocationIdZ(self.insert_global(|b| {
32                let id = b.extract(BuiltIn::LocalInvocationId, 2);
33                b.debug_name(id, "UNIT_POS_Z");
34                id
35            })),
36            Builtin::CubePosX => Variable::WorkgroupIdX(self.insert_global(|b| {
37                let id = b.extract(BuiltIn::WorkgroupId, 0);
38                b.debug_name(id, "CUBE_POS_X");
39                id
40            })),
41            Builtin::CubePosY => Variable::WorkgroupIdY(self.insert_global(|b| {
42                let id = b.extract(BuiltIn::WorkgroupId, 1);
43                b.debug_name(id, "CUBE_POS_Y");
44                id
45            })),
46            Builtin::CubePosZ => Variable::WorkgroupIdZ(self.insert_global(|b| {
47                let id = b.extract(BuiltIn::WorkgroupId, 2);
48                b.debug_name(id, "CUBE_POS_Z");
49                id
50            })),
51            Builtin::CubePosCluster
52            | Builtin::CubePosClusterX
53            | Builtin::CubePosClusterY
54            | Builtin::CubePosClusterZ => self.constant_var(0),
55            Builtin::CubeDim => Variable::WorkgroupSize(self.state.cube_size),
56            Builtin::CubeDimX => Variable::WorkgroupSizeX(self.state.cube_dims[0]),
57            Builtin::CubeDimY => Variable::WorkgroupSizeY(self.state.cube_dims[1]),
58            Builtin::CubeDimZ => Variable::WorkgroupSizeZ(self.state.cube_dims[2]),
59            Builtin::CubeClusterDim
60            | Builtin::CubeClusterDimX
61            | Builtin::CubeClusterDimY
62            | Builtin::CubeClusterDimZ => self.constant_var(1),
63            Builtin::CubeCount => {
64                Variable::WorkgroupSize(self.insert_global(|b: &mut SpirvCompiler<T>| {
65                    let int = b.type_int(32, 0);
66                    let x = b.compile_variable(built_var(Builtin::CubeCountX)).id(b);
67                    let y = b.compile_variable(built_var(Builtin::CubeCountY)).id(b);
68                    let z = b.compile_variable(built_var(Builtin::CubeCountZ)).id(b);
69                    let count = b.i_mul(int, None, x, y).unwrap();
70                    let count = b.i_mul(int, None, count, z).unwrap();
71                    b.debug_name(count, "CUBE_COUNT");
72                    count
73                }))
74            }
75            Builtin::CubeCountX => Variable::NumWorkgroupsX(self.insert_global(|b| {
76                let id = b.extract(BuiltIn::NumWorkgroups, 0);
77                b.debug_name(id, "CUBE_COUNT_X");
78                id
79            })),
80            Builtin::CubeCountY => Variable::NumWorkgroupsY(self.insert_global(|b| {
81                let id = b.extract(BuiltIn::NumWorkgroups, 1);
82                b.debug_name(id, "CUBE_COUNT_Y");
83                id
84            })),
85            Builtin::CubeCountZ => Variable::NumWorkgroupsZ(self.insert_global(|b| {
86                let id = b.extract(BuiltIn::NumWorkgroups, 2);
87                b.debug_name(id, "CUBE_COUNT_Z");
88                id
89            })),
90            Builtin::PlaneDim => {
91                let id = self.insert_global(|b| {
92                    let id =
93                        b.load_builtin(BuiltIn::SubgroupSize, Item::Scalar(Elem::Int(32, false)));
94                    b.debug_name(id, "PLANE_DIM");
95                    id
96                });
97                Variable::SubgroupSize(id)
98            }
99            Builtin::UnitPosPlane => {
100                let id = self.insert_global(|b| {
101                    let id = b.load_builtin(
102                        BuiltIn::SubgroupLocalInvocationId,
103                        Item::Scalar(Elem::Int(32, false)),
104                    );
105                    b.debug_name(id, "UNIT_POS_PLANE");
106                    id
107                });
108                Variable::SubgroupSize(id)
109            }
110            Builtin::CubePos => {
111                let id = self.insert_global(|b| {
112                    let x = b.compile_variable(built_var(Builtin::CubePosX)).id(b);
113                    let y = b.compile_variable(built_var(Builtin::CubePosY)).id(b);
114                    let z = b.compile_variable(built_var(Builtin::CubePosZ)).id(b);
115
116                    let groups_x = b.compile_variable(built_var(Builtin::CubeCountX)).id(b);
117                    let groups_y = b.compile_variable(built_var(Builtin::CubeCountY)).id(b);
118                    let ty = Elem::Int(32, false).id(b);
119                    let id = b.i_mul(ty, None, z, groups_y).unwrap();
120                    let id = b.i_add(ty, None, id, y).unwrap();
121                    let id = b.i_mul(ty, None, id, groups_x).unwrap();
122                    let id = b.i_add(ty, None, id, x).unwrap();
123                    b.debug_name(id, "CUBE_POS");
124                    id
125                });
126                Variable::WorkgroupId(id)
127            }
128            Builtin::AbsolutePos => {
129                let id = self.insert_global(|b| {
130                    let x = b.compile_variable(built_var(Builtin::AbsolutePosX)).id(b);
131                    let y = b.compile_variable(built_var(Builtin::AbsolutePosY)).id(b);
132                    let z = b.compile_variable(built_var(Builtin::AbsolutePosZ)).id(b);
133
134                    let groups_x = b.compile_variable(built_var(Builtin::CubeCountX)).id(b);
135                    let groups_y = b.compile_variable(built_var(Builtin::CubeCountY)).id(b);
136                    let size_x = b.state.cube_dims[0];
137                    let size_y = b.state.cube_dims[1];
138                    let ty = Elem::Int(32, false).id(b);
139                    let size_x = b.i_mul(ty, None, groups_x, size_x).unwrap();
140                    let size_y = b.i_mul(ty, None, groups_y, size_y).unwrap();
141                    let id = b.i_mul(ty, None, z, size_y).unwrap();
142                    let id = b.i_add(ty, None, id, y).unwrap();
143                    let id = b.i_mul(ty, None, id, size_x).unwrap();
144                    let id = b.i_add(ty, None, id, x).unwrap();
145                    b.debug_name(id, "ABSOLUTE_POS");
146                    id
147                });
148                Variable::GlobalInvocationIndex(id)
149            }
150            Builtin::AbsolutePosX => {
151                let id = self.insert_global(|b| {
152                    let id = b.extract(BuiltIn::GlobalInvocationId, 0);
153                    b.debug_name(id, "ABSOLUTE_POS_X");
154                    id
155                });
156
157                Variable::GlobalInvocationIdX(id)
158            }
159            Builtin::AbsolutePosY => {
160                let id = self.insert_global(|b| {
161                    let id = b.extract(BuiltIn::GlobalInvocationId, 1);
162                    b.debug_name(id, "ABSOLUTE_POS_Y");
163                    id
164                });
165
166                Variable::GlobalInvocationIdY(id)
167            }
168            Builtin::AbsolutePosZ => {
169                let id = self.insert_global(|b| {
170                    let id = b.extract(BuiltIn::GlobalInvocationId, 2);
171                    b.debug_name(id, "ABSOLUTE_POS_Z");
172                    id
173                });
174
175                Variable::GlobalInvocationIdZ(id)
176            }
177        }
178    }
179
180    fn constant_var(&mut self, value: u32) -> Variable {
181        let var =
182            ir::Variable::constant(ir::ConstantScalarValue::UInt(value as u64, UIntKind::U32));
183        self.compile_variable(var)
184    }
185
186    fn extract(&mut self, builtin: BuiltIn, idx: u32) -> Word {
187        let composite_id = self.vec_global(builtin);
188        let ty = Elem::Int(32, false).id(self);
189        self.composite_extract(ty, None, composite_id, vec![idx])
190            .unwrap()
191    }
192
193    fn vec_global(&mut self, builtin: BuiltIn) -> Word {
194        let item = Item::Vector(Elem::Int(32, false), 3);
195
196        self.insert_global(|b| b.load_builtin(builtin, item))
197    }
198
199    fn load_builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
200        let item_id = item.id(self);
201        let id = self.builtin(builtin, item);
202        self.load(item_id, None, id, None, vec![]).unwrap()
203    }
204}
205
206fn built_var(builtin: Builtin) -> ir::Variable {
207    ir::Variable::builtin(builtin)
208}