cubecl_spirv/
lookups.rs

1use std::collections::VecDeque;
2
3use cubecl_core::{
4    ir::{self, Builtin, Id, Type, VariableKind},
5    prelude::{Binding, KernelDefinition, Location, Visibility},
6};
7use cubecl_opt::{ConstArray, NodeIndex, SharedMemory};
8use hashbrown::{HashMap, HashSet};
9use rspirv::{
10    dr,
11    spirv::{
12        self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, Scope, StorageClass, Word,
13    },
14};
15
16use crate::{
17    MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
18    item::{Elem, Item},
19    variable::{ConstVal, Variable},
20};
21
22#[derive(Clone, Debug, Default)]
23pub struct LookupTables {
24    pub buffers: Vec<Word>,
25    pub scalar_bindings: HashMap<ir::StorageType, Word>,
26    pub info: Word,
27    pub cube_dims: Vec<Word>,
28    pub cube_size: Word,
29
30    pub const_arrays: Vec<Array>,
31    pub shared_arrays: HashMap<Id, SharedArray>,
32    pub shared: HashMap<Id, SharedVar>,
33    pub local_arrays: HashMap<Id, Array>,
34    pub matrices: HashMap<Id, Matrix>,
35    pub globals: HashMap<Builtin, Word>,
36    pub loaded_builtins: HashMap<BuiltIn, Word>,
37
38    pub used_builtins: HashMap<BuiltIn, (Word, Item)>,
39
40    pub scalars: HashMap<(Id, ir::StorageType), Word>,
41    pub array_types: HashSet<Word>,
42    pub constants: HashMap<(ConstVal, Item), Word>,
43    pub bindings: HashMap<Id, Word>,
44    pub variables: HashMap<Id, Word>,
45    pub versioned: HashMap<(Id, u16), Word>,
46    pub labels: HashMap<NodeIndex, Word>,
47    pub end_labels: HashMap<NodeIndex, Word>,
48
49    pub atomic_scopes: HashMap<Word, Scope>,
50
51    pub slices: HashMap<Id, Slice>,
52
53    // For break, continue
54    pub loops: VecDeque<Loop>,
55
56    // Explicitly decorated types, to avoid double decorating
57    pub decorated_types: HashSet<Word>,
58    pub debug_types: HashSet<Word>,
59}
60
61#[derive(Clone, Debug)]
62pub struct Slice {
63    pub ptr: Variable,
64    pub offset: Word,
65    pub end: Word,
66    pub const_len: Option<u32>,
67    pub item: Item,
68}
69
70impl From<&Slice> for Variable {
71    fn from(value: &Slice) -> Self {
72        Variable::Slice {
73            ptr: Box::new(value.ptr.clone()),
74            offset: value.offset,
75            end: value.end,
76            const_len: value.const_len,
77            item: value.item.clone(),
78        }
79    }
80}
81
82#[derive(Clone, Debug)]
83pub struct Array {
84    pub id: Word,
85    pub item: Item,
86    pub len: u32,
87    pub var: ir::Variable,
88    pub alignment: Option<u32>,
89}
90
91#[derive(Clone, Debug)]
92pub struct SharedArray {
93    pub id: Word,
94    pub item: Item,
95    pub len: u32,
96    pub align: u32,
97    pub offset: u32,
98}
99
100#[derive(Clone, Debug)]
101pub struct SharedVar {
102    pub id: Word,
103    pub item: Item,
104    pub offset: u32,
105    pub align: u32,
106}
107
108impl SharedArray {
109    pub fn end(&self) -> u32 {
110        self.offset + self.len * self.item.size()
111    }
112}
113
114#[derive(Debug, Clone, Copy, PartialEq)]
115#[allow(missing_docs)]
116pub struct Matrix {
117    pub id: Word,
118    pub ident: CooperativeMatrixUse,
119    pub m: u32,
120    pub n: u32,
121    pub k: u32,
122    pub elem: Elem,
123    pub layout: Option<CooperativeMatrixLayout>,
124}
125
126#[derive(Clone, Debug)]
127pub struct Loop {
128    pub header: Word,
129    pub continue_target: Word,
130    pub post: Word,
131}
132
133impl<T: SpirvTarget> SpirvCompiler<T> {
134    pub fn init_state(&mut self, kernel: KernelDefinition) {
135        let mut target = self.target.clone();
136
137        self.state.buffers = kernel
138            .buffers
139            .into_iter()
140            .map(|mut binding| {
141                // This is safe when combined with the unroll transform that adjusts all indices.
142                // Must not be used alone
143                if binding.ty.line_size() > MAX_VECTORIZATION {
144                    binding.ty = binding.ty.line(MAX_VECTORIZATION);
145                }
146                let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
147                let name = self.name_of_var(var);
148                target.generate_binding(self, binding, name.into())
149            })
150            .collect();
151
152        let mut offset = self.state.buffers.len() as u32;
153        let info_binding = Binding {
154            id: offset,
155            location: Location::Storage,
156            visibility: Visibility::Read,
157            ty: self.addr_type.into(),
158            size: None,
159            has_extended_meta: false,
160        };
161        if self.metadata.static_len() > 0 {
162            self.state.info = target.generate_binding(self, info_binding, "info".to_string());
163            offset += 1;
164        }
165
166        self.state.scalar_bindings = kernel
167            .scalars
168            .into_iter()
169            .enumerate()
170            .map(|(i, binding)| {
171                let elem = binding.ty;
172                let binding = Binding {
173                    id: i as u32 + offset,
174                    location: Location::Storage,
175                    visibility: Visibility::Read,
176                    ty: ir::Type::new(elem),
177                    size: Some(binding.count),
178                    has_extended_meta: false,
179                };
180                let name = format!("scalars({elem})");
181                (elem, target.generate_binding(self, binding, name))
182            })
183            .collect();
184
185        let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
186        self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
187        self.state.cube_size = self.const_u32(cube_dims.iter().product());
188
189        let shared_liveness = self.shared_liveness.clone();
190        for alloc in shared_liveness.allocations.values() {
191            let smem_id = self.id();
192
193            match alloc.smem {
194                SharedMemory::Array {
195                    id,
196                    length,
197                    ty,
198                    align,
199                } => {
200                    let item = self.compile_type(ty);
201                    self.state.shared_arrays.insert(
202                        id,
203                        SharedArray {
204                            id: smem_id,
205                            item,
206                            len: length as u32,
207                            align: align as u32,
208                            offset: alloc.offset as u32,
209                        },
210                    );
211                }
212                SharedMemory::Value { id, ty, align } => {
213                    let item = self.compile_type(ty);
214                    self.state.shared.insert(
215                        id,
216                        SharedVar {
217                            id: smem_id,
218                            item,
219                            offset: alloc.offset as u32,
220                            align: align as u32,
221                        },
222                    );
223                }
224            }
225        }
226    }
227
228    fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
229        self.module_ref()
230            .types_global_values
231            .iter()
232            .find(|it| {
233                it.class == inst.class
234                    && it.result_type == inst.result_type
235                    && it.operands == inst.operands
236            })
237            .and_then(|it| it.result_id)
238    }
239
240    pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
241        let inst = dr::Instruction::new(
242            spirv::Op::Constant,
243            Some(ty),
244            None,
245            vec![dr::Operand::LiteralBit32(val)],
246        );
247        if let Some(id) = self.dedup_const(&inst) {
248            id
249        } else {
250            self.constant_bit32(ty, val)
251        }
252    }
253
254    pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
255        let inst = dr::Instruction::new(
256            spirv::Op::Constant,
257            Some(ty),
258            None,
259            vec![dr::Operand::LiteralBit64(val)],
260        );
261        if let Some(id) = self.dedup_const(&inst) {
262            id
263        } else {
264            self.constant_bit64(ty, val)
265        }
266    }
267
268    pub fn const_u32(&mut self, value: u32) -> Word {
269        let ty = Item::Scalar(Elem::Int(32, false));
270        let ty_id = ty.id(self);
271        self.dedup_constant_bit32(ty_id, value)
272    }
273
274    pub fn insert_builtin(
275        &mut self,
276        builtin: BuiltIn,
277        insert: impl FnOnce(&mut Self) -> Word,
278    ) -> Word {
279        if let Some(id) = self.state.loaded_builtins.get(&builtin) {
280            *id
281        } else {
282            let id = self.insert_in_setup(insert);
283            self.state.loaded_builtins.insert(builtin, id);
284            id
285        }
286    }
287
288    pub fn insert_global(
289        &mut self,
290        builtin: Builtin,
291        insert: impl FnOnce(&mut Self) -> Word,
292    ) -> Word {
293        if let Some(id) = self.state.globals.get(&builtin) {
294            *id
295        } else {
296            let id = self.insert_in_setup(insert);
297            self.state.globals.insert(builtin, id);
298            id
299        }
300    }
301
302    pub fn insert_in_setup(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
303        let current_block = self.selected_block();
304        let setup = self.setup_block;
305        self.select_block(Some(setup)).unwrap();
306        let id = insert(self);
307        self.select_block(current_block).unwrap();
308        id
309    }
310
311    pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
312        if let Some(existing) = self.state.variables.get(&id) {
313            *existing
314        } else {
315            let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
316            let word = self.declare_function_variable(ty);
317            self.state.variables.insert(id, word);
318            self.debug_var_name(word, var);
319            word
320        }
321    }
322
323    pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
324        if let Some(existing) = self.state.bindings.get(&id) {
325            *existing
326        } else {
327            let word = self.id();
328            self.state.bindings.insert(id, word);
329            self.debug_var_name(word, *var);
330            word
331        }
332    }
333
334    pub fn merge_binding(&mut self, id: Id, word: Word) {
335        self.state.bindings.insert(id, word);
336    }
337
338    pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
339        if let Some(existing) = self.state.versioned.get(&id) {
340            *existing
341        } else {
342            let word = self.id();
343            self.state.versioned.insert(id, word);
344            let mut debug_var = *var;
345            debug_var.kind = VariableKind::LocalMut { id: id.0 };
346            let name = self.name_of_var(debug_var);
347            self.debug_name(word, format!("{name}.v{}", id.1));
348            word
349        }
350    }
351
352    pub fn label(&mut self, block: NodeIndex) -> Word {
353        if let Some(existing) = self.state.labels.get(&block) {
354            *existing
355        } else {
356            let word = self.id();
357            self.debug_name(word, format!("bb{}", block.index()));
358            self.state.labels.insert(block, word);
359            word
360        }
361    }
362
363    pub fn end_label(&mut self, block: NodeIndex) -> Word {
364        if let Some(existing) = self.state.end_labels.get(&block) {
365            *existing
366        } else {
367            let word = self.label(block);
368            self.state.end_labels.insert(block, word);
369            word
370        }
371    }
372
373    pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
374        if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
375            let item = self.compile_type(ir::Type::new(ty));
376            Variable::GlobalScalar(existing, item.elem())
377        } else {
378            let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
379            let current_block = self.selected_block();
380            let setup = self.setup_block;
381            self.select_block(Some(setup)).unwrap();
382            let arr_id = self.state.scalar_bindings[&ty];
383            let item = self.compile_type(ir::Type::new(ty));
384            let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
385            let const_id = self.const_u32(id);
386            let index = Variable::Constant(const_id, id.into(), Item::Scalar(Elem::Int(32, false)));
387            let read_id = self.id();
388            let var = Variable::GlobalScalar(read_id, item.elem());
389            self.debug_var_name(read_id, ir_var);
390            self.read_indexed_unchecked(&var, &arr, &index);
391            self.select_block(current_block).unwrap();
392            self.state.scalars.insert((id, ty), read_id);
393            var
394        }
395    }
396
397    pub fn register_const_array(&mut self, arr: ConstArray) {
398        let var = ir::Variable::new(
399            VariableKind::ConstantArray {
400                id: arr.id,
401                length: arr.length,
402                unroll_factor: 1,
403            },
404            arr.item,
405        );
406        let item = self.compile_type(arr.item);
407        let array_ty = Item::Array(Box::new(item.clone()), arr.length as u32);
408        let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
409        let array_ty = array_ty.id(self);
410        let values = arr
411            .values
412            .into_iter()
413            .map(|it| self.compile_variable(it))
414            .collect::<Vec<_>>()
415            .into_iter()
416            .map(|it| self.read_as(&it, &item))
417            .collect::<Vec<_>>();
418        let constant = self.constant_composite(array_ty, values);
419        let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
420        self.debug_var_name(id, var);
421        self.state.const_arrays.insert(
422            arr.id as usize,
423            Array {
424                id,
425                item,
426                len: arr.length as u32,
427                var,
428                alignment: None,
429            },
430        );
431    }
432}