1use std::collections::VecDeque;
2
3use cubecl_core::{
4 compute::{Binding, Location, Visibility},
5 ir::{self, Id, Type, VariableKind},
6 prelude::KernelDefinition,
7};
8use cubecl_opt::{ConstArray, NodeIndex};
9use hashbrown::{HashMap, HashSet};
10use rspirv::{
11 dr,
12 spirv::{self, BuiltIn, CooperativeMatrixLayout, CooperativeMatrixUse, StorageClass, Word},
13};
14
15use crate::{
16 MAX_VECTORIZATION, SpirvCompiler, SpirvTarget,
17 item::{Elem, Item},
18 variable::{ConstVal, Variable},
19};
20
/// Per-kernel lookup state of the SPIR-V compiler: result ids of bindings,
/// constants, variables, labels and other entities created while compiling,
/// keyed by their IR identity so they are only created once.
#[derive(Clone, Debug, Default)]
pub struct LookupTables {
    /// Binding ids generated for each kernel buffer, in `KernelDefinition::buffers` order.
    pub buffers: Vec<Word>,
    /// Binding id of the scalar buffer for each scalar storage type.
    pub scalar_bindings: HashMap<ir::StorageType, Word>,
    /// Binding id of the `info` metadata buffer. Only assigned by `init_state` when
    /// the metadata has a non-zero static length; otherwise it keeps its default value.
    pub info: Word,
    /// Ids of the `u32` constants holding the cube dimensions (x, y, z).
    pub cube_dims: Vec<Word>,
    /// Id of the `u32` constant holding the product of the cube dimensions.
    pub cube_size: Word,

    /// Constant arrays, indexed by their const array id.
    pub const_arrays: Vec<Array>,
    /// Shared memory allocations, keyed by shared memory id.
    pub shared_memories: HashMap<Id, SharedMemory>,
    /// Local arrays, keyed by IR id.
    pub local_arrays: HashMap<Id, Array>,
    /// Cooperative matrices, keyed by IR id.
    pub matrices: HashMap<Id, Matrix>,

    /// Builtins in use, with an id and item for each (presumably the variable id
    /// and its type — confirm at the insertion site).
    pub used_builtins: HashMap<BuiltIn, (Word, Item)>,

    /// Cached loaded values of global scalars, keyed by (scalar id, storage type).
    pub scalars: HashMap<(Id, ir::StorageType), Word>,
    /// Set of array type ids (NOTE(review): purpose not visible in this file).
    pub array_types: HashSet<Word>,
    /// Constant ids, keyed by value and item.
    pub constants: HashMap<(ConstVal, Item), Word>,
    /// Result ids allocated for IR ids (see `get_binding` / `merge_binding`).
    pub bindings: HashMap<Id, Word>,
    /// Function-storage pointer variables declared for local IR ids (see `get_local`).
    pub variables: HashMap<Id, Word>,
    /// Result ids for versioned variables, keyed by (IR id, version).
    pub versioned: HashMap<(Id, u16), Word>,
    /// Label ids of control-flow graph blocks.
    pub labels: HashMap<NodeIndex, Word>,
    /// Ids recorded as the end label of a block (defaults to the block's own label).
    pub end_labels: HashMap<NodeIndex, Word>,

    /// Slice views, keyed by IR id.
    pub slices: HashMap<Id, Slice>,

    /// Loop records (presumably the enclosing loops of the current compilation point — verify).
    pub loops: VecDeque<Loop>,

