Skip to main content

cubecl_ir/
type.rs

1use super::{ConstantValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5    e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6    quant::scheme::{QuantParam, QuantValue},
7    tf32, ue8m0,
8};
9use derive_more::From;
10use half::{bf16, f16};
11
/// Floating-point element kinds, from sub-byte micro-formats up to 64-bit floats.
///
/// NOTE: variant order is significant — it defines the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// FP4, 2 bit exponent, 1 bit mantissa
    /// Note: `ElemType::size` reports 1 byte for it, `ElemType::size_bits` reports 4 bits
    E2M1,
    /// FP6, 2 bit exponent, 3 bit mantissa
    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
    E2M3,
    /// FP6, 3 bit exponent, 2 bit mantissa
    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
    E3M2,
    /// FP8, 4 bit exponent, 3 bit mantissa
    E4M3,
    /// FP8, 5 bit exponent, 2 bit mantissa
    E5M2,
    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
    UE8M0,
    /// 16-bit half precision float (`half::f16`)
    F16,
    /// 16-bit brain float (`half::bf16`)
    BF16,
    /// `flex32` from `cubecl_common`; occupies 32 bits of storage
    Flex32,
    /// 32-bit single precision float (`f32`)
    F32,
    /// `tf32` from `cubecl_common`; occupies 32 bits of storage
    TF32,
    /// 64-bit double precision float (`f64`)
    F64,
}
37
/// Signed integer element kinds.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// 8-bit signed integer (`i8`)
    I8,
    /// 16-bit signed integer (`i16`)
    I16,
    /// 32-bit signed integer (`i32`)
    I32,
    /// 64-bit signed integer (`i64`)
    I64,
}
47
/// Unsigned integer element kinds.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// 8-bit unsigned integer (`u8`)
    U8,
    /// 16-bit unsigned integer (`u16`)
    U16,
    /// 32-bit unsigned integer (`u32`)
    U32,
    /// 64-bit unsigned integer (`u64`)
    U64,
}
57
/// Conceptual element type, not necessarily the physical type used in the code
///
/// `From` is derived (via `derive_more`), so the kind enums convert into it with `.into()`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)]
#[allow(missing_docs)]
pub enum ElemType {
    /// A floating-point element
    Float(FloatKind),
    /// A signed integer element
    Int(IntKind),
    /// An unsigned integer element
    UInt(UIntKind),
    /// A boolean element; note that `is_int` and `is_unsigned_int` report `true` for it
    Bool,
}
68
/// Types that have a physical size and can be stored, but are not otherwise manipulated like
/// regular elements (wrapped by `StorageType::Opaque`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
    /// A barrier at the given level; every level is sized as 64 bits (see `size_bits`)
    Barrier(BarrierLevel),
}
74
/// Types with no defined physical representation (wrapped by `Type::Semantic`); they have a
/// size of 0 and no storage type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Displayed as `barrier_token`
    BarrierToken,
    /// Displayed as `pipeline`
    Pipeline,
    /// Displayed as `tensor_map`
    TensorMap,
}
82
/// Physical type containing one or more elements
///
/// `Debug` is implemented manually rather than derived so that it always prints on one line.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// `ElemType` is the same as the physical type
    Scalar(ElemType),
    /// Packed values of type `ElemType`; the `usize` is the number of packed values
    Packed(ElemType, usize),
    /// Atomically accessed version of `ElemType`
    Atomic(ElemType),
    /// Opaque types that can be stored but not interacted with normally. Currently only barrier,
    /// but may be used for arrival tokens and tensor map descriptors, for example.
    Opaque(OpaqueType),
}
97
98impl core::fmt::Debug for StorageType {
99    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
100        // Ensure debug is not spread into multiple lines because it makes kernel ids very hard
101        // to read.
102        struct Dummy<'a>(&'a StorageType);
103
104        impl<'a> core::fmt::Debug for Dummy<'a> {
105            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
106                match self.0 {
107                    StorageType::Scalar(f0) => f.debug_tuple("Scalar").field(&f0).finish(),
108                    StorageType::Packed(f0, f1) => {
109                        f.debug_tuple("Packed").field(&f0).field(&f1).finish()
110                    }
111                    StorageType::Atomic(f0) => f.debug_tuple("Atomic").field(&f0).finish(),
112                    StorageType::Opaque(f0) => f.debug_tuple("Opaque").field(&f0).finish(),
113                }
114            }
115        }
116
117        write!(f, "{:?}", Dummy(self))
118    }
119}
120
121impl ElemType {
122    /// Creates an elem type that correspond to the given [`QuantParam`].
123    pub fn from_quant_param(quant_param: QuantParam) -> Self {
124        match quant_param {
125            QuantParam::F32 => Self::Float(FloatKind::F32),
126            QuantParam::F16 => Self::Float(FloatKind::F16),
127            QuantParam::BF16 => Self::Float(FloatKind::BF16),
128            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
129            QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
130        }
131    }
132
133    /// Creates an elem type that correspond to the given [`QuantValue`].
134    pub fn from_quant_value(quant_value: QuantValue) -> Self {
135        match quant_value {
136            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
137            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
138            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
139            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
140            other => panic!("Unsupported quant value {other:?}"),
141        }
142    }
143
144    /// Create a constant from a constant value.
145    ///
146    /// The output will have the same type as the element.
147    pub fn constant(&self, val: ConstantValue) -> Variable {
148        Variable::constant(val, Type::scalar(*self))
149    }
150
151    /// Get the size in bytes.
152    pub const fn size(&self) -> usize {
153        match self {
154            ElemType::Float(kind) => match kind {
155                FloatKind::E2M1
156                | FloatKind::E2M3
157                | FloatKind::E3M2
158                | FloatKind::E4M3
159                | FloatKind::E5M2
160                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
161                FloatKind::F16 => core::mem::size_of::<half::f16>(),
162                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
163                FloatKind::F32 => core::mem::size_of::<f32>(),
164                FloatKind::F64 => core::mem::size_of::<f64>(),
165                FloatKind::Flex32 => core::mem::size_of::<f32>(),
166                FloatKind::TF32 => core::mem::size_of::<f32>(),
167            },
168            ElemType::Int(kind) => match kind {
169                IntKind::I8 => core::mem::size_of::<i8>(),
170                IntKind::I16 => core::mem::size_of::<i16>(),
171                IntKind::I32 => core::mem::size_of::<i32>(),
172                IntKind::I64 => core::mem::size_of::<i64>(),
173            },
174            ElemType::UInt(kind) => match kind {
175                UIntKind::U8 => core::mem::size_of::<u8>(),
176                UIntKind::U16 => core::mem::size_of::<u16>(),
177                UIntKind::U32 => core::mem::size_of::<u32>(),
178                UIntKind::U64 => core::mem::size_of::<u64>(),
179            },
180            ElemType::Bool => core::mem::size_of::<bool>(),
181        }
182    }
183
184    /// Get the size in bits.
185    pub const fn size_bits(&self) -> usize {
186        match self {
187            ElemType::Float(kind) => match kind {
188                FloatKind::E2M3
189                | FloatKind::E3M2
190                | FloatKind::E4M3
191                | FloatKind::E5M2
192                | FloatKind::UE8M0
193                | FloatKind::F16
194                | FloatKind::BF16
195                | FloatKind::F32
196                | FloatKind::F64
197                | FloatKind::Flex32
198                | FloatKind::TF32 => self.size() * 8,
199                FloatKind::E2M1 => 4,
200            },
201            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
202        }
203    }
204
205    pub const fn min_vector_size(&self) -> u8 {
206        match self {
207            ElemType::Float(FloatKind::E2M1) => 2,
208            _ => 1,
209        }
210    }
211
212    pub fn is_int(&self) -> bool {
213        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
214    }
215
216    pub fn is_signed_int(&self) -> bool {
217        matches!(self, ElemType::Int(_))
218    }
219
220    pub fn is_unsigned_int(&self) -> bool {
221        matches!(self, ElemType::UInt(_) | ElemType::Bool)
222    }
223
224    pub fn is_float(&self) -> bool {
225        matches!(self, ElemType::Float(_))
226    }
227
228    pub fn is_bool(&self) -> bool {
229        matches!(self, ElemType::Bool)
230    }
231
232    pub fn as_float(&self) -> Option<FloatKind> {
233        match self {
234            ElemType::Float(kind) => Some(*kind),
235            _ => None,
236        }
237    }
238
239    pub fn max_variable(&self) -> Variable {
240        let value = match self {
241            ElemType::Float(kind) => match kind {
242                FloatKind::E2M1 => e2m1::MAX,
243                FloatKind::E2M3 => e2m3::MAX,
244                FloatKind::E3M2 => e3m2::MAX,
245                FloatKind::E4M3 => e4m3::MAX,
246                FloatKind::E5M2 => e5m2::MAX,
247                FloatKind::UE8M0 => ue8m0::MAX,
248                FloatKind::F16 => half::f16::MAX.to_f64(),
249                FloatKind::BF16 => half::bf16::MAX.to_f64(),
250                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MAX as f64,
251                FloatKind::F64 => f64::MAX,
252            }
253            .into(),
254            ElemType::Int(kind) => match kind {
255                IntKind::I8 => i8::MAX as i64,
256                IntKind::I16 => i16::MAX as i64,
257                IntKind::I32 => i32::MAX as i64,
258                IntKind::I64 => i64::MAX,
259            }
260            .into(),
261            ElemType::UInt(kind) => match kind {
262                UIntKind::U8 => u8::MAX as u64,
263                UIntKind::U16 => u16::MAX as u64,
264                UIntKind::U32 => u32::MAX as u64,
265                UIntKind::U64 => u64::MAX,
266            }
267            .into(),
268            ElemType::Bool => true.into(),
269        };
270
271        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
272    }
273
274    pub fn min_variable(&self) -> Variable {
275        let value = match self {
276            ElemType::Float(kind) => match kind {
277                FloatKind::E2M1 => e2m1::MIN,
278                FloatKind::E2M3 => e2m3::MIN,
279                FloatKind::E3M2 => e3m2::MIN,
280                FloatKind::E4M3 => e4m3::MIN,
281                FloatKind::E5M2 => e5m2::MIN,
282                FloatKind::UE8M0 => ue8m0::MIN,
283                FloatKind::F16 => half::f16::MIN.to_f64(),
284                FloatKind::BF16 => half::bf16::MIN.to_f64(),
285                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MIN as f64,
286                FloatKind::F64 => f64::MIN,
287            }
288            .into(),
289            ElemType::Int(kind) => match kind {
290                IntKind::I8 => i8::MIN as i64,
291                IntKind::I16 => i16::MIN as i64,
292                IntKind::I32 => i32::MIN as i64,
293                IntKind::I64 => i64::MIN,
294            }
295            .into(),
296            ElemType::UInt(kind) => match kind {
297                UIntKind::U8 => u8::MIN as u64,
298                UIntKind::U16 => u16::MIN as u64,
299                UIntKind::U32 => u32::MIN as u64,
300                UIntKind::U64 => u64::MIN,
301            }
302            .into(),
303            ElemType::Bool => false.into(),
304        };
305
306        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
307    }
308
309    pub fn epsilon(&self) -> f64 {
310        match self {
311            ElemType::Float(kind) => match kind {
312                FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
313                FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
314                FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
315                FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
316                FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
317                FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
318                FloatKind::F16 => half::f16::EPSILON.to_f64(),
319                FloatKind::BF16 => 0.0078125, // bf16 epsilon ≈ 2^-7
320                FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
321                FloatKind::F64 => f64::EPSILON,
322            },
323            ElemType::Int(_) | ElemType::UInt(_) => 1.0, // step of 1
324            ElemType::Bool => 1.0,
325        }
326    }
327}
328
329impl OpaqueType {
330    /// Get the size in bytes.
331    pub const fn size(&self) -> usize {
332        match self {
333            OpaqueType::Barrier(_) => 8,
334        }
335    }
336
337    /// Get the size in bits.
338    pub const fn size_bits(&self) -> usize {
339        match self {
340            OpaqueType::Barrier(_) => 64,
341        }
342    }
343}
344
345impl StorageType {
346    pub fn elem_type(&self) -> ElemType {
347        match self {
348            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
349            StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
350        }
351    }
352
353    pub fn packing_factor(&self) -> usize {
354        match self {
355            StorageType::Packed(_, factor) => *factor,
356            _ => 1,
357        }
358    }
359
360    pub fn is_atomic(&self) -> bool {
361        matches!(self, StorageType::Atomic(_))
362    }
363
364    pub fn size(&self) -> usize {
365        self.size_bits().div_ceil(8)
366    }
367
368    pub fn size_bits(&self) -> usize {
369        match self {
370            StorageType::Packed(ty, factor) => ty.size_bits() * *factor,
371            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
372            StorageType::Opaque(ty) => ty.size_bits(),
373        }
374    }
375
376    pub fn is_int(&self) -> bool {
377        self.elem_type().is_int()
378    }
379
380    pub fn is_signed_int(&self) -> bool {
381        self.elem_type().is_signed_int()
382    }
383
384    pub fn is_unsigned_int(&self) -> bool {
385        self.elem_type().is_unsigned_int()
386    }
387
388    pub fn is_float(&self) -> bool {
389        self.elem_type().is_float()
390    }
391
392    pub fn is_bool(&self) -> bool {
393        self.elem_type().is_bool()
394    }
395
396    /// Returns an empirical epsilon for this storage type, taking quantization into account.
397    pub fn epsilon(&self) -> f64 {
398        match self {
399            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
400            StorageType::Packed(ty, factor) => {
401                // For packed types, we can conservatively scale epsilon by the number of packed elements
402                ty.epsilon() * (*factor as f64)
403            }
404            StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
405        }
406    }
407
408    pub fn constant(&self, value: ConstantValue) -> Variable {
409        Variable::constant(value, Type::new(*self))
410    }
411}
412
413macro_rules! storage_from_elem {
414    ($($ty: ty),*) => {
415        $(impl From<$ty> for StorageType {
416            fn from(value: $ty) -> Self {
417                StorageType::Scalar(value.into())
418            }
419        })*
420    };
421}
422
423storage_from_elem!(FloatKind, IntKind, UIntKind, ElemType);
424
425impl From<OpaqueType> for StorageType {
426    fn from(val: OpaqueType) -> Self {
427        StorageType::Opaque(val)
428    }
429}
430
431impl<T: Into<StorageType>> From<T> for Type {
432    fn from(val: T) -> Self {
433        Type::new(val.into())
434    }
435}
436
437impl From<SemanticType> for Type {
438    fn from(val: SemanticType) -> Self {
439        Type::semantic(val)
440    }
441}
442
/// The type of an IR value: a scalar or vector of some storage, or a purely semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// Scalar type containing a single storage element
    Scalar(StorageType),
    /// Vector wrapping `n` storage elements
    Vector(StorageType, VectorSize),
    /// No defined physical representation, purely semantic. i.e. barrier, pipeline
    Semantic(SemanticType),
}

