1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3
4use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
5
/// A SPIR-V type as modelled by the compiler: a single element, or a composite / pointer
/// type built from other items. `Item::id` lowers it to a SPIR-V type id.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Item {
    /// A single element of type `Elem`.
    Scalar(Elem),
    /// A vector of `Elem` with the given component count (passed to `type_vector`).
    Vector(Elem, u32),
    /// A fixed-length array of the boxed item; the length is emitted as a `u32` constant.
    Array(Box<Item>, u32),
    /// An array of the boxed item without a compile-time length.
    RuntimeArray(Box<Item>),
    /// A struct whose members are the listed items, in order.
    Struct(Vec<Item>),
    /// A pointer in the given storage class to the boxed item.
    Pointer(StorageClass, Box<Item>),
    /// A cooperative matrix type, declared with subgroup scope by `Item::id`.
    CoopMatrix {
        /// Element type of the matrix components.
        ty: Elem,
        /// Number of rows.
        rows: u32,
        /// Number of columns.
        columns: u32,
        /// Matrix use; emitted as a `u32` constant operand of the type declaration.
        ident: CooperativeMatrixUse,
    },
}
22
23impl Item {
24 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
25 let id = match self {
26 Item::Scalar(elem) => elem.id(b),
27 Item::Vector(elem, vec) => {
28 let elem = elem.id(b);
29 b.type_vector(elem, *vec)
30 }
31 Item::Array(item, len) => {
32 let item = item.id(b);
33 let len = b.const_u32(*len);
34 b.type_array(item, len)
35 }
36 Item::RuntimeArray(item) => {
37 let item = item.id(b);
38 b.type_runtime_array(item)
39 }
40 Item::Struct(vec) => {
41 let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
42 let id = b.id(); b.type_struct_id(Some(id), items)
44 }
45 Item::Pointer(storage_class, item) => {
46 let item = item.id(b);
47 b.type_pointer(None, *storage_class, item)
48 }
49 Item::CoopMatrix {
50 ty,
51 rows,
52 columns,
53 ident,
54 } => {
55 let ty = ty.id(b);
56 let scope = b.const_u32(Scope::Subgroup as u32);
57 let usage = b.const_u32(*ident as u32);
58 b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
59 }
60 };
61 if b.debug_symbols && !b.state.debug_types.contains(&id) {
62 b.debug_name(id, format!("{self}"));
63 b.state.debug_types.insert(id);
64 }
65 id
66 }
67
68 pub fn size(&self) -> u32 {
69 match self {
70 Item::Scalar(elem) => elem.size(),
71 Item::Vector(elem, factor) => elem.size() * *factor,
72 Item::Array(item, len) => item.size() * *len,
73 Item::RuntimeArray(item) => item.size(),
74 Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
75 Item::Pointer(_, item) => item.size(),
76 Item::CoopMatrix { ty, .. } => ty.size(),
77 }
78 }
79
80 pub fn elem(&self) -> Elem {
81 match self {
82 Item::Scalar(elem) => *elem,
83 Item::Vector(elem, _) => *elem,
84 Item::Array(item, _) => item.elem(),
85 Item::RuntimeArray(item) => item.elem(),
86 Item::Struct(_) => Elem::Void,
87 Item::Pointer(_, item) => item.elem(),
88 Item::CoopMatrix { ty, .. } => *ty,
89 }
90 }
91
92 pub fn same_vectorization(&self, elem: Elem) -> Item {
93 match self {
94 Item::Scalar(_) => Item::Scalar(elem),
95 Item::Vector(_, factor) => Item::Vector(elem, *factor),
96 _ => unreachable!(),
97 }
98 }
99
100 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
101 let scalar = self.elem().constant(b, value);
102 let ty = self.id(b);
103 match self {
104 Item::Scalar(_) => scalar,
105 Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
106 Item::Array(item, len) => {
107 let elem = item.constant(b, value);
108 b.constant_composite(ty, (0..*len).map(|_| elem))
109 }
110 Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
111 Item::Struct(elems) => {
112 let items = elems
113 .iter()
114 .map(|item| item.constant(b, value))
115 .collect::<Vec<_>>();
116 b.constant_composite(ty, items)
117 }
118 Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
119 Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
120 }
121 }
122
123 pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
124 b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
125 }
126
127 pub fn broadcast<T: SpirvTarget>(
129 &self,
130 b: &mut SpirvCompiler<T>,
131 obj: Word,
132 out_id: Option<Word>,
133 other: &Item,
134 ) -> Word {
135 match (self, other) {
136 (Item::Scalar(elem), Item::Vector(_, factor)) => {
137 let item = Item::Vector(*elem, *factor);
138 let ty = item.id(b);
139 b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
140 .unwrap()
141 }
142 _ => obj,
143 }
144 }
145
146 pub fn cast_to<T: SpirvTarget>(
147 &self,
148 b: &mut SpirvCompiler<T>,
149 out_id: Option<Word>,
150 obj: Word,
151 other: &Item,
152 ) -> Word {
153 let ty = other.id(b);
154
155 let matching_vec = match (self, other) {
156 (Item::Scalar(_), Item::Scalar(_)) => true,
157 (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
158 _ => false,
159 };
160 let matching_elem = self.elem() == other.elem();
161
162 let convert_i_width =
163 |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
164 if signed {
165 b.s_convert(ty, out_id, obj).unwrap()
166 } else {
167 b.u_convert(ty, out_id, obj).unwrap()
168 }
169 };
170
171 let convert_int = |b: &mut SpirvCompiler<T>,
172 obj: Word,
173 out_id: Option<Word>,
174 (width_self, signed_self),
175 (width_other, signed_other)| {
176 let width_differs = width_self != width_other;
177 let sign_extend = signed_self && signed_other;
178 match width_differs {
179 true => convert_i_width(b, obj, out_id, sign_extend),
180 false => b.copy_object(ty, out_id, obj).unwrap(),
181 }
182 };
183
184 let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
185 match (self.elem(), other.elem()) {
186 (Elem::Bool, Elem::Int(_, _)) => {
187 let one = other.const_u32(b, 1);
188 let zero = other.const_u32(b, 0);
189 b.select(ty, out_id, obj, one, zero).unwrap()
190 }
191 (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
192 let one = other.const_u32(b, 1);
193 let zero = other.const_u32(b, 0);
194 b.select(ty, out_id, obj, one, zero).unwrap()
195 }
196 (Elem::Int(_, _), Elem::Bool) => {
197 let zero = self.const_u32(b, 0);
198 b.i_not_equal(ty, out_id, obj, zero).unwrap()
199 }
200 (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
201 convert_int(
202 b,
203 obj,
204 out_id,
205 (width_self, signed_self),
206 (width_other, signed_other),
207 )
208 }
209 (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
210 b.convert_u_to_f(ty, out_id, obj).unwrap()
211 }
212 (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
213 b.convert_s_to_f(ty, out_id, obj).unwrap()
214 }
215 (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
216 let zero = self.const_u32(b, 0);
217 b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
218 }
219 (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
220 b.convert_f_to_u(ty, out_id, obj).unwrap()
221 }
222 (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
223 b.convert_f_to_s(ty, out_id, obj).unwrap()
224 }
225 (Elem::Float(_, _), Elem::Float(_, _))
226 | (Elem::Float(_, _), Elem::Relaxed)
227 | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
228 (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
229 (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
230 (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
231 }
232 };
233
234 match (matching_vec, matching_elem) {
235 (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
236 (true, true) => obj,
237 (true, false) => cast_elem(b, obj, out_id),
238 (false, true) => self.broadcast(b, obj, out_id, other),
239 (false, false) => {
240 let broadcast = self.broadcast(b, obj, None, other);
241 cast_elem(b, broadcast, out_id)
242 }
243 }
244 }
245}
246
/// A scalar element type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Elem {
    /// No value; usable as a type but never as a constant.
    Void,
    /// Boolean; counted as 1 byte by `Elem::size`.
    Bool,
    /// Integer of the given bit width. The flag is `true` for signed (`i{width}`) and `false`
    /// for unsigned (`u{width}`); both lower to the same SPIR-V integer type.
    Int(u32, bool),
    /// Float of the given bit width. `None` is a plain `f{width}`; `Some` selects an
    /// alternative encoding (bf16, e4m3, e5m2).
    Float(u32, Option<FPEncoding>),
    /// `flex32`, lowered to a plain 32-bit float.
    Relaxed,
}
255
256impl Elem {
257 pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
258 let id = match self {
259 Elem::Void => b.type_void(),
260 Elem::Bool => b.type_bool(),
261 Elem::Int(width, _) => b.type_int(*width, 0),
262 Elem::Float(width, encoding) => b.type_float(*width, *encoding),
263 Elem::Relaxed => b.type_float(32, None),
264 };
265 if b.debug_symbols && !b.state.debug_types.contains(&id) {
266 b.debug_name(id, format!("{self}"));
267 b.state.debug_types.insert(id);
268 }
269 id
270 }
271
272 pub fn size(&self) -> u32 {
273 match self {
274 Elem::Void => 0,
275 Elem::Bool => 1,
276 Elem::Int(size, _) => *size / 8,
277 Elem::Float(size, _) => *size / 8,
278 Elem::Relaxed => 4,
279 }
280 }
281
282 pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
283 let ty = self.id(b);
284 match self {
285 Elem::Void => unreachable!(),
286 Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
287 Elem::Bool => b.constant_false(ty),
288 _ => match value {
289 ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
290 ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
291 },
292 }
293 }
294}
295
296impl<T: SpirvTarget> SpirvCompiler<T> {
297 pub fn compile_type(&mut self, item: core::Type) -> Item {
298 match item {
299 core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
300 core::Type::Line(storage, size) => {
301 Item::Vector(self.compile_storage_type(storage), size)
302 }
303 core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
304 }
305 }
306
307 pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
308 match ty {
309 core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
310 core::StorageType::Packed(_, _) => {
311 unimplemented!("Packed types not yet supported in SPIR-V")
312 }
313 }
314 }
315
316 pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
317 match elem {
318 core::ElemType::Float(
319 core::FloatKind::E2M1
320 | core::FloatKind::E2M3
321 | core::FloatKind::E3M2
322 | core::FloatKind::UE8M0,
323 ) => panic!("Minifloat not supported in SPIR-V"),
324 core::ElemType::Float(core::FloatKind::E4M3) => {
325 self.capabilities.insert(Capability::Float8EXT);
326 Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
327 }
328 core::ElemType::Float(core::FloatKind::E5M2) => {
329 self.capabilities.insert(Capability::Float8EXT);
330 Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
331 }
332 core::ElemType::Float(core::FloatKind::BF16) => {
333 self.capabilities.insert(Capability::BFloat16TypeKHR);
334 Elem::Float(16, Some(FPEncoding::BFloat16KHR))
335 }
336 core::ElemType::Float(FloatKind::F16) => {
337 self.capabilities.insert(Capability::Float16);
338 Elem::Float(16, None)
339 }
340 core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
341 core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
342 core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
343 core::ElemType::Float(FloatKind::F64) => {
344 self.capabilities.insert(Capability::Float64);
345 Elem::Float(64, None)
346 }
347 core::ElemType::Int(IntKind::I8) => {
348 self.capabilities.insert(Capability::Int8);
349 Elem::Int(8, true)
350 }
351 core::ElemType::Int(IntKind::I16) => {
352 self.capabilities.insert(Capability::Int16);
353 Elem::Int(16, true)
354 }
355 core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
356 core::ElemType::Int(IntKind::I64) => {
357 self.capabilities.insert(Capability::Int64);
358 Elem::Int(64, true)
359 }
360 core::ElemType::UInt(UIntKind::U64) => {
361 self.capabilities.insert(Capability::Int64);
362 Elem::Int(64, false)
363 }
364 core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
365 core::ElemType::UInt(UIntKind::U16) => {
366 self.capabilities.insert(Capability::Int16);
367 Elem::Int(16, false)
368 }
369 core::ElemType::UInt(UIntKind::U8) => {
370 self.capabilities.insert(Capability::Int8);
371 Elem::Int(8, false)
372 }
373 core::ElemType::Bool => Elem::Bool,
374 }
375 }
376
377 pub fn static_core(&mut self, val: core::Variable, item: &Item) -> Word {
378 let val = val.as_const().unwrap();
379
380 let value = match (val, item.elem()) {
381 (core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
382 (core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => {
383 ConstVal::from_uint(val as u64, width)
384 }
385 (core::ConstantScalarValue::Int(val, _), Elem::Int(width, true)) => {
386 ConstVal::from_int(val, width)
387 }
388 (core::ConstantScalarValue::Int(val, _), Elem::Float(width, encoding)) => {
389 ConstVal::from_float(val as f64, width, encoding)
390 }
391 (core::ConstantScalarValue::Int(val, _), Elem::Relaxed) => {
392 ConstVal::from_float(val as f64, 32, None)
393 }
394 (core::ConstantScalarValue::Float(val, _), Elem::Bool) => {
395 ConstVal::from_bool(val != 0.0)
396 }
397 (core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => {
398 ConstVal::from_uint(val as u64, width)
399 }
400 (core::ConstantScalarValue::Float(val, _), Elem::Int(width, true)) => {
401 ConstVal::from_int(val as i64, width)
402 }
403 (core::ConstantScalarValue::Float(val, _), Elem::Float(width, encoding)) => {
404 ConstVal::from_float(val, width, encoding)
405 }
406 (core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => {
407 ConstVal::from_float(val, 32, None)
408 }
409 (core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
410 (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => {
411 ConstVal::from_uint(val, width)
412 }
413 (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, true)) => {
414 ConstVal::from_int(val as i64, width)
415 }
416 (core::ConstantScalarValue::UInt(val, _), Elem::Float(width, encoding)) => {
417 ConstVal::from_float(val as f64, width, encoding)
418 }
419 (core::ConstantScalarValue::UInt(val, _), Elem::Relaxed) => {
420 ConstVal::from_float(val as f64, 32, None)
421 }
422 (core::ConstantScalarValue::Bool(val), Elem::Bool) => ConstVal::from_bool(val),
423 (core::ConstantScalarValue::Bool(val), Elem::Int(width, _)) => {
424 ConstVal::from_uint(val as u64, width)
425 }
426 (core::ConstantScalarValue::Bool(val), Elem::Float(width, encoding)) => {
427 ConstVal::from_float(val as u32 as f64, width, encoding)
428 }
429 (core::ConstantScalarValue::Bool(val), Elem::Relaxed) => {
430 ConstVal::from_float(val as u32 as f64, 32, None)
431 }
432 (_, Elem::Void) => unreachable!(),
433 };
434 item.constant(self, value)
435 }
436
437 pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> Word {
438 let elem_cast = match (*from, item.elem()) {
439 (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
440 (Elem::Bool, Elem::Float(width, encoding)) => {
441 ConstVal::from_float(val.as_u32() as f64, width, encoding)
442 }
443 (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
444 (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
445 (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
446 (Elem::Int(w_in, true), Elem::Int(width, _)) => {
447 ConstVal::from_uint(val.as_int(w_in) as u64, width)
448 }
449 (Elem::Int(_, false), Elem::Float(width, encoding)) => {
450 ConstVal::from_float(val.as_u64() as f64, width, encoding)
451 }
452 (Elem::Int(_, false), Elem::Relaxed) => {
453 ConstVal::from_float(val.as_u64() as f64, 32, None)
454 }
455 (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
456 ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
457 }
458 (Elem::Int(in_w, true), Elem::Relaxed) => {
459 ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
460 }
461 (Elem::Float(in_w, encoding), Elem::Bool) => {
462 ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
463 }
464 (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
465 (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
466 ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
467 }
468 (Elem::Relaxed, Elem::Int(out_w, false)) => {
469 ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
470 }
471 (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
472 ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
473 }
474 (Elem::Relaxed, Elem::Int(out_w, true)) => {
475 ConstVal::from_int(val.as_float(32, None) as i64, out_w)
476 }
477 (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
478 ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
479 }
480 (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
481 ConstVal::from_float(val.as_float(32, None), out_w, encoding)
482 }
483 (Elem::Float(in_w, encoding), Elem::Relaxed) => {
484 ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
485 }
486 (Elem::Bool, Elem::Bool) => val,
487 (Elem::Relaxed, Elem::Relaxed) => val,
488 (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
489 };
490 item.constant(self, elem_cast)
491 }
492}
493
494impl std::fmt::Display for Item {
495 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496 match self {
497 Item::Scalar(elem) => write!(f, "{elem}"),
498 Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
499 Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
500 Item::RuntimeArray(item) => write!(f, "array<{item}>"),
501 Item::Struct(members) => {
502 write!(f, "struct<")?;
503 for item in members {
504 write!(f, "{item}")?;
505 }
506 f.write_str(">")
507 }
508 Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
509 Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
510 }
511 }
512}
513
514impl std::fmt::Display for Elem {
515 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 match self {
517 Elem::Void => write!(f, "void"),
518 Elem::Bool => write!(f, "bool"),
519 Elem::Int(width, false) => write!(f, "u{width}"),
520 Elem::Int(width, true) => write!(f, "i{width}"),
521 Elem::Float(width, None) => write!(f, "f{width}"),
522 Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
523 Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
524 Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
525 Elem::Relaxed => write!(f, "flex32"),
526 }
527 }
528}