    /// Type ids that have been decorated (NOTE(review): inferred from the name).
    pub decorated_types: HashSet<Word>,
    /// Type ids that have debug info (NOTE(review): inferred from the name).
    pub debug_types: HashSet<Word>,
}
54
/// Cached description of a slice view; converted back into a `Variable::Slice`
/// via `From<&Slice>`.
#[derive(Clone, Debug)]
pub struct Slice {
    /// The variable being sliced.
    pub ptr: Variable,
    /// Id of the slice start value (assumed to be an index into `ptr` — confirm).
    pub offset: Word,
    /// Id of the slice end value.
    pub end: Word,
    /// Length of the slice when it is known at compile time.
    pub const_len: Option<u32>,
    /// Element item of the slice.
    pub item: Item,
}
63
64impl From<&Slice> for Variable {
65 fn from(value: &Slice) -> Self {
66 Variable::Slice {
67 ptr: Box::new(value.ptr.clone()),
68 offset: value.offset,
69 end: value.end,
70 const_len: value.const_len,
71 item: value.item.clone(),
72 }
73 }
74}
75
/// An array variable (used for constant arrays and local arrays).
#[derive(Clone, Debug)]
pub struct Array {
    /// Id of the array variable.
    pub id: Word,
    /// Element item of the array.
    pub item: Item,
    /// Number of elements.
    pub len: u32,
    /// The IR variable this array was created for (used for debug naming).
    pub var: ir::Variable,
    /// Optional alignment; `register_const_array` leaves it as `None`.
    pub alignment: Option<u32>,
}
84
/// A shared memory allocation, created in `init_state` from the shared liveness analysis.
#[derive(Clone, Debug)]
pub struct SharedMemory {
    /// Result id allocated for this shared memory.
    pub id: Word,
    /// Element item.
    pub item: Item,
    /// Number of elements.
    pub len: u32,
    /// Alignment requested by the allocation.
    pub align: u32,
    /// Offset of this allocation as computed by the liveness analysis
    /// (assumed to be in bytes — see `end`).
    pub offset: u32,
}
93
94impl SharedMemory {
95 pub fn end(&self) -> u32 {
96 self.offset + self.len * self.item.size()
97 }
98}
99
/// A cooperative matrix variable.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(missing_docs)]
pub struct Matrix {
    /// Result id of the matrix variable.
    pub id: Word,
    /// Cooperative matrix use (A, B or accumulator).
    pub ident: CooperativeMatrixUse,
    /// Matrix shape parameters (presumably the MxNxK of the multiply — verify).
    pub m: u32,
    pub n: u32,
    pub k: u32,
    /// Element type.
    pub elem: Elem,
    /// Memory layout, when one is specified.
    pub layout: Option<CooperativeMatrixLayout>,
}
111
/// Ids describing a loop's structure (presumably block labels — confirm at creation site).
#[derive(Clone, Debug)]
pub struct Loop {
    /// Loop header.
    pub header: Word,
    /// Continue target of the loop.
    pub continue_target: Word,
    /// Block following the loop.
    pub post: Word,
}
118
119impl<T: SpirvTarget> SpirvCompiler<T> {
120 pub fn init_state(&mut self, kernel: KernelDefinition) {
121 let mut target = self.target.clone();
122
123 self.state.buffers = kernel
124 .buffers
125 .into_iter()
126 .map(|mut binding| {
127 if binding.ty.line_size() > MAX_VECTORIZATION {
130 binding.ty = binding.ty.line(MAX_VECTORIZATION);
131 }
132 let var = ir::Variable::new(VariableKind::GlobalInputArray(binding.id), binding.ty);
133 let name = self.name_of_var(var);
134 target.generate_binding(self, binding, name.into())
135 })
136 .collect();
137
138 let mut offset = self.state.buffers.len() as u32;
139 let info_binding = Binding {
140 id: offset,
141 location: Location::Storage,
142 visibility: Visibility::Read,
143 ty: ir::Type::scalar(ir::ElemType::UInt(ir::UIntKind::U32)),
144 size: None,
145 has_extended_meta: false,
146 };
147 if self.metadata.static_len() > 0 {
148 self.state.info = target.generate_binding(self, info_binding, "info".to_string());
149 offset += 1;
150 }
151
152 self.state.scalar_bindings = kernel
153 .scalars
154 .into_iter()
155 .enumerate()
156 .map(|(i, binding)| {
157 let elem = binding.ty;
158 let binding = Binding {
159 id: i as u32 + offset,
160 location: Location::Storage,
161 visibility: Visibility::Read,
162 ty: ir::Type::new(elem),
163 size: Some(binding.count),
164 has_extended_meta: false,
165 };
166 let name = format!("scalars({elem})");
167 (elem, target.generate_binding(self, binding, name))
168 })
169 .collect();
170
171 let cube_dims = [kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
172 self.state.cube_dims = cube_dims.iter().map(|dim| self.const_u32(*dim)).collect();
173 self.state.cube_size = self.const_u32(cube_dims.iter().product());
174
175 let shared_liveness = self.shared_liveness.clone();
176 let shared_memories = shared_liveness.allocations.values().map(|alloc| {
177 let smem_id = self.id();
178 let smem = SharedMemory {
179 id: smem_id,
180 item: self.compile_type(alloc.smem.ty),
181 len: alloc.smem.length,
182 align: alloc.smem.align,
183 offset: alloc.offset,
184 };
185 (alloc.smem.id, smem)
186 });
187 self.state.shared_memories = shared_memories.collect();
188 }
189
190 fn dedup_const(&mut self, inst: &dr::Instruction) -> Option<Word> {
191 self.module_ref()
192 .types_global_values
193 .iter()
194 .find(|it| {
195 it.class == inst.class
196 && it.result_type == inst.result_type
197 && it.operands == inst.operands
198 })
199 .and_then(|it| it.result_id)
200 }
201
202 pub fn dedup_constant_bit32(&mut self, ty: Word, val: u32) -> Word {
203 let inst = dr::Instruction::new(
204 spirv::Op::Constant,
205 Some(ty),
206 None,
207 vec![dr::Operand::LiteralBit32(val)],
208 );
209 if let Some(id) = self.dedup_const(&inst) {
210 id
211 } else {
212 self.constant_bit32(ty, val)
213 }
214 }
215
216 pub fn dedup_constant_bit64(&mut self, ty: Word, val: u64) -> Word {
217 let inst = dr::Instruction::new(
218 spirv::Op::Constant,
219 Some(ty),
220 None,
221 vec![dr::Operand::LiteralBit64(val)],
222 );
223 if let Some(id) = self.dedup_const(&inst) {
224 id
225 } else {
226 self.constant_bit64(ty, val)
227 }
228 }
229
230 pub fn const_u32(&mut self, value: u32) -> Word {
231 let ty = Item::Scalar(Elem::Int(32, false));
232 let ty_id = ty.id(self);
233 self.dedup_constant_bit32(ty_id, value)
234 }
235
236 pub fn insert_global(&mut self, insert: impl FnOnce(&mut Self) -> Word) -> Word {
237 let current_block = self.selected_block();
238 let setup = self.setup_block;
239 self.select_block(Some(setup)).unwrap();
240 let id = insert(self);
241 self.select_block(current_block).unwrap();
242 id
243 }
244
245 pub fn get_local(&mut self, id: Id, item: &Item, var: ir::Variable) -> Word {
246 if let Some(existing) = self.state.variables.get(&id) {
247 *existing
248 } else {
249 let ty = Item::Pointer(StorageClass::Function, Box::new(item.clone())).id(self);
250 let word = self.declare_function_variable(ty);
251 self.state.variables.insert(id, word);
252 self.debug_var_name(word, var);
253 word
254 }
255 }
256
257 pub fn get_binding(&mut self, id: Id, var: &ir::Variable) -> Word {
258 if let Some(existing) = self.state.bindings.get(&id) {
259 *existing
260 } else {
261 let word = self.id();
262 self.state.bindings.insert(id, word);
263 self.debug_var_name(word, *var);
264 word
265 }
266 }
267
268 pub fn merge_binding(&mut self, id: Id, word: Word) {
269 self.state.bindings.insert(id, word);
270 }
271
272 pub fn get_versioned(&mut self, id: (Id, u16), var: &ir::Variable) -> Word {
273 if let Some(existing) = self.state.versioned.get(&id) {
274 *existing
275 } else {
276 let word = self.id();
277 self.state.versioned.insert(id, word);
278 let mut debug_var = *var;
279 debug_var.kind = VariableKind::LocalMut { id: id.0 };
280 let name = self.name_of_var(debug_var);
281 self.debug_name(word, format!("{name}.v{}", id.1));
282 word
283 }
284 }
285
286 pub fn label(&mut self, block: NodeIndex) -> Word {
287 if let Some(existing) = self.state.labels.get(&block) {
288 *existing
289 } else {
290 let word = self.id();
291 self.debug_name(word, format!("bb{}", block.index()));
292 self.state.labels.insert(block, word);
293 word
294 }
295 }
296
297 pub fn end_label(&mut self, block: NodeIndex) -> Word {
298 if let Some(existing) = self.state.end_labels.get(&block) {
299 *existing
300 } else {
301 let word = self.label(block);
302 self.state.end_labels.insert(block, word);
303 word
304 }
305 }
306
307 pub fn global_scalar(&mut self, id: Id, ty: ir::StorageType) -> Variable {
308 if let Some(existing) = self.state.scalars.get(&(id, ty)).copied() {
309 let item = self.compile_type(ir::Type::new(ty));
310 Variable::GlobalScalar(existing, item.elem())
311 } else {
312 let ir_var = ir::Variable::new(VariableKind::GlobalScalar(id), Type::new(ty));
313 let current_block = self.selected_block();
314 let setup = self.setup_block;
315 self.select_block(Some(setup)).unwrap();
316 let arr_id = self.state.scalar_bindings[&ty];
317 let item = self.compile_type(ir::Type::new(ty));
318 let arr = Variable::GlobalInputArray(arr_id, item.clone(), 0);
319 let const_id = self.const_u32(id);
320 let index = Variable::ConstantScalar(const_id, id.into(), Elem::Int(32, false));
321 let read_id = self.id();
322 let var = Variable::GlobalScalar(read_id, item.elem());
323 self.debug_var_name(read_id, ir_var);
324 self.read_indexed_unchecked(&var, &arr, &index);
325 self.select_block(current_block).unwrap();
326 self.state.scalars.insert((id, ty), read_id);
327 var
328 }
329 }
330
331 pub fn register_const_array(&mut self, arr: ConstArray) {
332 let var = ir::Variable::new(
333 VariableKind::ConstantArray {
334 id: arr.id,
335 length: arr.length,
336 unroll_factor: 1,
337 },
338 arr.item,
339 );
340 let item = self.compile_type(arr.item);
341 let array_ty = Item::Array(Box::new(item.clone()), arr.length);
342 let pointer_ty = Item::Pointer(StorageClass::Function, Box::new(array_ty.clone())).id(self);
343 let array_ty = array_ty.id(self);
344 let values = arr
345 .values
346 .into_iter()
347 .map(|it| self.compile_variable(it))
348 .collect::<Vec<_>>()
349 .into_iter()
350 .map(|it| self.read_as(&it, &item))
351 .collect::<Vec<_>>();
352 let constant = self.constant_composite(array_ty, values);
353 let id = self.variable(pointer_ty, None, StorageClass::Function, Some(constant));
354 self.debug_var_name(id, var);
355 self.state.const_arrays.insert(
356 arr.id as usize,
357 Array {
358 id,
359 item,
360 len: arr.length,
361 var,
362 alignment: None,
363 },
364 );
365 }
366}