/// Number of lanes in a [`Type::Vector`].
pub type VectorSize = usize;
455
456impl Type {
457    /// Fetch the elem of the item.
458    pub fn elem_type(&self) -> ElemType {
459        self.storage_type().elem_type()
460    }
461
462    /// Create a new item
463    pub fn new(storage: StorageType) -> Self {
464        Type::Scalar(storage)
465    }
466
467    pub fn scalar(elem: ElemType) -> Self {
468        Self::new(StorageType::Scalar(elem))
469    }
470
471    pub fn semantic(ty: SemanticType) -> Self {
472        Self::Semantic(ty)
473    }
474
475    pub fn with_vector_size(self, vector_size: VectorSize) -> Type {
476        match vector_size > 1 {
477            true => Type::Vector(self.storage_type(), vector_size),
478            false => Type::Scalar(self.storage_type()),
479        }
480    }
481
482    pub fn with_storage_type(self, storage: StorageType) -> Type {
483        let vector_size = self.vector_size();
484        Type::new(storage).with_vector_size(vector_size)
485    }
486
487    pub fn vector_size(&self) -> VectorSize {
488        match self {
489            Type::Scalar(_) => 1,
490            Type::Vector(_, vector_size) => *vector_size,
491            Type::Semantic(_) => 0,
492        }
493    }
494
495    pub fn size(&self) -> usize {
496        match self {
497            Type::Scalar(ty) => ty.size(),
498            Type::Vector(ty, vector_size) => ty.size() * *vector_size,
499            Type::Semantic(_) => 0,
500        }
501    }
502
503    pub fn size_bits(&self) -> usize {
504        match self {
505            Type::Scalar(ty) => ty.size_bits(),
506            Type::Vector(ty, vector_size) => ty.size_bits() * *vector_size,
507            Type::Semantic(_) => 0,
508        }
509    }
510
511    pub fn packing_factor(&self) -> usize {
512        match self {
513            Type::Scalar(ty) => ty.packing_factor(),
514            Type::Vector(ty, _) => ty.packing_factor(),
515            Type::Semantic(_) => 1,
516        }
517    }
518
519    pub fn is_atomic(&self) -> bool {
520        !self.is_semantic() && self.storage_type().is_atomic()
521    }
522
523    pub fn is_int(&self) -> bool {
524        !self.is_semantic() && self.storage_type().is_int()
525    }
526
527    pub fn is_signed_int(&self) -> bool {
528        !self.is_semantic() && self.storage_type().is_signed_int()
529    }
530
531    pub fn is_unsigned_int(&self) -> bool {
532        !self.is_semantic() && self.storage_type().is_unsigned_int()
533    }
534
535    pub fn is_float(&self) -> bool {
536        !self.is_semantic() && self.storage_type().is_float()
537    }
538
539    pub fn is_bool(&self) -> bool {
540        !self.is_semantic() && self.storage_type().is_bool()
541    }
542
543    pub fn storage_type(&self) -> StorageType {
544        match self {
545            Type::Scalar(ty) | Type::Vector(ty, _) => *ty,
546            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
547        }
548    }
549
550    pub fn is_semantic(&self) -> bool {
551        match self {
552            Type::Scalar(_) | Type::Vector(_, _) => false,
553            Type::Semantic(_) => true,
554        }
555    }
556
557    pub fn constant(&self, value: ConstantValue) -> Variable {
558        Variable::constant(value, *self)
559    }
560}
561
562impl Display for Type {
563    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
564        match self {
565            Type::Scalar(ty) => write!(f, "{ty}"),
566            Type::Vector(ty, vector_size) => write!(f, "vector<{ty}, {vector_size}>"),
567            Type::Semantic(ty) => write!(f, "{ty}"),
568        }
569    }
570}
571
572impl Display for StorageType {
573    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
574        match self {
575            StorageType::Scalar(ty) => write!(f, "{ty}"),
576            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
577            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
578            StorageType::Opaque(ty) => write!(f, "{ty}"),
579        }
580    }
581}
582
583impl Display for ElemType {
584    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
585        match self {
586            Self::Float(kind) => match kind {
587                FloatKind::E2M1 => f.write_str("e2m1"),
588                FloatKind::E2M3 => f.write_str("e2m3"),
589                FloatKind::E3M2 => f.write_str("e3m2"),
590                FloatKind::E4M3 => f.write_str("e4m3"),
591                FloatKind::E5M2 => f.write_str("e5m2"),
592                FloatKind::UE8M0 => f.write_str("ue8m0"),
593                FloatKind::F16 => f.write_str("f16"),
594                FloatKind::BF16 => f.write_str("bf16"),
595                FloatKind::Flex32 => f.write_str("flex32"),
596                FloatKind::TF32 => f.write_str("tf32"),
597                FloatKind::F32 => f.write_str("f32"),
598                FloatKind::F64 => f.write_str("f64"),
599            },
600            Self::Int(kind) => match kind {
601                IntKind::I8 => f.write_str("i8"),
602                IntKind::I16 => f.write_str("i16"),
603                IntKind::I32 => f.write_str("i32"),
604                IntKind::I64 => f.write_str("i64"),
605            },
606            Self::UInt(kind) => match kind {
607                UIntKind::U8 => f.write_str("u8"),
608                UIntKind::U16 => f.write_str("u16"),
609                UIntKind::U32 => f.write_str("u32"),
610                UIntKind::U64 => f.write_str("u64"),
611            },
612            Self::Bool => f.write_str("bool"),
613        }
614    }
615}
616
617impl Display for SemanticType {
618    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
619        match self {
620            SemanticType::BarrierToken => f.write_str("barrier_token"),
621            SemanticType::Pipeline => f.write_str("pipeline"),
622            SemanticType::TensorMap => f.write_str("tensor_map"),
623        }
624    }
625}
626
627impl Display for OpaqueType {
628    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
629        match self {
630            OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
631        }
632    }
633}
634
635impl From<e2m1x2> for Variable {
636    fn from(_value: e2m1x2) -> Self {
637        unimplemented!("Can't currently construct e2m1x2")
638    }
639}
640
641impl From<e2m3> for Variable {
642    fn from(_value: e2m3) -> Self {
643        unimplemented!("Can't currently construct fp6")
644    }
645}
646
647impl From<e3m2> for Variable {
648    fn from(_value: e3m2) -> Self {
649        unimplemented!("Can't currently construct fp6")
650    }
651}
652
653impl From<i8> for ConstantValue {
654    fn from(value: i8) -> Self {
655        ConstantValue::Int(value as i64)
656    }
657}
658
659impl From<i16> for ConstantValue {
660    fn from(value: i16) -> Self {
661        ConstantValue::Int(value as i64)
662    }
663}
664
665impl From<i32> for ConstantValue {
666    fn from(value: i32) -> Self {
667        ConstantValue::Int(value as i64)
668    }
669}
670
671impl From<isize> for ConstantValue {
672    fn from(value: isize) -> Self {
673        ConstantValue::Int(value as i64)
674    }
675}
676
677impl From<u8> for ConstantValue {
678    fn from(value: u8) -> Self {
679        ConstantValue::UInt(value as u64)
680    }
681}
682
683impl From<u16> for ConstantValue {
684    fn from(value: u16) -> Self {
685        ConstantValue::UInt(value as u64)
686    }
687}
688
689impl From<u32> for ConstantValue {
690    fn from(value: u32) -> Self {
691        ConstantValue::UInt(value as u64)
692    }
693}
694
695impl From<usize> for ConstantValue {
696    fn from(value: usize) -> Self {
697        ConstantValue::UInt(value as u64)
698    }
699}
700
701impl From<e2m1> for ConstantValue {
702    fn from(value: e2m1) -> Self {
703        ConstantValue::Float(value.to_f64())
704    }
705}
706
707impl From<e4m3> for ConstantValue {
708    fn from(value: e4m3) -> Self {
709        ConstantValue::Float(value.to_f64())
710    }
711}
712
713impl From<e5m2> for ConstantValue {
714    fn from(value: e5m2) -> Self {
715        ConstantValue::Float(value.to_f64())
716    }
717}
718
719impl From<ue8m0> for ConstantValue {
720    fn from(value: ue8m0) -> Self {
721        ConstantValue::Float(value.to_f64())
722    }
723}
724
725impl From<half::f16> for ConstantValue {
726    fn from(value: half::f16) -> Self {
727        ConstantValue::Float(value.to_f64())
728    }
729}
730
731impl From<half::bf16> for ConstantValue {
732    fn from(value: half::bf16) -> Self {
733        ConstantValue::Float(value.to_f64())
734    }
735}
736
737impl From<flex32> for ConstantValue {
738    fn from(value: flex32) -> Self {
739        ConstantValue::Float(value.to_f64())
740    }
741}
742
743impl From<tf32> for ConstantValue {
744    fn from(value: tf32) -> Self {
745        ConstantValue::Float(value.to_f64())
746    }
747}
748
749impl From<f32> for ConstantValue {
750    fn from(value: f32) -> Self {
751        ConstantValue::Float(value as f64)
752    }
753}
754
755macro_rules! impl_into_variable {
756    ($($ty: ty => $kind: path,)*) => {
757        $(
758            impl From<$ty> for Variable {
759                fn from(value: $ty) -> Self {
760                    Variable::new(VariableKind::Constant(value.into()), $kind.into())
761                }
762            }
763        )*
764    };
765}
766
767impl_into_variable!(
768    bool => ElemType::Bool,
769
770    i8 => IntKind::I8,
771    i16 => IntKind::I16,
772    i32 => IntKind::I32,
773    i64 => IntKind::I64,
774
775    u8 => UIntKind::U8,
776    u16 => UIntKind::U16,
777    u32 => UIntKind::U32,
778    u64 => UIntKind::U64,
779
780    e2m1 => FloatKind::E2M1,
781    e4m3 => FloatKind::E4M3,
782    e5m2 => FloatKind::E5M2,
783    ue8m0 => FloatKind::UE8M0,
784    f16 => FloatKind::F16,
785    bf16 => FloatKind::BF16,
786    f32 => FloatKind::F32,
787    flex32 => FloatKind::Flex32,
788    tf32 => FloatKind::TF32,
789    f64 => FloatKind::F64,
790
791    usize => UIntKind::U32,
792    isize => IntKind::I32,
793);