cubecl_spirv/
item.rs

1use cubecl_core::ir::{self as core, FloatKind, IntKind, UIntKind};
2use rspirv::spirv::{Capability, CooperativeMatrixUse, FPEncoding, Scope, StorageClass, Word};
3
4use crate::{compiler::SpirvCompiler, target::SpirvTarget, variable::ConstVal};
5
6#[derive(Debug, Clone, PartialEq, Eq, Hash)]
7pub enum Item {
8    Scalar(Elem),
9    // Vector of scalars. Must be 2, 3, or 4, or 8/16 for OpenCL only
10    Vector(Elem, u32),
11    Array(Box<Item>, u32),
12    RuntimeArray(Box<Item>),
13    Struct(Vec<Item>),
14    Pointer(StorageClass, Box<Item>),
15    CoopMatrix {
16        ty: Elem,
17        rows: u32,
18        columns: u32,
19        ident: CooperativeMatrixUse,
20    },
21}
22
23impl Item {
24    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
25        let id = match self {
26            Item::Scalar(elem) => elem.id(b),
27            Item::Vector(elem, vec) => {
28                let elem = elem.id(b);
29                b.type_vector(elem, *vec)
30            }
31            Item::Array(item, len) => {
32                let item = item.id(b);
33                let len = b.const_u32(*len);
34                b.type_array(item, len)
35            }
36            Item::RuntimeArray(item) => {
37                let item = item.id(b);
38                b.type_runtime_array(item)
39            }
40            Item::Struct(vec) => {
41                let items: Vec<_> = vec.iter().map(|item| item.id(b)).collect();
42                let id = b.id(); // Avoid deduplicating this struct, because of decorations
43                b.type_struct_id(Some(id), items)
44            }
45            Item::Pointer(storage_class, item) => {
46                let item = item.id(b);
47                b.type_pointer(None, *storage_class, item)
48            }
49            Item::CoopMatrix {
50                ty,
51                rows,
52                columns,
53                ident,
54            } => {
55                let ty = ty.id(b);
56                let scope = b.const_u32(Scope::Subgroup as u32);
57                let usage = b.const_u32(*ident as u32);
58                b.type_cooperative_matrix_khr(ty, scope, *rows, *columns, usage)
59            }
60        };
61        if b.debug_symbols && !b.state.debug_types.contains(&id) {
62            b.debug_name(id, format!("{self}"));
63            b.state.debug_types.insert(id);
64        }
65        id
66    }
67
68    pub fn size(&self) -> u32 {
69        match self {
70            Item::Scalar(elem) => elem.size(),
71            Item::Vector(elem, factor) => elem.size() * *factor,
72            Item::Array(item, len) => item.size() * *len,
73            Item::RuntimeArray(item) => item.size(),
74            Item::Struct(vec) => vec.iter().map(|it| it.size()).sum(),
75            Item::Pointer(_, item) => item.size(),
76            Item::CoopMatrix { ty, .. } => ty.size(),
77        }
78    }
79
80    pub fn elem(&self) -> Elem {
81        match self {
82            Item::Scalar(elem) => *elem,
83            Item::Vector(elem, _) => *elem,
84            Item::Array(item, _) => item.elem(),
85            Item::RuntimeArray(item) => item.elem(),
86            Item::Struct(_) => Elem::Void,
87            Item::Pointer(_, item) => item.elem(),
88            Item::CoopMatrix { ty, .. } => *ty,
89        }
90    }
91
92    pub fn same_vectorization(&self, elem: Elem) -> Item {
93        match self {
94            Item::Scalar(_) => Item::Scalar(elem),
95            Item::Vector(_, factor) => Item::Vector(elem, *factor),
96            _ => unreachable!(),
97        }
98    }
99
100    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
101        let scalar = self.elem().constant(b, value);
102        let ty = self.id(b);
103        match self {
104            Item::Scalar(_) => scalar,
105            Item::Vector(_, vec) => b.constant_composite(ty, (0..*vec).map(|_| scalar)),
106            Item::Array(item, len) => {
107                let elem = item.constant(b, value);
108                b.constant_composite(ty, (0..*len).map(|_| elem))
109            }
110            Item::RuntimeArray(_) => unimplemented!("Can't create constant runtime array"),
111            Item::Struct(elems) => {
112                let items = elems
113                    .iter()
114                    .map(|item| item.constant(b, value))
115                    .collect::<Vec<_>>();
116                b.constant_composite(ty, items)
117            }
118            Item::Pointer(_, _) => unimplemented!("Can't create constant pointer"),
119            Item::CoopMatrix { .. } => unimplemented!("Can't create constant cmma matrix"),
120        }
121    }
122
123    pub fn const_u32<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: u32) -> Word {
124        b.static_cast(ConstVal::Bit32(value), &Elem::Int(32, false), self)
125    }
126
127    /// Broadcast a scalar to a vector if needed, ex: f32 -> vec2<f32>, vec2<f32> -> vec2<f32>
128    pub fn broadcast<T: SpirvTarget>(
129        &self,
130        b: &mut SpirvCompiler<T>,
131        obj: Word,
132        out_id: Option<Word>,
133        other: &Item,
134    ) -> Word {
135        match (self, other) {
136            (Item::Scalar(elem), Item::Vector(_, factor)) => {
137                let item = Item::Vector(*elem, *factor);
138                let ty = item.id(b);
139                b.composite_construct(ty, out_id, (0..*factor).map(|_| obj).collect::<Vec<_>>())
140                    .unwrap()
141            }
142            _ => obj,
143        }
144    }
145
146    pub fn cast_to<T: SpirvTarget>(
147        &self,
148        b: &mut SpirvCompiler<T>,
149        out_id: Option<Word>,
150        obj: Word,
151        other: &Item,
152    ) -> Word {
153        let ty = other.id(b);
154
155        let matching_vec = match (self, other) {
156            (Item::Scalar(_), Item::Scalar(_)) => true,
157            (Item::Vector(_, factor_from), Item::Vector(_, factor_to)) => factor_from == factor_to,
158            _ => false,
159        };
160        let matching_elem = self.elem() == other.elem();
161
162        let convert_i_width =
163            |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>, signed: bool| {
164                if signed {
165                    b.s_convert(ty, out_id, obj).unwrap()
166                } else {
167                    b.u_convert(ty, out_id, obj).unwrap()
168                }
169            };
170
171        let convert_int = |b: &mut SpirvCompiler<T>,
172                           obj: Word,
173                           out_id: Option<Word>,
174                           (width_self, signed_self),
175                           (width_other, signed_other)| {
176            let width_differs = width_self != width_other;
177            let sign_extend = signed_self && signed_other;
178            match width_differs {
179                true => convert_i_width(b, obj, out_id, sign_extend),
180                false => b.copy_object(ty, out_id, obj).unwrap(),
181            }
182        };
183
184        let cast_elem = |b: &mut SpirvCompiler<T>, obj: Word, out_id: Option<Word>| -> Word {
185            match (self.elem(), other.elem()) {
186                (Elem::Bool, Elem::Int(_, _)) => {
187                    let one = other.const_u32(b, 1);
188                    let zero = other.const_u32(b, 0);
189                    b.select(ty, out_id, obj, one, zero).unwrap()
190                }
191                (Elem::Bool, Elem::Float(_, _)) | (Elem::Bool, Elem::Relaxed) => {
192                    let one = other.const_u32(b, 1);
193                    let zero = other.const_u32(b, 0);
194                    b.select(ty, out_id, obj, one, zero).unwrap()
195                }
196                (Elem::Int(_, _), Elem::Bool) => {
197                    let zero = self.const_u32(b, 0);
198                    b.i_not_equal(ty, out_id, obj, zero).unwrap()
199                }
200                (Elem::Int(width_self, signed_self), Elem::Int(width_other, signed_other)) => {
201                    convert_int(
202                        b,
203                        obj,
204                        out_id,
205                        (width_self, signed_self),
206                        (width_other, signed_other),
207                    )
208                }
209                (Elem::Int(_, false), Elem::Float(_, _)) | (Elem::Int(_, false), Elem::Relaxed) => {
210                    b.convert_u_to_f(ty, out_id, obj).unwrap()
211                }
212                (Elem::Int(_, true), Elem::Float(_, _)) | (Elem::Int(_, true), Elem::Relaxed) => {
213                    b.convert_s_to_f(ty, out_id, obj).unwrap()
214                }
215                (Elem::Float(_, _), Elem::Bool) | (Elem::Relaxed, Elem::Bool) => {
216                    let zero = self.const_u32(b, 0);
217                    b.f_unord_not_equal(ty, out_id, obj, zero).unwrap()
218                }
219                (Elem::Float(_, _), Elem::Int(_, false)) | (Elem::Relaxed, Elem::Int(_, false)) => {
220                    b.convert_f_to_u(ty, out_id, obj).unwrap()
221                }
222                (Elem::Float(_, _), Elem::Int(_, true)) | (Elem::Relaxed, Elem::Int(_, true)) => {
223                    b.convert_f_to_s(ty, out_id, obj).unwrap()
224                }
225                (Elem::Float(_, _), Elem::Float(_, _))
226                | (Elem::Float(_, _), Elem::Relaxed)
227                | (Elem::Relaxed, Elem::Float(_, _)) => b.f_convert(ty, out_id, obj).unwrap(),
228                (Elem::Bool, Elem::Bool) => b.copy_object(ty, out_id, obj).unwrap(),
229                (Elem::Relaxed, Elem::Relaxed) => b.copy_object(ty, out_id, obj).unwrap(),
230                (from, to) => panic!("Invalid cast from {from:?} to {to:?}"),
231            }
232        };
233
234        match (matching_vec, matching_elem) {
235            (true, true) if out_id.is_some() => b.copy_object(ty, out_id, obj).unwrap(),
236            (true, true) => obj,
237            (true, false) => cast_elem(b, obj, out_id),
238            (false, true) => self.broadcast(b, obj, out_id, other),
239            (false, false) => {
240                let broadcast = self.broadcast(b, obj, None, other);
241                cast_elem(b, broadcast, out_id)
242            }
243        }
244    }
245}
246
247#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
248pub enum Elem {
249    Void,
250    Bool,
251    Int(u32, bool),
252    Float(u32, Option<FPEncoding>),
253    Relaxed,
254}
255
256impl Elem {
257    pub fn id<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>) -> Word {
258        let id = match self {
259            Elem::Void => b.type_void(),
260            Elem::Bool => b.type_bool(),
261            Elem::Int(width, _) => b.type_int(*width, 0),
262            Elem::Float(width, encoding) => b.type_float(*width, *encoding),
263            Elem::Relaxed => b.type_float(32, None),
264        };
265        if b.debug_symbols && !b.state.debug_types.contains(&id) {
266            b.debug_name(id, format!("{self}"));
267            b.state.debug_types.insert(id);
268        }
269        id
270    }
271
272    pub fn size(&self) -> u32 {
273        match self {
274            Elem::Void => 0,
275            Elem::Bool => 1,
276            Elem::Int(size, _) => *size / 8,
277            Elem::Float(size, _) => *size / 8,
278            Elem::Relaxed => 4,
279        }
280    }
281
282    pub fn constant<T: SpirvTarget>(&self, b: &mut SpirvCompiler<T>, value: ConstVal) -> Word {
283        let ty = self.id(b);
284        match self {
285            Elem::Void => unreachable!(),
286            Elem::Bool if value.as_u64() != 0 => b.constant_true(ty),
287            Elem::Bool => b.constant_false(ty),
288            _ => match value {
289                ConstVal::Bit32(val) => b.dedup_constant_bit32(ty, val),
290                ConstVal::Bit64(val) => b.dedup_constant_bit64(ty, val),
291            },
292        }
293    }
294}
295
296impl<T: SpirvTarget> SpirvCompiler<T> {
297    pub fn compile_type(&mut self, item: core::Type) -> Item {
298        match item {
299            core::Type::Scalar(storage) => Item::Scalar(self.compile_storage_type(storage)),
300            core::Type::Line(storage, size) => {
301                Item::Vector(self.compile_storage_type(storage), size)
302            }
303            core::Type::Semantic(_) => unimplemented!("Can't compile semantic type"),
304        }
305    }
306
307    pub fn compile_storage_type(&mut self, ty: core::StorageType) -> Elem {
308        match ty {
309            core::StorageType::Scalar(ty) | core::StorageType::Atomic(ty) => self.compile_elem(ty),
310            core::StorageType::Packed(_, _) => {
311                unimplemented!("Packed types not yet supported in SPIR-V")
312            }
313        }
314    }
315
316    pub fn compile_elem(&mut self, elem: core::ElemType) -> Elem {
317        match elem {
318            core::ElemType::Float(
319                core::FloatKind::E2M1
320                | core::FloatKind::E2M3
321                | core::FloatKind::E3M2
322                | core::FloatKind::UE8M0,
323            ) => panic!("Minifloat not supported in SPIR-V"),
324            core::ElemType::Float(core::FloatKind::E4M3) => {
325                self.capabilities.insert(Capability::Float8EXT);
326                Elem::Float(8, Some(FPEncoding::Float8E4M3EXT))
327            }
328            core::ElemType::Float(core::FloatKind::E5M2) => {
329                self.capabilities.insert(Capability::Float8EXT);
330                Elem::Float(8, Some(FPEncoding::Float8E5M2EXT))
331            }
332            core::ElemType::Float(core::FloatKind::BF16) => {
333                self.capabilities.insert(Capability::BFloat16TypeKHR);
334                Elem::Float(16, Some(FPEncoding::BFloat16KHR))
335            }
336            core::ElemType::Float(FloatKind::F16) => {
337                self.capabilities.insert(Capability::Float16);
338                Elem::Float(16, None)
339            }
340            core::ElemType::Float(FloatKind::TF32) => panic!("TF32 not supported in SPIR-V"),
341            core::ElemType::Float(FloatKind::Flex32) => Elem::Relaxed,
342            core::ElemType::Float(FloatKind::F32) => Elem::Float(32, None),
343            core::ElemType::Float(FloatKind::F64) => {
344                self.capabilities.insert(Capability::Float64);
345                Elem::Float(64, None)
346            }
347            core::ElemType::Int(IntKind::I8) => {
348                self.capabilities.insert(Capability::Int8);
349                Elem::Int(8, true)
350            }
351            core::ElemType::Int(IntKind::I16) => {
352                self.capabilities.insert(Capability::Int16);
353                Elem::Int(16, true)
354            }
355            core::ElemType::Int(IntKind::I32) => Elem::Int(32, true),
356            core::ElemType::Int(IntKind::I64) => {
357                self.capabilities.insert(Capability::Int64);
358                Elem::Int(64, true)
359            }
360            core::ElemType::UInt(UIntKind::U64) => {
361                self.capabilities.insert(Capability::Int64);
362                Elem::Int(64, false)
363            }
364            core::ElemType::UInt(UIntKind::U32) => Elem::Int(32, false),
365            core::ElemType::UInt(UIntKind::U16) => {
366                self.capabilities.insert(Capability::Int16);
367                Elem::Int(16, false)
368            }
369            core::ElemType::UInt(UIntKind::U8) => {
370                self.capabilities.insert(Capability::Int8);
371                Elem::Int(8, false)
372            }
373            core::ElemType::Bool => Elem::Bool,
374        }
375    }
376
377    pub fn static_core(&mut self, val: core::Variable, item: &Item) -> Word {
378        let val = val.as_const().unwrap();
379
380        let value = match (val, item.elem()) {
381            (core::ConstantScalarValue::Int(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
382            (core::ConstantScalarValue::Int(val, _), Elem::Int(width, false)) => {
383                ConstVal::from_uint(val as u64, width)
384            }
385            (core::ConstantScalarValue::Int(val, _), Elem::Int(width, true)) => {
386                ConstVal::from_int(val, width)
387            }
388            (core::ConstantScalarValue::Int(val, _), Elem::Float(width, encoding)) => {
389                ConstVal::from_float(val as f64, width, encoding)
390            }
391            (core::ConstantScalarValue::Int(val, _), Elem::Relaxed) => {
392                ConstVal::from_float(val as f64, 32, None)
393            }
394            (core::ConstantScalarValue::Float(val, _), Elem::Bool) => {
395                ConstVal::from_bool(val != 0.0)
396            }
397            (core::ConstantScalarValue::Float(val, _), Elem::Int(width, false)) => {
398                ConstVal::from_uint(val as u64, width)
399            }
400            (core::ConstantScalarValue::Float(val, _), Elem::Int(width, true)) => {
401                ConstVal::from_int(val as i64, width)
402            }
403            (core::ConstantScalarValue::Float(val, _), Elem::Float(width, encoding)) => {
404                ConstVal::from_float(val, width, encoding)
405            }
406            (core::ConstantScalarValue::Float(val, _), Elem::Relaxed) => {
407                ConstVal::from_float(val, 32, None)
408            }
409            (core::ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstVal::from_bool(val != 0),
410            (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, false)) => {
411                ConstVal::from_uint(val, width)
412            }
413            (core::ConstantScalarValue::UInt(val, _), Elem::Int(width, true)) => {
414                ConstVal::from_int(val as i64, width)
415            }
416            (core::ConstantScalarValue::UInt(val, _), Elem::Float(width, encoding)) => {
417                ConstVal::from_float(val as f64, width, encoding)
418            }
419            (core::ConstantScalarValue::UInt(val, _), Elem::Relaxed) => {
420                ConstVal::from_float(val as f64, 32, None)
421            }
422            (core::ConstantScalarValue::Bool(val), Elem::Bool) => ConstVal::from_bool(val),
423            (core::ConstantScalarValue::Bool(val), Elem::Int(width, _)) => {
424                ConstVal::from_uint(val as u64, width)
425            }
426            (core::ConstantScalarValue::Bool(val), Elem::Float(width, encoding)) => {
427                ConstVal::from_float(val as u32 as f64, width, encoding)
428            }
429            (core::ConstantScalarValue::Bool(val), Elem::Relaxed) => {
430                ConstVal::from_float(val as u32 as f64, 32, None)
431            }
432            (_, Elem::Void) => unreachable!(),
433        };
434        item.constant(self, value)
435    }
436
437    pub fn static_cast(&mut self, val: ConstVal, from: &Elem, item: &Item) -> Word {
438        let elem_cast = match (*from, item.elem()) {
439            (Elem::Bool, Elem::Int(width, _)) => ConstVal::from_uint(val.as_u32() as u64, width),
440            (Elem::Bool, Elem::Float(width, encoding)) => {
441                ConstVal::from_float(val.as_u32() as f64, width, encoding)
442            }
443            (Elem::Bool, Elem::Relaxed) => ConstVal::from_float(val.as_u32() as f64, 32, None),
444            (Elem::Int(_, _), Elem::Bool) => ConstVal::from_bool(val.as_u64() != 0),
445            (Elem::Int(_, false), Elem::Int(width, _)) => ConstVal::from_uint(val.as_u64(), width),
446            (Elem::Int(w_in, true), Elem::Int(width, _)) => {
447                ConstVal::from_uint(val.as_int(w_in) as u64, width)
448            }
449            (Elem::Int(_, false), Elem::Float(width, encoding)) => {
450                ConstVal::from_float(val.as_u64() as f64, width, encoding)
451            }
452            (Elem::Int(_, false), Elem::Relaxed) => {
453                ConstVal::from_float(val.as_u64() as f64, 32, None)
454            }
455            (Elem::Int(in_w, true), Elem::Float(width, encoding)) => {
456                ConstVal::from_float(val.as_int(in_w) as f64, width, encoding)
457            }
458            (Elem::Int(in_w, true), Elem::Relaxed) => {
459                ConstVal::from_float(val.as_int(in_w) as f64, 32, None)
460            }
461            (Elem::Float(in_w, encoding), Elem::Bool) => {
462                ConstVal::from_bool(val.as_float(in_w, encoding) != 0.0)
463            }
464            (Elem::Relaxed, Elem::Bool) => ConstVal::from_bool(val.as_float(32, None) != 0.0),
465            (Elem::Float(in_w, encoding), Elem::Int(out_w, false)) => {
466                ConstVal::from_uint(val.as_float(in_w, encoding) as u64, out_w)
467            }
468            (Elem::Relaxed, Elem::Int(out_w, false)) => {
469                ConstVal::from_uint(val.as_float(32, None) as u64, out_w)
470            }
471            (Elem::Float(in_w, encoding), Elem::Int(out_w, true)) => {
472                ConstVal::from_int(val.as_float(in_w, encoding) as i64, out_w)
473            }
474            (Elem::Relaxed, Elem::Int(out_w, true)) => {
475                ConstVal::from_int(val.as_float(32, None) as i64, out_w)
476            }
477            (Elem::Float(in_w, encoding), Elem::Float(out_w, encoding_out)) => {
478                ConstVal::from_float(val.as_float(in_w, encoding), out_w, encoding_out)
479            }
480            (Elem::Relaxed, Elem::Float(out_w, encoding)) => {
481                ConstVal::from_float(val.as_float(32, None), out_w, encoding)
482            }
483            (Elem::Float(in_w, encoding), Elem::Relaxed) => {
484                ConstVal::from_float(val.as_float(in_w, encoding), 32, None)
485            }
486            (Elem::Bool, Elem::Bool) => val,
487            (Elem::Relaxed, Elem::Relaxed) => val,
488            (_, Elem::Void) | (Elem::Void, _) => unreachable!(),
489        };
490        item.constant(self, elem_cast)
491    }
492}
493
494impl std::fmt::Display for Item {
495    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496        match self {
497            Item::Scalar(elem) => write!(f, "{elem}"),
498            Item::Vector(elem, factor) => write!(f, "vec{factor}<{elem}>"),
499            Item::Array(item, len) => write!(f, "array<{item}, {len}>"),
500            Item::RuntimeArray(item) => write!(f, "array<{item}>"),
501            Item::Struct(members) => {
502                write!(f, "struct<")?;
503                for item in members {
504                    write!(f, "{item}")?;
505                }
506                f.write_str(">")
507            }
508            Item::Pointer(class, item) => write!(f, "ptr<{class:?}, {item}>"),
509            Item::CoopMatrix { ty, ident, .. } => write!(f, "matrix<{ty}, {ident:?}>"),
510        }
511    }
512}
513
514impl std::fmt::Display for Elem {
515    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516        match self {
517            Elem::Void => write!(f, "void"),
518            Elem::Bool => write!(f, "bool"),
519            Elem::Int(width, false) => write!(f, "u{width}"),
520            Elem::Int(width, true) => write!(f, "i{width}"),
521            Elem::Float(width, None) => write!(f, "f{width}"),
522            Elem::Float(_, Some(FPEncoding::BFloat16KHR)) => write!(f, "bf16"),
523            Elem::Float(_, Some(FPEncoding::Float8E4M3EXT)) => write!(f, "e4m3"),
524            Elem::Float(_, Some(FPEncoding::Float8E5M2EXT)) => write!(f, "e5m2"),
525            Elem::Relaxed => write!(f, "flex32"),
526        }
527    }
528}