1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3
4use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
5
/// A SPIR-V type as modelled by the compiler: scalars, vectors, aggregates, pointers
/// and cooperative matrices. Lowered to a type id by [`Item::id`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Item {
    /// A single element of type `Elem`.
    Scalar(Elem),
    /// A vector of `Elem` with the given component count.
    Vector(Elem, u32),
    /// A fixed-size array of the inner item; the `u32` is the element count.
    Array(Box<Item>, u32),
    /// An array without a compile-time length (lowered via `type_runtime_array`).
    RuntimeArray(Box<Item>),
    /// A struct whose members are the listed items, in order.
    Struct(Vec<Item>),
    /// A pointer to the inner item in the given storage class.
    Pointer(StorageClass, Box<Item>),
    /// A KHR cooperative matrix; `Item::id` declares it with subgroup scope.
    CoopMatrix {
        /// Element type of the matrix.
        ty: Elem,
        /// Number of rows.
        rows: u32,
        /// Number of columns.
        columns: u32,
        /// Cooperative-matrix use (which operand role the matrix plays).
        ident: CooperativeMatrixUse,
    },
}
22
23impl Item {
24 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
25 let id = match self {
26 Item::Scalar(elem) => elem.id(b),
27 Item::Vector(elem, vec) => {
28 let elem = elem.id(b);
29 b.type_vector(elem, *vec)
30 }
31 Item::Array(item, len) => {
32 let item = item.id(b);
33 let len = b.const_u32(*len);
34 b.type_array(item, len)
35 }
36 Item::RuntimeArray(item) => {
37 let item = item.id(b);
38 b.type_runtime_array(item)
39 }
40 Item::Struct(vec) => {
41 let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
42 let id = b.id(); b.type_struct_id(Some(id), items)
44 }
45 Item::Pointer(storage_class, item) => {
46 let item = item.id(b);
47 b.type_pointer(None, *storage_class, item)
48 }
49 Item::CoopMatrix {
50 ty,
51 rows,
52 columns,
53 ident,
54 } => {
55 let ty = ty.id(b);
56 let scope = b.const_u32(Scope::Subgroup as u32);
57 let usage = b.const_u32(*ident as u32);
58 b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
59 }
60 };
61 if b.debug_symbols && !b.state.debug_types.contains(&id) {
62 b.debug_name(id, format!("{self}"));
63 b.state.debug_types.insert(id);
64 }
65 id
66 }
67
68 pub fn builtin_u32() -> Self {
69 Item::Scalar(Elem::Int(32, false))
70 }
71
72 pub fn size(&self) -> u32 {
73 match self {
74 Item::Scalar(elem) => elem.size(),
75 Item::Vector(elem, factor) => elem.size() * *factor,
76 Item::Array(item, len) => item.size() * *len,
77 Item::RuntimeArray(item) => item.size(),
78 Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
79 Item::Pointer(_, item) => item.size(),
80 Item::CoopMatrix { ty, .. } => ty.size(),
81 }
82 }
83
84 pub fn elem(&self) -> Elem {
85 match self {
86 Item::Scalar(elem) => *elem,
87 Item::Vector(elem, _) => *elem,
88 Item::Array(item, _) => item.elem(),
89 Item::RuntimeArray(item) => item.elem(),
90 Item::Struct(_) => Elem::Void,
91 Item::Pointer(_, item) => item.elem(),
92 Item::CoopMatrix { ty, .. } => *ty,
93 }
94 }
95
96 pub fn same_vectorization(&self, elem: Elem) -> Item {
97 match self {
98 Item::Scalar(_) => Item::Scalar(elem),
99 Item::Vector(_, factor) => Item::Vector(elem, *factor),
100 _ => unreachable!(),
101 }
102 }
103
104 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
105 let scalar = self.elem().constant(b, value);
106 let ty = self.id(b);
107 match self {
108 Item::Scalar(_) => scalar,
109 Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
110 Item::Array(item, len) => {
111 let elem = item.constant(b, value);
112 b.constant_composite(ty, (0..*len).map(|_| elem))
113 }
114 Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
115 Item::Struct(elems) => {
116 let items = elems
117 .iter()
118 .map(|item| item.constant(b, value))
119 .collect::<Vec<_>>();
120 b.constant_composite(ty, items)
121 }
122 Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
123 Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
124 }
125 }
126
127 pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
128 b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
129 .0
130 }
131
132 pub fn broadcast<T: SpirvTarget>(
134 &self,
135 b: &mut SpirvCompiler<T>,
136 obj: Word,
137 out_id: Option<Word>,
138 other: &Item,
139 ) -> Word {
140 match (self, other) {
141 (Item::Scalar(elem), Item::Vector(_, factor)) => {
142 let item = Item::Vector(*elem, *factor);
143 let ty = item.id(b);
144 b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
145 .unwrap()
146 }
147 _ => obj,
148 }
149 }
150
151 pub fn cast_to<T: SpirvTarget>(
152 &self,
153 b: &mut SpirvCompiler<T>,
154 out_id: Option<Word>,
155 obj: Word,
156 other: &Item,
157 ) -> Word {
158 let ty = other.id(b);
159
160 let matching_vec = match (self, other) {
161 (Item::Scalar(_), Item::Scalar(_)) => true,
162 (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
163 _ => false,
164 };
165 let matching_elem = self.elem() == other.elem();
166
167 let convert_i_width =
168 |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
169 if signed {
170 b.s_convert(ty, out_id, obj).unwrap()
171 } else {
172 b.u_convert(ty, out_id, obj).unwrap()
173 }
174 };
175
176 let convert_int = |b: &mut SpirvCompiler<T>,
177 obj: Word,
178 out_id: Option<Word>,
179 (width_self, signed_self),
180 (width_other, signed_other)| {
181 let width_differs = width_self != width_other;
182 let sign_extend = signed_self && signed_other;
183 match width_differs {
184 true => convert_i_width(b, obj, out_id, sign_extend),
185 false => b.copy_object(ty, out_id, obj).unwrap(),
186 }
187 };
188
189 let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
190 match (self.elem(), other.elem()) {
191 (Elem::Bool, Elem::Int(_, _)) => {
192 let one = other.const_u32(b, 1);
193 let zero = other.const_u32(b, 0);
194 b.select(ty, out_id, obj, one, zero).unwrap()
195 }
196 (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
197 let one = other.const_u32(b, 1);
198 let zero = other.const_u32(b, 0);
199 b.select(ty, out_id, obj, one, zero).unwrap()
200 }
201 (Elem::Int(_, _), Elem::Bool) => {
202 let zero = self.const_u32(b, 0);
203 b.i_not_equal(ty, out_id, obj, zero).unwrap()
204 }
205 (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
206 convert_int(
207 b,
208 obj,
209 out_id,
210 (width_self, signed_self),
211 (width_other, signed_other),
212 )
213 }
214 (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
215 b.convert_u_to_f(ty, out_id, obj).unwrap()
216 }
217 (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
218 b.convert_s_to_f(ty, out_id, obj).unwrap()
219 }
220 (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
221 let zero = self.const_u32(b, 0);
222 b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
223 }
224 (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
225 b.convert_f_to_u(ty, out_id, obj).unwrap()
226 }
227 (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
228 b.convert_f_to_s(ty, out_id, obj).unwrap()
229 }
230 (Elem::Float(_, _), Elem::Float(_, _))
231 | (Elem::Float(_, _), Elem::Relaxed)
232 | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
233 (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
234 (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
235 (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
236 }
237 };
238
239 match (matching_vec, matching_elem) {
240 (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
241 (true, true) => obj,
242 (true, false) => cast_elem(b, obj, out_id),
243 (false, true) => self.broadcast(b, obj, out_id, other),
244 (false, false) => {
245 let broadcast = self.broadcast(b, obj, None, other);
246 cast_elem(b, broadcast, out_id)
247 }
248 }
249 }
250}
251
/// A scalar element type. Lowered to a type id by [`Elem::id`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Elem {
    /// No value. `Item::elem` reports this for structs.
    Void,
    /// Boolean.
    Bool,
    /// Integer of the given bit width; the `bool` is `true` for signed integers.
    Int(u32, bool),
    /// Float of the given bit width with an optional alternative encoding
    /// (bf16 / e4m3 / e5m2 in this file).
    Float(u32, Option<FPEncoding>),
    /// flex32 (from `FloatKind::Flex32`), lowered as a plain 32-bit float type.
    Relaxed,
}
260
261impl Elem {
262 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
263 let id = match self {
264 Elem::Void => b.type_void(),
265 Elem::Bool => b.type_bool(),
266 Elem::Int(width, _) => b.type_int(*width, 0),
267 Elem::Float(width, encoding) => b.type_float(*width, *encoding),
268 Elem::Relaxed => b.type_float(32, None),
269 };
270 if b.debug_symbols && !b.state.debug_types.contains(&id) {
271 b.debug_name(id, format!("{self}"));
272 b.state.debug_types.insert(id);
273 }
274 id
275 }
276
277 pub fn size(&self) -> u32 {
278 match self {
279 Elem::Void => 0,
280 Elem::Bool => 1,
281 Elem::Int(size, _) => *size / 8,
282 Elem::Float(size, _) => *size / 8,
283 Elem::Relaxed => 4,
284 }
285 }
286
287 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
288 let ty = self.id(b);
289 match self {
290 Elem::Void => unreachable!(),
291 Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
292 Elem::Bool => b.constant_false(ty),
293 _ => match value {
294 ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
295 ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
296 },
297 }
298 }
299
300 pub fn float_encoding(&self) -> Option<FPEncoding> {
301 match self {
302 Elem::Float(_, encoding) => *encoding,
303 _ => None,
304 }
305 }
306}
307
308impl<T: SpirvTarget> SpirvCompiler<T> {
309 pub fn compile_type(&mut self, item: core::Type) -> Item {
310 match item {
311 core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
312 core::Type::Line(storage, size) => {
313 Item::Vector(self.compile_storage_type(storage), size as u32)
314 }
315 core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
316 }
317 }
318
319 pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
320 match ty {
321 core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
322 core::StorageType::Opaque(ty) => match ty {
323 core::OpaqueType::Barrier(_) => {
324 unimplemented!("Barrier type not supported in SPIR-V")
325 }
326 },
327 core::StorageType::Packed(_, _) => {
328 unimplemented!("Packed types not yet supported in SPIR-V")
329 }
330 }
331 }
332
333 pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
334 match elem {
335 core::ElemType::Float(
336 core::FloatKind::E2M1
337 | core::FloatKind::E2M3
338 | core::FloatKind::E3M2
339 | core::FloatKind::UE8M0,
340 ) => panic!("Minifloat not supported in SPIR-V"),
341 core::ElemType::Float(core::FloatKind::E4M3) => {
342 self.capabilities.insert(Capability::Float8EXT);
343 Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
344 }
345 core::ElemType::Float(core::FloatKind::E5M2) => {
346 self.capabilities.insert(Capability::Float8EXT);
347 Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
348 }
349 core::ElemType::Float(core::FloatKind::BF16) => {
350 self.capabilities.insert(Capability::BFloat16TypeKHR);
351 Elem::Float(16, Some(FPEncoding::BFloat16KHR))
352 }
353 core::ElemType::Float(FloatKind::F16) => {
354 self.capabilities.insert(Capability::Float16);
355 Elem::Float(16, None)
356 }
357 core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
358 core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
359 core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
360 core::ElemType::Float(FloatKind::F64) => {
361 self.capabilities.insert(Capability::Float64);
362 Elem::Float(64, None)
363 }
364 core::ElemType::Int(IntKind::I8) => {
365 self.capabilities.insert(Capability::Int8);
366 Elem::Int(8, true)
367 }
368 core::ElemType::Int(IntKind::I16) => {
369 self.capabilities.insert(Capability::Int16);
370 Elem::Int(16, true)
371 }
372 core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
373 core::ElemType::Int(IntKind::I64) => {
374 self.capabilities.insert(Capability::Int64);
375 Elem::Int(64, true)
376 }
377 core::ElemType::UInt(UIntKind::U64) => {
378 self.capabilities.insert(Capability::Int64);
379 Elem::Int(64, false)
380 }
381 core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
382 core::ElemType::UInt(UIntKind::U16) => {
383 self.capabilities.insert(Capability::Int16);
384 Elem::Int(16, false)
385 }
386 core::ElemType::UInt(UIntKind::U8) => {
387 self.capabilities.insert(Capability::Int8);
388 Elem::Int(8, false)
389 }
390 core::ElemType::Bool => Elem::Bool,
391 }
392 }
393
394 pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> (Word, ConstVal) {
395 let elem_cast = match (*from, item.elem()) {
396 (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
397 (Elem::Bool, Elem::Float(width, encoding)) => {
398 ConstVal::from_float(val.as_u32() as f64, width, encoding)
399 }
400 (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
401 (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
402 (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
403 (Elem::Int(w_in, true), Elem::Int(width, _)) => {
404 ConstVal::from_uint(val.as_int(w_in) as u64, width)
405 }
406 (Elem::Int(_, false), Elem::Float(width, encoding)) => {
407 ConstVal::from_float(val.as_u64() as f64, width, encoding)
408 }
409 (Elem::Int(_, false), Elem::Relaxed) => {
410 ConstVal::from_float(val.as_u64() as f64, 32, None)
411 }
412 (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
413 ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
414 }
415 (Elem::Int(in_w, true), Elem::Relaxed) => {
416 ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
417 }
418 (Elem::Float(in_w, encoding), Elem::Bool) => {
419 ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
420 }
421 (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
422 (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
423 ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
424 }
425 (Elem::Relaxed, Elem::Int(out_w, false)) => {
426 ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
427 }
428 (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
429 ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
430 }
431 (Elem::Relaxed, Elem::Int(out_w, true)) => {
432 ConstVal::from_int(val.as_float(32, None) as i64, out_w)
433 }
434 (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
435 ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
436 }
437 (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
438 ConstVal::from_float(val.as_float(32, None), out_w, encoding)
439 }
440 (Elem::Float(in_w, encoding), Elem::Relaxed) => {
441 ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
442 }
443 (Elem::Bool, Elem::Bool) => val,
444 (Elem::Relaxed, Elem::Relaxed) => val,
445 (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
446 };
447 let id = item.constant(self, elem_cast);
448 (id, elem_cast)
449 }
450}
451
452impl std::fmt::Display for Item {
453 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454 match self {
455 Item::Scalar(elem) => write!(f, "{elem}"),
456 Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
457 Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
458 Item::RuntimeArray(item) => write!(f, "array<{item}>"),
459 Item::Struct(members) => {
460 write!(f, "struct<")?;
461 for item in members {
462 write!(f, "{item}")?;
463 }
464 f.write_str(">")
465 }
466 Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
467 Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
468 }
469 }
470}
471
472impl std::fmt::Display for Elem {
473 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
474 match self {
475 Elem::Void => write!(f, "void"),
476 Elem::Bool => write!(f, "bool"),
477 Elem::Int(width, false) => write!(f, "u{width}"),
478 Elem::Int(width, true) => write!(f, "i{width}"),
479 Elem::Float(width, None) => write!(f, "f{width}"),
480 Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
481 Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
482 Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
483 Elem::Relaxed => write!(f, "flex32"),
484 }
485 }
486}