cubecl_ir/
type.rs

1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5    e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6    quant::scheme::{QuantParam, QuantValue},
7    tf32, ue8m0,
8};
9
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
12#[allow(missing_docs)]
13pub enum FloatKind {
14    /// FP4, 2 bit exponent, 1 bit mantissa
15    E2M1,
16    /// FP6, 2 bit exponent, 3 bit mantissa
17    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
18    E2M3,
19    /// FP6, 3 bit exponent, 2 bit mantissa
20    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
21    E3M2,
22    /// FP8, 4 bit exponent, 3 bit mantissa
23    E4M3,
24    /// FP8, 5 bit exponent, 2 bit mantissa
25    E5M2,
26    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
27    UE8M0,
28    F16,
29    BF16,
30    Flex32,
31    F32,
32    TF32,
33    F64,
34}
35
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
38#[allow(missing_docs)]
39pub enum IntKind {
40    I8,
41    I16,
42    I32,
43    I64,
44}
45
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
48#[allow(missing_docs)]
49pub enum UIntKind {
50    U8,
51    U16,
52    U32,
53    U64,
54}
55
56/// Conceptual element type, not necessarily the physical type used in the code
57#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
58#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
59#[allow(missing_docs)]
60pub enum ElemType {
61    Float(FloatKind),
62    Int(IntKind),
63    UInt(UIntKind),
64    Bool,
65}
66
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
69pub enum OpaqueType {
70    Barrier(BarrierLevel),
71}
72
73#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
74#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
75pub enum SemanticType {
76    BarrierToken,
77    Pipeline,
78    TensorMap,
79}
80
81/// Physical type containing one or more elements
82#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
83#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
84pub enum StorageType {
85    /// `ElemType` is the same as the physical type
86    Scalar(ElemType),
87    /// Packed values of type `ElemType`
88    Packed(ElemType, u32),
89    /// Atomically accessed version of `ElemType`
90    Atomic(ElemType),
91    /// Opaque types that can be stored but not interacted with normally. Currently only barrier,
92    /// but may be used for arrival tokens and tensor map descriptors, for example.
93    Opaque(OpaqueType),
94}
95
96impl ElemType {
97    /// Creates an elem type that correspond to the given [QuantParam].
98    pub fn from_quant_param(quant_param: QuantParam) -> Self {
99        match quant_param {
100            QuantParam::F32 => Self::Float(FloatKind::F32),
101            QuantParam::F16 => Self::Float(FloatKind::F16),
102            QuantParam::BF16 => Self::Float(FloatKind::BF16),
103            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
104            QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
105        }
106    }
107
108    /// Creates an elem type that correspond to the given [QuantValue].
109    pub fn from_quant_value(quant_value: QuantValue) -> Self {
110        match quant_value {
111            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
112            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
113            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
114            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
115            other => panic!("Unsupported quant value {other:?}"),
116        }
117    }
118    /// Create a constant scalar from a float.
119    ///
120    /// The output will have the same type as the element.
121    pub fn constant_from_f64(&self, val: f64) -> Variable {
122        Variable::constant(match self {
123            ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
124            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
125            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126            ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
127        })
128    }
129    /// Create a constant scalar from a signed integer.
130    ///
131    /// The output will have the same type as the element.
132    pub fn constant_from_i64(&self, val: i64) -> Variable {
133        Variable::constant(match self {
134            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
135            ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
136            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
137            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
138        })
139    }
140    /// Create a constant scalar from a unsigned integer.
141    ///
142    /// The output will have the same type as the element.
143    pub fn constant_from_u64(&self, val: u64) -> Variable {
144        Variable::constant(match self {
145            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
146            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
147            ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
148            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
149        })
150    }
151    /// Create a constant scalar from a boolean.
152    ///
153    /// The output will have the same type as the element.
154    pub fn constant_from_bool(&self, val: bool) -> Variable {
155        Variable::constant(match self {
156            ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
157            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
158            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
159            ElemType::Bool => ConstantScalarValue::Bool(val),
160        })
161    }
162
163    /// Ensure that the variable provided, when a constant, is the same type as elem.
164    pub fn from_constant(&self, constant: Variable) -> Variable {
165        let value = match constant.kind {
166            VariableKind::ConstantScalar(value) => value,
167            _ => return constant,
168        };
169
170        match value {
171            ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
172            ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
173            ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
174            ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
175        }
176    }
177    /// Get the size in bytes.
178    pub const fn size(&self) -> usize {
179        match self {
180            ElemType::Float(kind) => match kind {
181                FloatKind::E2M1
182                | FloatKind::E2M3
183                | FloatKind::E3M2
184                | FloatKind::E4M3
185                | FloatKind::E5M2
186                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
187                FloatKind::F16 => core::mem::size_of::<half::f16>(),
188                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
189                FloatKind::F32 => core::mem::size_of::<f32>(),
190                FloatKind::F64 => core::mem::size_of::<f64>(),
191                FloatKind::Flex32 => core::mem::size_of::<f32>(),
192                FloatKind::TF32 => core::mem::size_of::<f32>(),
193            },
194            ElemType::Int(kind) => match kind {
195                IntKind::I8 => core::mem::size_of::<i8>(),
196                IntKind::I16 => core::mem::size_of::<i16>(),
197                IntKind::I32 => core::mem::size_of::<i32>(),
198                IntKind::I64 => core::mem::size_of::<i64>(),
199            },
200            ElemType::UInt(kind) => match kind {
201                UIntKind::U8 => core::mem::size_of::<u8>(),
202                UIntKind::U16 => core::mem::size_of::<u16>(),
203                UIntKind::U32 => core::mem::size_of::<u32>(),
204                UIntKind::U64 => core::mem::size_of::<u64>(),
205            },
206            ElemType::Bool => core::mem::size_of::<bool>(),
207        }
208    }
209
210    /// Get the size in bits.
211    pub const fn size_bits(&self) -> usize {
212        match self {
213            ElemType::Float(kind) => match kind {
214                FloatKind::E2M3
215                | FloatKind::E3M2
216                | FloatKind::E4M3
217                | FloatKind::E5M2
218                | FloatKind::UE8M0
219                | FloatKind::F16
220                | FloatKind::BF16
221                | FloatKind::F32
222                | FloatKind::F64
223                | FloatKind::Flex32
224                | FloatKind::TF32 => self.size() * 8,
225                FloatKind::E2M1 => 4,
226            },
227            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
228        }
229    }
230
231    pub const fn min_line_size(&self) -> u8 {
232        match self {
233            ElemType::Float(FloatKind::E2M1) => 2,
234            _ => 1,
235        }
236    }
237
238    pub fn is_int(&self) -> bool {
239        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
240    }
241
242    pub fn is_signed_int(&self) -> bool {
243        matches!(self, ElemType::Int(_))
244    }
245
246    pub fn is_unsigned_int(&self) -> bool {
247        matches!(self, ElemType::UInt(_) | ElemType::Bool)
248    }
249
250    pub fn is_float(&self) -> bool {
251        matches!(self, ElemType::Float(_))
252    }
253
254    pub fn max_variable(&self) -> Variable {
255        let value = match self {
256            ElemType::Float(kind) => match kind {
257                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
258                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
259                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
260                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
261                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
262                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
263                FloatKind::F16 => {
264                    ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
265                }
266                FloatKind::BF16 => {
267                    ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
268                }
269                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
270                FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
271                FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
272                FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
273            },
274            ElemType::Int(kind) => match kind {
275                IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
276                IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
277                IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
278                IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
279            },
280            ElemType::UInt(kind) => match kind {
281                UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
282                UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
283                UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
284                UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
285            },
286            ElemType::Bool => ConstantScalarValue::Bool(true),
287        };
288
289        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
290    }
291
292    pub fn min_variable(&self) -> Variable {
293        let value = match self {
294            ElemType::Float(kind) => match kind {
295                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
296                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
297                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
298                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
299                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
300                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
301                FloatKind::F16 => {
302                    ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
303                }
304                FloatKind::BF16 => {
305                    ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
306                }
307                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
308                FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
309                FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
310                FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
311            },
312            ElemType::Int(kind) => match kind {
313                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
314                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
315                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
316                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
317            },
318            ElemType::UInt(kind) => match kind {
319                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
320                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
321                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
322                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
323            },
324            ElemType::Bool => ConstantScalarValue::Bool(false),
325        };
326
327        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
328    }
329
330    pub fn epsilon(&self) -> f64 {
331        match self {
332            ElemType::Float(kind) => match kind {
333                FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
334                FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
335                FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
336                FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
337                FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
338                FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
339                FloatKind::F16 => half::f16::EPSILON.to_f64(),
340                FloatKind::BF16 => 0.0078125, // bf16 epsilon ≈ 2^-7
341                FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
342                FloatKind::F64 => f64::EPSILON,
343            },
344            ElemType::Int(_) | ElemType::UInt(_) => 1.0, // step of 1
345            ElemType::Bool => 1.0,
346        }
347    }
348}
349
350impl OpaqueType {
351    /// Get the size in bytes.
352    pub const fn size(&self) -> usize {
353        match self {
354            OpaqueType::Barrier(_) => 8,
355        }
356    }
357
358    /// Get the size in bits.
359    pub const fn size_bits(&self) -> usize {
360        match self {
361            OpaqueType::Barrier(_) => 64,
362        }
363    }
364}
365
366impl StorageType {
367    pub fn elem_type(&self) -> ElemType {
368        match self {
369            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
370            StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
371        }
372    }
373
374    pub fn packing_factor(&self) -> u32 {
375        match self {
376            StorageType::Packed(_, factor) => *factor,
377            _ => 1,
378        }
379    }
380
381    pub fn is_atomic(&self) -> bool {
382        matches!(self, StorageType::Atomic(_))
383    }
384
385    pub fn size(&self) -> usize {
386        self.size_bits().div_ceil(8)
387    }
388
389    pub fn size_bits(&self) -> usize {
390        match self {
391            StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
392            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
393            StorageType::Opaque(ty) => ty.size_bits(),
394        }
395    }
396
397    /// Ensure that the variable provided, when a constant, is the same type as elem.
398    pub fn from_constant(&self, constant: Variable) -> Variable {
399        self.elem_type().from_constant(constant)
400    }
401
402    pub fn is_int(&self) -> bool {
403        self.elem_type().is_int()
404    }
405
406    pub fn is_signed_int(&self) -> bool {
407        self.elem_type().is_signed_int()
408    }
409
410    pub fn is_unsigned_int(&self) -> bool {
411        self.elem_type().is_unsigned_int()
412    }
413
414    pub fn is_float(&self) -> bool {
415        self.elem_type().is_float()
416    }
417
418    /// Returns an empirical epsilon for this storage type, taking quantization into account.
419    pub fn epsilon(&self) -> f64 {
420        match self {
421            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
422            StorageType::Packed(ty, factor) => {
423                // For packed types, we can conservatively scale epsilon by the number of packed elements
424                ty.epsilon() * (*factor as f64)
425            }
426            StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
427        }
428    }
429}
430
431impl From<ElemType> for StorageType {
432    fn from(val: ElemType) -> Self {
433        StorageType::Scalar(val)
434    }
435}
436
437impl From<OpaqueType> for StorageType {
438    fn from(val: OpaqueType) -> Self {
439        StorageType::Opaque(val)
440    }
441}
442
443impl<T: Into<StorageType>> From<T> for Type {
444    fn from(val: T) -> Self {
445        Type::new(val.into())
446    }
447}
448
449impl From<SemanticType> for Type {
450    fn from(val: SemanticType) -> Self {
451        Type::semantic(val)
452    }
453}
454
455#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
456#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
457pub enum Type {
458    /// Scalar type containing a single storage element
459    Scalar(StorageType),
460    /// Line wrapping `n` storage elements
461    Line(StorageType, u32),
462    /// No defined physical representation, purely semantic. i.e. barrier, pipeline
463    Semantic(SemanticType),
464}
465
466pub type LineSize = u32;
467
468impl Type {
469    /// Fetch the elem of the item.
470    pub fn elem_type(&self) -> ElemType {
471        self.storage_type().elem_type()
472    }
473
474    /// Create a new item
475    pub fn new(storage: StorageType) -> Self {
476        Type::Scalar(storage)
477    }
478
479    pub fn scalar(elem: ElemType) -> Self {
480        Self::new(StorageType::Scalar(elem))
481    }
482
483    pub fn semantic(ty: SemanticType) -> Self {
484        Self::Semantic(ty)
485    }
486
487    pub fn line(self, line_size: LineSize) -> Type {
488        match line_size > 1 {
489            true => Type::Line(self.storage_type(), line_size),
490            false => Type::Scalar(self.storage_type()),
491        }
492    }
493
494    pub fn line_size(&self) -> u32 {
495        match self {
496            Type::Scalar(_) => 1,
497            Type::Line(_, line_size) => *line_size,
498            Type::Semantic(_) => 0,
499        }
500    }
501
502    pub fn size(&self) -> usize {
503        match self {
504            Type::Scalar(ty) => ty.size(),
505            Type::Line(ty, line_size) => ty.size() * *line_size as usize,
506            Type::Semantic(_) => 0,
507        }
508    }
509
510    pub fn size_bits(&self) -> usize {
511        match self {
512            Type::Scalar(ty) => ty.size_bits(),
513            Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
514            Type::Semantic(_) => 0,
515        }
516    }
517
518    pub fn is_atomic(&self) -> bool {
519        !self.is_semantic() && self.storage_type().is_atomic()
520    }
521
522    pub fn is_int(&self) -> bool {
523        !self.is_semantic() && self.storage_type().is_int()
524    }
525
526    pub fn is_signed_int(&self) -> bool {
527        !self.is_semantic() && self.storage_type().is_signed_int()
528    }
529
530    pub fn is_unsigned_int(&self) -> bool {
531        !self.is_semantic() && self.storage_type().is_unsigned_int()
532    }
533
534    pub fn is_float(&self) -> bool {
535        !self.is_semantic() && self.storage_type().is_float()
536    }
537
538    pub fn storage_type(&self) -> StorageType {
539        match self {
540            Type::Scalar(ty) | Type::Line(ty, _) => *ty,
541            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
542        }
543    }
544
545    pub fn is_semantic(&self) -> bool {
546        match self {
547            Type::Scalar(_) | Type::Line(_, _) => false,
548            Type::Semantic(_) => true,
549        }
550    }
551}
552
553impl Display for Type {
554    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
555        match self {
556            Type::Scalar(ty) => write!(f, "{ty}"),
557            Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
558            Type::Semantic(ty) => write!(f, "{ty}"),
559        }
560    }
561}
562
563impl Display for StorageType {
564    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
565        match self {
566            StorageType::Scalar(ty) => write!(f, "{ty}"),
567            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
568            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
569            StorageType::Opaque(ty) => write!(f, "{ty}"),
570        }
571    }
572}
573
574impl Display for ElemType {
575    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
576        match self {
577            Self::Float(kind) => match kind {
578                FloatKind::E2M1 => f.write_str("e2m1"),
579                FloatKind::E2M3 => f.write_str("e2m3"),
580                FloatKind::E3M2 => f.write_str("e3m2"),
581                FloatKind::E4M3 => f.write_str("e4m3"),
582                FloatKind::E5M2 => f.write_str("e5m2"),
583                FloatKind::UE8M0 => f.write_str("ue8m0"),
584                FloatKind::F16 => f.write_str("f16"),
585                FloatKind::BF16 => f.write_str("bf16"),
586                FloatKind::Flex32 => f.write_str("flex32"),
587                FloatKind::TF32 => f.write_str("tf32"),
588                FloatKind::F32 => f.write_str("f32"),
589                FloatKind::F64 => f.write_str("f64"),
590            },
591            Self::Int(kind) => match kind {
592                IntKind::I8 => f.write_str("i8"),
593                IntKind::I16 => f.write_str("i16"),
594                IntKind::I32 => f.write_str("i32"),
595                IntKind::I64 => f.write_str("i64"),
596            },
597            Self::UInt(kind) => match kind {
598                UIntKind::U8 => f.write_str("u8"),
599                UIntKind::U16 => f.write_str("u16"),
600                UIntKind::U32 => f.write_str("u32"),
601                UIntKind::U64 => f.write_str("u64"),
602            },
603            Self::Bool => f.write_str("bool"),
604        }
605    }
606}
607
608impl Display for SemanticType {
609    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
610        match self {
611            SemanticType::BarrierToken => f.write_str("barrier_token"),
612            SemanticType::Pipeline => f.write_str("pipeline"),
613            SemanticType::TensorMap => f.write_str("tensor_map"),
614        }
615    }
616}
617
618impl Display for OpaqueType {
619    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
620        match self {
621            OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
622        }
623    }
624}
625
626impl From<bool> for Variable {
627    fn from(value: bool) -> Self {
628        Variable::constant(ConstantScalarValue::Bool(value))
629    }
630}
631
632impl From<i8> for Variable {
633    fn from(value: i8) -> Self {
634        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
635    }
636}
637
638impl From<i16> for Variable {
639    fn from(value: i16) -> Self {
640        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
641    }
642}
643
644impl From<i32> for Variable {
645    fn from(value: i32) -> Self {
646        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
647    }
648}
649
650impl From<i64> for Variable {
651    fn from(value: i64) -> Self {
652        Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
653    }
654}
655
656impl From<e2m1> for Variable {
657    fn from(_value: e2m1) -> Self {
658        unimplemented!("Can't currently construct minifloats")
659    }
660}
661
662impl From<e2m1x2> for Variable {
663    fn from(_value: e2m1x2) -> Self {
664        unimplemented!("Can't currently construct minifloats")
665    }
666}
667
668impl From<e2m3> for Variable {
669    fn from(_value: e2m3) -> Self {
670        unimplemented!("Can't currently construct minifloats")
671    }
672}
673
674impl From<e3m2> for Variable {
675    fn from(_value: e3m2) -> Self {
676        unimplemented!("Can't currently construct minifloats")
677    }
678}
679
680impl From<e4m3> for Variable {
681    fn from(_value: e4m3) -> Self {
682        unimplemented!("Can't currently construct minifloats")
683    }
684}
685
686impl From<e5m2> for Variable {
687    fn from(_value: e5m2) -> Self {
688        unimplemented!("Can't currently construct minifloats")
689    }
690}
691
692impl From<ue8m0> for Variable {
693    fn from(_value: ue8m0) -> Self {
694        unimplemented!("Can't currently construct minifloats")
695    }
696}
697
698impl From<half::f16> for Variable {
699    fn from(value: half::f16) -> Self {
700        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
701    }
702}
703
704impl From<half::bf16> for Variable {
705    fn from(value: half::bf16) -> Self {
706        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
707    }
708}
709
710impl From<flex32> for Variable {
711    fn from(value: flex32) -> Self {
712        Variable::constant(ConstantScalarValue::Float(
713            value.to_f64(),
714            FloatKind::Flex32,
715        ))
716    }
717}
718
719impl From<tf32> for Variable {
720    fn from(value: tf32) -> Self {
721        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
722    }
723}
724
725impl From<f32> for Variable {
726    fn from(value: f32) -> Self {
727        Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
728    }
729}
730
731impl From<f64> for Variable {
732    fn from(value: f64) -> Self {
733        Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
734    }
735}
736
737impl From<u8> for Variable {
738    fn from(value: u8) -> Self {
739        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
740    }
741}
742
743impl From<u16> for Variable {
744    fn from(value: u16) -> Self {
745        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
746    }
747}
748
749impl From<u32> for Variable {
750    fn from(value: u32) -> Self {
751        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
752    }
753}
754
755impl From<u64> for Variable {
756    fn from(value: u64) -> Self {
757        Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
758    }
759}
760
761impl From<usize> for Variable {
762    fn from(value: usize) -> Self {
763        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
764    }
765}