cubecl_ir/
type.rs

1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::TypeHash;
3use core::fmt::Display;
4use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
5
/// Floating-point element kinds, from 4-bit minifloats up to 64-bit doubles.
///
/// Declaration order is significant: it defines the derived `PartialOrd`/`Ord`
/// ordering, so variants should not be reordered casually.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// FP4, 2 bit exponent, 1 bit mantissa
    E2M1,
    /// FP6, 2 bit exponent, 3 bit mantissa
    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
    E2M3,
    /// FP6, 3 bit exponent, 2 bit mantissa
    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
    E3M2,
    /// FP8, 4 bit exponent, 3 bit mantissa
    E4M3,
    /// FP8, 5 bit exponent, 2 bit mantissa
    E5M2,
    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
    UE8M0,
    /// 16-bit half-precision float, sized like `half::f16`.
    F16,
    /// 16-bit bfloat16, sized like `half::bf16`.
    BF16,
    /// Occupies 32 bits, like `f32`.
    /// NOTE(review): presumably a relaxed-precision f32 that backends may lower — confirm.
    Flex32,
    /// 32-bit single-precision float.
    F32,
    /// Occupies 32 bits, like `f32`.
    /// NOTE(review): presumably TensorFloat-32 (reduced mantissa precision) — confirm.
    TF32,
    /// 64-bit double-precision float.
    F64,
}
31
/// Signed integer element kinds, named by bit width.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
}
41
/// Unsigned integer element kinds, named by bit width.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// 8-bit unsigned integer.
    U8,
    /// 16-bit unsigned integer.
    U16,
    /// 32-bit unsigned integer.
    U32,
    /// 64-bit unsigned integer.
    U64,
}
51
/// Conceptual element type, not necessarily the physical type used in the code.
///
/// How an element is physically stored (plain, packed or atomic) is described
/// separately by `StorageType`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum ElemType {
    /// Floating-point element of the given kind.
    Float(FloatKind),
    /// Signed integer element of the given width.
    Int(IntKind),
    /// Unsigned integer element of the given width.
    UInt(UIntKind),
    /// Boolean element; the `is_int`/`is_unsigned_int` predicates classify it as an unsigned integer.
    Bool,
}
62
/// Types that exist only for their meaning in the IR and have no defined
/// physical representation (they report a size of 0 when wrapped in `Type`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Barrier object (displayed as `barrier`).
    Barrier,
    /// Pipeline object (displayed as `pipeline`).
    Pipeline,
    /// Tensor map object (displayed as `tensor_map`).
    TensorMap,
}
70
/// Physical type containing one or more elements.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// `ElemType` is the same as the physical type
    Scalar(ElemType),
    /// Packed values of type `ElemType`.
    ///
    /// The `u32` is the packing factor: how many elements share one storage
    /// slot, so the slot occupies `elem.size_bits() * factor` bits.
    Packed(ElemType, u32),
    /// Atomically accessed version of `ElemType`
    Atomic(ElemType),
}
82
83impl ElemType {
84    /// Create a constant scalar from a float.
85    ///
86    /// The output will have the same type as the element.
87    pub fn constant_from_f64(&self, val: f64) -> Variable {
88        Variable::constant(match self {
89            ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
90            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
91            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
92            ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
93        })
94    }
95    /// Create a constant scalar from a signed integer.
96    ///
97    /// The output will have the same type as the element.
98    pub fn constant_from_i64(&self, val: i64) -> Variable {
99        Variable::constant(match self {
100            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
101            ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
102            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
103            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
104        })
105    }
106    /// Create a constant scalar from a unsigned integer.
107    ///
108    /// The output will have the same type as the element.
109    pub fn constant_from_u64(&self, val: u64) -> Variable {
110        Variable::constant(match self {
111            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
112            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
113            ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
114            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
115        })
116    }
117    /// Create a constant scalar from a boolean.
118    ///
119    /// The output will have the same type as the element.
120    pub fn constant_from_bool(&self, val: bool) -> Variable {
121        Variable::constant(match self {
122            ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
123            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
124            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
125            ElemType::Bool => ConstantScalarValue::Bool(val),
126        })
127    }
128
129    /// Ensure that the variable provided, when a constant, is the same type as elem.
130    pub fn from_constant(&self, constant: Variable) -> Variable {
131        let value = match constant.kind {
132            VariableKind::ConstantScalar(value) => value,
133            _ => return constant,
134        };
135
136        match value {
137            ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
138            ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
139            ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
140            ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
141        }
142    }
143    /// Get the size in bytes.
144    pub const fn size(&self) -> usize {
145        match self {
146            ElemType::Float(kind) => match kind {
147                FloatKind::E2M1
148                | FloatKind::E2M3
149                | FloatKind::E3M2
150                | FloatKind::E4M3
151                | FloatKind::E5M2
152                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
153                FloatKind::F16 => core::mem::size_of::<half::f16>(),
154                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
155                FloatKind::F32 => core::mem::size_of::<f32>(),
156                FloatKind::F64 => core::mem::size_of::<f64>(),
157                FloatKind::Flex32 => core::mem::size_of::<f32>(),
158                FloatKind::TF32 => core::mem::size_of::<f32>(),
159            },
160            ElemType::Int(kind) => match kind {
161                IntKind::I8 => core::mem::size_of::<i8>(),
162                IntKind::I16 => core::mem::size_of::<i16>(),
163                IntKind::I32 => core::mem::size_of::<i32>(),
164                IntKind::I64 => core::mem::size_of::<i64>(),
165            },
166            ElemType::UInt(kind) => match kind {
167                UIntKind::U8 => core::mem::size_of::<u8>(),
168                UIntKind::U16 => core::mem::size_of::<u16>(),
169                UIntKind::U32 => core::mem::size_of::<u32>(),
170                UIntKind::U64 => core::mem::size_of::<u64>(),
171            },
172            ElemType::Bool => core::mem::size_of::<bool>(),
173        }
174    }
175
176    /// Get the size in bits.
177    pub const fn size_bits(&self) -> usize {
178        match self {
179            ElemType::Float(kind) => match kind {
180                FloatKind::E2M3
181                | FloatKind::E3M2
182                | FloatKind::E4M3
183                | FloatKind::E5M2
184                | FloatKind::UE8M0
185                | FloatKind::F16
186                | FloatKind::BF16
187                | FloatKind::F32
188                | FloatKind::F64
189                | FloatKind::Flex32
190                | FloatKind::TF32 => self.size() * 8,
191                FloatKind::E2M1 => 4,
192            },
193            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
194        }
195    }
196
197    pub const fn min_line_size(&self) -> u8 {
198        match self {
199            ElemType::Float(FloatKind::E2M1) => 2,
200            _ => 1,
201        }
202    }
203
204    pub fn is_int(&self) -> bool {
205        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
206    }
207
208    pub fn is_signed_int(&self) -> bool {
209        matches!(self, ElemType::Int(_))
210    }
211
212    pub fn is_unsigned_int(&self) -> bool {
213        matches!(self, ElemType::UInt(_) | ElemType::Bool)
214    }
215
216    pub fn is_float(&self) -> bool {
217        matches!(self, ElemType::Float(_))
218    }
219
220    pub fn max_variable(&self) -> Variable {
221        let value = match self {
222            ElemType::Float(kind) => match kind {
223                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
224                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
225                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
226                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
227                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
228                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
229                FloatKind::F16 => {
230                    ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
231                }
232                FloatKind::BF16 => {
233                    ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
234                }
235                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
236                FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
237                FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
238                FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
239            },
240            ElemType::Int(kind) => match kind {
241                IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
242                IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
243                IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
244                IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
245            },
246            ElemType::UInt(kind) => match kind {
247                UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
248                UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
249                UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
250                UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
251            },
252            ElemType::Bool => ConstantScalarValue::Bool(true),
253        };
254
255        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
256    }
257
258    pub fn min_variable(&self) -> Variable {
259        let value = match self {
260            ElemType::Float(kind) => match kind {
261                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
262                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
263                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
264                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
265                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
266                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
267                FloatKind::F16 => {
268                    ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
269                }
270                FloatKind::BF16 => {
271                    ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
272                }
273                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
274                FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
275                FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
276                FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
277            },
278            ElemType::Int(kind) => match kind {
279                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
280                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
281                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
282                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
283            },
284            ElemType::UInt(kind) => match kind {
285                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
286                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
287                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
288                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
289            },
290            ElemType::Bool => ConstantScalarValue::Bool(false),
291        };
292
293        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
294    }
295}
296
297impl StorageType {
298    pub fn elem_type(&self) -> ElemType {
299        match self {
300            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
301        }
302    }
303
304    pub fn packing_factor(&self) -> u32 {
305        match self {
306            StorageType::Packed(_, factor) => *factor,
307            _ => 1,
308        }
309    }
310
311    pub fn is_atomic(&self) -> bool {
312        matches!(self, StorageType::Atomic(_))
313    }
314
315    pub fn size(&self) -> usize {
316        self.size_bits().div_ceil(8)
317    }
318
319    pub fn size_bits(&self) -> usize {
320        match self {
321            StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
322            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
323        }
324    }
325
326    /// Ensure that the variable provided, when a constant, is the same type as elem.
327    pub fn from_constant(&self, constant: Variable) -> Variable {
328        self.elem_type().from_constant(constant)
329    }
330
331    pub fn is_int(&self) -> bool {
332        self.elem_type().is_int()
333    }
334
335    pub fn is_signed_int(&self) -> bool {
336        self.elem_type().is_signed_int()
337    }
338
339    pub fn is_unsigned_int(&self) -> bool {
340        self.elem_type().is_unsigned_int()
341    }
342
343    pub fn is_float(&self) -> bool {
344        self.elem_type().is_float()
345    }
346}
347
348impl From<ElemType> for Type {
349    fn from(val: ElemType) -> Self {
350        Type::scalar(val)
351    }
352}
353
354impl From<ElemType> for StorageType {
355    fn from(val: ElemType) -> Self {
356        StorageType::Scalar(val)
357    }
358}
359
360impl From<StorageType> for Type {
361    fn from(val: StorageType) -> Self {
362        Type::new(val)
363    }
364}
365
366impl From<SemanticType> for Type {
367    fn from(val: SemanticType) -> Self {
368        Type::semantic(val)
369    }
370}
371
/// A value's full type: plain storage, a line of storage elements, or a
/// purely semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// Scalar type containing a single storage element
    Scalar(StorageType),
    /// Line wrapping `n` storage elements.
    ///
    /// `Type::line` only builds this variant for `n > 1`; smaller sizes
    /// collapse to `Scalar`.
    Line(StorageType, u32),
    /// No defined physical representation, purely semantic, e.g. barrier or pipeline.
    Semantic(SemanticType),
}

