1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3use serde::{Deserialize, Serialize};
4
5use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
6
/// A SPIR-V type as tracked by the compiler.
///
/// Lowered to a SPIR-V type id with [`Item::id`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Item {
    /// A single scalar element.
    Scalar(Elem),
    /// A vector of `Elem` with the given number of components.
    Vector(Elem, u32),
    /// A fixed-length array of the inner item; the `u32` is the element count.
    Array(Box<Item>, u32),
    /// An array whose length is not fixed in the type.
    RuntimeArray(Box<Item>),
    /// A struct whose members are the given items, in order.
    Struct(Vec<Item>),
    /// A pointer to the inner item in the given storage class.
    Pointer(StorageClass, Box<Item>),
    /// A cooperative matrix (`OpTypeCooperativeMatrixKHR`), declared with subgroup scope.
    CoopMatrix {
        /// Component type of the matrix.
        ty: Elem,
        /// Number of rows.
        rows: u32,
        /// Number of columns.
        columns: u32,
        /// The matrix `Use` operand.
        ident: CooperativeMatrixUse,
    },
}
23
24impl Item {
25 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
26 let id = match self {
27 Item::Scalar(elem) => elem.id(b),
28 Item::Vector(elem, vec) => {
29 let elem = elem.id(b);
30 b.type_vector(elem, *vec)
31 }
32 Item::Array(item, len) => {
33 let item = item.id(b);
34 let len = b.const_u32(*len);
35 b.type_array(item, len)
36 }
37 Item::RuntimeArray(item) => {
38 let item = item.id(b);
39 b.type_runtime_array(item)
40 }
41 Item::Struct(vec) => {
42 let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
43 let id = b.id(); b.type_struct_id(Some(id), items)
45 }
46 Item::Pointer(storage_class, item) => {
47 let item = item.id(b);
48 b.type_pointer(None, *storage_class, item)
49 }
50 Item::CoopMatrix {
51 ty,
52 rows,
53 columns,
54 ident,
55 } => {
56 let ty = ty.id(b);
57 let scope = b.const_u32(Scope::Subgroup as u32);
58 let usage = b.const_u32(*ident as u32);
59 b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
60 }
61 };
62 if b.debug_symbols && !b.state.debug_types.contains(&id) {
63 b.debug_name(id, format!("{self}"));
64 b.state.debug_types.insert(id);
65 }
66 id
67 }
68
69 pub fn builtin_u32() -> Self {
70 Item::Scalar(Elem::Int(32, false))
71 }
72
73 pub fn size(&self) -> u32 {
74 match self {
75 Item::Scalar(elem) => elem.size(),
76 Item::Vector(elem, factor) => elem.size() * *factor,
77 Item::Array(item, len) => item.size() * *len,
78 Item::RuntimeArray(item) => item.size(),
79 Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
80 Item::Pointer(_, item) => item.size(),
81 Item::CoopMatrix { ty, .. } => ty.size(),
82 }
83 }
84
85 pub fn elem(&self) -> Elem {
86 match self {
87 Item::Scalar(elem) => *elem,
88 Item::Vector(elem, _) => *elem,
89 Item::Array(item, _) => item.elem(),
90 Item::RuntimeArray(item) => item.elem(),
91 Item::Struct(_) => Elem::Void,
92 Item::Pointer(_, item) => item.elem(),
93 Item::CoopMatrix { ty, .. } => *ty,
94 }
95 }
96
97 pub fn same_vectorization(&self, elem: Elem) -> Item {
98 match self {
99 Item::Scalar(_) => Item::Scalar(elem),
100 Item::Vector(_, factor) => Item::Vector(elem, *factor),
101 _ => unreachable!(),
102 }
103 }
104
105 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
106 let scalar = self.elem().constant(b, value);
107 let ty = self.id(b);
108 match self {
109 Item::Scalar(_) => scalar,
110 Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
111 Item::Array(item, len) => {
112 let elem = item.constant(b, value);
113 b.constant_composite(ty, (0..*len).map(|_| elem))
114 }
115 Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
116 Item::Struct(elems) => {
117 let items = elems
118 .iter()
119 .map(|item| item.constant(b, value))
120 .collect::<Vec<_>>();
121 b.constant_composite(ty, items)
122 }
123 Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
124 Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
125 }
126 }
127
128 pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
129 b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
130 .0
131 }
132
133 pub fn broadcast<T: SpirvTarget>(
135 &self,
136 b: &mut SpirvCompiler<T>,
137 obj: Word,
138 out_id: Option<Word>,
139 other: &Item,
140 ) -> Word {
141 match (self, other) {
142 (Item::Scalar(elem), Item::Vector(_, factor)) => {
143 let item = Item::Vector(*elem, *factor);
144 let ty = item.id(b);
145 b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
146 .unwrap()
147 }
148 _ => obj,
149 }
150 }
151
152 pub fn cast_to<T: SpirvTarget>(
153 &self,
154 b: &mut SpirvCompiler<T>,
155 out_id: Option<Word>,
156 obj: Word,
157 other: &Item,
158 ) -> Word {
159 let ty = other.id(b);
160
161 let matching_vec = match (self, other) {
162 (Item::Scalar(_), Item::Scalar(_)) => true,
163 (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
164 _ => false,
165 };
166 let matching_elem = self.elem() == other.elem();
167
168 let convert_i_width =
169 |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
170 if signed {
171 b.s_convert(ty, out_id, obj).unwrap()
172 } else {
173 b.u_convert(ty, out_id, obj).unwrap()
174 }
175 };
176
177 let convert_int = |b: &mut SpirvCompiler<T>,
178 obj: Word,
179 out_id: Option<Word>,
180 (width_self, signed_self),
181 (width_other, signed_other)| {
182 let width_differs = width_self != width_other;
183 let sign_extend = signed_self && signed_other;
184 match width_differs {
185 true => convert_i_width(b, obj, out_id, sign_extend),
186 false => b.copy_object(ty, out_id, obj).unwrap(),
187 }
188 };
189
190 let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
191 match (self.elem(), other.elem()) {
192 (Elem::Bool, Elem::Int(_, _)) => {
193 let one = other.const_u32(b, 1);
194 let zero = other.const_u32(b, 0);
195 b.select(ty, out_id, obj, one, zero).unwrap()
196 }
197 (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
198 let one = other.const_u32(b, 1);
199 let zero = other.const_u32(b, 0);
200 b.select(ty, out_id, obj, one, zero).unwrap()
201 }
202 (Elem::Int(_, _), Elem::Bool) => {
203 let zero = self.const_u32(b, 0);
204 b.i_not_equal(ty, out_id, obj, zero).unwrap()
205 }
206 (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
207 convert_int(
208 b,
209 obj,
210 out_id,
211 (width_self, signed_self),
212 (width_other, signed_other),
213 )
214 }
215 (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
216 b.convert_u_to_f(ty, out_id, obj).unwrap()
217 }
218 (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
219 b.convert_s_to_f(ty, out_id, obj).unwrap()
220 }
221 (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
222 let zero = self.const_u32(b, 0);
223 b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
224 }
225 (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
226 b.convert_f_to_u(ty, out_id, obj).unwrap()
227 }
228 (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
229 b.convert_f_to_s(ty, out_id, obj).unwrap()
230 }
231 (Elem::Float(_, _), Elem::Float(_, _))
232 | (Elem::Float(_, _), Elem::Relaxed)
233 | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
234 (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
235 (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
236 (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
237 }
238 };
239
240 match (matching_vec, matching_elem) {
241 (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
242 (true, true) => obj,
243 (true, false) => cast_elem(b, obj, out_id),
244 (false, true) => self.broadcast(b, obj, out_id, other),
245 (false, false) => {
246 let broadcast = self.broadcast(b, obj, None, other);
247 cast_elem(b, broadcast, out_id)
248 }
249 }
250 }
251}
252
/// A scalar element type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Elem {
    /// No value. `Item::elem` returns this for structs.
    Void,
    /// Boolean.
    Bool,
    /// Integer with the given bit width and signedness (`true` = signed).
    ///
    /// Signedness is only tracked here: `Elem::id` always declares the SPIR-V integer
    /// type with signedness 0.
    Int(u32, bool),
    /// Float of the given bit width with an optional alternate encoding (bf16 / fp8).
    Float(u32, Option<FPEncoding>),
    /// `flex32`; declared as a plain 32-bit float type.
    Relaxed,
}
261
262impl Elem {
263 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
264 let id = match self {
265 Elem::Void => b.type_void(),
266 Elem::Bool => b.type_bool(),
267 Elem::Int(width, _) => b.type_int(*width, 0),
268 Elem::Float(width, encoding) => b.type_float(*width, *encoding),
269 Elem::Relaxed => b.type_float(32, None),
270 };
271 if b.debug_symbols && !b.state.debug_types.contains(&id) {
272 b.debug_name(id, format!("{self}"));
273 b.state.debug_types.insert(id);
274 }
275 id
276 }
277
278 pub fn size(&self) -> u32 {
279 match self {
280 Elem::Void => 0,
281 Elem::Bool => 1,
282 Elem::Int(size, _) => *size / 8,
283 Elem::Float(size, _) => *size / 8,
284 Elem::Relaxed => 4,
285 }
286 }
287
288 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
289 let ty = self.id(b);
290 match self {
291 Elem::Void => unreachable!(),
292 Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
293 Elem::Bool => b.constant_false(ty),
294 _ => match value {
295 ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
296 ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
297 },
298 }
299 }
300
301 pub fn float_encoding(&self) -> Option<FPEncoding> {
302 match self {
303 Elem::Float(_, encoding) => *encoding,
304 _ => None,
305 }
306 }
307}
308
309impl<T: SpirvTarget> SpirvCompiler<T> {
310 pub fn compile_type(&mut self, item: core::Type) -> Item {
311 match item {
312 core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
313 core::Type::Line(storage, size) => {
314 Item::Vector(self.compile_storage_type(storage), size as u32)
315 }
316 core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
317 }
318 }
319
320 pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
321 match ty {
322 core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
323 core::StorageType::Opaque(ty) => match ty {
324 core::OpaqueType::Barrier(_) => {
325 unimplemented!("Barrier type not supported in SPIR-V")
326 }
327 },
328 core::StorageType::Packed(_, _) => {
329 unimplemented!("Packed types not yet supported in SPIR-V")
330 }
331 }
332 }
333
334 pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
335 match elem {
336 core::ElemType::Float(
337 core::FloatKind::E2M1
338 | core::FloatKind::E2M3
339 | core::FloatKind::E3M2
340 | core::FloatKind::UE8M0,
341 ) => panic!("Minifloat not supported in SPIR-V"),
342 core::ElemType::Float(core::FloatKind::E4M3) => {
343 self.capabilities.insert(Capability::Float8EXT);
344 Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
345 }
346 core::ElemType::Float(core::FloatKind::E5M2) => {
347 self.capabilities.insert(Capability::Float8EXT);
348 Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
349 }
350 core::ElemType::Float(core::FloatKind::BF16) => {
351 self.capabilities.insert(Capability::BFloat16TypeKHR);
352 Elem::Float(16, Some(FPEncoding::BFloat16KHR))
353 }
354 core::ElemType::Float(FloatKind::F16) => {
355 self.capabilities.insert(Capability::Float16);
356 Elem::Float(16, None)
357 }
358 core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
359 core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
360 core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
361 core::ElemType::Float(FloatKind::F64) => {
362 self.capabilities.insert(Capability::Float64);
363 Elem::Float(64, None)
364 }
365 core::ElemType::Int(IntKind::I8) => {
366 self.capabilities.insert(Capability::Int8);
367 Elem::Int(8, true)
368 }
369 core::ElemType::Int(IntKind::I16) => {
370 self.capabilities.insert(Capability::Int16);
371 Elem::Int(16, true)
372 }
373 core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
374 core::ElemType::Int(IntKind::I64) => {
375 self.capabilities.insert(Capability::Int64);
376 Elem::Int(64, true)
377 }
378 core::ElemType::UInt(UIntKind::U64) => {
379 self.capabilities.insert(Capability::Int64);
380 Elem::Int(64, false)
381 }
382 core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
383 core::ElemType::UInt(UIntKind::U16) => {
384 self.capabilities.insert(Capability::Int16);
385 Elem::Int(16, false)
386 }
387 core::ElemType::UInt(UIntKind::U8) => {
388 self.capabilities.insert(Capability::Int8);
389 Elem::Int(8, false)
390 }
391 core::ElemType::Bool => Elem::Bool,
392 }
393 }
394
395 pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> (Word, ConstVal) {
396 let elem_cast = match (*from, item.elem()) {
397 (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
398 (Elem::Bool, Elem::Float(width, encoding)) => {
399 ConstVal::from_float(val.as_u32() as f64, width, encoding)
400 }
401 (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
402 (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
403 (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
404 (Elem::Int(w_in, true), Elem::Int(width, _)) => {
405 ConstVal::from_uint(val.as_int(w_in) as u64, width)
406 }
407 (Elem::Int(_, false), Elem::Float(width, encoding)) => {
408 ConstVal::from_float(val.as_u64() as f64, width, encoding)
409 }
410 (Elem::Int(_, false), Elem::Relaxed) => {
411 ConstVal::from_float(val.as_u64() as f64, 32, None)
412 }
413 (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
414 ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
415 }
416 (Elem::Int(in_w, true), Elem::Relaxed) => {
417 ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
418 }
419 (Elem::Float(in_w, encoding), Elem::Bool) => {
420 ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
421 }
422 (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
423 (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
424 ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
425 }
426 (Elem::Relaxed, Elem::Int(out_w, false)) => {
427 ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
428 }
429 (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
430 ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
431 }
432 (Elem::Relaxed, Elem::Int(out_w, true)) => {
433 ConstVal::from_int(val.as_float(32, None) as i64, out_w)
434 }
435 (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
436 ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
437 }
438 (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
439 ConstVal::from_float(val.as_float(32, None), out_w, encoding)
440 }
441 (Elem::Float(in_w, encoding), Elem::Relaxed) => {
442 ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
443 }
444 (Elem::Bool, Elem::Bool) => val,
445 (Elem::Relaxed, Elem::Relaxed) => val,
446 (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
447 };
448 let id = item.constant(self, elem_cast);
449 (id, elem_cast)
450 }
451}
452
453impl std::fmt::Display for Item {
454 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 match self {
456 Item::Scalar(elem) => write!(f, "{elem}"),
457 Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
458 Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
459 Item::RuntimeArray(item) => write!(f, "array<{item}>"),
460 Item::Struct(members) => {
461 write!(f, "struct<")?;
462 for item in members {
463 write!(f, "{item}")?;
464 }
465 f.write_str(">")
466 }
467 Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
468 Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
469 }
470 }
471}
472
473impl std::fmt::Display for Elem {
474 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
475 match self {
476 Elem::Void => write!(f, "void"),
477 Elem::Bool => write!(f, "bool"),
478 Elem::Int(width, false) => write!(f, "u{width}"),
479 Elem::Int(width, true) => write!(f, "i{width}"),
480 Elem::Float(width, None) => write!(f, "f{width}"),
481 Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
482 Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
483 Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
484 Elem::Relaxed => write!(f, "flex32"),
485 }
486 }
487}