Skip to main content

cubecl_spirv/
item.rs

1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3use serde::{Deserialize, Serialize};
4
5use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
6
7#[derive(Debug, Clone, PartialEq, Eq, Hash)]
8pub enum Item {
9    Scalar(Elem),
10    // Vector of scalars. Must be 2, 3, or 4, or 8/16 for OpenCL only
11    Vector(Elem, u32),
12    Array(Box<Item>, u32),
13    RuntimeArray(Box<Item>),
14    Struct(Vec<Item>),
15    Pointer(StorageClass, Box<Item>),
16    CoopMatrix {
17        ty: Elem,
18        rows: u32,
19        columns: u32,
20        ident: CooperativeMatrixUse,
21    },
22}
23
24impl Item {
25    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
26        let id = match self {
27            Item::Scalar(elem) => elem.id(b),
28            Item::Vector(elem, vec) => {
29                let elem = elem.id(b);
30                b.type_vector(elem, *vec)
31            }
32            Item::Array(item, len) => {
33                let item = item.id(b);
34                let len = b.const_u32(*len);
35                b.type_array(item, len)
36            }
37            Item::RuntimeArray(item) => {
38                let item = item.id(b);
39                b.type_runtime_array(item)
40            }
41            Item::Struct(vec) => {
42                let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
43                let id = b.id(); // Avoid deduplicating this struct, because of decorations
44                b.type_struct_id(Some(id), items)
45            }
46            Item::Pointer(storage_class, item) => {
47                let item = item.id(b);
48                b.type_pointer(None, *storage_class, item)
49            }
50            Item::CoopMatrix {
51                ty,
52                rows,
53                columns,
54                ident,
55            } => {
56                let ty = ty.id(b);
57                let scope = b.const_u32(Scope::Subgroup as u32);
58                let usage = b.const_u32(*ident as u32);
59                b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
60            }
61        };
62        if b.debug_symbols && !b.state.debug_types.contains(&id) {
63            b.debug_name(id, format!("{self}"));
64            b.state.debug_types.insert(id);
65        }
66        id
67    }
68
69    pub fn builtin_u32() -> Self {
70        Item::Scalar(Elem::Int(32, false))
71    }
72
73    pub fn size(&self) -> u32 {
74        match self {
75            Item::Scalar(elem) => elem.size(),
76            Item::Vector(elem, factor) => elem.size() * *factor,
77            Item::Array(item, len) => item.size() * *len,
78            Item::RuntimeArray(item) => item.size(),
79            Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
80            Item::Pointer(_, item) => item.size(),
81            Item::CoopMatrix { ty, .. } => ty.size(),
82        }
83    }
84
85    pub fn elem(&self) -> Elem {
86        match self {
87            Item::Scalar(elem) => *elem,
88            Item::Vector(elem, _) => *elem,
89            Item::Array(item, _) => item.elem(),
90            Item::RuntimeArray(item) => item.elem(),
91            Item::Struct(_) => Elem::Void,
92            Item::Pointer(_, item) => item.elem(),
93            Item::CoopMatrix { ty, .. } => *ty,
94        }
95    }
96
97    pub fn same_vectorization(&self, elem: Elem) -> Item {
98        match self {
99            Item::Scalar(_) => Item::Scalar(elem),
100            Item::Vector(_, factor) => Item::Vector(elem, *factor),
101            _ => unreachable!(),
102        }
103    }
104
105    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
106        let scalar = self.elem().constant(b, value);
107        let ty = self.id(b);
108        match self {
109            Item::Scalar(_) => scalar,
110            Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
111            Item::Array(item, len) => {
112                let elem = item.constant(b, value);
113                b.constant_composite(ty, (0..*len).map(|_| elem))
114            }
115            Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
116            Item::Struct(elems) => {
117                let items = elems
118                    .iter()
119                    .map(|item| item.constant(b, value))
120                    .collect::<Vec<_>>();
121                b.constant_composite(ty, items)
122            }
123            Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
124            Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
125        }
126    }
127
128    pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
129        b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
130            .0
131    }
132
133    /// Broadcast a scalar to a vector if needed, ex: f32 -> vec2<f32>, vec2<f32> -> vec2<f32>
134    pub fn broadcast<T: SpirvTarget>(
135        &self,
136        b: &mut SpirvCompiler<T>,
137        obj: Word,
138        out_id: Option<Word>,
139        other: &Item,
140    ) -> Word {
141        match (self, other) {
142            (Item::Scalar(elem), Item::Vector(_, factor)) => {
143                let item = Item::Vector(*elem, *factor);
144                let ty = item.id(b);
145                b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
146                    .unwrap()
147            }
148            _ => obj,
149        }
150    }
151
152    pub fn cast_to<T: SpirvTarget>(
153        &self,
154        b: &mut SpirvCompiler<T>,
155        out_id: Option<Word>,
156        obj: Word,
157        other: &Item,
158    ) -> Word {
159        let ty = other.id(b);
160
161        let matching_vec = match (self, other) {
162            (Item::Scalar(_), Item::Scalar(_)) => true,
163            (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
164            _ => false,
165        };
166        let matching_elem = self.elem() == other.elem();
167
168        let convert_i_width =
169            |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
170                if signed {
171                    b.s_convert(ty, out_id, obj).unwrap()
172                } else {
173                    b.u_convert(ty, out_id, obj).unwrap()
174                }
175            };
176
177        let convert_int = |b: &mut SpirvCompiler<T>,
178                           obj: Word,
179                           out_id: Option<Word>,
180                           (width_self, signed_self),
181                           (width_other, signed_other)| {
182            let width_differs = width_self != width_other;
183            let sign_extend = signed_self && signed_other;
184            match width_differs {
185                true => convert_i_width(b, obj, out_id, sign_extend),
186                false => b.copy_object(ty, out_id, obj).unwrap(),
187            }
188        };
189
190        let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
191            match (self.elem(), other.elem()) {
192                (Elem::Bool, Elem::Int(_, _)) => {
193                    let one = other.const_u32(b, 1);
194                    let zero = other.const_u32(b, 0);
195                    b.select(ty, out_id, obj, one, zero).unwrap()
196                }
197                (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
198                    let one = other.const_u32(b, 1);
199                    let zero = other.const_u32(b, 0);
200                    b.select(ty, out_id, obj, one, zero).unwrap()
201                }
202                (Elem::Int(_, _), Elem::Bool) => {
203                    let zero = self.const_u32(b, 0);
204                    b.i_not_equal(ty, out_id, obj, zero).unwrap()
205                }
206                (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
207                    convert_int(
208                        b,
209                        obj,
210                        out_id,
211                        (width_self, signed_self),
212                        (width_other, signed_other),
213                    )
214                }
215                (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
216                    b.convert_u_to_f(ty, out_id, obj).unwrap()
217                }
218                (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
219                    b.convert_s_to_f(ty, out_id, obj).unwrap()
220                }
221                (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
222                    let zero = self.const_u32(b, 0);
223                    b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
224                }
225                (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
226                    b.convert_f_to_u(ty, out_id, obj).unwrap()
227                }
228                (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
229                    b.convert_f_to_s(ty, out_id, obj).unwrap()
230                }
231                (Elem::Float(_, _), Elem::Float(_, _))
232                | (Elem::Float(_, _), Elem::Relaxed)
233                | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
234                (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
235                (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
236                (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
237            }
238        };
239
240        match (matching_vec, matching_elem) {
241            (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
242            (true, true) => obj,
243            (true, false) => cast_elem(b, obj, out_id),
244            (false, true) => self.broadcast(b, obj, out_id, other),
245            (false, false) => {
246                let broadcast = self.broadcast(b, obj, None, other);
247                cast_elem(b, broadcast, out_id)
248            }
249        }
250    }
251}
252
253#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
254pub enum Elem {
255    Void,
256    Bool,
257    Int(u32, bool),
258    Float(u32, Option<FPEncoding>),
259    Relaxed,
260}
261
262impl Elem {
263    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
264        let id = match self {
265            Elem::Void => b.type_void(),
266            Elem::Bool => b.type_bool(),
267            Elem::Int(width, _) => b.type_int(*width, 0),
268            Elem::Float(width, encoding) => b.type_float(*width, *encoding),
269            Elem::Relaxed => b.type_float(32, None),
270        };
271        if b.debug_symbols && !b.state.debug_types.contains(&id) {
272            b.debug_name(id, format!("{self}"));
273            b.state.debug_types.insert(id);
274        }
275        id
276    }
277
278    pub fn size(&self) -> u32 {
279        match self {
280            Elem::Void => 0,
281            Elem::Bool => 1,
282            Elem::Int(size, _) => *size / 8,
283            Elem::Float(size, _) => *size / 8,
284            Elem::Relaxed => 4,
285        }
286    }
287
288    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
289        let ty = self.id(b);
290        match self {
291            Elem::Void => unreachable!(),
292            Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
293            Elem::Bool => b.constant_false(ty),
294            _ => match value {
295                ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
296                ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
297            },
298        }
299    }
300
301    pub fn float_encoding(&self) -> Option<FPEncoding> {
302        match self {
303            Elem::Float(_, encoding) => *encoding,
304            _ => None,
305        }
306    }
307}
308
309impl<T: SpirvTarget> SpirvCompiler<T> {
310    pub fn compile_type(&mut self, item: core::Type) -> Item {
311        match item {
312            core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
313            core::Type::Line(storage, size) => {
314                Item::Vector(self.compile_storage_type(storage), size as u32)
315            }
316            core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
317        }
318    }
319
320    pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
321        match ty {
322            core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
323            core::StorageType::Opaque(ty) => match ty {
324                core::OpaqueType::Barrier(_) => {
325                    unimplemented!("Barrier type not supported in SPIR-V")
326                }
327            },
328            core::StorageType::Packed(_, _) => {
329                unimplemented!("Packed types not yet supported in SPIR-V")
330            }
331        }
332    }
333
334    pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
335        match elem {
336            core::ElemType::Float(
337                core::FloatKind::E2M1
338                | core::FloatKind::E2M3
339                | core::FloatKind::E3M2
340                | core::FloatKind::UE8M0,
341            ) => panic!("Minifloat not supported in SPIR-V"),
342            core::ElemType::Float(core::FloatKind::E4M3) => {
343                self.capabilities.insert(Capability::Float8EXT);
344                Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
345            }
346            core::ElemType::Float(core::FloatKind::E5M2) => {
347                self.capabilities.insert(Capability::Float8EXT);
348                Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
349            }
350            core::ElemType::Float(core::FloatKind::BF16) => {
351                self.capabilities.insert(Capability::BFloat16TypeKHR);
352                Elem::Float(16, Some(FPEncoding::BFloat16KHR))
353            }
354            core::ElemType::Float(FloatKind::F16) => {
355                self.capabilities.insert(Capability::Float16);
356                Elem::Float(16, None)
357            }
358            core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
359            core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
360            core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
361            core::ElemType::Float(FloatKind::F64) => {
362                self.capabilities.insert(Capability::Float64);
363                Elem::Float(64, None)
364            }
365            core::ElemType::Int(IntKind::I8) => {
366                self.capabilities.insert(Capability::Int8);
367                Elem::Int(8, true)
368            }
369            core::ElemType::Int(IntKind::I16) => {
370                self.capabilities.insert(Capability::Int16);
371                Elem::Int(16, true)
372            }
373            core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
374            core::ElemType::Int(IntKind::I64) => {
375                self.capabilities.insert(Capability::Int64);
376                Elem::Int(64, true)
377            }
378            core::ElemType::UInt(UIntKind::U64) => {
379                self.capabilities.insert(Capability::Int64);
380                Elem::Int(64, false)
381            }
382            core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
383            core::ElemType::UInt(UIntKind::U16) => {
384                self.capabilities.insert(Capability::Int16);
385                Elem::Int(16, false)
386            }
387            core::ElemType::UInt(UIntKind::U8) => {
388                self.capabilities.insert(Capability::Int8);
389                Elem::Int(8, false)
390            }
391            core::ElemType::Bool => Elem::Bool,
392        }
393    }
394
395    pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> (Word, ConstVal) {
396        let elem_cast = match (*from, item.elem()) {
397            (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
398            (Elem::Bool, Elem::Float(width, encoding)) => {
399                ConstVal::from_float(val.as_u32() as f64, width, encoding)
400            }
401            (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
402            (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
403            (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
404            (Elem::Int(w_in, true), Elem::Int(width, _)) => {
405                ConstVal::from_uint(val.as_int(w_in) as u64, width)
406            }
407            (Elem::Int(_, false), Elem::Float(width, encoding)) => {
408                ConstVal::from_float(val.as_u64() as f64, width, encoding)
409            }
410            (Elem::Int(_, false), Elem::Relaxed) => {
411                ConstVal::from_float(val.as_u64() as f64, 32, None)
412            }
413            (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
414                ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
415            }
416            (Elem::Int(in_w, true), Elem::Relaxed) => {
417                ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
418            }
419            (Elem::Float(in_w, encoding), Elem::Bool) => {
420                ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
421            }
422            (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
423            (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
424                ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
425            }
426            (Elem::Relaxed, Elem::Int(out_w, false)) => {
427                ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
428            }
429            (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
430                ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
431            }
432            (Elem::Relaxed, Elem::Int(out_w, true)) => {
433                ConstVal::from_int(val.as_float(32, None) as i64, out_w)
434            }
435            (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
436                ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
437            }
438            (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
439                ConstVal::from_float(val.as_float(32, None), out_w, encoding)
440            }
441            (Elem::Float(in_w, encoding), Elem::Relaxed) => {
442                ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
443            }
444            (Elem::Bool, Elem::Bool) => val,
445            (Elem::Relaxed, Elem::Relaxed) => val,
446            (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
447        };
448        let id = item.constant(self, elem_cast);
449        (id, elem_cast)
450    }
451}
452
453impl std::fmt::Display for Item {
454    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455        match self {
456            Item::Scalar(elem) => write!(f, "{elem}"),
457            Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
458            Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
459            Item::RuntimeArray(item) => write!(f, "array<{item}>"),
460            Item::Struct(members) => {
461                write!(f, "struct<")?;
462                for item in members {
463                    write!(f, "{item}")?;
464                }
465                f.write_str(">")
466            }
467            Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
468            Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
469        }
470    }
471}
472
473impl std::fmt::Display for Elem {
474    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
475        match self {
476            Elem::Void => write!(f, "void"),
477            Elem::Bool => write!(f, "bool"),
478            Elem::Int(width, false) => write!(f, "u{width}"),
479            Elem::Int(width, true) => write!(f, "i{width}"),
480            Elem::Float(width, None) => write!(f, "f{width}"),
481            Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
482            Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
483            Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
484            Elem::Relaxed => write!(f, "flex32"),
485        }
486    }
487}