cubecl_ir/
type.rs

1use super::{ConstantValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5    e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6    quant::scheme::{QuantParam, QuantValue},
7    tf32, ue8m0,
8};
9use derive_more::From;
10use half::{bf16, f16};
11
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
14#[allow(missing_docs)]
15pub enum FloatKind {
16    /// FP4, 2 bit exponent, 1 bit mantissa
17    E2M1,
18    /// FP6, 2 bit exponent, 3 bit mantissa
19    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
20    E2M3,
21    /// FP6, 3 bit exponent, 2 bit mantissa
22    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
23    E3M2,
24    /// FP8, 4 bit exponent, 3 bit mantissa
25    E4M3,
26    /// FP8, 5 bit exponent, 2 bit mantissa
27    E5M2,
28    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
29    UE8M0,
30    F16,
31    BF16,
32    Flex32,
33    F32,
34    TF32,
35    F64,
36}
37
38#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
39#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
40#[allow(missing_docs)]
41pub enum IntKind {
42    I8,
43    I16,
44    I32,
45    I64,
46}
47
48#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
49#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
50#[allow(missing_docs)]
51pub enum UIntKind {
52    U8,
53    U16,
54    U32,
55    U64,
56}
57
58/// Conceptual element type, not necessarily the physical type used in the code
59#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
60#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)]
61#[allow(missing_docs)]
62pub enum ElemType {
63    Float(FloatKind),
64    Int(IntKind),
65    UInt(UIntKind),
66    Bool,
67}
68
69#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
70#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
71pub enum OpaqueType {
72    Barrier(BarrierLevel),
73}
74
75#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
76#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
77pub enum SemanticType {
78    BarrierToken,
79    Pipeline,
80    TensorMap,
81}
82
83/// Physical type containing one or more elements
84#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
85#[derive(Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
86pub enum StorageType {
87    /// `ElemType` is the same as the physical type
88    Scalar(ElemType),
89    /// Packed values of type `ElemType`
90    Packed(ElemType, usize),
91    /// Atomically accessed version of `ElemType`
92    Atomic(ElemType),
93    /// Opaque types that can be stored but not interacted with normally. Currently only barrier,
94    /// but may be used for arrival tokens and tensor map descriptors, for example.
95    Opaque(OpaqueType),
96}
97
98impl core::fmt::Debug for StorageType {
99    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
100        // Ensure debug is not spread into multiple lines because it makes kernel ids very hard
101        // to read.
102        struct Dummy<'a>(&'a StorageType);
103
104        impl<'a> core::fmt::Debug for Dummy<'a> {
105            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
106                match self.0 {
107                    StorageType::Scalar(f0) => f.debug_tuple("Scalar").field(&f0).finish(),
108                    StorageType::Packed(f0, f1) => {
109                        f.debug_tuple("Packed").field(&f0).field(&f1).finish()
110                    }
111                    StorageType::Atomic(f0) => f.debug_tuple("Atomic").field(&f0).finish(),
112                    StorageType::Opaque(f0) => f.debug_tuple("Opaque").field(&f0).finish(),
113                }
114            }
115        }
116
117        write!(f, "{:?}", Dummy(self))
118    }
119}
120
121impl ElemType {
122    /// Creates an elem type that correspond to the given [QuantParam].
123    pub fn from_quant_param(quant_param: QuantParam) -> Self {
124        match quant_param {
125            QuantParam::F32 => Self::Float(FloatKind::F32),
126            QuantParam::F16 => Self::Float(FloatKind::F16),
127            QuantParam::BF16 => Self::Float(FloatKind::BF16),
128            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
129            QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
130        }
131    }
132
133    /// Creates an elem type that correspond to the given [QuantValue].
134    pub fn from_quant_value(quant_value: QuantValue) -> Self {
135        match quant_value {
136            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
137            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
138            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
139            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
140            other => panic!("Unsupported quant value {other:?}"),
141        }
142    }
143
144    /// Create a constant from a constant value.
145    ///
146    /// The output will have the same type as the element.
147    pub fn constant(&self, val: ConstantValue) -> Variable {
148        Variable::constant(val, Type::scalar(*self))
149    }
150
151    /// Get the size in bytes.
152    pub const fn size(&self) -> usize {
153        match self {
154            ElemType::Float(kind) => match kind {
155                FloatKind::E2M1
156                | FloatKind::E2M3
157                | FloatKind::E3M2
158                | FloatKind::E4M3
159                | FloatKind::E5M2
160                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
161                FloatKind::F16 => core::mem::size_of::<half::f16>(),
162                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
163                FloatKind::F32 => core::mem::size_of::<f32>(),
164                FloatKind::F64 => core::mem::size_of::<f64>(),
165                FloatKind::Flex32 => core::mem::size_of::<f32>(),
166                FloatKind::TF32 => core::mem::size_of::<f32>(),
167            },
168            ElemType::Int(kind) => match kind {
169                IntKind::I8 => core::mem::size_of::<i8>(),
170                IntKind::I16 => core::mem::size_of::<i16>(),
171                IntKind::I32 => core::mem::size_of::<i32>(),
172                IntKind::I64 => core::mem::size_of::<i64>(),
173            },
174            ElemType::UInt(kind) => match kind {
175                UIntKind::U8 => core::mem::size_of::<u8>(),
176                UIntKind::U16 => core::mem::size_of::<u16>(),
177                UIntKind::U32 => core::mem::size_of::<u32>(),
178                UIntKind::U64 => core::mem::size_of::<u64>(),
179            },
180            ElemType::Bool => core::mem::size_of::<bool>(),
181        }
182    }
183
184    /// Get the size in bits.
185    pub const fn size_bits(&self) -> usize {
186        match self {
187            ElemType::Float(kind) => match kind {
188                FloatKind::E2M3
189                | FloatKind::E3M2
190                | FloatKind::E4M3
191                | FloatKind::E5M2
192                | FloatKind::UE8M0
193                | FloatKind::F16
194                | FloatKind::BF16
195                | FloatKind::F32
196                | FloatKind::F64
197                | FloatKind::Flex32
198                | FloatKind::TF32 => self.size() * 8,
199                FloatKind::E2M1 => 4,
200            },
201            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
202        }
203    }
204
205    pub const fn min_line_size(&self) -> u8 {
206        match self {
207            ElemType::Float(FloatKind::E2M1) => 2,
208            _ => 1,
209        }
210    }
211
212    pub fn is_int(&self) -> bool {
213        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
214    }
215
216    pub fn is_signed_int(&self) -> bool {
217        matches!(self, ElemType::Int(_))
218    }
219
220    pub fn is_unsigned_int(&self) -> bool {
221        matches!(self, ElemType::UInt(_) | ElemType::Bool)
222    }
223
224    pub fn is_float(&self) -> bool {
225        matches!(self, ElemType::Float(_))
226    }
227
228    pub fn is_bool(&self) -> bool {
229        matches!(self, ElemType::Bool)
230    }
231
232    pub fn as_float(&self) -> Option<FloatKind> {
233        match self {
234            ElemType::Float(kind) => Some(*kind),
235            _ => None,
236        }
237    }
238
239    pub fn max_variable(&self) -> Variable {
240        let value = match self {
241            ElemType::Float(kind) => match kind {
242                FloatKind::E2M1 => e2m1::MAX,
243                FloatKind::E2M3 => e2m3::MAX,
244                FloatKind::E3M2 => e3m2::MAX,
245                FloatKind::E4M3 => e4m3::MAX,
246                FloatKind::E5M2 => e5m2::MAX,
247                FloatKind::UE8M0 => ue8m0::MAX,
248                FloatKind::F16 => half::f16::MAX.to_f64(),
249                FloatKind::BF16 => half::bf16::MAX.to_f64(),
250                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MAX as f64,
251                FloatKind::F64 => f64::MAX,
252            }
253            .into(),
254            ElemType::Int(kind) => match kind {
255                IntKind::I8 => i8::MAX as i64,
256                IntKind::I16 => i16::MAX as i64,
257                IntKind::I32 => i32::MAX as i64,
258                IntKind::I64 => i64::MAX,
259            }
260            .into(),
261            ElemType::UInt(kind) => match kind {
262                UIntKind::U8 => u8::MAX as u64,
263                UIntKind::U16 => u16::MAX as u64,
264                UIntKind::U32 => u32::MAX as u64,
265                UIntKind::U64 => u64::MAX,
266            }
267            .into(),
268            ElemType::Bool => true.into(),
269        };
270
271        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
272    }
273
274    pub fn min_variable(&self) -> Variable {
275        let value = match self {
276            ElemType::Float(kind) => match kind {
277                FloatKind::E2M1 => e2m1::MIN,
278                FloatKind::E2M3 => e2m3::MIN,
279                FloatKind::E3M2 => e3m2::MIN,
280                FloatKind::E4M3 => e4m3::MIN,
281                FloatKind::E5M2 => e5m2::MIN,
282                FloatKind::UE8M0 => ue8m0::MIN,
283                FloatKind::F16 => half::f16::MIN.to_f64(),
284                FloatKind::BF16 => half::bf16::MIN.to_f64(),
285                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MIN as f64,
286                FloatKind::F64 => f64::MIN,
287            }
288            .into(),
289            ElemType::Int(kind) => match kind {
290                IntKind::I8 => i8::MIN as i64,
291                IntKind::I16 => i16::MIN as i64,
292                IntKind::I32 => i32::MIN as i64,
293                IntKind::I64 => i64::MIN,
294            }
295            .into(),
296            ElemType::UInt(kind) => match kind {
297                UIntKind::U8 => u8::MIN as u64,
298                UIntKind::U16 => u16::MIN as u64,
299                UIntKind::U32 => u32::MIN as u64,
300                UIntKind::U64 => u64::MIN,
301            }
302            .into(),
303            ElemType::Bool => false.into(),
304        };
305
306        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
307    }
308
309    pub fn epsilon(&self) -> f64 {
310        match self {
311            ElemType::Float(kind) => match kind {
312                FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
313                FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
314                FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
315                FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
316                FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
317                FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
318                FloatKind::F16 => half::f16::EPSILON.to_f64(),
319                FloatKind::BF16 => 0.0078125, // bf16 epsilon ≈ 2^-7
320                FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
321                FloatKind::F64 => f64::EPSILON,
322            },
323            ElemType::Int(_) | ElemType::UInt(_) => 1.0, // step of 1
324            ElemType::Bool => 1.0,
325        }
326    }
327}
328
329impl OpaqueType {
330    /// Get the size in bytes.
331    pub const fn size(&self) -> usize {
332        match self {
333            OpaqueType::Barrier(_) => 8,
334        }
335    }
336
337    /// Get the size in bits.
338    pub const fn size_bits(&self) -> usize {
339        match self {
340            OpaqueType::Barrier(_) => 64,
341        }
342    }
343}
344
345impl StorageType {
346    pub fn elem_type(&self) -> ElemType {
347        match self {
348            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
349            StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
350        }
351    }
352
353    pub fn packing_factor(&self) -> usize {
354        match self {
355            StorageType::Packed(_, factor) => *factor,
356            _ => 1,
357        }
358    }
359
360    pub fn is_atomic(&self) -> bool {
361        matches!(self, StorageType::Atomic(_))
362    }
363
364    pub fn size(&self) -> usize {
365        self.size_bits().div_ceil(8)
366    }
367
368    pub fn size_bits(&self) -> usize {
369        match self {
370            StorageType::Packed(ty, factor) => ty.size_bits() * *factor,
371            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
372            StorageType::Opaque(ty) => ty.size_bits(),
373        }
374    }
375
376    pub fn is_int(&self) -> bool {
377        self.elem_type().is_int()
378    }
379
380    pub fn is_signed_int(&self) -> bool {
381        self.elem_type().is_signed_int()
382    }
383
384    pub fn is_unsigned_int(&self) -> bool {
385        self.elem_type().is_unsigned_int()
386    }
387
388    pub fn is_float(&self) -> bool {
389        self.elem_type().is_float()
390    }
391
392    pub fn is_bool(&self) -> bool {
393        self.elem_type().is_bool()
394    }
395
396    /// Returns an empirical epsilon for this storage type, taking quantization into account.
397    pub fn epsilon(&self) -> f64 {
398        match self {
399            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
400            StorageType::Packed(ty, factor) => {
401                // For packed types, we can conservatively scale epsilon by the number of packed elements
402                ty.epsilon() * (*factor as f64)
403            }
404            StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
405        }
406    }
407
408    pub fn constant(&self, value: ConstantValue) -> Variable {
409        Variable::constant(value, Type::new(*self))
410    }
411}
412
413macro_rules! storage_from_elem {
414    ($($ty: ty),*) => {
415        $(impl From<$ty> for StorageType {
416            fn from(value: $ty) -> Self {
417                StorageType::Scalar(value.into())
418            }
419        })*
420    };
421}
422
423storage_from_elem!(FloatKind, IntKind, UIntKind, ElemType);
424
425impl From<OpaqueType> for StorageType {
426    fn from(val: OpaqueType) -> Self {
427        StorageType::Opaque(val)
428    }
429}
430
431impl<T: Into<StorageType>> From<T> for Type {
432    fn from(val: T) -> Self {
433        Type::new(val.into())
434    }
435}
436
437impl From<SemanticType> for Type {
438    fn from(val: SemanticType) -> Self {
439        Type::semantic(val)
440    }
441}
442
443#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
444#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
445pub enum Type {
446    /// Scalar type containing a single storage element
447    Scalar(StorageType),
448    /// Line wrapping `n` storage elements
449    Line(StorageType, LineSize),
450    /// No defined physical representation, purely semantic. i.e. barrier, pipeline
451    Semantic(SemanticType),
452}
453
454pub type LineSize = usize;
455
456impl Type {
457    /// Fetch the elem of the item.
458    pub fn elem_type(&self) -> ElemType {
459        self.storage_type().elem_type()
460    }
461
462    /// Create a new item
463    pub fn new(storage: StorageType) -> Self {
464        Type::Scalar(storage)
465    }
466
467    pub fn scalar(elem: ElemType) -> Self {
468        Self::new(StorageType::Scalar(elem))
469    }
470
471    pub fn semantic(ty: SemanticType) -> Self {
472        Self::Semantic(ty)
473    }
474
475    pub fn line(self, line_size: LineSize) -> Type {
476        match line_size > 1 {
477            true => Type::Line(self.storage_type(), line_size),
478            false => Type::Scalar(self.storage_type()),
479        }
480    }
481
482    pub fn line_size(&self) -> LineSize {
483        match self {
484            Type::Scalar(_) => 1,
485            Type::Line(_, line_size) => *line_size,
486            Type::Semantic(_) => 0,
487        }
488    }
489
490    pub fn size(&self) -> usize {
491        match self {
492            Type::Scalar(ty) => ty.size(),
493            Type::Line(ty, line_size) => ty.size() * *line_size,
494            Type::Semantic(_) => 0,
495        }
496    }
497
498    pub fn size_bits(&self) -> usize {
499        match self {
500            Type::Scalar(ty) => ty.size_bits(),
501            Type::Line(ty, line_size) => ty.size_bits() * *line_size,
502            Type::Semantic(_) => 0,
503        }
504    }
505
506    pub fn is_atomic(&self) -> bool {
507        !self.is_semantic() && self.storage_type().is_atomic()
508    }
509
510    pub fn is_int(&self) -> bool {
511        !self.is_semantic() && self.storage_type().is_int()
512    }
513
514    pub fn is_signed_int(&self) -> bool {
515        !self.is_semantic() && self.storage_type().is_signed_int()
516    }
517
518    pub fn is_unsigned_int(&self) -> bool {
519        !self.is_semantic() && self.storage_type().is_unsigned_int()
520    }
521
522    pub fn is_float(&self) -> bool {
523        !self.is_semantic() && self.storage_type().is_float()
524    }
525
526    pub fn is_bool(&self) -> bool {
527        !self.is_semantic() && self.storage_type().is_bool()
528    }
529
530    pub fn storage_type(&self) -> StorageType {
531        match self {
532            Type::Scalar(ty) | Type::Line(ty, _) => *ty,
533            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
534        }
535    }
536
537    pub fn is_semantic(&self) -> bool {
538        match self {
539            Type::Scalar(_) | Type::Line(_, _) => false,
540            Type::Semantic(_) => true,
541        }
542    }
543
544    pub fn constant(&self, value: ConstantValue) -> Variable {
545        Variable::constant(value, *self)
546    }
547}
548
549impl Display for Type {
550    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
551        match self {
552            Type::Scalar(ty) => write!(f, "{ty}"),
553            Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
554            Type::Semantic(ty) => write!(f, "{ty}"),
555        }
556    }
557}
558
559impl Display for StorageType {
560    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
561        match self {
562            StorageType::Scalar(ty) => write!(f, "{ty}"),
563            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
564            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
565            StorageType::Opaque(ty) => write!(f, "{ty}"),
566        }
567    }
568}
569
570impl Display for ElemType {
571    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
572        match self {
573            Self::Float(kind) => match kind {
574                FloatKind::E2M1 => f.write_str("e2m1"),
575                FloatKind::E2M3 => f.write_str("e2m3"),
576                FloatKind::E3M2 => f.write_str("e3m2"),
577                FloatKind::E4M3 => f.write_str("e4m3"),
578                FloatKind::E5M2 => f.write_str("e5m2"),
579                FloatKind::UE8M0 => f.write_str("ue8m0"),
580                FloatKind::F16 => f.write_str("f16"),
581                FloatKind::BF16 => f.write_str("bf16"),
582                FloatKind::Flex32 => f.write_str("flex32"),
583                FloatKind::TF32 => f.write_str("tf32"),
584                FloatKind::F32 => f.write_str("f32"),
585                FloatKind::F64 => f.write_str("f64"),
586            },
587            Self::Int(kind) => match kind {
588                IntKind::I8 => f.write_str("i8"),
589                IntKind::I16 => f.write_str("i16"),
590                IntKind::I32 => f.write_str("i32"),
591                IntKind::I64 => f.write_str("i64"),
592            },
593            Self::UInt(kind) => match kind {
594                UIntKind::U8 => f.write_str("u8"),
595                UIntKind::U16 => f.write_str("u16"),
596                UIntKind::U32 => f.write_str("u32"),
597                UIntKind::U64 => f.write_str("u64"),
598            },
599            Self::Bool => f.write_str("bool"),
600        }
601    }
602}
603
604impl Display for SemanticType {
605    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
606        match self {
607            SemanticType::BarrierToken => f.write_str("barrier_token"),
608            SemanticType::Pipeline => f.write_str("pipeline"),
609            SemanticType::TensorMap => f.write_str("tensor_map"),
610        }
611    }
612}
613
614impl Display for OpaqueType {
615    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
616        match self {
617            OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
618        }
619    }
620}
621
622impl From<e2m1x2> for Variable {
623    fn from(_value: e2m1x2) -> Self {
624        unimplemented!("Can't currently construct e2m1x2")
625    }
626}
627
628impl From<e2m3> for Variable {
629    fn from(_value: e2m3) -> Self {
630        unimplemented!("Can't currently construct fp6")
631    }
632}
633
634impl From<e3m2> for Variable {
635    fn from(_value: e3m2) -> Self {
636        unimplemented!("Can't currently construct fp6")
637    }
638}
639
640impl From<i8> for ConstantValue {
641    fn from(value: i8) -> Self {
642        ConstantValue::Int(value as i64)
643    }
644}
645
646impl From<i16> for ConstantValue {
647    fn from(value: i16) -> Self {
648        ConstantValue::Int(value as i64)
649    }
650}
651
652impl From<i32> for ConstantValue {
653    fn from(value: i32) -> Self {
654        ConstantValue::Int(value as i64)
655    }
656}
657
658impl From<isize> for ConstantValue {
659    fn from(value: isize) -> Self {
660        ConstantValue::Int(value as i64)
661    }
662}
663
664impl From<u8> for ConstantValue {
665    fn from(value: u8) -> Self {
666        ConstantValue::UInt(value as u64)
667    }
668}
669
670impl From<u16> for ConstantValue {
671    fn from(value: u16) -> Self {
672        ConstantValue::UInt(value as u64)
673    }
674}
675
676impl From<u32> for ConstantValue {
677    fn from(value: u32) -> Self {
678        ConstantValue::UInt(value as u64)
679    }
680}
681
682impl From<usize> for ConstantValue {
683    fn from(value: usize) -> Self {
684        ConstantValue::UInt(value as u64)
685    }
686}
687
688impl From<e2m1> for ConstantValue {
689    fn from(value: e2m1) -> Self {
690        ConstantValue::Float(value.to_f64())
691    }
692}
693
694impl From<e4m3> for ConstantValue {
695    fn from(value: e4m3) -> Self {
696        ConstantValue::Float(value.to_f64())
697    }
698}
699
700impl From<e5m2> for ConstantValue {
701    fn from(value: e5m2) -> Self {
702        ConstantValue::Float(value.to_f64())
703    }
704}
705
706impl From<ue8m0> for ConstantValue {
707    fn from(value: ue8m0) -> Self {
708        ConstantValue::Float(value.to_f64())
709    }
710}
711
712impl From<half::f16> for ConstantValue {
713    fn from(value: half::f16) -> Self {
714        ConstantValue::Float(value.to_f64())
715    }
716}
717
718impl From<half::bf16> for ConstantValue {
719    fn from(value: half::bf16) -> Self {
720        ConstantValue::Float(value.to_f64())
721    }
722}
723
724impl From<flex32> for ConstantValue {
725    fn from(value: flex32) -> Self {
726        ConstantValue::Float(value.to_f64())
727    }
728}
729
730impl From<tf32> for ConstantValue {
731    fn from(value: tf32) -> Self {
732        ConstantValue::Float(value.to_f64())
733    }
734}
735
736impl From<f32> for ConstantValue {
737    fn from(value: f32) -> Self {
738        ConstantValue::Float(value as f64)
739    }
740}
741
742macro_rules! impl_into_variable {
743    ($($ty: ty => $kind: path,)*) => {
744        $(
745            impl From<$ty> for Variable {
746                fn from(value: $ty) -> Self {
747                    Variable::new(VariableKind::Constant(value.into()), $kind.into())
748                }
749            }
750        )*
751    };
752}
753
754impl_into_variable!(
755    bool => ElemType::Bool,
756
757    i8 => IntKind::I8,
758    i16 => IntKind::I16,
759    i32 => IntKind::I32,
760    i64 => IntKind::I64,
761
762    u8 => UIntKind::U8,
763    u16 => UIntKind::U16,
764    u32 => UIntKind::U32,
765    u64 => UIntKind::U64,
766
767    e2m1 => FloatKind::E2M1,
768    e4m3 => FloatKind::E4M3,
769    e5m2 => FloatKind::E5M2,
770    ue8m0 => FloatKind::UE8M0,
771    f16 => FloatKind::F16,
772    bf16 => FloatKind::BF16,
773    f32 => FloatKind::F32,
774    flex32 => FloatKind::Flex32,
775    tf32 => FloatKind::TF32,
776    f64 => FloatKind::F64,
777
778    usize => UIntKind::U32,
779    isize => IntKind::I32,
780);