cubecl_spirv/
lookups.rs

1use std::collections::VecDeque;
2
3use cubecl_core::{
4    compute::{Binding, Location, Visibility},
5    ir::{self, Id, Type, VariableKind},
6    prelude::KernelDefinition,
7};
8use cubecl_opt::{ConstArray, NodeIndex};
9use hashbrown::{HashMap, HashSet};
10use rspirv::{
11    dr,
12    spirv::{self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, StorageClass, Word},
13};
14
15use crate::{
16    MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
17    item::{Elem, Item},
18    variable::{ConstVal, Variable},
19};
20
21#[derive(Clone, Debug, Default)]
22pub struct LookupTables {
23    pub buffers: Vec<Word>,
24    pub scalar_bindings: HashMap<ir::StorageType, Word>,
25    pub info: Word,
26    pub cube_dims: Vec<Word>,
27    pub cube_size: Word,
28
29    pub const_arrays: Vec<Array>,
30    pub shared_memories: HashMap<Id, SharedMemory>,
31    pub local_arrays: HashMap<Id, Array>,
32    pub matrices: HashMap<Id, Matrix>,
33
34    pub used_builtins: HashMap<BuiltIn, (Word, Item)>,
35
36    pub scalars: HashMap<(Id, ir::StorageType), Word>,
37    pub array_types: HashSet<Word>,
38    pub constants: HashMap<(ConstVal, Item), Word>,
39    pub bindings: HashMap<Id, Word>,
40    pub variables: HashMap<Id, Word>,
41    pub versioned: HashMap<(Id, u16), Word>,
42    pub labels: HashMap<NodeIndex, Word>,
43    pub end_labels: HashMap<NodeIndex, Word>,
44
45    pub slices: HashMap<Id, Slice>,
46
47    // For break, continue
48    pub loops: VecDeque<Loop>,
49
50    // Explicitly decorated types, to avoid double decorating
51    pub decorated_types: HashSet<Word>,
52    pub debug_types: HashSet<Word>,
53}
54
55#[derive(Clone, Debug)]
56pub struct Slice {
57    pub ptr: Variable,
58    pub offset: Word,
59    pub end: Word,
60    pub const_len: Option<u32>,
61    pub item: Item,
62}
63
64impl From<&Slice> for Variable {
65    fn from(value: &Slice) -> Self {
66        Variable::Slice {
67            ptr: Box::new(value.ptr.clone()),
68            offset: value.offset,
69            end: value.end,
70            const_len: value.const_len,
71            item: value.item.clone(),
72        }
73    }
74}
75
76#[derive(Clone, Debug)]
77pub struct Array {
78    pub id: Word,
79    pub item: Item,
80    pub len: u32,
81    pub var: ir::Variable,
82    pub alignment: Option<u32>,
83}
84
85#[derive(Clone, Debug)]
86pub struct SharedMemory {
87    pub id: Word,
88    pub item: Item,
89    pub len: u32,
90    pub align: u32,
91    pub offset: u32,
92}
93
94impl SharedMemory {
95    pub fn end(&self) -> u32 {
96        self.offset + self.len * self.item.size()
97    }
98}
99
100#[derive(Debug, Clone, Copy, PartialEq)]
101#[allow(missing_docs)]
102pub struct Matrix {
103    pub id: Word,
104    pub ident: CooperativeMatrixUse,
105    pub m: u32,
106    pub n: u32,
107    pub k: u32,
108    pub elem: Elem,
109    pub layout: Option<CooperativeMatrixLayout>,
110}
111
112#[derive(Clone, Debug)]
113pub struct Loop {
114    pub header: Word,
115    pub continue_target: Word,
116    pub post: Word,
117}
118
119impl<T: SpirvTarget> SpirvCompiler<T> {
120    pub fn init_state(&mut self, kernel: KernelDefinition) {
121        let mut target = self.target.clone();
122
123        self.state.buffers = kernel
124            .buffers
125            .into_iter()
126            .map(|mut binding| {
127                // This is safe when combined with the unroll transform that adjusts all indices.
128                // Must not be used alone
129                if binding.ty.line_size() > MAX_VECTORIZATION {
130                    binding.ty = binding.ty.line(MAX_VECTORIZATION);
131                }
132                let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
133                let name = self.name_of_var(var);
134                target.generate_binding(self, binding, name.into())
135            })
136            .collect();
137
138        let mut offset = self.state.buffers.len() as u32;
139        let info_binding = Binding {
140            id: offset,
141            location: Location::Storage,
142            visibility: Visibility::Read,
143            ty: ir::Type::scalar(ir::ElemType::UInt(ir::UIntKind::U32)),
144            size: None,
145            has_extended_meta: false,
146        };
147        if self.metadata.static_len() > 0 {
148            self.state.info = target.generate_binding(self, info_binding, "info".to_string());
149            offset += 1;
150        }
151
152        self.state.scalar_bindings = kernel
153            .scalars
154            .into_iter()
155            .enumerate()
156            .map(|(i, binding)| {
157                let elem = binding.ty;
158                let binding = Binding {
159                    id: i as u32 + offset,
160                    location: Location::Storage,
161                    visibility: Visibility::Read,
162                    ty: ir::Type::new(elem),
163                    size: Some(binding.count),
164                    has_extended_meta: false,
165                };
166                let name = format!("scalars({elem})");
167                (elem, target.generate_binding(self, binding, name))
168            })
169            .collect();
170
171        let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
172        self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
173        self.state.cube_size = self.const_u32(cube_dims.iter().product());
174
175        let shared_liveness = self.shared_liveness.clone();
176        let shared_memories = shared_liveness.allocations.values().map(|alloc| {
177            let smem_id = self.id();
178            let smem = SharedMemory {
179                id: smem_id,
180                item: self.compile_type(alloc.smem.ty),
181                len: alloc.smem.length,
182                align: alloc.smem.align,
183                offset: alloc.offset,
184            };
185            (alloc.smem.id, smem)
186        });
187        self.state.shared_memories = shared_memories.collect();
188    }
189
190    fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
191        self.module_ref()
192            .types_global_values
193            .iter()
194            .find(|it| {
195                it.class == inst.class
196                    && it.result_type == inst.result_type
197                    && it.operands == inst.operands
198            })
199            .and_then(|it| it.result_id)
200    }
201
202    pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
203        let inst = dr::Instruction::new(
204            spirv::Op::Constant,
205            Some(ty),
206            None,
207            vec![dr::Operand::LiteralBit32(val)],
208        );
209        if let Some(id) = self.dedup_const(&inst) {
210            id
211        } else {
212            self.constant_bit32(ty, val)
213        }
214    }
215
216    pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
217        let inst = dr::Instruction::new(
218            spirv::Op::Constant,
219            Some(ty),
220            None,
221            vec![dr::Operand::LiteralBit64(val)],
222        );
223        if let Some(id) = self.dedup_const(&inst) {
224            id
225        } else {
226            self.constant_bit64(ty, val)
227        }
228    }
229
230    pub fn const_u32(&mut self, value: u32) -> Word {
231        let ty = Item::Scalar(Elem::Int(32, false));
232        let ty_id = ty.id(self);
233        self.dedup_constant_bit32(ty_id, value)
234    }
235
236    pub fn insert_global(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
237        let current_block = self.selected_block();
238        let setup = self.setup_block;
239        self.select_block(Some(setup)).unwrap();
240        let id = insert(self);
241        self.select_block(current_block).unwrap();
242        id
243    }
244
245    pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
246        if let Some(existing) = self.state.variables.get(&id) {
247            *existing
248        } else {
249            let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
250            let word = self.declare_function_variable(ty);
251            self.state.variables.insert(id, word);
252            self.debug_var_name(word, var);
253            word
254        }
255    }
256
257    pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
258        if let Some(existing) = self.state.bindings.get(&id) {
259            *existing
260        } else {
261            let word = self.id();
262            self.state.bindings.insert(id, word);
263            self.debug_var_name(word, *var);
264            word
265        }
266    }
267
268    pub fn merge_binding(&mut self, id: Id, word: Word) {
269        self.state.bindings.insert(id, word);
270    }
271
272    pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
273        if let Some(existing) = self.state.versioned.get(&id) {
274            *existing
275        } else {
276            let word = self.id();
277            self.state.versioned.insert(id, word);
278            let mut debug_var = *var;
279            debug_var.kind = VariableKind::LocalMut { id: id.0 };
280            let name = self.name_of_var(debug_var);
281            self.debug_name(word, format!("{name}.v{}", id.1));
282            word
283        }
284    }
285
286    pub fn label(&mut self, block: NodeIndex) -> Word {
287        if let Some(existing) = self.state.labels.get(&block) {
288            *existing
289        } else {
290            let word = self.id();
291            self.debug_name(word, format!("bb{}", block.index()));
292            self.state.labels.insert(block, word);
293            word
294        }
295    }
296
297    pub fn end_label(&mut self, block: NodeIndex) -> Word {
298        if let Some(existing) = self.state.end_labels.get(&block) {
299            *existing
300        } else {
301            let word = self.label(block);
302            self.state.end_labels.insert(block, word);
303            word
304        }
305    }
306
307    pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
308        if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
309            let item = self.compile_type(ir::Type::new(ty));
310            Variable::GlobalScalar(existing, item.elem())
311        } else {
312            let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
313            let current_block = self.selected_block();
314            let setup = self.setup_block;
315            self.select_block(Some(setup)).unwrap();
316            let arr_id = self.state.scalar_bindings[&ty];
317            let item = self.compile_type(ir::Type::new(ty));
318            let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
319            let const_id = self.const_u32(id);
320            let index = Variable::ConstantScalar(const_id, id.into(), Elem::Int(32, false));
321            let read_id = self.id();
322            let var = Variable::GlobalScalar(read_id, item.elem());
323            self.debug_var_name(read_id, ir_var);
324            self.read_indexed_unchecked(&var, &arr, &index);
325            self.select_block(current_block).unwrap();
326            self.state.scalars.insert((id, ty), read_id);
327            var
328        }
329    }
330
331    pub fn register_const_array(&mut self, arr: ConstArray) {
332        let var = ir::Variable::new(
333            VariableKind::ConstantArray {
334                id: arr.id,
335                length: arr.length,
336                unroll_factor: 1,
337            },
338            arr.item,
339        );
340        let item = self.compile_type(arr.item);
341        let array_ty = Item::Array(Box::new(item.clone()), arr.length);
342        let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
343        let array_ty = array_ty.id(self);
344        let values = arr
345            .values
346            .into_iter()
347            .map(|it| self.compile_variable(it))
348            .collect::<Vec<_>>()
349            .into_iter()
350            .map(|it| self.read_as(&it, &item))
351            .collect::<Vec<_>>();
352        let constant = self.constant_composite(array_ty, values);
353        let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
354        self.debug_var_name(id, var);
355        self.state.const_arrays.insert(
356            arr.id as usize,
357            Array {
358                id,
359                item,
360                len: arr.length,
361                var,
362                alignment: None,
363            },
364        );
365    }
366}