/// Number of storage elements in a [`Type::Line`].
pub type LineSize = u32;
384
385impl Type {
386    /// Fetch the elem of the item.
387    pub fn elem_type(&self) -> ElemType {
388        self.storage_type().elem_type()
389    }
390
391    /// Create a new item
392    pub fn new(storage: StorageType) -> Self {
393        Type::Scalar(storage)
394    }
395
396    pub fn scalar(elem: ElemType) -> Self {
397        Self::new(StorageType::Scalar(elem))
398    }
399
400    pub fn semantic(ty: SemanticType) -> Self {
401        Self::Semantic(ty)
402    }
403
404    pub fn line(self, line_size: LineSize) -> Type {
405        match line_size > 1 {
406            true => Type::Line(self.storage_type(), line_size),
407            false => Type::Scalar(self.storage_type()),
408        }
409    }
410
411    pub fn line_size(&self) -> u32 {
412        match self {
413            Type::Scalar(_) => 1,
414            Type::Line(_, line_size) => *line_size,
415            Type::Semantic(_) => 0,
416        }
417    }
418
419    pub fn size(&self) -> usize {
420        match self {
421            Type::Scalar(ty) => ty.size(),
422            Type::Line(ty, line_size) => ty.size() * *line_size as usize,
423            Type::Semantic(_) => 0,
424        }
425    }
426
427    pub fn size_bits(&self) -> usize {
428        match self {
429            Type::Scalar(ty) => ty.size_bits(),
430            Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
431            Type::Semantic(_) => 0,
432        }
433    }
434
435    pub fn is_atomic(&self) -> bool {
436        !self.is_semantic() && self.storage_type().is_atomic()
437    }
438
439    pub fn is_int(&self) -> bool {
440        !self.is_semantic() && self.storage_type().is_int()
441    }
442
443    pub fn is_signed_int(&self) -> bool {
444        !self.is_semantic() && self.storage_type().is_signed_int()
445    }
446
447    pub fn is_unsigned_int(&self) -> bool {
448        !self.is_semantic() && self.storage_type().is_unsigned_int()
449    }
450
451    pub fn is_float(&self) -> bool {
452        !self.is_semantic() && self.storage_type().is_float()
453    }
454
455    pub fn storage_type(&self) -> StorageType {
456        match self {
457            Type::Scalar(ty) | Type::Line(ty, _) => *ty,
458            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
459        }
460    }
461
462    pub fn is_semantic(&self) -> bool {
463        match self {
464            Type::Scalar(_) | Type::Line(_, _) => false,
465            Type::Semantic(_) => true,
466        }
467    }
468}
469
470impl Display for Type {
471    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
472        match self {
473            Type::Scalar(ty) => write!(f, "{ty}"),
474            Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
475            Type::Semantic(ty) => write!(f, "{ty}"),
476        }
477    }
478}
479
480impl Display for StorageType {
481    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
482        match self {
483            StorageType::Scalar(ty) => write!(f, "{ty}"),
484            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
485            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
486        }
487    }
488}
489
490impl Display for ElemType {
491    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
492        match self {
493            Self::Float(kind) => match kind {
494                FloatKind::E2M1 => f.write_str("e2m1"),
495                FloatKind::E2M3 => f.write_str("e2m3"),
496                FloatKind::E3M2 => f.write_str("e3m2"),
497                FloatKind::E4M3 => f.write_str("e4m3"),
498                FloatKind::E5M2 => f.write_str("e5m2"),
499                FloatKind::UE8M0 => f.write_str("ue8m0"),
500                FloatKind::F16 => f.write_str("f16"),
501                FloatKind::BF16 => f.write_str("bf16"),
502                FloatKind::Flex32 => f.write_str("flex32"),
503                FloatKind::TF32 => f.write_str("tf32"),
504                FloatKind::F32 => f.write_str("f32"),
505                FloatKind::F64 => f.write_str("f64"),
506            },
507            Self::Int(kind) => match kind {
508                IntKind::I8 => f.write_str("i8"),
509                IntKind::I16 => f.write_str("i16"),
510                IntKind::I32 => f.write_str("i32"),
511                IntKind::I64 => f.write_str("i64"),
512            },
513            Self::UInt(kind) => match kind {
514                UIntKind::U8 => f.write_str("u8"),
515                UIntKind::U16 => f.write_str("u16"),
516                UIntKind::U32 => f.write_str("u32"),
517                UIntKind::U64 => f.write_str("u64"),
518            },
519            Self::Bool => f.write_str("bool"),
520        }
521    }
522}
523
524impl Display for SemanticType {
525    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
526        match self {
527            SemanticType::Barrier => f.write_str("barrier"),
528            SemanticType::Pipeline => f.write_str("pipeline"),
529            SemanticType::TensorMap => f.write_str("tensor_map"),
530        }
531    }
532}
533
534impl From<bool> for Variable {
535    fn from(value: bool) -> Self {
536        Variable::constant(ConstantScalarValue::Bool(value))
537    }
538}
539
540impl From<i8> for Variable {
541    fn from(value: i8) -> Self {
542        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
543    }
544}
545
546impl From<i16> for Variable {
547    fn from(value: i16) -> Self {
548        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
549    }
550}
551
552impl From<i32> for Variable {
553    fn from(value: i32) -> Self {
554        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
555    }
556}
557
558impl From<i64> for Variable {
559    fn from(value: i64) -> Self {
560        Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
561    }
562}
563
564impl From<e2m1> for Variable {
565    fn from(_value: e2m1) -> Self {
566        unimplemented!("Can't currently construct minifloats")
567    }
568}
569
570impl From<e2m1x2> for Variable {
571    fn from(_value: e2m1x2) -> Self {
572        unimplemented!("Can't currently construct minifloats")
573    }
574}
575
576impl From<e2m3> for Variable {
577    fn from(_value: e2m3) -> Self {
578        unimplemented!("Can't currently construct minifloats")
579    }
580}
581
582impl From<e3m2> for Variable {
583    fn from(_value: e3m2) -> Self {
584        unimplemented!("Can't currently construct minifloats")
585    }
586}
587
588impl From<e4m3> for Variable {
589    fn from(_value: e4m3) -> Self {
590        unimplemented!("Can't currently construct minifloats")
591    }
592}
593
594impl From<e5m2> for Variable {
595    fn from(_value: e5m2) -> Self {
596        unimplemented!("Can't currently construct minifloats")
597    }
598}
599
600impl From<ue8m0> for Variable {
601    fn from(_value: ue8m0) -> Self {
602        unimplemented!("Can't currently construct minifloats")
603    }
604}
605
606impl From<half::f16> for Variable {
607    fn from(value: half::f16) -> Self {
608        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
609    }
610}
611
612impl From<half::bf16> for Variable {
613    fn from(value: half::bf16) -> Self {
614        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
615    }
616}
617
618impl From<flex32> for Variable {
619    fn from(value: flex32) -> Self {
620        Variable::constant(ConstantScalarValue::Float(
621            value.to_f64(),
622            FloatKind::Flex32,
623        ))
624    }
625}
626
627impl From<tf32> for Variable {
628    fn from(value: tf32) -> Self {
629        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
630    }
631}
632
633impl From<f32> for Variable {
634    fn from(value: f32) -> Self {
635        Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
636    }
637}
638
639impl From<f64> for Variable {
640    fn from(value: f64) -> Self {
641        Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
642    }
643}
644
645impl From<u8> for Variable {
646    fn from(value: u8) -> Self {
647        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
648    }
649}
650
651impl From<u16> for Variable {
652    fn from(value: u16) -> Self {
653        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
654    }
655}
656
657impl From<u32> for Variable {
658    fn from(value: u32) -> Self {
659        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
660    }
661}
662
663impl From<u64> for Variable {
664    fn from(value: u64) -> Self {
665        Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
666    }
667}
668
669impl From<usize> for Variable {
670    fn from(value: usize) -> Self {
671        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
672    }
673}