// cubecl_spirv/lookups.rs

1use std::collections::VecDeque;
2
3use cubecl_core::{
4    ir::{self, Id, Type, VariableKind},
5    prelude::{Binding, KernelDefinition, Location, Visibility},
6};
7use cubecl_opt::{ConstArray, NodeIndex};
8use hashbrown::{HashMap, HashSet};
9use rspirv::{
10    dr,
11    spirv::{self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, StorageClass, Word},
12};
13
14use crate::{
15    MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
16    item::{Elem, Item},
17    variable::{ConstVal, Variable},
18};
19
20#[derive(Clone, Debug, Default)]
21pub struct LookupTables {
22    pub buffers: Vec<Word>,
23    pub scalar_bindings: HashMap<ir::StorageType, Word>,
24    pub info: Word,
25    pub cube_dims: Vec<Word>,
26    pub cube_size: Word,
27
28    pub const_arrays: Vec<Array>,
29    pub shared_memories: HashMap<Id, SharedMemory>,
30    pub local_arrays: HashMap<Id, Array>,
31    pub matrices: HashMap<Id, Matrix>,
32
33    pub used_builtins: HashMap<BuiltIn, (Word, Item)>,
34
35    pub scalars: HashMap<(Id, ir::StorageType), Word>,
36    pub array_types: HashSet<Word>,
37    pub constants: HashMap<(ConstVal, Item), Word>,
38    pub bindings: HashMap<Id, Word>,
39    pub variables: HashMap<Id, Word>,
40    pub versioned: HashMap<(Id, u16), Word>,
41    pub labels: HashMap<NodeIndex, Word>,
42    pub end_labels: HashMap<NodeIndex, Word>,
43
44    pub slices: HashMap<Id, Slice>,
45
46    // For break, continue
47    pub loops: VecDeque<Loop>,
48
49    // Explicitly decorated types, to avoid double decorating
50    pub decorated_types: HashSet<Word>,
51    pub debug_types: HashSet<Word>,
52}
53
54#[derive(Clone, Debug)]
55pub struct Slice {
56    pub ptr: Variable,
57    pub offset: Word,
58    pub end: Word,
59    pub const_len: Option<u32>,
60    pub item: Item,
61}
62
63impl From<&Slice> for Variable {
64    fn from(value: &Slice) -> Self {
65        Variable::Slice {
66            ptr: Box::new(value.ptr.clone()),
67            offset: value.offset,
68            end: value.end,
69            const_len: value.const_len,
70            item: value.item.clone(),
71        }
72    }
73}
74
75#[derive(Clone, Debug)]
76pub struct Array {
77    pub id: Word,
78    pub item: Item,
79    pub len: u32,
80    pub var: ir::Variable,
81    pub alignment: Option<u32>,
82}
83
84#[derive(Clone, Debug)]
85pub struct SharedMemory {
86    pub id: Word,
87    pub item: Item,
88    pub len: u32,
89    pub align: u32,
90    pub offset: u32,
91}
92
93impl SharedMemory {
94    pub fn end(&self) -> u32 {
95        self.offset + self.len * self.item.size()
96    }
97}
98
99#[derive(Debug, Clone, Copy, PartialEq)]
100#[allow(missing_docs)]
101pub struct Matrix {
102    pub id: Word,
103    pub ident: CooperativeMatrixUse,
104    pub m: u32,
105    pub n: u32,
106    pub k: u32,
107    pub elem: Elem,
108    pub layout: Option<CooperativeMatrixLayout>,
109}
110
111#[derive(Clone, Debug)]
112pub struct Loop {
113    pub header: Word,
114    pub continue_target: Word,
115    pub post: Word,
116}
117
118impl<T: SpirvTarget> SpirvCompiler<T> {
119    pub fn init_state(&mut self, kernel: KernelDefinition) {
120        let mut target = self.target.clone();
121
122        self.state.buffers = kernel
123            .buffers
124            .into_iter()
125            .map(|mut binding| {
126                // This is safe when combined with the unroll transform that adjusts all indices.
127                // Must not be used alone
128                if binding.ty.line_size() > MAX_VECTORIZATION {
129                    binding.ty = binding.ty.line(MAX_VECTORIZATION);
130                }
131                let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
132                let name = self.name_of_var(var);
133                target.generate_binding(self, binding, name.into())
134            })
135            .collect();
136
137        let mut offset = self.state.buffers.len() as u32;
138        let info_binding = Binding {
139            id: offset,
140            location: Location::Storage,
141            visibility: Visibility::Read,
142            ty: ir::Type::scalar(ir::ElemType::UInt(ir::UIntKind::U32)),
143            size: None,
144            has_extended_meta: false,
145        };
146        if self.metadata.static_len() > 0 {
147            self.state.info = target.generate_binding(self, info_binding, "info".to_string());
148            offset += 1;
149        }
150
151        self.state.scalar_bindings = kernel
152            .scalars
153            .into_iter()
154            .enumerate()
155            .map(|(i, binding)| {
156                let elem = binding.ty;
157                let binding = Binding {
158                    id: i as u32 + offset,
159                    location: Location::Storage,
160                    visibility: Visibility::Read,
161                    ty: ir::Type::new(elem),
162                    size: Some(binding.count),
163                    has_extended_meta: false,
164                };
165                let name = format!("scalars({elem})");
166                (elem, target.generate_binding(self, binding, name))
167            })
168            .collect();
169
170        let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
171        self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
172        self.state.cube_size = self.const_u32(cube_dims.iter().product());
173
174        let shared_liveness = self.shared_liveness.clone();
175        let shared_memories = shared_liveness.allocations.values().map(|alloc| {
176            let smem_id = self.id();
177            let smem = SharedMemory {
178                id: smem_id,
179                item: self.compile_type(alloc.smem.ty),
180                len: alloc.smem.length,
181                align: alloc.smem.align,
182                offset: alloc.offset,
183            };
184            (alloc.smem.id, smem)
185        });
186        self.state.shared_memories = shared_memories.collect();
187    }
188
189    fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
190        self.module_ref()
191            .types_global_values
192            .iter()
193            .find(|it| {
194                it.class == inst.class
195                    && it.result_type == inst.result_type
196                    && it.operands == inst.operands
197            })
198            .and_then(|it| it.result_id)
199    }
200
201    pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
202        let inst = dr::Instruction::new(
203            spirv::Op::Constant,
204            Some(ty),
205            None,
206            vec![dr::Operand::LiteralBit32(val)],
207        );
208        if let Some(id) = self.dedup_const(&inst) {
209            id
210        } else {
211            self.constant_bit32(ty, val)
212        }
213    }
214
215    pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
216        let inst = dr::Instruction::new(
217            spirv::Op::Constant,
218            Some(ty),
219            None,
220            vec![dr::Operand::LiteralBit64(val)],
221        );
222        if let Some(id) = self.dedup_const(&inst) {
223            id
224        } else {
225            self.constant_bit64(ty, val)
226        }
227    }
228
229    pub fn const_u32(&mut self, value: u32) -> Word {
230        let ty = Item::Scalar(Elem::Int(32, false));
231        let ty_id = ty.id(self);
232        self.dedup_constant_bit32(ty_id, value)
233    }
234
235    pub fn insert_global(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
236        let current_block = self.selected_block();
237        let setup = self.setup_block;
238        self.select_block(Some(setup)).unwrap();
239        let id = insert(self);
240        self.select_block(current_block).unwrap();
241        id
242    }
243
244    pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
245        if let Some(existing) = self.state.variables.get(&id) {
246            *existing
247        } else {
248            let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
249            let word = self.declare_function_variable(ty);
250            self.state.variables.insert(id, word);
251            self.debug_var_name(word, var);
252            word
253        }
254    }
255
256    pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
257        if let Some(existing) = self.state.bindings.get(&id) {
258            *existing
259        } else {
260            let word = self.id();
261            self.state.bindings.insert(id, word);
262            self.debug_var_name(word, *var);
263            word
264        }
265    }
266
267    pub fn merge_binding(&mut self, id: Id, word: Word) {
268        self.state.bindings.insert(id, word);
269    }
270
271    pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
272        if let Some(existing) = self.state.versioned.get(&id) {
273            *existing
274        } else {
275            let word = self.id();
276            self.state.versioned.insert(id, word);
277            let mut debug_var = *var;
278            debug_var.kind = VariableKind::LocalMut { id: id.0 };
279            let name = self.name_of_var(debug_var);
280            self.debug_name(word, format!("{name}.v{}", id.1));
281            word
282        }
283    }
284
285    pub fn label(&mut self, block: NodeIndex) -> Word {
286        if let Some(existing) = self.state.labels.get(&block) {
287            *existing
288        } else {
289            let word = self.id();
290            self.debug_name(word, format!("bb{}", block.index()));
291            self.state.labels.insert(block, word);
292            word
293        }
294    }
295
296    pub fn end_label(&mut self, block: NodeIndex) -> Word {
297        if let Some(existing) = self.state.end_labels.get(&block) {
298            *existing
299        } else {
300            let word = self.label(block);
301            self.state.end_labels.insert(block, word);
302            word
303        }
304    }
305
306    pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
307        if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
308            let item = self.compile_type(ir::Type::new(ty));
309            Variable::GlobalScalar(existing, item.elem())
310        } else {
311            let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
312            let current_block = self.selected_block();
313            let setup = self.setup_block;
314            self.select_block(Some(setup)).unwrap();
315            let arr_id = self.state.scalar_bindings[&ty];
316            let item = self.compile_type(ir::Type::new(ty));
317            let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
318            let const_id = self.const_u32(id);
319            let index = Variable::ConstantScalar(const_id, id.into(), Elem::Int(32, false));
320            let read_id = self.id();
321            let var = Variable::GlobalScalar(read_id, item.elem());
322            self.debug_var_name(read_id, ir_var);
323            self.read_indexed_unchecked(&var, &arr, &index);
324            self.select_block(current_block).unwrap();
325            self.state.scalars.insert((id, ty), read_id);
326            var
327        }
328    }
329
330    pub fn register_const_array(&mut self, arr: ConstArray) {
331        let var = ir::Variable::new(
332            VariableKind::ConstantArray {
333                id: arr.id,
334                length: arr.length,
335                unroll_factor: 1,
336            },
337            arr.item,
338        );
339        let item = self.compile_type(arr.item);
340        let array_ty = Item::Array(Box::new(item.clone()), arr.length);
341        let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
342        let array_ty = array_ty.id(self);
343        let values = arr
344            .values
345            .into_iter()
346            .map(|it| self.compile_variable(it))
347            .collect::<Vec<_>>()
348            .into_iter()
349            .map(|it| self.read_as(&it, &item))
350            .collect::<Vec<_>>();
351        let constant = self.constant_composite(array_ty, values);
352        let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
353        self.debug_var_name(id, var);
354        self.state.const_arrays.insert(
355            arr.id as usize,
356            Array {
357                id,
358                item,
359                len: arr.length,
360                var,
361                alignment: None,
362            },
363        );
364    }
365}