// cubecl_spirv/item.rs

use std::fmt;

use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
use serde::{Deserialize, Serialize};

use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};

7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub enum Item {
9    Scalar(Elem),
10    // Vector of scalars. Must be 2, 3, or 4, or 8/16 for OpenCL only
11    Vector(Elem, u32),
12    Array(Box<Item>, u32),
13    RuntimeArray(Box<Item>),
14    Struct(Vec<Item>),
15    Pointer(StorageClass, Box<Item>),
16    CoopMatrix {
17        ty: Elem,
18        rows: u32,
19        columns: u32,
20        ident: CooperativeMatrixUse,
21    },
22}
23
24impl Item {
25    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
26        let id = match self {
27            Item::Scalar(elem) => elem.id(b),
28            Item::Vector(elem, vec) => {
29                let elem = elem.id(b);
30                b.type_vector(elem, *vec)
31            }
32            Item::Array(item, len) => {
33                let item = item.id(b);
34                let len = b.const_u32(*len);
35                b.type_array(item, len)
36            }
37            Item::RuntimeArray(item) => {
38                let item = item.id(b);
39                b.type_runtime_array(item)
40            }
41            Item::Struct(vec) => {
42                let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
43                let id = b.id(); // Avoid deduplicating this struct, because of decorations
44                b.type_struct_id(Some(id), items)
45            }
46            Item::Pointer(storage_class, item) => {
47                let item = item.id(b);
48                b.type_pointer(None, *storage_class, item)
49            }
50            Item::CoopMatrix {
51                ty,
52                rows,
53                columns,
54                ident,
55            } => {
56                let ty = ty.id(b);
57                let scope = b.const_u32(Scope::Subgroup as u32);
58                let usage = b.const_u32(*ident as u32);
59                b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
60            }
61        };
62        if b.debug_symbols && !b.state.debug_types.contains(&id) {
63            b.debug_name(id, format!("{self}"));
64            b.state.debug_types.insert(id);
65        }
66        id
67    }
68
69    pub fn builtin_u32() -> Self {
70        Item::Scalar(Elem::Int(32, false))
71    }
72
73    pub fn size(&self) -> u32 {
74        match self {
75            Item::Scalar(elem) => elem.size(),
76            Item::Vector(elem, factor) => elem.size() * *factor,
77            Item::Array(item, len) => item.size() * *len,
78            Item::RuntimeArray(item) => item.size(),
79            Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
80            Item::Pointer(_, item) => item.size(),
81            Item::CoopMatrix { ty, .. } => ty.size(),
82        }
83    }
84
85    pub fn elem(&self) -> Elem {
86        match self {
87            Item::Scalar(elem) => *elem,
88            Item::Vector(elem, _) => *elem,
89            Item::Array(item, _) => item.elem(),
90            Item::RuntimeArray(item) => item.elem(),
91            Item::Struct(_) => Elem::Void,
92            Item::Pointer(_, item) => item.elem(),
93            Item::CoopMatrix { ty, .. } => *ty,
94        }
95    }
96
97    pub fn same_vectorization(&self, elem: Elem) -> Item {
98        match self {
99            Item::Scalar(_) => Item::Scalar(elem),
100            Item::Vector(_, factor) => Item::Vector(elem, *factor),
101            _ => unreachable!(),
102        }
103    }
104
105    pub fn vectorization(&self) -> u32 {
106        match self {
107            Item::Vector(_, factor) => *factor,
108            _ => 1,
109        }
110    }
111
112    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
113        let scalar = self.elem().constant(b, value);
114        let ty = self.id(b);
115        match self {
116            Item::Scalar(_) => scalar,
117            Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
118            Item::Array(item, len) => {
119                let elem = item.constant(b, value);
120                b.constant_composite(ty, (0..*len).map(|_| elem))
121            }
122            Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
123            Item::Struct(elems) => {
124                let items = elems
125                    .iter()
126                    .map(|item| item.constant(b, value))
127                    .collect::<Vec<_>>();
128                b.constant_composite(ty, items)
129            }
130            Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
131            Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
132        }
133    }
134
135    pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
136        b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
137            .0
138    }
139
140    /// Broadcast a scalar to a vector if needed, ex: f32 -> vec2<f32>, vec2<f32> -> vec2<f32>
141    pub fn broadcast<T: SpirvTarget>(
142        &self,
143        b: &mut SpirvCompiler<T>,
144        obj: Word,
145        out_id: Option<Word>,
146        other: &Item,
147    ) -> Word {
148        match (self, other) {
149            (Item::Scalar(elem), Item::Vector(_, factor)) => {
150                let item = Item::Vector(*elem, *factor);
151                let ty = item.id(b);
152                b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
153                    .unwrap()
154            }
155            _ => obj,
156        }
157    }
158
159    pub fn cast_to<T: SpirvTarget>(
160        &self,
161        b: &mut SpirvCompiler<T>,
162        out_id: Option<Word>,
163        obj: Word,
164        other: &Item,
165    ) -> Word {
166        let ty = other.id(b);
167
168        let matching_vec = match (self, other) {
169            (Item::Scalar(_), Item::Scalar(_)) => true,
170            (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
171            _ => false,
172        };
173        let matching_elem = self.elem() == other.elem();
174
175        let convert_i_width =
176            |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
177                if signed {
178                    b.s_convert(ty, out_id, obj).unwrap()
179                } else {
180                    b.u_convert(ty, out_id, obj).unwrap()
181                }
182            };
183
184        let convert_int = |b: &mut SpirvCompiler<T>,
185                           obj: Word,
186                           out_id: Option<Word>,
187                           (width_self, signed_self),
188                           (width_other, signed_other)| {
189            let width_differs = width_self != width_other;
190            let sign_extend = signed_self && signed_other;
191            match width_differs {
192                true => convert_i_width(b, obj, out_id, sign_extend),
193                false => b.copy_object(ty, out_id, obj).unwrap(),
194            }
195        };
196
197        let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
198            match (self.elem(), other.elem()) {
199                (Elem::Bool, Elem::Int(_, _)) => {
200                    let one = other.const_u32(b, 1);
201                    let zero = other.const_u32(b, 0);
202                    b.select(ty, out_id, obj, one, zero).unwrap()
203                }
204                (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
205                    let one = other.const_u32(b, 1);
206                    let zero = other.const_u32(b, 0);
207                    b.select(ty, out_id, obj, one, zero).unwrap()
208                }
209                (Elem::Int(_, _), Elem::Bool) => {
210                    let zero = self.const_u32(b, 0);
211                    b.i_not_equal(ty, out_id, obj, zero).unwrap()
212                }
213                (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
214                    convert_int(
215                        b,
216                        obj,
217                        out_id,
218                        (width_self, signed_self),
219                        (width_other, signed_other),
220                    )
221                }
222                (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
223                    b.convert_u_to_f(ty, out_id, obj).unwrap()
224                }
225                (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
226                    b.convert_s_to_f(ty, out_id, obj).unwrap()
227                }
228                (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
229                    let zero = self.const_u32(b, 0);
230                    b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
231                }
232                (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
233                    b.convert_f_to_u(ty, out_id, obj).unwrap()
234                }
235                (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
236                    b.convert_f_to_s(ty, out_id, obj).unwrap()
237                }
238                (Elem::Float(_, _), Elem::Float(_, _))
239                | (Elem::Float(_, _), Elem::Relaxed)
240                | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
241                (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
242                (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
243                (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
244            }
245        };
246
247        match (matching_vec, matching_elem) {
248            (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
249            (true, true) => obj,
250            (true, false) => cast_elem(b, obj, out_id),
251            (false, true) => self.broadcast(b, obj, out_id, other),
252            (false, false) => {
253                let broadcast = self.broadcast(b, obj, None, other);
254                cast_elem(b, broadcast, out_id)
255            }
256        }
257    }
258}
259
260#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
261pub enum Elem {
262    Void,
263    Bool,
264    Int(u32, bool),
265    Float(u32, Option<FPEncoding>),
266    Relaxed,
267}
268
269impl Elem {
270    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
271        let id = match self {
272            Elem::Void => b.type_void(),
273            Elem::Bool => b.type_bool(),
274            Elem::Int(width, _) => b.type_int(*width, 0),
275            Elem::Float(width, encoding) => b.type_float(*width, *encoding),
276            Elem::Relaxed => b.type_float(32, None),
277        };
278        if b.debug_symbols && !b.state.debug_types.contains(&id) {
279            b.debug_name(id, format!("{self}"));
280            b.state.debug_types.insert(id);
281        }
282        id
283    }
284
285    pub fn size(&self) -> u32 {
286        match self {
287            Elem::Void => 0,
288            Elem::Bool => 1,
289            Elem::Int(size, _) => *size / 8,
290            Elem::Float(size, _) => *size / 8,
291            Elem::Relaxed => 4,
292        }
293    }
294
295    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
296        let ty = self.id(b);
297        match self {
298            Elem::Void => unreachable!(),
299            Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
300            Elem::Bool => b.constant_false(ty),
301            _ => match value {
302                ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
303                ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
304            },
305        }
306    }
307
308    pub fn float_encoding(&self) -> Option<FPEncoding> {
309        match self {
310            Elem::Float(_, encoding) => *encoding,
311            _ => None,
312        }
313    }
314}
315
316impl<T: SpirvTarget> SpirvCompiler<T> {
317    pub fn compile_type(&mut self, item: core::Type) -> Item {
318        match item {
319            core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
320            core::Type::Vector(storage, size) => {
321                Item::Vector(self.compile_storage_type(storage), size as u32)
322            }
323            core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
324        }
325    }
326
327    pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
328        match ty {
329            core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
330            core::StorageType::Opaque(ty) => match ty {
331                core::OpaqueType::Barrier(_) => {
332                    unimplemented!("Barrier type not supported in SPIR-V")
333                }
334            },
335            core::StorageType::Packed(_, _) => {
336                unimplemented!("Packed types not yet supported in SPIR-V")
337            }
338        }
339    }
340
341    pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
342        match elem {
343            core::ElemType::Float(
344                core::FloatKind::E2M1
345                | core::FloatKind::E2M3
346                | core::FloatKind::E3M2
347                | core::FloatKind::UE8M0,
348            ) => panic!("Minifloat not supported in SPIR-V"),
349            core::ElemType::Float(core::FloatKind::E4M3) => {
350                self.capabilities.insert(Capability::Float8EXT);
351                Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
352            }
353            core::ElemType::Float(core::FloatKind::E5M2) => {
354                self.capabilities.insert(Capability::Float8EXT);
355                Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
356            }
357            core::ElemType::Float(core::FloatKind::BF16) => {
358                self.capabilities.insert(Capability::BFloat16TypeKHR);
359                Elem::Float(16, Some(FPEncoding::BFloat16KHR))
360            }
361            core::ElemType::Float(FloatKind::F16) => {
362                self.capabilities.insert(Capability::Float16);
363                Elem::Float(16, None)
364            }
365            core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
366            core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
367            core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
368            core::ElemType::Float(FloatKind::F64) => {
369                self.capabilities.insert(Capability::Float64);
370                Elem::Float(64, None)
371            }
372            core::ElemType::Int(IntKind::I8) => {
373                self.capabilities.insert(Capability::Int8);
374                Elem::Int(8, true)
375            }
376            core::ElemType::Int(IntKind::I16) => {
377                self.capabilities.insert(Capability::Int16);
378                Elem::Int(16, true)
379            }
380            core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
381            core::ElemType::Int(IntKind::I64) => {
382                self.capabilities.insert(Capability::Int64);
383                Elem::Int(64, true)
384            }
385            core::ElemType::UInt(UIntKind::U64) => {
386                self.capabilities.insert(Capability::Int64);
387                Elem::Int(64, false)
388            }
389            core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
390            core::ElemType::UInt(UIntKind::U16) => {
391                self.capabilities.insert(Capability::Int16);
392                Elem::Int(16, false)
393            }
394            core::ElemType::UInt(UIntKind::U8) => {
395                self.capabilities.insert(Capability::Int8);
396                Elem::Int(8, false)
397            }
398            core::ElemType::Bool => Elem::Bool,
399        }
400    }
401
402    pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> (Word, ConstVal) {
403        let elem_cast = match (*from, item.elem()) {
404            (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
405            (Elem::Bool, Elem::Float(width, encoding)) => {
406                ConstVal::from_float(val.as_u32() as f64, width, encoding)
407            }
408            (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
409            (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
410            (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
411            (Elem::Int(w_in, true), Elem::Int(width, _)) => {
412                ConstVal::from_uint(val.as_int(w_in) as u64, width)
413            }
414            (Elem::Int(_, false), Elem::Float(width, encoding)) => {
415                ConstVal::from_float(val.as_u64() as f64, width, encoding)
416            }
417            (Elem::Int(_, false), Elem::Relaxed) => {
418                ConstVal::from_float(val.as_u64() as f64, 32, None)
419            }
420            (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
421                ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
422            }
423            (Elem::Int(in_w, true), Elem::Relaxed) => {
424                ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
425            }
426            (Elem::Float(in_w, encoding), Elem::Bool) => {
427                ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
428            }
429            (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
430            (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
431                ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
432            }
433            (Elem::Relaxed, Elem::Int(out_w, false)) => {
434                ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
435            }
436            (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
437                ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
438            }
439            (Elem::Relaxed, Elem::Int(out_w, true)) => {
440                ConstVal::from_int(val.as_float(32, None) as i64, out_w)
441            }
442            (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
443                ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
444            }
445            (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
446                ConstVal::from_float(val.as_float(32, None), out_w, encoding)
447            }
448            (Elem::Float(in_w, encoding), Elem::Relaxed) => {
449                ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
450            }
451            (Elem::Bool, Elem::Bool) => val,
452            (Elem::Relaxed, Elem::Relaxed) => val,
453            (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
454        };
455        let id = item.constant(self, elem_cast);
456        (id, elem_cast)
457    }
458}
459
460impl std::fmt::Display for Item {
461    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
462        match self {
463            Item::Scalar(elem) => write!(f, "{elem}"),
464            Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
465            Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
466            Item::RuntimeArray(item) => write!(f, "array<{item}>"),
467            Item::Struct(members) => {
468                write!(f, "struct<")?;
469                for item in members {
470                    write!(f, "{item}")?;
471                }
472                f.write_str(">")
473            }
474            Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
475            Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
476        }
477    }
478}
479
480impl std::fmt::Display for Elem {
481    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482        match self {
483            Elem::Void => write!(f, "void"),
484            Elem::Bool => write!(f, "bool"),
485            Elem::Int(width, false) => write!(f, "u{width}"),
486            Elem::Int(width, true) => write!(f, "i{width}"),
487            Elem::Float(width, None) => write!(f, "f{width}"),
488            Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
489            Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
490            Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
491            Elem::Relaxed => write!(f, "flex32"),
492        }
493    }
494}