cubecl_ir/
type.rs

1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::TypeHash;
3use core::fmt::Display;
4use cubecl_common::{
5    e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6    quant::scheme::{QuantParam, QuantValue},
7    tf32, ue8m0,
8};
9
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
12#[allow(missing_docs)]
13pub enum FloatKind {
14    /// FP4, 2 bit exponent, 1 bit mantissa
15    E2M1,
16    /// FP6, 2 bit exponent, 3 bit mantissa
17    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
18    E2M3,
19    /// FP6, 3 bit exponent, 2 bit mantissa
20    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
21    E3M2,
22    /// FP8, 4 bit exponent, 3 bit mantissa
23    E4M3,
24    /// FP8, 5 bit exponent, 2 bit mantissa
25    E5M2,
26    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
27    UE8M0,
28    F16,
29    BF16,
30    Flex32,
31    F32,
32    TF32,
33    F64,
34}
35
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
38#[allow(missing_docs)]
39pub enum IntKind {
40    I8,
41    I16,
42    I32,
43    I64,
44}
45
46#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
47#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
48#[allow(missing_docs)]
49pub enum UIntKind {
50    U8,
51    U16,
52    U32,
53    U64,
54}
55
56/// Conceptual element type, not necessarily the physical type used in the code
57#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
58#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
59#[allow(missing_docs)]
60pub enum ElemType {
61    Float(FloatKind),
62    Int(IntKind),
63    UInt(UIntKind),
64    Bool,
65}
66
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
69pub enum SemanticType {
70    Barrier,
71    BarrierToken,
72    Pipeline,
73    TensorMap,
74}
75
76/// Physical type containing one or more elements
77#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
78#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
79pub enum StorageType {
80    /// `ElemType` is the same as the physical type
81    Scalar(ElemType),
82    /// Packed values of type `ElemType`
83    Packed(ElemType, u32),
84    /// Atomically accessed version of `ElemType`
85    Atomic(ElemType),
86}
87
88impl ElemType {
89    /// Creates an elem type that correspond to the given [QuantParam].
90    pub fn from_quant_param(quant_param: QuantParam) -> Self {
91        match quant_param {
92            QuantParam::F32 => Self::Float(FloatKind::F32),
93            QuantParam::F16 => Self::Float(FloatKind::F16),
94            QuantParam::BF16 => Self::Float(FloatKind::BF16),
95            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
96            QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
97        }
98    }
99
100    /// Creates an elem type that correspond to the given [QuantValue].
101    pub fn from_quant_value(quant_value: QuantValue) -> Self {
102        match quant_value {
103            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
104            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
105            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
106            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
107            other => panic!("Unsupported quant value {other:?}"),
108        }
109    }
110    /// Create a constant scalar from a float.
111    ///
112    /// The output will have the same type as the element.
113    pub fn constant_from_f64(&self, val: f64) -> Variable {
114        Variable::constant(match self {
115            ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
116            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
117            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
118            ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
119        })
120    }
121    /// Create a constant scalar from a signed integer.
122    ///
123    /// The output will have the same type as the element.
124    pub fn constant_from_i64(&self, val: i64) -> Variable {
125        Variable::constant(match self {
126            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
127            ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
128            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
129            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
130        })
131    }
132    /// Create a constant scalar from a unsigned integer.
133    ///
134    /// The output will have the same type as the element.
135    pub fn constant_from_u64(&self, val: u64) -> Variable {
136        Variable::constant(match self {
137            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
138            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
139            ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
140            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
141        })
142    }
143    /// Create a constant scalar from a boolean.
144    ///
145    /// The output will have the same type as the element.
146    pub fn constant_from_bool(&self, val: bool) -> Variable {
147        Variable::constant(match self {
148            ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
149            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
150            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
151            ElemType::Bool => ConstantScalarValue::Bool(val),
152        })
153    }
154
155    /// Ensure that the variable provided, when a constant, is the same type as elem.
156    pub fn from_constant(&self, constant: Variable) -> Variable {
157        let value = match constant.kind {
158            VariableKind::ConstantScalar(value) => value,
159            _ => return constant,
160        };
161
162        match value {
163            ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
164            ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
165            ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
166            ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
167        }
168    }
169    /// Get the size in bytes.
170    pub const fn size(&self) -> usize {
171        match self {
172            ElemType::Float(kind) => match kind {
173                FloatKind::E2M1
174                | FloatKind::E2M3
175                | FloatKind::E3M2
176                | FloatKind::E4M3
177                | FloatKind::E5M2
178                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
179                FloatKind::F16 => core::mem::size_of::<half::f16>(),
180                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
181                FloatKind::F32 => core::mem::size_of::<f32>(),
182                FloatKind::F64 => core::mem::size_of::<f64>(),
183                FloatKind::Flex32 => core::mem::size_of::<f32>(),
184                FloatKind::TF32 => core::mem::size_of::<f32>(),
185            },
186            ElemType::Int(kind) => match kind {
187                IntKind::I8 => core::mem::size_of::<i8>(),
188                IntKind::I16 => core::mem::size_of::<i16>(),
189                IntKind::I32 => core::mem::size_of::<i32>(),
190                IntKind::I64 => core::mem::size_of::<i64>(),
191            },
192            ElemType::UInt(kind) => match kind {
193                UIntKind::U8 => core::mem::size_of::<u8>(),
194                UIntKind::U16 => core::mem::size_of::<u16>(),
195                UIntKind::U32 => core::mem::size_of::<u32>(),
196                UIntKind::U64 => core::mem::size_of::<u64>(),
197            },
198            ElemType::Bool => core::mem::size_of::<bool>(),
199        }
200    }
201
202    /// Get the size in bits.
203    pub const fn size_bits(&self) -> usize {
204        match self {
205            ElemType::Float(kind) => match kind {
206                FloatKind::E2M3
207                | FloatKind::E3M2
208                | FloatKind::E4M3
209                | FloatKind::E5M2
210                | FloatKind::UE8M0
211                | FloatKind::F16
212                | FloatKind::BF16
213                | FloatKind::F32
214                | FloatKind::F64
215                | FloatKind::Flex32
216                | FloatKind::TF32 => self.size() * 8,
217                FloatKind::E2M1 => 4,
218            },
219            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
220        }
221    }
222
223    pub const fn min_line_size(&self) -> u8 {
224        match self {
225            ElemType::Float(FloatKind::E2M1) => 2,
226            _ => 1,
227        }
228    }
229
230    pub fn is_int(&self) -> bool {
231        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
232    }
233
234    pub fn is_signed_int(&self) -> bool {
235        matches!(self, ElemType::Int(_))
236    }
237
238    pub fn is_unsigned_int(&self) -> bool {
239        matches!(self, ElemType::UInt(_) | ElemType::Bool)
240    }
241
242    pub fn is_float(&self) -> bool {
243        matches!(self, ElemType::Float(_))
244    }
245
246    pub fn max_variable(&self) -> Variable {
247        let value = match self {
248            ElemType::Float(kind) => match kind {
249                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
250                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
251                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
252                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
253                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
254                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
255                FloatKind::F16 => {
256                    ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
257                }
258                FloatKind::BF16 => {
259                    ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
260                }
261                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
262                FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
263                FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
264                FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
265            },
266            ElemType::Int(kind) => match kind {
267                IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
268                IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
269                IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
270                IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
271            },
272            ElemType::UInt(kind) => match kind {
273                UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
274                UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
275                UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
276                UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
277            },
278            ElemType::Bool => ConstantScalarValue::Bool(true),
279        };
280
281        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
282    }
283
284    pub fn min_variable(&self) -> Variable {
285        let value = match self {
286            ElemType::Float(kind) => match kind {
287                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
288                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
289                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
290                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
291                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
292                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
293                FloatKind::F16 => {
294                    ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
295                }
296                FloatKind::BF16 => {
297                    ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
298                }
299                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
300                FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
301                FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
302                FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
303            },
304            ElemType::Int(kind) => match kind {
305                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
306                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
307                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
308                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
309            },
310            ElemType::UInt(kind) => match kind {
311                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
312                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
313                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
314                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
315            },
316            ElemType::Bool => ConstantScalarValue::Bool(false),
317        };
318
319        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
320    }
321}
322
323impl StorageType {
324    pub fn elem_type(&self) -> ElemType {
325        match self {
326            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
327        }
328    }
329
330    pub fn packing_factor(&self) -> u32 {
331        match self {
332            StorageType::Packed(_, factor) => *factor,
333            _ => 1,
334        }
335    }
336
337    pub fn is_atomic(&self) -> bool {
338        matches!(self, StorageType::Atomic(_))
339    }
340
341    pub fn size(&self) -> usize {
342        self.size_bits().div_ceil(8)
343    }
344
345    pub fn size_bits(&self) -> usize {
346        match self {
347            StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
348            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
349        }
350    }
351
352    /// Ensure that the variable provided, when a constant, is the same type as elem.
353    pub fn from_constant(&self, constant: Variable) -> Variable {
354        self.elem_type().from_constant(constant)
355    }
356
357    pub fn is_int(&self) -> bool {
358        self.elem_type().is_int()
359    }
360
361    pub fn is_signed_int(&self) -> bool {
362        self.elem_type().is_signed_int()
363    }
364
365    pub fn is_unsigned_int(&self) -> bool {
366        self.elem_type().is_unsigned_int()
367    }
368
369    pub fn is_float(&self) -> bool {
370        self.elem_type().is_float()
371    }
372}
373
374impl From<ElemType> for Type {
375    fn from(val: ElemType) -> Self {
376        Type::scalar(val)
377    }
378}
379
380impl From<ElemType> for StorageType {
381    fn from(val: ElemType) -> Self {
382        StorageType::Scalar(val)
383    }
384}
385
386impl From<StorageType> for Type {
387    fn from(val: StorageType) -> Self {
388        Type::new(val)
389    }
390}
391
392impl From<SemanticType> for Type {
393    fn from(val: SemanticType) -> Self {
394        Type::semantic(val)
395    }
396}
397
398#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
399#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
400pub enum Type {
401    /// Scalar type containing a single storage element
402    Scalar(StorageType),
403    /// Line wrapping `n` storage elements
404    Line(StorageType, u32),
405    /// No defined physical representation, purely semantic. i.e. barrier, pipeline
406    Semantic(SemanticType),
407}
408
/// Number of storage elements in a line, as taken by [`Type::line`].
pub type LineSize = u32;
410
411impl Type {
412    /// Fetch the elem of the item.
413    pub fn elem_type(&self) -> ElemType {
414        self.storage_type().elem_type()
415    }
416
417    /// Create a new item
418    pub fn new(storage: StorageType) -> Self {
419        Type::Scalar(storage)
420    }
421
422    pub fn scalar(elem: ElemType) -> Self {
423        Self::new(StorageType::Scalar(elem))
424    }
425
426    pub fn semantic(ty: SemanticType) -> Self {
427        Self::Semantic(ty)
428    }
429
430    pub fn line(self, line_size: LineSize) -> Type {
431        match line_size > 1 {
432            true => Type::Line(self.storage_type(), line_size),
433            false => Type::Scalar(self.storage_type()),
434        }
435    }
436
437    pub fn line_size(&self) -> u32 {
438        match self {
439            Type::Scalar(_) => 1,
440            Type::Line(_, line_size) => *line_size,
441            Type::Semantic(_) => 0,
442        }
443    }
444
445    pub fn size(&self) -> usize {
446        match self {
447            Type::Scalar(ty) => ty.size(),
448            Type::Line(ty, line_size) => ty.size() * *line_size as usize,
449            Type::Semantic(_) => 0,
450        }
451    }
452
453    pub fn size_bits(&self) -> usize {
454        match self {
455            Type::Scalar(ty) => ty.size_bits(),
456            Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
457            Type::Semantic(_) => 0,
458        }
459    }
460
461    pub fn is_atomic(&self) -> bool {
462        !self.is_semantic() && self.storage_type().is_atomic()
463    }
464
465    pub fn is_int(&self) -> bool {
466        !self.is_semantic() && self.storage_type().is_int()
467    }
468
469    pub fn is_signed_int(&self) -> bool {
470        !self.is_semantic() && self.storage_type().is_signed_int()
471    }
472
473    pub fn is_unsigned_int(&self) -> bool {
474        !self.is_semantic() && self.storage_type().is_unsigned_int()
475    }
476
477    pub fn is_float(&self) -> bool {
478        !self.is_semantic() && self.storage_type().is_float()
479    }
480
481    pub fn storage_type(&self) -> StorageType {
482        match self {
483            Type::Scalar(ty) | Type::Line(ty, _) => *ty,
484            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
485        }
486    }
487
488    pub fn is_semantic(&self) -> bool {
489        match self {
490            Type::Scalar(_) | Type::Line(_, _) => false,
491            Type::Semantic(_) => true,
492        }
493    }
494}
495
496impl Display for Type {
497    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
498        match self {
499            Type::Scalar(ty) => write!(f, "{ty}"),
500            Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
501            Type::Semantic(ty) => write!(f, "{ty}"),
502        }
503    }
504}
505
506impl Display for StorageType {
507    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
508        match self {
509            StorageType::Scalar(ty) => write!(f, "{ty}"),
510            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
511            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
512        }
513    }
514}
515
516impl Display for ElemType {
517    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
518        match self {
519            Self::Float(kind) => match kind {
520                FloatKind::E2M1 => f.write_str("e2m1"),
521                FloatKind::E2M3 => f.write_str("e2m3"),
522                FloatKind::E3M2 => f.write_str("e3m2"),
523                FloatKind::E4M3 => f.write_str("e4m3"),
524                FloatKind::E5M2 => f.write_str("e5m2"),
525                FloatKind::UE8M0 => f.write_str("ue8m0"),
526                FloatKind::F16 => f.write_str("f16"),
527                FloatKind::BF16 => f.write_str("bf16"),
528                FloatKind::Flex32 => f.write_str("flex32"),
529                FloatKind::TF32 => f.write_str("tf32"),
530                FloatKind::F32 => f.write_str("f32"),
531                FloatKind::F64 => f.write_str("f64"),
532            },
533            Self::Int(kind) => match kind {
534                IntKind::I8 => f.write_str("i8"),
535                IntKind::I16 => f.write_str("i16"),
536                IntKind::I32 => f.write_str("i32"),
537                IntKind::I64 => f.write_str("i64"),
538            },
539            Self::UInt(kind) => match kind {
540                UIntKind::U8 => f.write_str("u8"),
541                UIntKind::U16 => f.write_str("u16"),
542                UIntKind::U32 => f.write_str("u32"),
543                UIntKind::U64 => f.write_str("u64"),
544            },
545            Self::Bool => f.write_str("bool"),
546        }
547    }
548}
549
550impl Display for SemanticType {
551    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
552        match self {
553            SemanticType::Barrier => f.write_str("barrier"),
554            SemanticType::BarrierToken => f.write_str("barrier_token"),
555            SemanticType::Pipeline => f.write_str("pipeline"),
556            SemanticType::TensorMap => f.write_str("tensor_map"),
557        }
558    }
559}
560
561impl From<bool> for Variable {
562    fn from(value: bool) -> Self {
563        Variable::constant(ConstantScalarValue::Bool(value))
564    }
565}
566
567impl From<i8> for Variable {
568    fn from(value: i8) -> Self {
569        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
570    }
571}
572
573impl From<i16> for Variable {
574    fn from(value: i16) -> Self {
575        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
576    }
577}
578
579impl From<i32> for Variable {
580    fn from(value: i32) -> Self {
581        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
582    }
583}
584
585impl From<i64> for Variable {
586    fn from(value: i64) -> Self {
587        Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
588    }
589}
590
591impl From<e2m1> for Variable {
592    fn from(_value: e2m1) -> Self {
593        unimplemented!("Can't currently construct minifloats")
594    }
595}
596
597impl From<e2m1x2> for Variable {
598    fn from(_value: e2m1x2) -> Self {
599        unimplemented!("Can't currently construct minifloats")
600    }
601}
602
603impl From<e2m3> for Variable {
604    fn from(_value: e2m3) -> Self {
605        unimplemented!("Can't currently construct minifloats")
606    }
607}
608
609impl From<e3m2> for Variable {
610    fn from(_value: e3m2) -> Self {
611        unimplemented!("Can't currently construct minifloats")
612    }
613}
614
615impl From<e4m3> for Variable {
616    fn from(_value: e4m3) -> Self {
617        unimplemented!("Can't currently construct minifloats")
618    }
619}
620
621impl From<e5m2> for Variable {
622    fn from(_value: e5m2) -> Self {
623        unimplemented!("Can't currently construct minifloats")
624    }
625}
626
627impl From<ue8m0> for Variable {
628    fn from(_value: ue8m0) -> Self {
629        unimplemented!("Can't currently construct minifloats")
630    }
631}
632
633impl From<half::f16> for Variable {
634    fn from(value: half::f16) -> Self {
635        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
636    }
637}
638
639impl From<half::bf16> for Variable {
640    fn from(value: half::bf16) -> Self {
641        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
642    }
643}
644
645impl From<flex32> for Variable {
646    fn from(value: flex32) -> Self {
647        Variable::constant(ConstantScalarValue::Float(
648            value.to_f64(),
649            FloatKind::Flex32,
650        ))
651    }
652}
653
654impl From<tf32> for Variable {
655    fn from(value: tf32) -> Self {
656        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
657    }
658}
659
660impl From<f32> for Variable {
661    fn from(value: f32) -> Self {
662        Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
663    }
664}
665
666impl From<f64> for Variable {
667    fn from(value: f64) -> Self {
668        Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
669    }
670}
671
672impl From<u8> for Variable {
673    fn from(value: u8) -> Self {
674        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
675    }
676}
677
678impl From<u16> for Variable {
679    fn from(value: u16) -> Self {
680        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
681    }
682}
683
684impl From<u32> for Variable {
685    fn from(value: u32) -> Self {
686        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
687    }
688}
689
690impl From<u64> for Variable {
691    fn from(value: u64) -> Self {
692        Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
693    }
694}
695
696impl From<usize> for Variable {
697    fn from(value: usize) -> Self {
698        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
699    }
700}