cubecl_ir/
type.rs

1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5    e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6    quant::scheme::{QuantParam, QuantValue},
7    tf32, ue8m0,
8};
9
/// Floating-point element kinds, from 4-bit minifloats up to 64-bit floats.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// FP4, 2 bit exponent, 1 bit mantissa
    E2M1,
    /// FP6, 2 bit exponent, 3 bit mantissa
    ///
    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
    E2M3,
    /// FP6, 3 bit exponent, 2 bit mantissa
    ///
    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
    E3M2,
    /// FP8, 4 bit exponent, 3 bit mantissa
    E4M3,
    /// FP8, 5 bit exponent, 2 bit mantissa
    E5M2,
    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
    UE8M0,
    /// 16-bit half-precision float (`half::f16`)
    F16,
    /// 16-bit brain float (`half::bf16`)
    BF16,
    /// `flex32`; occupies 4 bytes, like `f32`
    Flex32,
    /// 32-bit single-precision float
    F32,
    /// TensorFloat-32; occupies 4 bytes, like `f32`
    TF32,
    /// 64-bit double-precision float
    F64,
}
35
/// Signed integer element kinds.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// 8-bit signed integer
    I8,
    /// 16-bit signed integer
    I16,
    /// 32-bit signed integer
    I32,
    /// 64-bit signed integer
    I64,
}
45
/// Unsigned integer element kinds.
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// 8-bit unsigned integer
    U8,
    /// 16-bit unsigned integer
    U16,
    /// 32-bit unsigned integer
    U32,
    /// 64-bit unsigned integer
    U64,
}
55
/// Conceptual element type, not necessarily the physical type used in the code
///
/// Variant order is significant: it defines the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum ElemType {
    /// Floating-point element of the given [`FloatKind`]
    Float(FloatKind),
    /// Signed integer element of the given [`IntKind`]
    Int(IntKind),
    /// Unsigned integer element of the given [`UIntKind`]
    UInt(UIntKind),
    /// Boolean element
    Bool,
}
66
/// Objects that can be stored but not interacted with as ordinary values.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
    /// A barrier at the given [`BarrierLevel`]; `OpaqueType::size` reports 8 bytes for it.
    Barrier(BarrierLevel),
}
72
/// Types with no defined physical representation; wrapped by `Type::Semantic`, which
/// reports a size of 0.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Barrier token (displayed as `barrier_token`)
    BarrierToken,
    /// Pipeline (displayed as `pipeline`)
    Pipeline,
    /// Tensor map (displayed as `tensor_map`)
    TensorMap,
}
80
/// Physical type containing one or more elements
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// `ElemType` is the same as the physical type
    Scalar(ElemType),
    /// Packed values of type `ElemType`; the `u32` is how many values share one storage
    /// value (returned by `StorageType::packing_factor`)
    Packed(ElemType, u32),
    /// Atomically accessed version of `ElemType`
    Atomic(ElemType),
    /// Opaque types that can be stored but not interacted with normally. Currently only barrier,
    /// but may be used for arrival tokens and tensor map descriptors, for example.
    Opaque(OpaqueType),
}
95
96impl ElemType {
97    /// Creates an elem type that correspond to the given [QuantParam].
98    pub fn from_quant_param(quant_param: QuantParam) -> Self {
99        match quant_param {
100            QuantParam::F32 => Self::Float(FloatKind::F32),
101            QuantParam::F16 => Self::Float(FloatKind::F16),
102            QuantParam::BF16 => Self::Float(FloatKind::BF16),
103            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
104            QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
105        }
106    }
107
108    /// Creates an elem type that correspond to the given [QuantValue].
109    pub fn from_quant_value(quant_value: QuantValue) -> Self {
110        match quant_value {
111            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
112            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
113            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
114            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
115            other => panic!("Unsupported quant value {other:?}"),
116        }
117    }
118    /// Create a constant scalar from a float.
119    ///
120    /// The output will have the same type as the element.
121    pub fn constant_from_f64(&self, val: f64) -> Variable {
122        Variable::constant(match self {
123            ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
124            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
125            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126            ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
127        })
128    }
129    /// Create a constant scalar from a signed integer.
130    ///
131    /// The output will have the same type as the element.
132    pub fn constant_from_i64(&self, val: i64) -> Variable {
133        Variable::constant(match self {
134            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
135            ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
136            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
137            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
138        })
139    }
140    /// Create a constant scalar from a unsigned integer.
141    ///
142    /// The output will have the same type as the element.
143    pub fn constant_from_u64(&self, val: u64) -> Variable {
144        Variable::constant(match self {
145            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
146            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
147            ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
148            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
149        })
150    }
151    /// Create a constant scalar from a boolean.
152    ///
153    /// The output will have the same type as the element.
154    pub fn constant_from_bool(&self, val: bool) -> Variable {
155        Variable::constant(match self {
156            ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
157            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
158            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
159            ElemType::Bool => ConstantScalarValue::Bool(val),
160        })
161    }
162
163    /// Ensure that the variable provided, when a constant, is the same type as elem.
164    pub fn from_constant(&self, constant: Variable) -> Variable {
165        let value = match constant.kind {
166            VariableKind::ConstantScalar(value) => value,
167            _ => return constant,
168        };
169
170        match value {
171            ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
172            ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
173            ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
174            ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
175        }
176    }
177    /// Get the size in bytes.
178    pub const fn size(&self) -> usize {
179        match self {
180            ElemType::Float(kind) => match kind {
181                FloatKind::E2M1
182                | FloatKind::E2M3
183                | FloatKind::E3M2
184                | FloatKind::E4M3
185                | FloatKind::E5M2
186                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
187                FloatKind::F16 => core::mem::size_of::<half::f16>(),
188                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
189                FloatKind::F32 => core::mem::size_of::<f32>(),
190                FloatKind::F64 => core::mem::size_of::<f64>(),
191                FloatKind::Flex32 => core::mem::size_of::<f32>(),
192                FloatKind::TF32 => core::mem::size_of::<f32>(),
193            },
194            ElemType::Int(kind) => match kind {
195                IntKind::I8 => core::mem::size_of::<i8>(),
196                IntKind::I16 => core::mem::size_of::<i16>(),
197                IntKind::I32 => core::mem::size_of::<i32>(),
198                IntKind::I64 => core::mem::size_of::<i64>(),
199            },
200            ElemType::UInt(kind) => match kind {
201                UIntKind::U8 => core::mem::size_of::<u8>(),
202                UIntKind::U16 => core::mem::size_of::<u16>(),
203                UIntKind::U32 => core::mem::size_of::<u32>(),
204                UIntKind::U64 => core::mem::size_of::<u64>(),
205            },
206            ElemType::Bool => core::mem::size_of::<bool>(),
207        }
208    }
209
210    /// Get the size in bits.
211    pub const fn size_bits(&self) -> usize {
212        match self {
213            ElemType::Float(kind) => match kind {
214                FloatKind::E2M3
215                | FloatKind::E3M2
216                | FloatKind::E4M3
217                | FloatKind::E5M2
218                | FloatKind::UE8M0
219                | FloatKind::F16
220                | FloatKind::BF16
221                | FloatKind::F32
222                | FloatKind::F64
223                | FloatKind::Flex32
224                | FloatKind::TF32 => self.size() * 8,
225                FloatKind::E2M1 => 4,
226            },
227            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
228        }
229    }
230
231    pub const fn min_line_size(&self) -> u8 {
232        match self {
233            ElemType::Float(FloatKind::E2M1) => 2,
234            _ => 1,
235        }
236    }
237
238    pub fn is_int(&self) -> bool {
239        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
240    }
241
242    pub fn is_signed_int(&self) -> bool {
243        matches!(self, ElemType::Int(_))
244    }
245
246    pub fn is_unsigned_int(&self) -> bool {
247        matches!(self, ElemType::UInt(_) | ElemType::Bool)
248    }
249
250    pub fn is_float(&self) -> bool {
251        matches!(self, ElemType::Float(_))
252    }
253
254    pub fn max_variable(&self) -> Variable {
255        let value = match self {
256            ElemType::Float(kind) => match kind {
257                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
258                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
259                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
260                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
261                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
262                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
263                FloatKind::F16 => {
264                    ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
265                }
266                FloatKind::BF16 => {
267                    ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
268                }
269                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
270                FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
271                FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
272                FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
273            },
274            ElemType::Int(kind) => match kind {
275                IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
276                IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
277                IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
278                IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
279            },
280            ElemType::UInt(kind) => match kind {
281                UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
282                UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
283                UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
284                UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
285            },
286            ElemType::Bool => ConstantScalarValue::Bool(true),
287        };
288
289        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
290    }
291
292    pub fn min_variable(&self) -> Variable {
293        let value = match self {
294            ElemType::Float(kind) => match kind {
295                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
296                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
297                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
298                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
299                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
300                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
301                FloatKind::F16 => {
302                    ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
303                }
304                FloatKind::BF16 => {
305                    ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
306                }
307                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
308                FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
309                FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
310                FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
311            },
312            ElemType::Int(kind) => match kind {
313                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
314                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
315                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
316                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
317            },
318            ElemType::UInt(kind) => match kind {
319                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
320                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
321                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
322                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
323            },
324            ElemType::Bool => ConstantScalarValue::Bool(false),
325        };
326
327        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
328    }
329}
330
331impl OpaqueType {
332    /// Get the size in bytes.
333    pub const fn size(&self) -> usize {
334        match self {
335            OpaqueType::Barrier(_) => 8,
336        }
337    }
338
339    /// Get the size in bits.
340    pub const fn size_bits(&self) -> usize {
341        match self {
342            OpaqueType::Barrier(_) => 64,
343        }
344    }
345}
346
347impl StorageType {
348    pub fn elem_type(&self) -> ElemType {
349        match self {
350            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
351            StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
352        }
353    }
354
355    pub fn packing_factor(&self) -> u32 {
356        match self {
357            StorageType::Packed(_, factor) => *factor,
358            _ => 1,
359        }
360    }
361
362    pub fn is_atomic(&self) -> bool {
363        matches!(self, StorageType::Atomic(_))
364    }
365
366    pub fn size(&self) -> usize {
367        self.size_bits().div_ceil(8)
368    }
369
370    pub fn size_bits(&self) -> usize {
371        match self {
372            StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
373            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
374            StorageType::Opaque(ty) => ty.size_bits(),
375        }
376    }
377
378    /// Ensure that the variable provided, when a constant, is the same type as elem.
379    pub fn from_constant(&self, constant: Variable) -> Variable {
380        self.elem_type().from_constant(constant)
381    }
382
383    pub fn is_int(&self) -> bool {
384        self.elem_type().is_int()
385    }
386
387    pub fn is_signed_int(&self) -> bool {
388        self.elem_type().is_signed_int()
389    }
390
391    pub fn is_unsigned_int(&self) -> bool {
392        self.elem_type().is_unsigned_int()
393    }
394
395    pub fn is_float(&self) -> bool {
396        self.elem_type().is_float()
397    }
398}
399
400impl From<ElemType> for StorageType {
401    fn from(val: ElemType) -> Self {
402        StorageType::Scalar(val)
403    }
404}
405
406impl From<OpaqueType> for StorageType {
407    fn from(val: OpaqueType) -> Self {
408        StorageType::Opaque(val)
409    }
410}
411
412impl<T: Into<StorageType>> From<T> for Type {
413    fn from(val: T) -> Self {
414        Type::new(val.into())
415    }
416}
417
418impl From<SemanticType> for Type {
419    fn from(val: SemanticType) -> Self {
420        Type::semantic(val)
421    }
422}
423
/// Full type of a value: a scalar or a line of [`StorageType`] elements, or a purely
/// semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// Scalar type containing a single storage element
    Scalar(StorageType),
    /// Line wrapping `n` storage elements; the `u32` is `n`
    Line(StorageType, u32),
    /// No defined physical representation, purely semantic, e.g. barrier, pipeline
    Semantic(SemanticType),
}
434
/// Number of elements in a [`Type::Line`] (accepted by `Type::line`).
pub type LineSize = u32;
436
437impl Type {
438    /// Fetch the elem of the item.
439    pub fn elem_type(&self) -> ElemType {
440        self.storage_type().elem_type()
441    }
442
443    /// Create a new item
444    pub fn new(storage: StorageType) -> Self {
445        Type::Scalar(storage)
446    }
447
448    pub fn scalar(elem: ElemType) -> Self {
449        Self::new(StorageType::Scalar(elem))
450    }
451
452    pub fn semantic(ty: SemanticType) -> Self {
453        Self::Semantic(ty)
454    }
455
456    pub fn line(self, line_size: LineSize) -> Type {
457        match line_size > 1 {
458            true => Type::Line(self.storage_type(), line_size),
459            false => Type::Scalar(self.storage_type()),
460        }
461    }
462
463    pub fn line_size(&self) -> u32 {
464        match self {
465            Type::Scalar(_) => 1,
466            Type::Line(_, line_size) => *line_size,
467            Type::Semantic(_) => 0,
468        }
469    }
470
471    pub fn size(&self) -> usize {
472        match self {
473            Type::Scalar(ty) => ty.size(),
474            Type::Line(ty, line_size) => ty.size() * *line_size as usize,
475            Type::Semantic(_) => 0,
476        }
477    }
478
479    pub fn size_bits(&self) -> usize {
480        match self {
481            Type::Scalar(ty) => ty.size_bits(),
482            Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
483            Type::Semantic(_) => 0,
484        }
485    }
486
487    pub fn is_atomic(&self) -> bool {
488        !self.is_semantic() && self.storage_type().is_atomic()
489    }
490
491    pub fn is_int(&self) -> bool {
492        !self.is_semantic() && self.storage_type().is_int()
493    }
494
495    pub fn is_signed_int(&self) -> bool {
496        !self.is_semantic() && self.storage_type().is_signed_int()
497    }
498
499    pub fn is_unsigned_int(&self) -> bool {
500        !self.is_semantic() && self.storage_type().is_unsigned_int()
501    }
502
503    pub fn is_float(&self) -> bool {
504        !self.is_semantic() && self.storage_type().is_float()
505    }
506
507    pub fn storage_type(&self) -> StorageType {
508        match self {
509            Type::Scalar(ty) | Type::Line(ty, _) => *ty,
510            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
511        }
512    }
513
514    pub fn is_semantic(&self) -> bool {
515        match self {
516            Type::Scalar(_) | Type::Line(_, _) => false,
517            Type::Semantic(_) => true,
518        }
519    }
520}
521
522impl Display for Type {
523    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524        match self {
525            Type::Scalar(ty) => write!(f, "{ty}"),
526            Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
527            Type::Semantic(ty) => write!(f, "{ty}"),
528        }
529    }
530}
531
532impl Display for StorageType {
533    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534        match self {
535            StorageType::Scalar(ty) => write!(f, "{ty}"),
536            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
537            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
538            StorageType::Opaque(ty) => write!(f, "{ty}"),
539        }
540    }
541}
542
543impl Display for ElemType {
544    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
545        match self {
546            Self::Float(kind) => match kind {
547                FloatKind::E2M1 => f.write_str("e2m1"),
548                FloatKind::E2M3 => f.write_str("e2m3"),
549                FloatKind::E3M2 => f.write_str("e3m2"),
550                FloatKind::E4M3 => f.write_str("e4m3"),
551                FloatKind::E5M2 => f.write_str("e5m2"),
552                FloatKind::UE8M0 => f.write_str("ue8m0"),
553                FloatKind::F16 => f.write_str("f16"),
554                FloatKind::BF16 => f.write_str("bf16"),
555                FloatKind::Flex32 => f.write_str("flex32"),
556                FloatKind::TF32 => f.write_str("tf32"),
557                FloatKind::F32 => f.write_str("f32"),
558                FloatKind::F64 => f.write_str("f64"),
559            },
560            Self::Int(kind) => match kind {
561                IntKind::I8 => f.write_str("i8"),
562                IntKind::I16 => f.write_str("i16"),
563                IntKind::I32 => f.write_str("i32"),
564                IntKind::I64 => f.write_str("i64"),
565            },
566            Self::UInt(kind) => match kind {
567                UIntKind::U8 => f.write_str("u8"),
568                UIntKind::U16 => f.write_str("u16"),
569                UIntKind::U32 => f.write_str("u32"),
570                UIntKind::U64 => f.write_str("u64"),
571            },
572            Self::Bool => f.write_str("bool"),
573        }
574    }
575}
576
577impl Display for SemanticType {
578    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
579        match self {
580            SemanticType::BarrierToken => f.write_str("barrier_token"),
581            SemanticType::Pipeline => f.write_str("pipeline"),
582            SemanticType::TensorMap => f.write_str("tensor_map"),
583        }
584    }
585}
586
587impl Display for OpaqueType {
588    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
589        match self {
590            OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
591        }
592    }
593}
594
595impl From<bool> for Variable {
596    fn from(value: bool) -> Self {
597        Variable::constant(ConstantScalarValue::Bool(value))
598    }
599}
600
601impl From<i8> for Variable {
602    fn from(value: i8) -> Self {
603        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
604    }
605}
606
607impl From<i16> for Variable {
608    fn from(value: i16) -> Self {
609        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
610    }
611}
612
613impl From<i32> for Variable {
614    fn from(value: i32) -> Self {
615        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
616    }
617}
618
619impl From<i64> for Variable {
620    fn from(value: i64) -> Self {
621        Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
622    }
623}
624
625impl From<e2m1> for Variable {
626    fn from(_value: e2m1) -> Self {
627        unimplemented!("Can't currently construct minifloats")
628    }
629}
630
631impl From<e2m1x2> for Variable {
632    fn from(_value: e2m1x2) -> Self {
633        unimplemented!("Can't currently construct minifloats")
634    }
635}
636
637impl From<e2m3> for Variable {
638    fn from(_value: e2m3) -> Self {
639        unimplemented!("Can't currently construct minifloats")
640    }
641}
642
643impl From<e3m2> for Variable {
644    fn from(_value: e3m2) -> Self {
645        unimplemented!("Can't currently construct minifloats")
646    }
647}
648
649impl From<e4m3> for Variable {
650    fn from(_value: e4m3) -> Self {
651        unimplemented!("Can't currently construct minifloats")
652    }
653}
654
655impl From<e5m2> for Variable {
656    fn from(_value: e5m2) -> Self {
657        unimplemented!("Can't currently construct minifloats")
658    }
659}
660
661impl From<ue8m0> for Variable {
662    fn from(_value: ue8m0) -> Self {
663        unimplemented!("Can't currently construct minifloats")
664    }
665}
666
667impl From<half::f16> for Variable {
668    fn from(value: half::f16) -> Self {
669        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
670    }
671}
672
673impl From<half::bf16> for Variable {
674    fn from(value: half::bf16) -> Self {
675        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
676    }
677}
678
679impl From<flex32> for Variable {
680    fn from(value: flex32) -> Self {
681        Variable::constant(ConstantScalarValue::Float(
682            value.to_f64(),
683            FloatKind::Flex32,
684        ))
685    }
686}
687
688impl From<tf32> for Variable {
689    fn from(value: tf32) -> Self {
690        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
691    }
692}
693
694impl From<f32> for Variable {
695    fn from(value: f32) -> Self {
696        Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
697    }
698}
699
700impl From<f64> for Variable {
701    fn from(value: f64) -> Self {
702        Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
703    }
704}
705
706impl From<u8> for Variable {
707    fn from(value: u8) -> Self {
708        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
709    }
710}
711
712impl From<u16> for Variable {
713    fn from(value: u16) -> Self {
714        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
715    }
716}
717
718impl From<u32> for Variable {
719    fn from(value: u32) -> Self {
720        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
721    }
722}
723
724impl From<u64> for Variable {
725    fn from(value: u64) -> Self {
726        Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
727    }
728}
729
730impl From<usize> for Variable {
731    fn from(value: usize) -> Self {
732        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
733    }
734}