1use std::collections::VecDeque;
2
3use cubecl_core::{
4 ir::{self, Builtin, Id, Type, VariableKind},
5 prelude::{Binding, KernelDefinition, Location, Visibility},
6};
7use cubecl_opt::{ConstArray, NodeIndex, SharedMemory};
8use hashbrown::{HashMap, HashSet};
9use rspirv::{
10 dr,
11 spirv::{
12 self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, Scope, StorageClass, Word,
13 },
14};
15
16use crate::{
17 MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
18 item::{Elem, Item},
19 variable::{ConstVal, Variable},
20};
21
/// Lookup state of the SPIR-V compiler: ids and descriptors that have already been
/// generated, keyed by the IR entity they belong to.
///
/// All fields start out empty/zero through `Default`; the kernel-wide ones are filled
/// in by `SpirvCompiler::init_state`.
#[derive(Clone, Debug, Default)]
pub struct LookupTables {
    /// Binding words of the kernel buffers, in the order `init_state` generated them.
    pub buffers: Vec<Word>,
    /// Binding word of the scalar array for each storage type (filled by `init_state`,
    /// read by `global_scalar`).
    pub scalar_bindings: HashMap<ir::StorageType, Word>,
    /// Binding word of the `"info"` buffer. Only assigned by `init_state` when
    /// `metadata.static_len()` is greater than zero.
    pub info: Word,
    /// `u32` constants for the cube dimensions, in x, y, z order.
    pub cube_dims: Vec<Word>,
    /// `u32` constant holding the product of the three cube dimensions.
    pub cube_size: Word,

    /// Constant arrays, inserted by `register_const_array` at the index given by
    /// their IR id.
    pub const_arrays: Vec<Array>,
    /// Shared-memory arrays keyed by IR id (filled by `init_state`).
    pub shared_arrays: HashMap<Id, SharedArray>,
    /// Shared-memory single values keyed by IR id (filled by `init_state`).
    pub shared: HashMap<Id, SharedVar>,
    /// Local arrays keyed by IR id.
    pub local_arrays: HashMap<Id, Array>,
    /// Cooperative matrices keyed by IR id.
    pub matrices: HashMap<Id, Matrix>,
    /// Cache behind `insert_global`: word already generated for each IR builtin.
    pub globals: HashMap<Builtin, Word>,
    /// Cache behind `insert_builtin`: word already generated for each SPIR-V builtin.
    pub loaded_builtins: HashMap<BuiltIn, Word>,

    /// SPIR-V builtins with an associated word and item.
    /// NOTE(review): presumably the builtin variables that have been declared — confirm
    /// against the code that fills this map.
    pub used_builtins: HashMap<BuiltIn, (Word, Item)>,

    /// Cache behind `global_scalar`: word of the value read for each
    /// (scalar id, storage type) pair.
    pub scalars: HashMap<(Id, ir::StorageType), Word>,
    /// Set of type words.
    /// NOTE(review): presumably the array types that were already declared — confirm.
    pub array_types: HashSet<Word>,
    /// Constant words keyed by (value, item).
    pub constants: HashMap<(ConstVal, Item), Word>,
    /// Words handed out by `get_binding` (or overridden by `merge_binding`), keyed by IR id.
    pub bindings: HashMap<Id, Word>,
    /// Function-local variable words created by `get_local`, keyed by IR id.
    pub variables: HashMap<Id, Word>,
    /// Words handed out by `get_versioned`, keyed by (IR id, version).
    pub versioned: HashMap<(Id, u16), Word>,
    /// Label words handed out by `label`, keyed by block.
    pub labels: HashMap<NodeIndex, Word>,
    /// Words cached by `end_label`, keyed by block.
    pub end_labels: HashMap<NodeIndex, Word>,

    /// Scope associated with a word.
    /// NOTE(review): presumably the memory scope of atomic variables — confirm.
    pub atomic_scopes: HashMap<Word, Scope>,

    /// Slice descriptors keyed by IR id.
    pub slices: HashMap<Id, Slice>,

    /// Loop descriptors.
    /// NOTE(review): presumably the currently open loops, used to resolve
    /// break/continue targets — confirm.
    pub loops: VecDeque<Loop>,

