// cubecl_spirv/lookups.rs

1use std::collections::VecDeque;
2
3use cubecl_core::{
4    ir::{self, Id, Type, VariableKind},
5    prelude::{Binding, KernelDefinition, Location, Visibility},
6};
7use cubecl_opt::{ConstArray, NodeIndex, SharedMemory};
8use hashbrown::{HashMap, HashSet};
9use rspirv::{
10    dr,
11    spirv::{self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, StorageClass, Word},
12};
13
14use crate::{
15    MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
16    item::{Elem, Item},
17    variable::{ConstVal, Variable},
18};
19
20#[derive(Clone, Debug, Default)]
21pub struct LookupTables {
22    pub buffers: Vec<Word>,
23    pub scalar_bindings: HashMap<ir::StorageType, Word>,
24    pub info: Word,
25    pub cube_dims: Vec<Word>,
26    pub cube_size: Word,
27
28    pub const_arrays: Vec<Array>,
29    pub shared_arrays: HashMap<Id, SharedArray>,
30    pub shared: HashMap<Id, SharedVar>,
31    pub local_arrays: HashMap<Id, Array>,
32    pub matrices: HashMap<Id, Matrix>,
33
34    pub used_builtins: HashMap<BuiltIn, (Word, Item)>,
35
36    pub scalars: HashMap<(Id, ir::StorageType), Word>,
37    pub array_types: HashSet<Word>,
38    pub constants: HashMap<(ConstVal, Item), Word>,
39    pub bindings: HashMap<Id, Word>,
40    pub variables: HashMap<Id, Word>,
41    pub versioned: HashMap<(Id, u16), Word>,
42    pub labels: HashMap<NodeIndex, Word>,
43    pub end_labels: HashMap<NodeIndex, Word>,
44
45    pub slices: HashMap<Id, Slice>,
46
47    // For break, continue
48    pub loops: VecDeque<Loop>,
49
50    // Explicitly decorated types, to avoid double decorating
51    pub decorated_types: HashSet<Word>,
52    pub debug_types: HashSet<Word>,
53}
54
55#[derive(Clone, Debug)]
56pub struct Slice {
57    pub ptr: Variable,
58    pub offset: Word,
59    pub end: Word,
60    pub const_len: Option<u32>,
61    pub item: Item,
62}
63
64impl From<&Slice> for Variable {
65    fn from(value: &Slice) -> Self {
66        Variable::Slice {
67            ptr: Box::new(value.ptr.clone()),
68            offset: value.offset,
69            end: value.end,
70            const_len: value.const_len,
71            item: value.item.clone(),
72        }
73    }
74}
75
76#[derive(Clone, Debug)]
77pub struct Array {
78    pub id: Word,
79    pub item: Item,
80    pub len: u32,
81    pub var: ir::Variable,
82    pub alignment: Option<u32>,
83}
84
85#[derive(Clone, Debug)]
86pub struct SharedArray {
87    pub id: Word,
88    pub item: Item,
89    pub len: u32,
90    pub align: u32,
91    pub offset: u32,
92}
93
94#[derive(Clone, Debug)]
95pub struct SharedVar {
96    pub id: Word,
97    pub item: Item,
98    pub offset: u32,
99    pub align: u32,
100}
101
102impl SharedArray {
103    pub fn end(&self) -> u32 {
104        self.offset + self.len * self.item.size()
105    }
106}
107
108#[derive(Debug, Clone, Copy, PartialEq)]
109#[allow(missing_docs)]
110pub struct Matrix {
111    pub id: Word,
112    pub ident: CooperativeMatrixUse,
113    pub m: u32,
114    pub n: u32,
115    pub k: u32,
116    pub elem: Elem,
117    pub layout: Option<CooperativeMatrixLayout>,
118}
119
120#[derive(Clone, Debug)]
121pub struct Loop {
122    pub header: Word,
123    pub continue_target: Word,
124    pub post: Word,
125}
126
127impl<T: SpirvTarget> SpirvCompiler<T> {
128    pub fn init_state(&mut self, kernel: KernelDefinition) {
129        let mut target = self.target.clone();
130
131        self.state.buffers = kernel
132            .buffers
133            .into_iter()
134            .map(|mut binding| {
135                // This is safe when combined with the unroll transform that adjusts all indices.
136                // Must not be used alone
137                if binding.ty.line_size() > MAX_VECTORIZATION {
138                    binding.ty = binding.ty.line(MAX_VECTORIZATION);
139                }
140                let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
141                let name = self.name_of_var(var);
142                target.generate_binding(self, binding, name.into())
143            })
144            .collect();
145
146        let mut offset = self.state.buffers.len() as u32;
147        let info_binding = Binding {
148            id: offset,
149            location: Location::Storage,
150            visibility: Visibility::Read,
151            ty: ir::Type::scalar(ir::ElemType::UInt(ir::UIntKind::U32)),
152            size: None,
153            has_extended_meta: false,
154        };
155        if self.metadata.static_len() > 0 {
156            self.state.info = target.generate_binding(self, info_binding, "info".to_string());
157            offset += 1;
158        }
159
160        self.state.scalar_bindings = kernel
161            .scalars
162            .into_iter()
163            .enumerate()
164            .map(|(i, binding)| {
165                let elem = binding.ty;
166                let binding = Binding {
167                    id: i as u32 + offset,
168                    location: Location::Storage,
169                    visibility: Visibility::Read,
170                    ty: ir::Type::new(elem),
171                    size: Some(binding.count),
172                    has_extended_meta: false,
173                };
174                let name = format!("scalars({elem})");
175                (elem, target.generate_binding(self, binding, name))
176            })
177            .collect();
178
179        let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
180        self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
181        self.state.cube_size = self.const_u32(cube_dims.iter().product());
182
183        let shared_liveness = self.shared_liveness.clone();
184        for alloc in shared_liveness.allocations.values() {
185            let smem_id = self.id();
186
187            match alloc.smem {
188                SharedMemory::Array {
189                    id,
190                    length,
191                    ty,
192                    align,
193                } => {
194                    let item = self.compile_type(ty);
195                    self.state.shared_arrays.insert(
196                        id,
197                        SharedArray {
198                            id: smem_id,
199                            item,
200                            len: length,
201                            align,
202                            offset,
203                        },
204                    );
205                }
206                SharedMemory::Value { id, ty, align } => {
207                    let item = self.compile_type(ty);
208                    self.state.shared.insert(
209                        id,
210                        SharedVar {
211                            id: smem_id,
212                            item,
213                            offset,
214                            align,
215                        },
216                    );
217                }
218            }
219        }
220    }
221
222    fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
223        self.module_ref()
224            .types_global_values
225            .iter()
226            .find(|it| {
227                it.class == inst.class
228                    && it.result_type == inst.result_type
229                    && it.operands == inst.operands
230            })
231            .and_then(|it| it.result_id)
232    }
233
234    pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
235        let inst = dr::Instruction::new(
236            spirv::Op::Constant,
237            Some(ty),
238            None,
239            vec![dr::Operand::LiteralBit32(val)],
240        );
241        if let Some(id) = self.dedup_const(&inst) {
242            id
243        } else {
244            self.constant_bit32(ty, val)
245        }
246    }
247
248    pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
249        let inst = dr::Instruction::new(
250            spirv::Op::Constant,
251            Some(ty),
252            None,
253            vec![dr::Operand::LiteralBit64(val)],
254        );
255        if let Some(id) = self.dedup_const(&inst) {
256            id
257        } else {
258            self.constant_bit64(ty, val)
259        }
260    }
261
262    pub fn const_u32(&mut self, value: u32) -> Word {
263        let ty = Item::Scalar(Elem::Int(32, false));
264        let ty_id = ty.id(self);
265        self.dedup_constant_bit32(ty_id, value)
266    }
267
268    pub fn insert_global(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
269        let current_block = self.selected_block();
270        let setup = self.setup_block;
271        self.select_block(Some(setup)).unwrap();
272        let id = insert(self);
273        self.select_block(current_block).unwrap();
274        id
275    }
276
277    pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
278        if let Some(existing) = self.state.variables.get(&id) {
279            *existing
280        } else {
281            let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
282            let word = self.declare_function_variable(ty);
283            self.state.variables.insert(id, word);
284            self.debug_var_name(word, var);
285            word
286        }
287    }
288
289    pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
290        if let Some(existing) = self.state.bindings.get(&id) {
291            *existing
292        } else {
293            let word = self.id();
294            self.state.bindings.insert(id, word);
295            self.debug_var_name(word, *var);
296            word
297        }
298    }
299
300    pub fn merge_binding(&mut self, id: Id, word: Word) {
301        self.state.bindings.insert(id, word);
302    }
303
304    pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
305        if let Some(existing) = self.state.versioned.get(&id) {
306            *existing
307        } else {
308            let word = self.id();
309            self.state.versioned.insert(id, word);
310            let mut debug_var = *var;
311            debug_var.kind = VariableKind::LocalMut { id: id.0 };
312            let name = self.name_of_var(debug_var);
313            self.debug_name(word, format!("{name}.v{}", id.1));
314            word
315        }
316    }
317
318    pub fn label(&mut self, block: NodeIndex) -> Word {
319        if let Some(existing) = self.state.labels.get(&block) {
320            *existing
321        } else {
322            let word = self.id();
323            self.debug_name(word, format!("bb{}", block.index()));
324            self.state.labels.insert(block, word);
325            word
326        }
327    }
328
329    pub fn end_label(&mut self, block: NodeIndex) -> Word {
330        if let Some(existing) = self.state.end_labels.get(&block) {
331            *existing
332        } else {
333            let word = self.label(block);
334            self.state.end_labels.insert(block, word);
335            word
336        }
337    }
338
339    pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
340        if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
341            let item = self.compile_type(ir::Type::new(ty));
342            Variable::GlobalScalar(existing, item.elem())
343        } else {
344            let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
345            let current_block = self.selected_block();
346            let setup = self.setup_block;
347            self.select_block(Some(setup)).unwrap();
348            let arr_id = self.state.scalar_bindings[&ty];
349            let item = self.compile_type(ir::Type::new(ty));
350            let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
351            let const_id = self.const_u32(id);
352            let index = Variable::ConstantScalar(const_id, id.into(), Elem::Int(32, false));
353            let read_id = self.id();
354            let var = Variable::GlobalScalar(read_id, item.elem());
355            self.debug_var_name(read_id, ir_var);
356            self.read_indexed_unchecked(&var, &arr, &index);
357            self.select_block(current_block).unwrap();
358            self.state.scalars.insert((id, ty), read_id);
359            var
360        }
361    }
362
363    pub fn register_const_array(&mut self, arr: ConstArray) {
364        let var = ir::Variable::new(
365            VariableKind::ConstantArray {
366                id: arr.id,
367                length: arr.length,
368                unroll_factor: 1,
369            },
370            arr.item,
371        );
372        let item = self.compile_type(arr.item);
373        let array_ty = Item::Array(Box::new(item.clone()), arr.length);
374        let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
375        let array_ty = array_ty.id(self);
376        let values = arr
377            .values
378            .into_iter()
379            .map(|it| self.compile_variable(it))
380            .collect::<Vec<_>>()
381            .into_iter()
382            .map(|it| self.read_as(&it, &item))
383            .collect::<Vec<_>>();
384        let constant = self.constant_composite(array_ty, values);
385        let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
386        self.debug_var_name(id, var);
387        self.state.const_arrays.insert(
388            arr.id as usize,
389            Array {
390                id,
391                item,
392                len: arr.length,
393                var,
394                alignment: None,
395            },
396        );
397    }
398}