cubecl_spirv/
item.rs

1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3
4use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
5
/// A SPIR-V type as modelled by the compiler; lowered to a type id by `Item::id`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Item {
    /// A single scalar element.
    Scalar(Elem),
    // Vector of scalars. Must be 2, 3, or 4, or 8/16 for OpenCL only
    Vector(Elem, u32),
    /// Fixed-length array of `Item`; the length is emitted as a `u32` constant.
    Array(Box<Item>, u32),
    /// Array whose length is not known at compile time.
    RuntimeArray(Box<Item>),
    /// Struct with the given member types, in declaration order. Each call to
    /// `Item::id` on a struct declares a fresh (non-deduplicated) type.
    Struct(Vec<Item>),
    /// Pointer to `Item` in the given storage class.
    Pointer(StorageClass, Box<Item>),
    /// Cooperative matrix type, declared with subgroup scope.
    CoopMatrix {
        /// Component element type.
        ty: Elem,
        /// Number of rows.
        rows: u32,
        /// Number of columns.
        columns: u32,
        /// The `CooperativeMatrixUse` operand of the matrix type.
        ident: CooperativeMatrixUse,
    },
}
22
23impl Item {
24    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
25        let id = match self {
26            Item::Scalar(elem) => elem.id(b),
27            Item::Vector(elem, vec) => {
28                let elem = elem.id(b);
29                b.type_vector(elem, *vec)
30            }
31            Item::Array(item, len) => {
32                let item = item.id(b);
33                let len = b.const_u32(*len);
34                b.type_array(item, len)
35            }
36            Item::RuntimeArray(item) => {
37                let item = item.id(b);
38                b.type_runtime_array(item)
39            }
40            Item::Struct(vec) => {
41                let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
42                let id = b.id(); // Avoid deduplicating this struct, because of decorations
43                b.type_struct_id(Some(id), items)
44            }
45            Item::Pointer(storage_class, item) => {
46                let item = item.id(b);
47                b.type_pointer(None, *storage_class, item)
48            }
49            Item::CoopMatrix {
50                ty,
51                rows,
52                columns,
53                ident,
54            } => {
55                let ty = ty.id(b);
56                let scope = b.const_u32(Scope::Subgroup as u32);
57                let usage = b.const_u32(*ident as u32);
58                b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
59            }
60        };
61        if b.debug_symbols && !b.state.debug_types.contains(&id) {
62            b.debug_name(id, format!("{self}"));
63            b.state.debug_types.insert(id);
64        }
65        id
66    }
67
68    pub fn builtin_u32() -> Self {
69        Item::Scalar(Elem::Int(32, false))
70    }
71
72    pub fn size(&self) -> u32 {
73        match self {
74            Item::Scalar(elem) => elem.size(),
75            Item::Vector(elem, factor) => elem.size() * *factor,
76            Item::Array(item, len) => item.size() * *len,
77            Item::RuntimeArray(item) => item.size(),
78            Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
79            Item::Pointer(_, item) => item.size(),
80            Item::CoopMatrix { ty, .. } => ty.size(),
81        }
82    }
83
84    pub fn elem(&self) -> Elem {
85        match self {
86            Item::Scalar(elem) => *elem,
87            Item::Vector(elem, _) => *elem,
88            Item::Array(item, _) => item.elem(),
89            Item::RuntimeArray(item) => item.elem(),
90            Item::Struct(_) => Elem::Void,
91            Item::Pointer(_, item) => item.elem(),
92            Item::CoopMatrix { ty, .. } => *ty,
93        }
94    }
95
96    pub fn same_vectorization(&self, elem: Elem) -> Item {
97        match self {
98            Item::Scalar(_) => Item::Scalar(elem),
99            Item::Vector(_, factor) => Item::Vector(elem, *factor),
100            _ => unreachable!(),
101        }
102    }
103
104    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
105        let scalar = self.elem().constant(b, value);
106        let ty = self.id(b);
107        match self {
108            Item::Scalar(_) => scalar,
109            Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
110            Item::Array(item, len) => {
111                let elem = item.constant(b, value);
112                b.constant_composite(ty, (0..*len).map(|_| elem))
113            }
114            Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
115            Item::Struct(elems) => {
116                let items = elems
117                    .iter()
118                    .map(|item| item.constant(b, value))
119                    .collect::<Vec<_>>();
120                b.constant_composite(ty, items)
121            }
122            Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
123            Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
124        }
125    }
126
127    pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
128        b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
129            .0
130    }
131
132    /// Broadcast a scalar to a vector if needed, ex: f32 -> vec2<f32>, vec2<f32> -> vec2<f32>
133    pub fn broadcast<T: SpirvTarget>(
134        &self,
135        b: &mut SpirvCompiler<T>,
136        obj: Word,
137        out_id: Option<Word>,
138        other: &Item,
139    ) -> Word {
140        match (self, other) {
141            (Item::Scalar(elem), Item::Vector(_, factor)) => {
142                let item = Item::Vector(*elem, *factor);
143                let ty = item.id(b);
144                b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
145                    .unwrap()
146            }
147            _ => obj,
148        }
149    }
150
151    pub fn cast_to<T: SpirvTarget>(
152        &self,
153        b: &mut SpirvCompiler<T>,
154        out_id: Option<Word>,
155        obj: Word,
156        other: &Item,
157    ) -> Word {
158        let ty = other.id(b);
159
160        let matching_vec = match (self, other) {
161            (Item::Scalar(_), Item::Scalar(_)) => true,
162            (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
163            _ => false,
164        };
165        let matching_elem = self.elem() == other.elem();
166
167        let convert_i_width =
168            |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
169                if signed {
170                    b.s_convert(ty, out_id, obj).unwrap()
171                } else {
172                    b.u_convert(ty, out_id, obj).unwrap()
173                }
174            };
175
176        let convert_int = |b: &mut SpirvCompiler<T>,
177                           obj: Word,
178                           out_id: Option<Word>,
179                           (width_self, signed_self),
180                           (width_other, signed_other)| {
181            let width_differs = width_self != width_other;
182            let sign_extend = signed_self && signed_other;
183            match width_differs {
184                true => convert_i_width(b, obj, out_id, sign_extend),
185                false => b.copy_object(ty, out_id, obj).unwrap(),
186            }
187        };
188
189        let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
190            match (self.elem(), other.elem()) {
191                (Elem::Bool, Elem::Int(_, _)) => {
192                    let one = other.const_u32(b, 1);
193                    let zero = other.const_u32(b, 0);
194                    b.select(ty, out_id, obj, one, zero).unwrap()
195                }
196                (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
197                    let one = other.const_u32(b, 1);
198                    let zero = other.const_u32(b, 0);
199                    b.select(ty, out_id, obj, one, zero).unwrap()
200                }
201                (Elem::Int(_, _), Elem::Bool) => {
202                    let zero = self.const_u32(b, 0);
203                    b.i_not_equal(ty, out_id, obj, zero).unwrap()
204                }
205                (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
206                    convert_int(
207                        b,
208                        obj,
209                        out_id,
210                        (width_self, signed_self),
211                        (width_other, signed_other),
212                    )
213                }
214                (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
215                    b.convert_u_to_f(ty, out_id, obj).unwrap()
216                }
217                (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
218                    b.convert_s_to_f(ty, out_id, obj).unwrap()
219                }
220                (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
221                    let zero = self.const_u32(b, 0);
222                    b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
223                }
224                (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
225                    b.convert_f_to_u(ty, out_id, obj).unwrap()
226                }
227                (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
228                    b.convert_f_to_s(ty, out_id, obj).unwrap()
229                }
230                (Elem::Float(_, _), Elem::Float(_, _))
231                | (Elem::Float(_, _), Elem::Relaxed)
232                | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
233                (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
234                (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
235                (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
236            }
237        };
238
239        match (matching_vec, matching_elem) {
240            (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
241            (true, true) => obj,
242            (true, false) => cast_elem(b, obj, out_id),
243            (false, true) => self.broadcast(b, obj, out_id, other),
244            (false, false) => {
245                let broadcast = self.broadcast(b, obj, None, other);
246                cast_elem(b, broadcast, out_id)
247            }
248        }
249    }
250}
251
/// Scalar element type used to build SPIR-V types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Elem {
    /// No element; also what `Item::elem` reports for structs.
    Void,
    /// Boolean.
    Bool,
    /// `Int(width, signed)`: integer of `width` bits. Signedness is tracked here
    /// but not in the declared SPIR-V type (see `Elem::id`).
    Int(u32, bool),
    /// `Float(width, encoding)`: float of `width` bits with an optional alternate
    /// encoding (bf16 / fp8 variants).
    Float(u32, Option<FPEncoding>),
    /// Relaxed-precision float (`flex32`), declared as a 32-bit float.
    Relaxed,
}
260
261impl Elem {
262    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
263        let id = match self {
264            Elem::Void => b.type_void(),
265            Elem::Bool => b.type_bool(),
266            Elem::Int(width, _) => b.type_int(*width, 0),
267            Elem::Float(width, encoding) => b.type_float(*width, *encoding),
268            Elem::Relaxed => b.type_float(32, None),
269        };
270        if b.debug_symbols && !b.state.debug_types.contains(&id) {
271            b.debug_name(id, format!("{self}"));
272            b.state.debug_types.insert(id);
273        }
274        id
275    }
276
277    pub fn size(&self) -> u32 {
278        match self {
279            Elem::Void => 0,
280            Elem::Bool => 1,
281            Elem::Int(size, _) => *size / 8,
282            Elem::Float(size, _) => *size / 8,
283            Elem::Relaxed => 4,
284        }
285    }
286
287    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
288        let ty = self.id(b);
289        match self {
290            Elem::Void => unreachable!(),
291            Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
292            Elem::Bool => b.constant_false(ty),
293            _ => match value {
294                ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
295                ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
296            },
297        }
298    }
299
300    pub fn float_encoding(&self) -> Option<FPEncoding> {
301        match self {
302            Elem::Float(_, encoding) => *encoding,
303            _ => None,
304        }
305    }
306}
307
308impl<T: SpirvTarget> SpirvCompiler<T> {
309    pub fn compile_type(&mut self, item: core::Type) -> Item {
310        match item {
311            core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
312            core::Type::Line(storage, size) => {
313                Item::Vector(self.compile_storage_type(storage), size as u32)
314            }
315            core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
316        }
317    }
318
319    pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
320        match ty {
321            core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
322            core::StorageType::Opaque(ty) => match ty {
323                core::OpaqueType::Barrier(_) => {
324                    unimplemented!("Barrier type not supported in SPIR-V")
325                }
326            },
327            core::StorageType::Packed(_, _) => {
328                unimplemented!("Packed types not yet supported in SPIR-V")
329            }
330        }
331    }
332
333    pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
334        match elem {
335            core::ElemType::Float(
336                core::FloatKind::E2M1
337                | core::FloatKind::E2M3
338                | core::FloatKind::E3M2
339                | core::FloatKind::UE8M0,
340            ) => panic!("Minifloat not supported in SPIR-V"),
341            core::ElemType::Float(core::FloatKind::E4M3) => {
342                self.capabilities.insert(Capability::Float8EXT);
343                Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
344            }
345            core::ElemType::Float(core::FloatKind::E5M2) => {
346                self.capabilities.insert(Capability::Float8EXT);
347                Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
348            }
349            core::ElemType::Float(core::FloatKind::BF16) => {
350                self.capabilities.insert(Capability::BFloat16TypeKHR);
351                Elem::Float(16, Some(FPEncoding::BFloat16KHR))
352            }
353            core::ElemType::Float(FloatKind::F16) => {
354                self.capabilities.insert(Capability::Float16);
355                Elem::Float(16, None)
356            }
357            core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
358            core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
359            core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
360            core::ElemType::Float(FloatKind::F64) => {
361                self.capabilities.insert(Capability::Float64);
362                Elem::Float(64, None)
363            }
364            core::ElemType::Int(IntKind::I8) => {
365                self.capabilities.insert(Capability::Int8);
366                Elem::Int(8, true)
367            }
368            core::ElemType::Int(IntKind::I16) => {
369                self.capabilities.insert(Capability::Int16);
370                Elem::Int(16, true)
371            }
372            core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
373            core::ElemType::Int(IntKind::I64) => {
374                self.capabilities.insert(Capability::Int64);
375                Elem::Int(64, true)
376            }
377            core::ElemType::UInt(UIntKind::U64) => {
378                self.capabilities.insert(Capability::Int64);
379                Elem::Int(64, false)
380            }
381            core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
382            core::ElemType::UInt(UIntKind::U16) => {
383                self.capabilities.insert(Capability::Int16);
384                Elem::Int(16, false)
385            }
386            core::ElemType::UInt(UIntKind::U8) => {
387                self.capabilities.insert(Capability::Int8);
388                Elem::Int(8, false)
389            }
390            core::ElemType::Bool => Elem::Bool,
391        }
392    }
393
394    pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> (Word, ConstVal) {
395        let elem_cast = match (*from, item.elem()) {
396            (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
397            (Elem::Bool, Elem::Float(width, encoding)) => {
398                ConstVal::from_float(val.as_u32() as f64, width, encoding)
399            }
400            (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
401            (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
402            (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
403            (Elem::Int(w_in, true), Elem::Int(width, _)) => {
404                ConstVal::from_uint(val.as_int(w_in) as u64, width)
405            }
406            (Elem::Int(_, false), Elem::Float(width, encoding)) => {
407                ConstVal::from_float(val.as_u64() as f64, width, encoding)
408            }
409            (Elem::Int(_, false), Elem::Relaxed) => {
410                ConstVal::from_float(val.as_u64() as f64, 32, None)
411            }
412            (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
413                ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
414            }
415            (Elem::Int(in_w, true), Elem::Relaxed) => {
416                ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
417            }
418            (Elem::Float(in_w, encoding), Elem::Bool) => {
419                ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
420            }
421            (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
422            (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
423                ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
424            }
425            (Elem::Relaxed, Elem::Int(out_w, false)) => {
426                ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
427            }
428            (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
429                ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
430            }
431            (Elem::Relaxed, Elem::Int(out_w, true)) => {
432                ConstVal::from_int(val.as_float(32, None) as i64, out_w)
433            }
434            (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
435                ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
436            }
437            (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
438                ConstVal::from_float(val.as_float(32, None), out_w, encoding)
439            }
440            (Elem::Float(in_w, encoding), Elem::Relaxed) => {
441                ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
442            }
443            (Elem::Bool, Elem::Bool) => val,
444            (Elem::Relaxed, Elem::Relaxed) => val,
445            (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
446        };
447        let id = item.constant(self, elem_cast);
448        (id, elem_cast)
449    }
450}
451
452impl std::fmt::Display for Item {
453    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454        match self {
455            Item::Scalar(elem) => write!(f, "{elem}"),
456            Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
457            Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
458            Item::RuntimeArray(item) => write!(f, "array<{item}>"),
459            Item::Struct(members) => {
460                write!(f, "struct<")?;
461                for item in members {
462                    write!(f, "{item}")?;
463                }
464                f.write_str(">")
465            }
466            Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
467            Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
468        }
469    }
470}
471
472impl std::fmt::Display for Elem {
473    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
474        match self {
475            Elem::Void => write!(f, "void"),
476            Elem::Bool => write!(f, "bool"),
477            Elem::Int(width, false) => write!(f, "u{width}"),
478            Elem::Int(width, true) => write!(f, "i{width}"),
479            Elem::Float(width, None) => write!(f, "f{width}"),
480            Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
481            Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
482            Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
483            Elem::Relaxed => write!(f, "flex32"),
484        }
485    }
486}