    /// Set of type words.
    /// NOTE(review): presumably types that already carry decorations — confirm.
    pub decorated_types: HashSet<Word>,
    /// Set of type words.
    /// NOTE(review): presumably types that already have debug info attached — confirm.
    pub debug_types: HashSet<Word>,
}
60
/// Descriptor of a slice; converts into `Variable::Slice` through the `From<&Slice>`
/// implementation in this file.
#[derive(Clone, Debug)]
pub struct Slice {
    /// Variable the slice is taken from.
    pub ptr: Variable,
    /// Word for the slice start (presumably the id of the offset value — confirm).
    pub offset: Word,
    /// Word for the slice end (presumably the id of the end value — confirm).
    pub end: Word,
    /// Length of the slice when it is known as a constant, otherwise `None`.
    pub const_len: Option<u32>,
    /// Item type associated with the slice.
    pub item: Item,
}
69
70impl From<&Slice> for Variable {
71 fn from(value: &Slice) -> Self {
72 Variable::Slice {
73 ptr: Box::new(value.ptr.clone()),
74 offset: value.offset,
75 end: value.end,
76 const_len: value.const_len,
77 item: value.item.clone(),
78 }
79 }
80}
81
/// Descriptor of an array variable (used for constant arrays and local arrays).
#[derive(Clone, Debug)]
pub struct Array {
    /// Word of the array variable.
    pub id: Word,
    /// Compiled item type of the array elements.
    pub item: Item,
    /// Number of elements.
    pub len: u32,
    /// IR variable this array was created for.
    pub var: ir::Variable,
    /// Optional alignment; `register_const_array` always stores `None`.
    pub alignment: Option<u32>,
}
90
/// Descriptor of a shared-memory array, created by `init_state` from a
/// `SharedMemory::Array` allocation.
#[derive(Clone, Debug)]
pub struct SharedArray {
    /// Word reserved for this array.
    pub id: Word,
    /// Compiled item type of the elements.
    pub item: Item,
    /// Number of elements.
    pub len: u32,
    /// Alignment taken from the allocation.
    pub align: u32,
    /// Offset taken from the shared-memory allocation.
    pub offset: u32,
}
99
/// Descriptor of a single shared-memory value, created by `init_state` from a
/// `SharedMemory::Value` allocation.
#[derive(Clone, Debug)]
pub struct SharedVar {
    /// Word reserved for this value.
    pub id: Word,
    /// Compiled item type of the value.
    pub item: Item,
    /// Offset taken from the shared-memory allocation.
    pub offset: u32,
    /// Alignment taken from the allocation.
    pub align: u32,
}
107
108impl SharedArray {
109 pub fn end(&self) -> u32 {
110 self.offset + self.len * self.item.size()
111 }
112}
113
/// Descriptor of a cooperative matrix.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(missing_docs)]
pub struct Matrix {
    /// Word identifying the matrix.
    pub id: Word,
    /// Cooperative-matrix use of this matrix.
    pub ident: CooperativeMatrixUse,
    /// `m` dimension of the matrix operation.
    pub m: u32,
    /// `n` dimension of the matrix operation.
    pub n: u32,
    /// `k` dimension of the matrix operation.
    pub k: u32,
    /// Element type.
    pub elem: Elem,
    /// Memory layout, when one is specified.
    pub layout: Option<CooperativeMatrixLayout>,
}
125
/// Words describing one loop.
/// NOTE(review): the three fields are presumably label ids of the corresponding
/// blocks — confirm against the code that pushes onto `LookupTables::loops`.
#[derive(Clone, Debug)]
pub struct Loop {
    /// Word of the loop header.
    pub header: Word,
    /// Word of the continue target.
    pub continue_target: Word,
    /// Word of what follows the loop.
    pub post: Word,
}
132
133impl<T: SpirvTarget> SpirvCompiler<T> {
134 pub fn init_state(&mut self, kernel: KernelDefinition) {
135 let mut target = self.target.clone();
136
137 self.state.buffers = kernel
138 .buffers
139 .into_iter()
140 .map(|mut binding| {
141 if binding.ty.line_size() > MAX_VECTORIZATION {
144 binding.ty = binding.ty.line(MAX_VECTORIZATION);
145 }
146 let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
147 let name = self.name_of_var(var);
148 target.generate_binding(self, binding, name.into())
149 })
150 .collect();
151
152 let mut offset = self.state.buffers.len() as u32;
153 let info_binding = Binding {
154 id: offset,
155 location: Location::Storage,
156 visibility: Visibility::Read,
157 ty: self.addr_type.into(),
158 size: None,
159 has_extended_meta: false,
160 };
161 if self.metadata.static_len() > 0 {
162 self.state.info = target.generate_binding(self, info_binding, "info".to_string());
163 offset += 1;
164 }
165
166 self.state.scalar_bindings = kernel
167 .scalars
168 .into_iter()
169 .enumerate()
170 .map(|(i, binding)| {
171 let elem = binding.ty;
172 let binding = Binding {
173 id: i as u32 + offset,
174 location: Location::Storage,
175 visibility: Visibility::Read,
176 ty: ir::Type::new(elem),
177 size: Some(binding.count),
178 has_extended_meta: false,
179 };
180 let name = format!("scalars({elem})");
181 (elem, target.generate_binding(self, binding, name))
182 })
183 .collect();
184
185 let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
186 self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
187 self.state.cube_size = self.const_u32(cube_dims.iter().product());
188
189 let shared_liveness = self.shared_liveness.clone();
190 for alloc in shared_liveness.allocations.values() {
191 let smem_id = self.id();
192
193 match alloc.smem {
194 SharedMemory::Array {
195 id,
196 length,
197 ty,
198 align,
199 } => {
200 let item = self.compile_type(ty);
201 self.state.shared_arrays.insert(
202 id,
203 SharedArray {
204 id: smem_id,
205 item,
206 len: length as u32,
207 align: align as u32,
208 offset: alloc.offset as u32,
209 },
210 );
211 }
212 SharedMemory::Value { id, ty, align } => {
213 let item = self.compile_type(ty);
214 self.state.shared.insert(
215 id,
216 SharedVar {
217 id: smem_id,
218 item,
219 offset: alloc.offset as u32,
220 align: align as u32,
221 },
222 );
223 }
224 }
225 }
226 }
227
228 fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
229 self.module_ref()
230 .types_global_values
231 .iter()
232 .find(|it| {
233 it.class == inst.class
234 && it.result_type == inst.result_type
235 && it.operands == inst.operands
236 })
237 .and_then(|it| it.result_id)
238 }
239
240 pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
241 let inst = dr::Instruction::new(
242 spirv::Op::Constant,
243 Some(ty),
244 None,
245 vec![dr::Operand::LiteralBit32(val)],
246 );
247 if let Some(id) = self.dedup_const(&inst) {
248 id
249 } else {
250 self.constant_bit32(ty, val)
251 }
252 }
253
254 pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
255 let inst = dr::Instruction::new(
256 spirv::Op::Constant,
257 Some(ty),
258 None,
259 vec![dr::Operand::LiteralBit64(val)],
260 );
261 if let Some(id) = self.dedup_const(&inst) {
262 id
263 } else {
264 self.constant_bit64(ty, val)
265 }
266 }
267
268 pub fn const_u32(&mut self, value: u32) -> Word {
269 let ty = Item::Scalar(Elem::Int(32, false));
270 let ty_id = ty.id(self);
271 self.dedup_constant_bit32(ty_id, value)
272 }
273
274 pub fn insert_builtin(
275 &mut self,
276 builtin: BuiltIn,
277 insert: impl FnOnce(&mut Self) -> Word,
278 ) -> Word {
279 if let Some(id) = self.state.loaded_builtins.get(&builtin) {
280 *id
281 } else {
282 let id = self.insert_in_setup(insert);
283 self.state.loaded_builtins.insert(builtin, id);
284 id
285 }
286 }
287
288 pub fn insert_global(
289 &mut self,
290 builtin: Builtin,
291 insert: impl FnOnce(&mut Self) -> Word,
292 ) -> Word {
293 if let Some(id) = self.state.globals.get(&builtin) {
294 *id
295 } else {
296 let id = self.insert_in_setup(insert);
297 self.state.globals.insert(builtin, id);
298 id
299 }
300 }
301
302 pub fn insert_in_setup(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
303 let current_block = self.selected_block();
304 let setup = self.setup_block;
305 self.select_block(Some(setup)).unwrap();
306 let id = insert(self);
307 self.select_block(current_block).unwrap();
308 id
309 }
310
311 pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
312 if let Some(existing) = self.state.variables.get(&id) {
313 *existing
314 } else {
315 let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
316 let word = self.declare_function_variable(ty);
317 self.state.variables.insert(id, word);
318 self.debug_var_name(word, var);
319 word
320 }
321 }
322
323 pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
324 if let Some(existing) = self.state.bindings.get(&id) {
325 *existing
326 } else {
327 let word = self.id();
328 self.state.bindings.insert(id, word);
329 self.debug_var_name(word, *var);
330 word
331 }
332 }
333
334 pub fn merge_binding(&mut self, id: Id, word: Word) {
335 self.state.bindings.insert(id, word);
336 }
337
338 pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
339 if let Some(existing) = self.state.versioned.get(&id) {
340 *existing
341 } else {
342 let word = self.id();
343 self.state.versioned.insert(id, word);
344 let mut debug_var = *var;
345 debug_var.kind = VariableKind::LocalMut { id: id.0 };
346 let name = self.name_of_var(debug_var);
347 self.debug_name(word, format!("{name}.v{}", id.1));
348 word
349 }
350 }
351
352 pub fn label(&mut self, block: NodeIndex) -> Word {
353 if let Some(existing) = self.state.labels.get(&block) {
354 *existing
355 } else {
356 let word = self.id();
357 self.debug_name(word, format!("bb{}", block.index()));
358 self.state.labels.insert(block, word);
359 word
360 }
361 }
362
363 pub fn end_label(&mut self, block: NodeIndex) -> Word {
364 if let Some(existing) = self.state.end_labels.get(&block) {
365 *existing
366 } else {
367 let word = self.label(block);
368 self.state.end_labels.insert(block, word);
369 word
370 }
371 }
372
373 pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
374 if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
375 let item = self.compile_type(ir::Type::new(ty));
376 Variable::GlobalScalar(existing, item.elem())
377 } else {
378 let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
379 let current_block = self.selected_block();
380 let setup = self.setup_block;
381 self.select_block(Some(setup)).unwrap();
382 let arr_id = self.state.scalar_bindings[&ty];
383 let item = self.compile_type(ir::Type::new(ty));
384 let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
385 let const_id = self.const_u32(id);
386 let index = Variable::Constant(const_id, id.into(), Item::Scalar(Elem::Int(32, false)));
387 let read_id = self.id();
388 let var = Variable::GlobalScalar(read_id, item.elem());
389 self.debug_var_name(read_id, ir_var);
390 self.read_indexed_unchecked(&var, &arr, &index);
391 self.select_block(current_block).unwrap();
392 self.state.scalars.insert((id, ty), read_id);
393 var
394 }
395 }
396
397 pub fn register_const_array(&mut self, arr: ConstArray) {
398 let var = ir::Variable::new(
399 VariableKind::ConstantArray {
400 id: arr.id,
401 length: arr.length,
402 unroll_factor: 1,
403 },
404 arr.item,
405 );
406 let item = self.compile_type(arr.item);
407 let array_ty = Item::Array(Box::new(item.clone()), arr.length as u32);
408 let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
409 let array_ty = array_ty.id(self);
410 let values = arr
411 .values
412 .into_iter()
413 .map(|it| self.compile_variable(it))
414 .collect::<Vec<_>>()
415 .into_iter()
416 .map(|it| self.read_as(&it, &item))
417 .collect::<Vec<_>>();
418 let constant = self.constant_composite(array_ty, values);
419 let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
420 self.debug_var_name(id, var);
421 self.state.const_arrays.insert(
422 arr.id as usize,
423 Array {
424 id,
425 item,
426 len: arr.length as u32,
427 var,
428 alignment: None,
429 },
430 );
431 }
432}