1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3use serde::{Deserialize, Serialize};
4
5use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
6
/// A SPIR-V value type as tracked by the compiler: scalars, vectors,
/// aggregates, pointers and cooperative matrices.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Item {
    /// A single scalar element.
    Scalar(Elem),
    /// A vector of `Elem` with the given component count.
    Vector(Elem, u32),
    /// A fixed-length array of `Item`; the length is emitted as a `u32` constant.
    Array(Box<Item>, u32),
    /// An array whose length is not part of the type.
    RuntimeArray(Box<Item>),
    /// A struct with the given member types, in order.
    Struct(Vec<Item>),
    /// A pointer to `Item` in the given storage class.
    Pointer(StorageClass, Box<Item>),
    /// A cooperative matrix type (`OpTypeCooperativeMatrixKHR`, subgroup scope).
    CoopMatrix {
        /// Component type of the matrix.
        ty: Elem,
        // NOTE(review): `rows`/`columns` are forwarded unchanged as the Rows/Columns
        // operands of `type_cooperative_matrix_khr`, so they presumably already hold
        // constant ids rather than literal counts — confirm at construction sites.
        rows: u32,
        columns: u32,
        /// Matrix use, emitted as a `u32` constant for the Use operand.
        ident: CooperativeMatrixUse,
    },
}
23
24impl Item {
25 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
26 let id = match self {
27 Item::Scalar(elem) => elem.id(b),
28 Item::Vector(elem, vec) => {
29 let elem = elem.id(b);
30 b.type_vector(elem, *vec)
31 }
32 Item::Array(item, len) => {
33 let item = item.id(b);
34 let len = b.const_u32(*len);
35 b.type_array(item, len)
36 }
37 Item::RuntimeArray(item) => {
38 let item = item.id(b);
39 b.type_runtime_array(item)
40 }
41 Item::Struct(vec) => {
42 let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
43 let id = b.id(); b.type_struct_id(Some(id), items)
45 }
46 Item::Pointer(storage_class, item) => {
47 let item = item.id(b);
48 b.type_pointer(None, *storage_class, item)
49 }
50 Item::CoopMatrix {
51 ty,
52 rows,
53 columns,
54 ident,
55 } => {
56 let ty = ty.id(b);
57 let scope = b.const_u32(Scope::Subgroup as u32);
58 let usage = b.const_u32(*ident as u32);
59 b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
60 }
61 };
62 if b.debug_symbols && !b.state.debug_types.contains(&id) {
63 b.debug_name(id, format!("{self}"));
64 b.state.debug_types.insert(id);
65 }
66 id
67 }
68
69 pub fn builtin_u32() -> Self {
70 Item::Scalar(Elem::Int(32, false))
71 }
72
73 pub fn size(&self) -> u32 {
74 match self {
75 Item::Scalar(elem) => elem.size(),
76 Item::Vector(elem, factor) => elem.size() * *factor,
77 Item::Array(item, len) => item.size() * *len,
78 Item::RuntimeArray(item) => item.size(),
79 Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
80 Item::Pointer(_, item) => item.size(),
81 Item::CoopMatrix { ty, .. } => ty.size(),
82 }
83 }
84
85 pub fn elem(&self) -> Elem {
86 match self {
87 Item::Scalar(elem) => *elem,
88 Item::Vector(elem, _) => *elem,
89 Item::Array(item, _) => item.elem(),
90 Item::RuntimeArray(item) => item.elem(),
91 Item::Struct(_) => Elem::Void,
92 Item::Pointer(_, item) => item.elem(),
93 Item::CoopMatrix { ty, .. } => *ty,
94 }
95 }
96
97 pub fn same_vectorization(&self, elem: Elem) -> Item {
98 match self {
99 Item::Scalar(_) => Item::Scalar(elem),
100 Item::Vector(_, factor) => Item::Vector(elem, *factor),
101 _ => unreachable!(),
102 }
103 }
104
105 pub fn vectorization(&self) -> u32 {
106 match self {
107 Item::Vector(_, factor) => *factor,
108 _ => 1,
109 }
110 }
111
112 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
113 let scalar = self.elem().constant(b, value);
114 let ty = self.id(b);
115 match self {
116 Item::Scalar(_) => scalar,
117 Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
118 Item::Array(item, len) => {
119 let elem = item.constant(b, value);
120 b.constant_composite(ty, (0..*len).map(|_| elem))
121 }
122 Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
123 Item::Struct(elems) => {
124 let items = elems
125 .iter()
126 .map(|item| item.constant(b, value))
127 .collect::<Vec<_>>();
128 b.constant_composite(ty, items)
129 }
130 Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
131 Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
132 }
133 }
134
135 pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
136 b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
137 .0
138 }
139
140 pub fn broadcast<T: SpirvTarget>(
142 &self,
143 b: &mut SpirvCompiler<T>,
144 obj: Word,
145 out_id: Option<Word>,
146 other: &Item,
147 ) -> Word {
148 match (self, other) {
149 (Item::Scalar(elem), Item::Vector(_, factor)) => {
150 let item = Item::Vector(*elem, *factor);
151 let ty = item.id(b);
152 b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
153 .unwrap()
154 }
155 _ => obj,
156 }
157 }
158
159 pub fn cast_to<T: SpirvTarget>(
160 &self,
161 b: &mut SpirvCompiler<T>,
162 out_id: Option<Word>,
163 obj: Word,
164 other: &Item,
165 ) -> Word {
166 let ty = other.id(b);
167
168 let matching_vec = match (self, other) {
169 (Item::Scalar(_), Item::Scalar(_)) => true,
170 (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
171 _ => false,
172 };
173 let matching_elem = self.elem() == other.elem();
174
175 let convert_i_width =
176 |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
177 if signed {
178 b.s_convert(ty, out_id, obj).unwrap()
179 } else {
180 b.u_convert(ty, out_id, obj).unwrap()
181 }
182 };
183
184 let convert_int = |b: &mut SpirvCompiler<T>,
185 obj: Word,
186 out_id: Option<Word>,
187 (width_self, signed_self),
188 (width_other, signed_other)| {
189 let width_differs = width_self != width_other;
190 let sign_extend = signed_self && signed_other;
191 match width_differs {
192 true => convert_i_width(b, obj, out_id, sign_extend),
193 false => b.copy_object(ty, out_id, obj).unwrap(),
194 }
195 };
196
197 let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
198 match (self.elem(), other.elem()) {
199 (Elem::Bool, Elem::Int(_, _)) => {
200 let one = other.const_u32(b, 1);
201 let zero = other.const_u32(b, 0);
202 b.select(ty, out_id, obj, one, zero).unwrap()
203 }
204 (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
205 let one = other.const_u32(b, 1);
206 let zero = other.const_u32(b, 0);
207 b.select(ty, out_id, obj, one, zero).unwrap()
208 }
209 (Elem::Int(_, _), Elem::Bool) => {
210 let zero = self.const_u32(b, 0);
211 b.i_not_equal(ty, out_id, obj, zero).unwrap()
212 }
213 (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
214 convert_int(
215 b,
216 obj,
217 out_id,
218 (width_self, signed_self),
219 (width_other, signed_other),
220 )
221 }
222 (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
223 b.convert_u_to_f(ty, out_id, obj).unwrap()
224 }
225 (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
226 b.convert_s_to_f(ty, out_id, obj).unwrap()
227 }
228 (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
229 let zero = self.const_u32(b, 0);
230 b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
231 }
232 (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
233 b.convert_f_to_u(ty, out_id, obj).unwrap()
234 }
235 (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
236 b.convert_f_to_s(ty, out_id, obj).unwrap()
237 }
238 (Elem::Float(_, _), Elem::Float(_, _))
239 | (Elem::Float(_, _), Elem::Relaxed)
240 | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
241 (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
242 (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
243 (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
244 }
245 };
246
247 match (matching_vec, matching_elem) {
248 (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
249 (true, true) => obj,
250 (true, false) => cast_elem(b, obj, out_id),
251 (false, true) => self.broadcast(b, obj, out_id, other),
252 (false, false) => {
253 let broadcast = self.broadcast(b, obj, None, other);
254 cast_elem(b, broadcast, out_id)
255 }
256 }
257 }
258}
259
/// Scalar element type of an [`Item`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Elem {
    /// No value; reported as the element of struct items.
    Void,
    /// Boolean.
    Bool,
    /// Integer of the given bit width; the flag is `true` for signed.
    Int(u32, bool),
    /// Float of the given bit width, with an optional alternate encoding
    /// (e.g. bf16, e4m3, e5m2).
    Float(u32, Option<FPEncoding>),
    /// Relaxed-precision float (`flex32`), declared as a 32-bit float.
    Relaxed,
}
268
269impl Elem {
270 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
271 let id = match self {
272 Elem::Void => b.type_void(),
273 Elem::Bool => b.type_bool(),
274 Elem::Int(width, _) => b.type_int(*width, 0),
275 Elem::Float(width, encoding) => b.type_float(*width, *encoding),
276 Elem::Relaxed => b.type_float(32, None),
277 };
278 if b.debug_symbols && !b.state.debug_types.contains(&id) {
279 b.debug_name(id, format!("{self}"));
280 b.state.debug_types.insert(id);
281 }
282 id
283 }
284
285 pub fn size(&self) -> u32 {
286 match self {
287 Elem::Void => 0,
288 Elem::Bool => 1,
289 Elem::Int(size, _) => *size / 8,
290 Elem::Float(size, _) => *size / 8,
291 Elem::Relaxed => 4,
292 }
293 }
294
295 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
296 let ty = self.id(b);
297 match self {
298 Elem::Void => unreachable!(),
299 Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
300 Elem::Bool => b.constant_false(ty),
301 _ => match value {
302 ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
303 ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
304 },
305 }
306 }
307
308 pub fn float_encoding(&self) -> Option<FPEncoding> {
309 match self {
310 Elem::Float(_, encoding) => *encoding,
311 _ => None,
312 }
313 }
314}
315
316impl<T: SpirvTarget> SpirvCompiler<T> {
317 pub fn compile_type(&mut self, item: core::Type) -> Item {
318 match item {
319 core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
320 core::Type::Vector(storage, size) => {
321 Item::Vector(self.compile_storage_type(storage), size as u32)
322 }
323 core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
324 }
325 }
326
327 pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
328 match ty {
329 core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
330 core::StorageType::Opaque(ty) => match ty {
331 core::OpaqueType::Barrier(_) => {
332 unimplemented!("Barrier type not supported in SPIR-V")
333 }
334 },
335 core::StorageType::Packed(_, _) => {
336 unimplemented!("Packed types not yet supported in SPIR-V")
337 }
338 }
339 }
340
341 pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
342 match elem {
343 core::ElemType::Float(
344 core::FloatKind::E2M1
345 | core::FloatKind::E2M3
346 | core::FloatKind::E3M2
347 | core::FloatKind::UE8M0,
348 ) => panic!("Minifloat not supported in SPIR-V"),
349 core::ElemType::Float(core::FloatKind::E4M3) => {
350 self.capabilities.insert(Capability::Float8EXT);
351 Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
352 }
353 core::ElemType::Float(core::FloatKind::E5M2) => {
354 self.capabilities.insert(Capability::Float8EXT);
355 Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
356 }
357 core::ElemType::Float(core::FloatKind::BF16) => {
358 self.capabilities.insert(Capability::BFloat16TypeKHR);
359 Elem::Float(16, Some(FPEncoding::BFloat16KHR))
360 }
361 core::ElemType::Float(FloatKind::F16) => {
362 self.capabilities.insert(Capability::Float16);
363 Elem::Float(16, None)
364 }
365 core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
366 core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
367 core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
368 core::ElemType::Float(FloatKind::F64) => {
369 self.capabilities.insert(Capability::Float64);
370 Elem::Float(64, None)
371 }
372 core::ElemType::Int(IntKind::I8) => {
373 self.capabilities.insert(Capability::Int8);
374 Elem::Int(8, true)
375 }
376 core::ElemType::Int(IntKind::I16) => {
377 self.capabilities.insert(Capability::Int16);
378 Elem::Int(16, true)
379 }
380 core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
381 core::ElemType::Int(IntKind::I64) => {
382 self.capabilities.insert(Capability::Int64);
383 Elem::Int(64, true)
384 }
385 core::ElemType::UInt(UIntKind::U64) => {
386 self.capabilities.insert(Capability::Int64);
387 Elem::Int(64, false)
388 }
389 core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
390 core::ElemType::UInt(UIntKind::U16) => {
391 self.capabilities.insert(Capability::Int16);
392 Elem::Int(16, false)
393 }
394 core::ElemType::UInt(UIntKind::U8) => {
395 self.capabilities.insert(Capability::Int8);
396 Elem::Int(8, false)
397 }
398 core::ElemType::Bool => Elem::Bool,
399 }
400 }
401
402 pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> (Word, ConstVal) {
403 let elem_cast = match (*from, item.elem()) {
404 (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
405 (Elem::Bool, Elem::Float(width, encoding)) => {
406 ConstVal::from_float(val.as_u32() as f64, width, encoding)
407 }
408 (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
409 (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
410 (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
411 (Elem::Int(w_in, true), Elem::Int(width, _)) => {
412 ConstVal::from_uint(val.as_int(w_in) as u64, width)
413 }
414 (Elem::Int(_, false), Elem::Float(width, encoding)) => {
415 ConstVal::from_float(val.as_u64() as f64, width, encoding)
416 }
417 (Elem::Int(_, false), Elem::Relaxed) => {
418 ConstVal::from_float(val.as_u64() as f64, 32, None)
419 }
420 (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
421 ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
422 }
423 (Elem::Int(in_w, true), Elem::Relaxed) => {
424 ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
425 }
426 (Elem::Float(in_w, encoding), Elem::Bool) => {
427 ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
428 }
429 (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
430 (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
431 ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
432 }
433 (Elem::Relaxed, Elem::Int(out_w, false)) => {
434 ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
435 }
436 (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
437 ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
438 }
439 (Elem::Relaxed, Elem::Int(out_w, true)) => {
440 ConstVal::from_int(val.as_float(32, None) as i64, out_w)
441 }
442 (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
443 ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
444 }
445 (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
446 ConstVal::from_float(val.as_float(32, None), out_w, encoding)
447 }
448 (Elem::Float(in_w, encoding), Elem::Relaxed) => {
449 ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
450 }
451 (Elem::Bool, Elem::Bool) => val,
452 (Elem::Relaxed, Elem::Relaxed) => val,
453 (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
454 };
455 let id = item.constant(self, elem_cast);
456 (id, elem_cast)
457 }
458}
459
460impl std::fmt::Display for Item {
461 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462 match self {
463 Item::Scalar(elem) => write!(f, "{elem}"),
464 Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
465 Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
466 Item::RuntimeArray(item) => write!(f, "array<{item}>"),
467 Item::Struct(members) => {
468 write!(f, "struct<")?;
469 for item in members {
470 write!(f, "{item}")?;
471 }
472 f.write_str(">")
473 }
474 Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
475 Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
476 }
477 }
478}
479
480impl std::fmt::Display for Elem {
481 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482 match self {
483 Elem::Void => write!(f, "void"),
484 Elem::Bool => write!(f, "bool"),
485 Elem::Int(width, false) => write!(f, "u{width}"),
486 Elem::Int(width, true) => write!(f, "i{width}"),
487 Elem::Float(width, None) => write!(f, "f{width}"),
488 Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
489 Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
490 Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
491 Elem::Relaxed => write!(f, "flex32"),
492 }
493 }
494}