1use std::collections::VecDeque;
2
3use cubecl_core::{
4 ir::{self, Id, Type, VariableKind},
5 prelude::{Binding, KernelDefinition, Location, Visibility},
6};
7use cubecl_opt::{ConstArray, NodeIndex, SharedMemory};
8use hashbrown::{HashMap, HashSet};
9use rspirv::{
10 dr,
11 spirv::{self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, StorageClass, Word},
12};
13
14use crate::{
15 MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
16 item::{Elem, Item},
17 variable::{ConstVal, Variable},
18};
19
20#[derive(Clone, Debug, Default)]
21pub struct LookupTables {
22 pub buffers: Vec<Word>,
23 pub scalar_bindings: HashMap<ir::StorageType, Word>,
24 pub info: Word,
25 pub cube_dims: Vec<Word>,
26 pub cube_size: Word,
27
28 pub const_arrays: Vec<Array>,
29 pub shared_arrays: HashMap<Id, SharedArray>,
30 pub shared: HashMap<Id, SharedVar>,
31 pub local_arrays: HashMap<Id, Array>,
32 pub matrices: HashMap<Id, Matrix>,
33
34 pub used_builtins: HashMap<BuiltIn, (Word, Item)>,
35
36 pub scalars: HashMap<(Id, ir::StorageType), Word>,
37 pub array_types: HashSet<Word>,
38 pub constants: HashMap<(ConstVal, Item), Word>,
39 pub bindings: HashMap<Id, Word>,
40 pub variables: HashMap<Id, Word>,
41 pub versioned: HashMap<(Id, u16), Word>,
42 pub labels: HashMap<NodeIndex, Word>,
43 pub end_labels: HashMap<NodeIndex, Word>,
44
45 pub slices: HashMap<Id, Slice>,
46
47 pub loops: VecDeque<Loop>,
49
50 pub decorated_types: HashSet<Word>,
52 pub debug_types: HashSet<Word>,
53}
54
55#[derive(Clone, Debug)]
56pub struct Slice {
57 pub ptr: Variable,
58 pub offset: Word,
59 pub end: Word,
60 pub const_len: Option<u32>,
61 pub item: Item,
62}
63
64impl From<&Slice> for Variable {
65 fn from(value: &Slice) -> Self {
66 Variable::Slice {
67 ptr: Box::new(value.ptr.clone()),
68 offset: value.offset,
69 end: value.end,
70 const_len: value.const_len,
71 item: value.item.clone(),
72 }
73 }
74}
75
76#[derive(Clone, Debug)]
77pub struct Array {
78 pub id: Word,
79 pub item: Item,
80 pub len: u32,
81 pub var: ir::Variable,
82 pub alignment: Option<u32>,
83}
84
85#[derive(Clone, Debug)]
86pub struct SharedArray {
87 pub id: Word,
88 pub item: Item,
89 pub len: u32,
90 pub align: u32,
91 pub offset: u32,
92}
93
94#[derive(Clone, Debug)]
95pub struct SharedVar {
96 pub id: Word,
97 pub item: Item,
98 pub offset: u32,
99 pub align: u32,
100}
101
102impl SharedArray {
103 pub fn end(&self) -> u32 {
104 self.offset + self.len * self.item.size()
105 }
106}
107
108#[derive(Debug, Clone, Copy, PartialEq)]
109#[allow(missing_docs)]
110pub struct Matrix {
111 pub id: Word,
112 pub ident: CooperativeMatrixUse,
113 pub m: u32,
114 pub n: u32,
115 pub k: u32,
116 pub elem: Elem,
117 pub layout: Option<CooperativeMatrixLayout>,
118}
119
120#[derive(Clone, Debug)]
121pub struct Loop {
122 pub header: Word,
123 pub continue_target: Word,
124 pub post: Word,
125}
126
127impl<T: SpirvTarget> SpirvCompiler<T> {
128 pub fn init_state(&mut self, kernel: KernelDefinition) {
129 let mut target = self.target.clone();
130
131 self.state.buffers = kernel
132 .buffers
133 .into_iter()
134 .map(|mut binding| {
135 if binding.ty.line_size() > MAX_VECTORIZATION {
138 binding.ty = binding.ty.line(MAX_VECTORIZATION);
139 }
140 let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
141 let name = self.name_of_var(var);
142 target.generate_binding(self, binding, name.into())
143 })
144 .collect();
145
146 let mut offset = self.state.buffers.len() as u32;
147 let info_binding = Binding {
148 id: offset,
149 location: Location::Storage,
150 visibility: Visibility::Read,
151 ty: ir::Type::scalar(ir::ElemType::UInt(ir::UIntKind::U32)),
152 size: None,
153 has_extended_meta: false,
154 };
155 if self.metadata.static_len() > 0 {
156 self.state.info = target.generate_binding(self, info_binding, "info".to_string());
157 offset += 1;
158 }
159
160 self.state.scalar_bindings = kernel
161 .scalars
162 .into_iter()
163 .enumerate()
164 .map(|(i, binding)| {
165 let elem = binding.ty;
166 let binding = Binding {
167 id: i as u32 + offset,
168 location: Location::Storage,
169 visibility: Visibility::Read,
170 ty: ir::Type::new(elem),
171 size: Some(binding.count),
172 has_extended_meta: false,
173 };
174 let name = format!("scalars({elem})");
175 (elem, target.generate_binding(self, binding, name))
176 })
177 .collect();
178
179 let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
180 self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
181 self.state.cube_size = self.const_u32(cube_dims.iter().product());
182
183 let shared_liveness = self.shared_liveness.clone();
184 for alloc in shared_liveness.allocations.values() {
185 let smem_id = self.id();
186
187 match alloc.smem {
188 SharedMemory::Array {
189 id,
190 length,
191 ty,
192 align,
193 } => {
194 let item = self.compile_type(ty);
195 self.state.shared_arrays.insert(
196 id,
197 SharedArray {
198 id: smem_id,
199 item,
200 len: length,
201 align,
202 offset,
203 },
204 );
205 }
206 SharedMemory::Value { id, ty, align } => {
207 let item = self.compile_type(ty);
208 self.state.shared.insert(
209 id,
210 SharedVar {
211 id: smem_id,
212 item,
213 offset,
214 align,
215 },
216 );
217 }
218 }
219 }
220 }
221
222 fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
223 self.module_ref()
224 .types_global_values
225 .iter()
226 .find(|it| {
227 it.class == inst.class
228 && it.result_type == inst.result_type
229 && it.operands == inst.operands
230 })
231 .and_then(|it| it.result_id)
232 }
233
234 pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
235 let inst = dr::Instruction::new(
236 spirv::Op::Constant,
237 Some(ty),
238 None,
239 vec![dr::Operand::LiteralBit32(val)],
240 );
241 if let Some(id) = self.dedup_const(&inst) {
242 id
243 } else {
244 self.constant_bit32(ty, val)
245 }
246 }
247
248 pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
249 let inst = dr::Instruction::new(
250 spirv::Op::Constant,
251 Some(ty),
252 None,
253 vec![dr::Operand::LiteralBit64(val)],
254 );
255 if let Some(id) = self.dedup_const(&inst) {
256 id
257 } else {
258 self.constant_bit64(ty, val)
259 }
260 }
261
262 pub fn const_u32(&mut self, value: u32) -> Word {
263 let ty = Item::Scalar(Elem::Int(32, false));
264 let ty_id = ty.id(self);
265 self.dedup_constant_bit32(ty_id, value)
266 }
267
268 pub fn insert_global(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
269 let current_block = self.selected_block();
270 let setup = self.setup_block;
271 self.select_block(Some(setup)).unwrap();
272 let id = insert(self);
273 self.select_block(current_block).unwrap();
274 id
275 }
276
277 pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
278 if let Some(existing) = self.state.variables.get(&id) {
279 *existing
280 } else {
281 let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
282 let word = self.declare_function_variable(ty);
283 self.state.variables.insert(id, word);
284 self.debug_var_name(word, var);
285 word
286 }
287 }
288
289 pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
290 if let Some(existing) = self.state.bindings.get(&id) {
291 *existing
292 } else {
293 let word = self.id();
294 self.state.bindings.insert(id, word);
295 self.debug_var_name(word, *var);
296 word
297 }
298 }
299
300 pub fn merge_binding(&mut self, id: Id, word: Word) {
301 self.state.bindings.insert(id, word);
302 }
303
304 pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
305 if let Some(existing) = self.state.versioned.get(&id) {
306 *existing
307 } else {
308 let word = self.id();
309 self.state.versioned.insert(id, word);
310 let mut debug_var = *var;
311 debug_var.kind = VariableKind::LocalMut { id: id.0 };
312 let name = self.name_of_var(debug_var);
313 self.debug_name(word, format!("{name}.v{}", id.1));
314 word
315 }
316 }
317
318 pub fn label(&mut self, block: NodeIndex) -> Word {
319 if let Some(existing) = self.state.labels.get(&block) {
320 *existing
321 } else {
322 let word = self.id();
323 self.debug_name(word, format!("bb{}", block.index()));
324 self.state.labels.insert(block, word);
325 word
326 }
327 }
328
329 pub fn end_label(&mut self, block: NodeIndex) -> Word {
330 if let Some(existing) = self.state.end_labels.get(&block) {
331 *existing
332 } else {
333 let word = self.label(block);
334 self.state.end_labels.insert(block, word);
335 word
336 }
337 }
338
339 pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
340 if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
341 let item = self.compile_type(ir::Type::new(ty));
342 Variable::GlobalScalar(existing, item.elem())
343 } else {
344 let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
345 let current_block = self.selected_block();
346 let setup = self.setup_block;
347 self.select_block(Some(setup)).unwrap();
348 let arr_id = self.state.scalar_bindings[&ty];
349 let item = self.compile_type(ir::Type::new(ty));
350 let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
351 let const_id = self.const_u32(id);
352 let index = Variable::ConstantScalar(const_id, id.into(), Elem::Int(32, false));
353 let read_id = self.id();
354 let var = Variable::GlobalScalar(read_id, item.elem());
355 self.debug_var_name(read_id, ir_var);
356 self.read_indexed_unchecked(&var, &arr, &index);
357 self.select_block(current_block).unwrap();
358 self.state.scalars.insert((id, ty), read_id);
359 var
360 }
361 }
362
363 pub fn register_const_array(&mut self, arr: ConstArray) {
364 let var = ir::Variable::new(
365 VariableKind::ConstantArray {
366 id: arr.id,
367 length: arr.length,
368 unroll_factor: 1,
369 },
370 arr.item,
371 );
372 let item = self.compile_type(arr.item);
373 let array_ty = Item::Array(Box::new(item.clone()), arr.length);
374 let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
375 let array_ty = array_ty.id(self);
376 let values = arr
377 .values
378 .into_iter()
379 .map(|it| self.compile_variable(it))
380 .collect::<Vec<_>>()
381 .into_iter()
382 .map(|it| self.read_as(&it, &item))
383 .collect::<Vec<_>>();
384 let constant = self.constant_composite(array_ty, values);
385 let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
386 self.debug_var_name(id, var);
387 self.state.const_arrays.insert(
388 arr.id as usize,
389 Array {
390 id,
391 item,
392 len: arr.length,
393 var,
394 alignment: None,
395 },
396 );
397 }
398}