1use std::collections::VecDeque;
2
3use cubecl_core::{
4 ir::{self, Id, Type, VariableKind},
5 prelude::{Binding, KernelDefinition, Location, Visibility},
6};
7use cubecl_opt::{ConstArray, NodeIndex};
8use hashbrown::{HashMap, HashSet};
9use rspirv::{
10 dr,
11 spirv::{self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, StorageClass, Word},
12};
13
14use crate::{
15 MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
16 item::{Elem, Item},
17 variable::{ConstVal, Variable},
18};
19
20#[derive(Clone, Debug, Default)]
21pub struct LookupTables {
22 pub buffers: Vec<Word>,
23 pub scalar_bindings: HashMap<ir::StorageType, Word>,
24 pub info: Word,
25 pub cube_dims: Vec<Word>,
26 pub cube_size: Word,
27
28 pub const_arrays: Vec<Array>,
29 pub shared_memories: HashMap<Id, SharedMemory>,
30 pub local_arrays: HashMap<Id, Array>,
31 pub matrices: HashMap<Id, Matrix>,
32
33 pub used_builtins: HashMap<BuiltIn, (Word, Item)>,
34
35 pub scalars: HashMap<(Id, ir::StorageType), Word>,
36 pub array_types: HashSet<Word>,
37 pub constants: HashMap<(ConstVal, Item), Word>,
38 pub bindings: HashMap<Id, Word>,
39 pub variables: HashMap<Id, Word>,
40 pub versioned: HashMap<(Id, u16), Word>,
41 pub labels: HashMap<NodeIndex, Word>,
42 pub end_labels: HashMap<NodeIndex, Word>,
43
44 pub slices: HashMap<Id, Slice>,
45
46 pub loops: VecDeque<Loop>,
48
49 pub decorated_types: HashSet<Word>,
51 pub debug_types: HashSet<Word>,
52}
53
54#[derive(Clone, Debug)]
55pub struct Slice {
56 pub ptr: Variable,
57 pub offset: Word,
58 pub end: Word,
59 pub const_len: Option<u32>,
60 pub item: Item,
61}
62
63impl From<&Slice> for Variable {
64 fn from(value: &Slice) -> Self {
65 Variable::Slice {
66 ptr: Box::new(value.ptr.clone()),
67 offset: value.offset,
68 end: value.end,
69 const_len: value.const_len,
70 item: value.item.clone(),
71 }
72 }
73}
74
75#[derive(Clone, Debug)]
76pub struct Array {
77 pub id: Word,
78 pub item: Item,
79 pub len: u32,
80 pub var: ir::Variable,
81 pub alignment: Option<u32>,
82}
83
84#[derive(Clone, Debug)]
85pub struct SharedMemory {
86 pub id: Word,
87 pub item: Item,
88 pub len: u32,
89 pub align: u32,
90 pub offset: u32,
91}
92
93impl SharedMemory {
94 pub fn end(&self) -> u32 {
95 self.offset + self.len * self.item.size()
96 }
97}
98
99#[derive(Debug, Clone, Copy, PartialEq)]
100#[allow(missing_docs)]
101pub struct Matrix {
102 pub id: Word,
103 pub ident: CooperativeMatrixUse,
104 pub m: u32,
105 pub n: u32,
106 pub k: u32,
107 pub elem: Elem,
108 pub layout: Option<CooperativeMatrixLayout>,
109}
110
111#[derive(Clone, Debug)]
112pub struct Loop {
113 pub header: Word,
114 pub continue_target: Word,
115 pub post: Word,
116}
117
118impl<T: SpirvTarget> SpirvCompiler<T> {
119 pub fn init_state(&mut self, kernel: KernelDefinition) {
120 let mut target = self.target.clone();
121
122 self.state.buffers = kernel
123 .buffers
124 .into_iter()
125 .map(|mut binding| {
126 if binding.ty.line_size() > MAX_VECTORIZATION {
129 binding.ty = binding.ty.line(MAX_VECTORIZATION);
130 }
131 let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
132 let name = self.name_of_var(var);
133 target.generate_binding(self, binding, name.into())
134 })
135 .collect();
136
137 let mut offset = self.state.buffers.len() as u32;
138 let info_binding = Binding {
139 id: offset,
140 location: Location::Storage,
141 visibility: Visibility::Read,
142 ty: ir::Type::scalar(ir::ElemType::UInt(ir::UIntKind::U32)),
143 size: None,
144 has_extended_meta: false,
145 };
146 if self.metadata.static_len() > 0 {
147 self.state.info = target.generate_binding(self, info_binding, "info".to_string());
148 offset += 1;
149 }
150
151 self.state.scalar_bindings = kernel
152 .scalars
153 .into_iter()
154 .enumerate()
155 .map(|(i, binding)| {
156 let elem = binding.ty;
157 let binding = Binding {
158 id: i as u32 + offset,
159 location: Location::Storage,
160 visibility: Visibility::Read,
161 ty: ir::Type::new(elem),
162 size: Some(binding.count),
163 has_extended_meta: false,
164 };
165 let name = format!("scalars({elem})");
166 (elem, target.generate_binding(self, binding, name))
167 })
168 .collect();
169
170 let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
171 self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
172 self.state.cube_size = self.const_u32(cube_dims.iter().product());
173
174 let shared_liveness = self.shared_liveness.clone();
175 let shared_memories = shared_liveness.allocations.values().map(|alloc| {
176 let smem_id = self.id();
177 let smem = SharedMemory {
178 id: smem_id,
179 item: self.compile_type(alloc.smem.ty),
180 len: alloc.smem.length,
181 align: alloc.smem.align,
182 offset: alloc.offset,
183 };
184 (alloc.smem.id, smem)
185 });
186 self.state.shared_memories = shared_memories.collect();
187 }
188
189 fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
190 self.module_ref()
191 .types_global_values
192 .iter()
193 .find(|it| {
194 it.class == inst.class
195 && it.result_type == inst.result_type
196 && it.operands == inst.operands
197 })
198 .and_then(|it| it.result_id)
199 }
200
201 pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
202 let inst = dr::Instruction::new(
203 spirv::Op::Constant,
204 Some(ty),
205 None,
206 vec![dr::Operand::LiteralBit32(val)],
207 );
208 if let Some(id) = self.dedup_const(&inst) {
209 id
210 } else {
211 self.constant_bit32(ty, val)
212 }
213 }
214
215 pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
216 let inst = dr::Instruction::new(
217 spirv::Op::Constant,
218 Some(ty),
219 None,
220 vec![dr::Operand::LiteralBit64(val)],
221 );
222 if let Some(id) = self.dedup_const(&inst) {
223 id
224 } else {
225 self.constant_bit64(ty, val)
226 }
227 }
228
229 pub fn const_u32(&mut self, value: u32) -> Word {
230 let ty = Item::Scalar(Elem::Int(32, false));
231 let ty_id = ty.id(self);
232 self.dedup_constant_bit32(ty_id, value)
233 }
234
235 pub fn insert_global(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
236 let current_block = self.selected_block();
237 let setup = self.setup_block;
238 self.select_block(Some(setup)).unwrap();
239 let id = insert(self);
240 self.select_block(current_block).unwrap();
241 id
242 }
243
244 pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
245 if let Some(existing) = self.state.variables.get(&id) {
246 *existing
247 } else {
248 let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
249 let word = self.declare_function_variable(ty);
250 self.state.variables.insert(id, word);
251 self.debug_var_name(word, var);
252 word
253 }
254 }
255
256 pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
257 if let Some(existing) = self.state.bindings.get(&id) {
258 *existing
259 } else {
260 let word = self.id();
261 self.state.bindings.insert(id, word);
262 self.debug_var_name(word, *var);
263 word
264 }
265 }
266
267 pub fn merge_binding(&mut self, id: Id, word: Word) {
268 self.state.bindings.insert(id, word);
269 }
270
271 pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
272 if let Some(existing) = self.state.versioned.get(&id) {
273 *existing
274 } else {
275 let word = self.id();
276 self.state.versioned.insert(id, word);
277 let mut debug_var = *var;
278 debug_var.kind = VariableKind::LocalMut { id: id.0 };
279 let name = self.name_of_var(debug_var);
280 self.debug_name(word, format!("{name}.v{}", id.1));
281 word
282 }
283 }
284
285 pub fn label(&mut self, block: NodeIndex) -> Word {
286 if let Some(existing) = self.state.labels.get(&block) {
287 *existing
288 } else {
289 let word = self.id();
290 self.debug_name(word, format!("bb{}", block.index()));
291 self.state.labels.insert(block, word);
292 word
293 }
294 }
295
296 pub fn end_label(&mut self, block: NodeIndex) -> Word {
297 if let Some(existing) = self.state.end_labels.get(&block) {
298 *existing
299 } else {
300 let word = self.label(block);
301 self.state.end_labels.insert(block, word);
302 word
303 }
304 }
305
306 pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
307 if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
308 let item = self.compile_type(ir::Type::new(ty));
309 Variable::GlobalScalar(existing, item.elem())
310 } else {
311 let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
312 let current_block = self.selected_block();
313 let setup = self.setup_block;
314 self.select_block(Some(setup)).unwrap();
315 let arr_id = self.state.scalar_bindings[&ty];
316 let item = self.compile_type(ir::Type::new(ty));
317 let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
318 let const_id = self.const_u32(id);
319 let index = Variable::ConstantScalar(const_id, id.into(), Elem::Int(32, false));
320 let read_id = self.id();
321 let var = Variable::GlobalScalar(read_id, item.elem());
322 self.debug_var_name(read_id, ir_var);
323 self.read_indexed_unchecked(&var, &arr, &index);
324 self.select_block(current_block).unwrap();
325 self.state.scalars.insert((id, ty), read_id);
326 var
327 }
328 }
329
330 pub fn register_const_array(&mut self, arr: ConstArray) {
331 let var = ir::Variable::new(
332 VariableKind::ConstantArray {
333 id: arr.id,
334 length: arr.length,
335 unroll_factor: 1,
336 },
337 arr.item,
338 );
339 let item = self.compile_type(arr.item);
340 let array_ty = Item::Array(Box::new(item.clone()), arr.length);
341 let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
342 let array_ty = array_ty.id(self);
343 let values = arr
344 .values
345 .into_iter()
346 .map(|it| self.compile_variable(it))
347 .collect::<Vec<_>>()
348 .into_iter()
349 .map(|it| self.read_as(&it, &item))
350 .collect::<Vec<_>>();
351 let constant = self.constant_composite(array_ty, values);
352 let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
353 self.debug_var_name(id, var);
354 self.state.const_arrays.insert(
355 arr.id as usize,
356 Array {
357 id,
358 item,
359 len: arr.length,
360 var,
361 alignment: None,
362 },
363 );
364 }
365}