cubecl_spirv/
item.rs

1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3
4use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
5
/// A SPIR-V type as handled by the compiler: a scalar or vector of [`Elem`], or a
/// composite (array, struct, pointer, cooperative matrix) built from other items.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Item {
    /// A single scalar element.
    Scalar(Elem),
    /// Vector of scalars: (element type, component count).
    /// The count must be 2, 3, or 4, or 8/16 for OpenCL only.
    Vector(Elem, u32),
    /// Fixed-length array: (element item, length). The length is declared as a
    /// `u32` constant when the type id is created.
    Array(Box<Item>, u32),
    /// Array whose length is not part of the type (lowered with `type_runtime_array`).
    RuntimeArray(Box<Item>),
    /// Struct with the given member items, in declaration order.
    Struct(Vec<Item>),
    /// Pointer to an item in the given storage class.
    Pointer(StorageClass, Box<Item>),
    /// Cooperative matrix type, declared at subgroup scope.
    CoopMatrix {
        /// Element type of the matrix.
        ty: Elem,
        /// Number of rows.
        rows: u32,
        /// Number of columns.
        columns: u32,
        /// Which cooperative-matrix operand role ("use") this matrix has.
        ident: CooperativeMatrixUse,
    },
}
22
23impl Item {
24    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
25        let id = match self {
26            Item::Scalar(elem) => elem.id(b),
27            Item::Vector(elem, vec) => {
28                let elem = elem.id(b);
29                b.type_vector(elem, *vec)
30            }
31            Item::Array(item, len) => {
32                let item = item.id(b);
33                let len = b.const_u32(*len);
34                b.type_array(item, len)
35            }
36            Item::RuntimeArray(item) => {
37                let item = item.id(b);
38                b.type_runtime_array(item)
39            }
40            Item::Struct(vec) => {
41                let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
42                let id = b.id(); // Avoid deduplicating this struct, because of decorations
43                b.type_struct_id(Some(id), items)
44            }
45            Item::Pointer(storage_class, item) => {
46                let item = item.id(b);
47                b.type_pointer(None, *storage_class, item)
48            }
49            Item::CoopMatrix {
50                ty,
51                rows,
52                columns,
53                ident,
54            } => {
55                let ty = ty.id(b);
56                let scope = b.const_u32(Scope::Subgroup as u32);
57                let usage = b.const_u32(*ident as u32);
58                b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
59            }
60        };
61        if b.debug_symbols && !b.state.debug_types.contains(&id) {
62            b.debug_name(id, format!("{self}"));
63            b.state.debug_types.insert(id);
64        }
65        id
66    }
67
68    pub fn size(&self) -> u32 {
69        match self {
70            Item::Scalar(elem) => elem.size(),
71            Item::Vector(elem, factor) => elem.size() * *factor,
72            Item::Array(item, len) => item.size() * *len,
73            Item::RuntimeArray(item) => item.size(),
74            Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
75            Item::Pointer(_, item) => item.size(),
76            Item::CoopMatrix { ty, .. } => ty.size(),
77        }
78    }
79
80    pub fn elem(&self) -> Elem {
81        match self {
82            Item::Scalar(elem) => *elem,
83            Item::Vector(elem, _) => *elem,
84            Item::Array(item, _) => item.elem(),
85            Item::RuntimeArray(item) => item.elem(),
86            Item::Struct(_) => Elem::Void,
87            Item::Pointer(_, item) => item.elem(),
88            Item::CoopMatrix { ty, .. } => *ty,
89        }
90    }
91
92    pub fn same_vectorization(&self, elem: Elem) -> Item {
93        match self {
94            Item::Scalar(_) => Item::Scalar(elem),
95            Item::Vector(_, factor) => Item::Vector(elem, *factor),
96            _ => unreachable!(),
97        }
98    }
99
100    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
101        let scalar = self.elem().constant(b, value);
102        let ty = self.id(b);
103        match self {
104            Item::Scalar(_) => scalar,
105            Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
106            Item::Array(item, len) => {
107                let elem = item.constant(b, value);
108                b.constant_composite(ty, (0..*len).map(|_| elem))
109            }
110            Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
111            Item::Struct(elems) => {
112                let items = elems
113                    .iter()
114                    .map(|item| item.constant(b, value))
115                    .collect::<Vec<_>>();
116                b.constant_composite(ty, items)
117            }
118            Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
119            Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
120        }
121    }
122
123    pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
124        b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
125    }
126
127    /// Broadcast a scalar to a vector if needed, ex: f32 -> vec2<f32>, vec2<f32> -> vec2<f32>
128    pub fn broadcast<T: SpirvTarget>(
129        &self,
130        b: &mut SpirvCompiler<T>,
131        obj: Word,
132        out_id: Option<Word>,
133        other: &Item,
134    ) -> Word {
135        match (self, other) {
136            (Item::Scalar(elem), Item::Vector(_, factor)) => {
137                let item = Item::Vector(*elem, *factor);
138                let ty = item.id(b);
139                b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
140                    .unwrap()
141            }
142            _ => obj,
143        }
144    }
145
146    pub fn cast_to<T: SpirvTarget>(
147        &self,
148        b: &mut SpirvCompiler<T>,
149        out_id: Option<Word>,
150        obj: Word,
151        other: &Item,
152    ) -> Word {
153        let ty = other.id(b);
154
155        let matching_vec = match (self, other) {
156            (Item::Scalar(_), Item::Scalar(_)) => true,
157            (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
158            _ => false,
159        };
160        let matching_elem = self.elem() == other.elem();
161
162        let convert_i_width =
163            |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
164                if signed {
165                    b.s_convert(ty, out_id, obj).unwrap()
166                } else {
167                    b.u_convert(ty, out_id, obj).unwrap()
168                }
169            };
170
171        let convert_int = |b: &mut SpirvCompiler<T>,
172                           obj: Word,
173                           out_id: Option<Word>,
174                           (width_self, signed_self),
175                           (width_other, signed_other)| {
176            let width_differs = width_self != width_other;
177            let sign_extend = signed_self && signed_other;
178            match width_differs {
179                true => convert_i_width(b, obj, out_id, sign_extend),
180                false => b.copy_object(ty, out_id, obj).unwrap(),
181            }
182        };
183
184        let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
185            match (self.elem(), other.elem()) {
186                (Elem::Bool, Elem::Int(_, _)) => {
187                    let one = other.const_u32(b, 1);
188                    let zero = other.const_u32(b, 0);
189                    b.select(ty, out_id, obj, one, zero).unwrap()
190                }
191                (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
192                    let one = other.const_u32(b, 1);
193                    let zero = other.const_u32(b, 0);
194                    b.select(ty, out_id, obj, one, zero).unwrap()
195                }
196                (Elem::Int(_, _), Elem::Bool) => {
197                    let zero = self.const_u32(b, 0);
198                    b.i_not_equal(ty, out_id, obj, zero).unwrap()
199                }
200                (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
201                    convert_int(
202                        b,
203                        obj,
204                        out_id,
205                        (width_self, signed_self),
206                        (width_other, signed_other),
207                    )
208                }
209                (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
210                    b.convert_u_to_f(ty, out_id, obj).unwrap()
211                }
212                (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
213                    b.convert_s_to_f(ty, out_id, obj).unwrap()
214                }
215                (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
216                    let zero = self.const_u32(b, 0);
217                    b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
218                }
219                (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
220                    b.convert_f_to_u(ty, out_id, obj).unwrap()
221                }
222                (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
223                    b.convert_f_to_s(ty, out_id, obj).unwrap()
224                }
225                (Elem::Float(_, _), Elem::Float(_, _))
226                | (Elem::Float(_, _), Elem::Relaxed)
227                | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
228                (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
229                (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
230                (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
231            }
232        };
233
234        match (matching_vec, matching_elem) {
235            (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
236            (true, true) => obj,
237            (true, false) => cast_elem(b, obj, out_id),
238            (false, true) => self.broadcast(b, obj, out_id, other),
239            (false, false) => {
240                let broadcast = self.broadcast(b, obj, None, other);
241                cast_elem(b, broadcast, out_id)
242            }
243        }
244    }
245}
246
/// Scalar element type of an [`Item`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Elem {
    /// No value (used e.g. as the element of struct items).
    Void,
    /// Boolean.
    Bool,
    /// Integer: (bit width, signed). `true` means signed (`i8`..`i64`), `false`
    /// unsigned (`u8`..`u64`).
    Int(u32, bool),
    /// Floating point: (bit width, encoding). The encoding is `None` for f16/f32/f64
    /// and `Some(..)` for bf16 and the 8-bit float formats.
    Float(u32, Option<FPEncoding>),
    /// flex32: declared as a plain 32-bit float type.
    Relaxed,
}
255
256impl Elem {
257    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
258        let id = match self {
259            Elem::Void => b.type_void(),
260            Elem::Bool => b.type_bool(),
261            Elem::Int(width, _) => b.type_int(*width, 0),
262            Elem::Float(width, encoding) => b.type_float(*width, *encoding),
263            Elem::Relaxed => b.type_float(32, None),
264        };
265        if b.debug_symbols && !b.state.debug_types.contains(&id) {
266            b.debug_name(id, format!("{self}"));
267            b.state.debug_types.insert(id);
268        }
269        id
270    }
271
272    pub fn size(&self) -> u32 {
273        match self {
274            Elem::Void => 0,
275            Elem::Bool => 1,
276            Elem::Int(size, _) => *size / 8,
277            Elem::Float(size, _) => *size / 8,
278            Elem::Relaxed => 4,
279        }
280    }
281
282    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
283        let ty = self.id(b);
284        match self {
285            Elem::Void => unreachable!(),
286            Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
287            Elem::Bool => b.constant_false(ty),
288            _ => match value {
289                ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
290                ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
291            },
292        }
293    }
294}
295
296impl<T: SpirvTarget> SpirvCompiler<T> {
297    pub fn compile_type(&mut self, item: core::Type) -> Item {
298        match item {
299            core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
300            core::Type::Line(storage, size) => {
301                Item::Vector(self.compile_storage_type(storage), size)
302            }
303            core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
304        }
305    }
306
307    pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
308        match ty {
309            core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
310            core::StorageType::Opaque(ty) => match ty {
311                core::OpaqueType::Barrier(_) => {
312                    unimplemented!("Barrier type not supported in SPIR-V")
313                }
314            },
315            core::StorageType::Packed(_, _) => {
316                unimplemented!("Packed types not yet supported in SPIR-V")
317            }
318        }
319    }
320
321    pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
322        match elem {
323            core::ElemType::Float(
324                core::FloatKind::E2M1
325                | core::FloatKind::E2M3
326                | core::FloatKind::E3M2
327                | core::FloatKind::UE8M0,
328            ) => panic!("Minifloat not supported in SPIR-V"),
329            core::ElemType::Float(core::FloatKind::E4M3) => {
330                self.capabilities.insert(Capability::Float8EXT);
331                Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
332            }
333            core::ElemType::Float(core::FloatKind::E5M2) => {
334                self.capabilities.insert(Capability::Float8EXT);
335                Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
336            }
337            core::ElemType::Float(core::FloatKind::BF16) => {
338                self.capabilities.insert(Capability::BFloat16TypeKHR);
339                Elem::Float(16, Some(FPEncoding::BFloat16KHR))
340            }
341            core::ElemType::Float(FloatKind::F16) => {
342                self.capabilities.insert(Capability::Float16);
343                Elem::Float(16, None)
344            }
345            core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
346            core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
347            core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
348            core::ElemType::Float(FloatKind::F64) => {
349                self.capabilities.insert(Capability::Float64);
350                Elem::Float(64, None)
351            }
352            core::ElemType::Int(IntKind::I8) => {
353                self.capabilities.insert(Capability::Int8);
354                Elem::Int(8, true)
355            }
356            core::ElemType::Int(IntKind::I16) => {
357                self.capabilities.insert(Capability::Int16);
358                Elem::Int(16, true)
359            }
360            core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
361            core::ElemType::Int(IntKind::I64) => {
362                self.capabilities.insert(Capability::Int64);
363                Elem::Int(64, true)
364            }
365            core::ElemType::UInt(UIntKind::U64) => {
366                self.capabilities.insert(Capability::Int64);
367                Elem::Int(64, false)
368            }
369            core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
370            core::ElemType::UInt(UIntKind::U16) => {
371                self.capabilities.insert(Capability::Int16);
372                Elem::Int(16, false)
373            }
374            core::ElemType::UInt(UIntKind::U8) => {
375                self.capabilities.insert(Capability::Int8);
376                Elem::Int(8, false)
377            }
378            core::ElemType::Bool => Elem::Bool,
379        }
380    }
381
382    pub fn static_core(&mut self, val: core::Variable, item: &Item) -> Word {
383        let val = val.as_const().unwrap();
384
385        let value = match (val, item.elem()) {
386            (core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
387            (core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => {
388                ConstVal::from_uint(val as u64, width)
389            }
390            (core::ConstantScalarValue::Int(val, _), Elem::Int(width, true)) => {
391                ConstVal::from_int(val, width)
392            }
393            (core::ConstantScalarValue::Int(val, _), Elem::Float(width, encoding)) => {
394                ConstVal::from_float(val as f64, width, encoding)
395            }
396            (core::ConstantScalarValue::Int(val, _), Elem::Relaxed) => {
397                ConstVal::from_float(val as f64, 32, None)
398            }
399            (core::ConstantScalarValue::Float(val, _), Elem::Bool) => {
400                ConstVal::from_bool(val != 0.0)
401            }
402            (core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => {
403                ConstVal::from_uint(val as u64, width)
404            }
405            (core::ConstantScalarValue::Float(val, _), Elem::Int(width, true)) => {
406                ConstVal::from_int(val as i64, width)
407            }
408            (core::ConstantScalarValue::Float(val, _), Elem::Float(width, encoding)) => {
409                ConstVal::from_float(val, width, encoding)
410            }
411            (core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => {
412                ConstVal::from_float(val, 32, None)
413            }
414            (core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
415            (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => {
416                ConstVal::from_uint(val, width)
417            }
418            (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, true)) => {
419                ConstVal::from_int(val as i64, width)
420            }
421            (core::ConstantScalarValue::UInt(val, _), Elem::Float(width, encoding)) => {
422                ConstVal::from_float(val as f64, width, encoding)
423            }
424            (core::ConstantScalarValue::UInt(val, _), Elem::Relaxed) => {
425                ConstVal::from_float(val as f64, 32, None)
426            }
427            (core::ConstantScalarValue::Bool(val), Elem::Bool) => ConstVal::from_bool(val),
428            (core::ConstantScalarValue::Bool(val), Elem::Int(width, _)) => {
429                ConstVal::from_uint(val as u64, width)
430            }
431            (core::ConstantScalarValue::Bool(val), Elem::Float(width, encoding)) => {
432                ConstVal::from_float(val as u32 as f64, width, encoding)
433            }
434            (core::ConstantScalarValue::Bool(val), Elem::Relaxed) => {
435                ConstVal::from_float(val as u32 as f64, 32, None)
436            }
437            (_, Elem::Void) => unreachable!(),
438        };
439        item.constant(self, value)
440    }
441
442    pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> Word {
443        let elem_cast = match (*from, item.elem()) {
444            (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
445            (Elem::Bool, Elem::Float(width, encoding)) => {
446                ConstVal::from_float(val.as_u32() as f64, width, encoding)
447            }
448            (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
449            (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
450            (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
451            (Elem::Int(w_in, true), Elem::Int(width, _)) => {
452                ConstVal::from_uint(val.as_int(w_in) as u64, width)
453            }
454            (Elem::Int(_, false), Elem::Float(width, encoding)) => {
455                ConstVal::from_float(val.as_u64() as f64, width, encoding)
456            }
457            (Elem::Int(_, false), Elem::Relaxed) => {
458                ConstVal::from_float(val.as_u64() as f64, 32, None)
459            }
460            (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
461                ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
462            }
463            (Elem::Int(in_w, true), Elem::Relaxed) => {
464                ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
465            }
466            (Elem::Float(in_w, encoding), Elem::Bool) => {
467                ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
468            }
469            (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
470            (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
471                ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
472            }
473            (Elem::Relaxed, Elem::Int(out_w, false)) => {
474                ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
475            }
476            (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
477                ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
478            }
479            (Elem::Relaxed, Elem::Int(out_w, true)) => {
480                ConstVal::from_int(val.as_float(32, None) as i64, out_w)
481            }
482            (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
483                ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
484            }
485            (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
486                ConstVal::from_float(val.as_float(32, None), out_w, encoding)
487            }
488            (Elem::Float(in_w, encoding), Elem::Relaxed) => {
489                ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
490            }
491            (Elem::Bool, Elem::Bool) => val,
492            (Elem::Relaxed, Elem::Relaxed) => val,
493            (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
494        };
495        item.constant(self, elem_cast)
496    }
497}
498
499impl std::fmt::Display for Item {
500    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
501        match self {
502            Item::Scalar(elem) => write!(f, "{elem}"),
503            Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
504            Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
505            Item::RuntimeArray(item) => write!(f, "array<{item}>"),
506            Item::Struct(members) => {
507                write!(f, "struct<")?;
508                for item in members {
509                    write!(f, "{item}")?;
510                }
511                f.write_str(">")
512            }
513            Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
514            Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
515        }
516    }
517}
518
519impl std::fmt::Display for Elem {
520    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521        match self {
522            Elem::Void => write!(f, "void"),
523            Elem::Bool => write!(f, "bool"),
524            Elem::Int(width, false) => write!(f, "u{width}"),
525            Elem::Int(width, true) => write!(f, "i{width}"),
526            Elem::Float(width, None) => write!(f, "f{width}"),
527            Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
528            Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
529            Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
530            Elem::Relaxed => write!(f, "flex32"),
531        }
532    }
533}