cubecl_ir/
type.rs

1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::TypeHash;
3use core::fmt::Display;
4use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
5
6#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
7#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
8#[allow(missing_docs)]
9pub enum FloatKind {
10    /// FP4, 2 bit exponent, 1 bit mantissa
11    E2M1,
12    /// FP6, 2 bit exponent, 3 bit mantissa
13    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
14    E2M3,
15    /// FP6, 3 bit exponent, 2 bit mantissa
16    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
17    E3M2,
18    /// FP8, 4 bit exponent, 3 bit mantissa
19    E4M3,
20    /// FP8, 5 bit exponent, 2 bit mantissa
21    E5M2,
22    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
23    UE8M0,
24    F16,
25    BF16,
26    Flex32,
27    F32,
28    TF32,
29    F64,
30}
31
32#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
33#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
34#[allow(missing_docs)]
35pub enum IntKind {
36    I8,
37    I16,
38    I32,
39    I64,
40}
41
42#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
43#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
44#[allow(missing_docs)]
45pub enum UIntKind {
46    U8,
47    U16,
48    U32,
49    U64,
50}
51
52/// Conceptual element type, not necessarily the physical type used in the code
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
54#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
55#[allow(missing_docs)]
56pub enum ElemType {
57    Float(FloatKind),
58    Int(IntKind),
59    UInt(UIntKind),
60    Bool,
61}
62
63#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
64#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
65pub enum SemanticType {
66    Barrier,
67    BarrierToken,
68    Pipeline,
69    TensorMap,
70}
71
72/// Physical type containing one or more elements
73#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
74#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
75pub enum StorageType {
76    /// `ElemType` is the same as the physical type
77    Scalar(ElemType),
78    /// Packed values of type `ElemType`
79    Packed(ElemType, u32),
80    /// Atomically accessed version of `ElemType`
81    Atomic(ElemType),
82}
83
84impl ElemType {
85    /// Create a constant scalar from a float.
86    ///
87    /// The output will have the same type as the element.
88    pub fn constant_from_f64(&self, val: f64) -> Variable {
89        Variable::constant(match self {
90            ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
91            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
92            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
93            ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
94        })
95    }
96    /// Create a constant scalar from a signed integer.
97    ///
98    /// The output will have the same type as the element.
99    pub fn constant_from_i64(&self, val: i64) -> Variable {
100        Variable::constant(match self {
101            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
102            ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
103            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
104            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
105        })
106    }
107    /// Create a constant scalar from a unsigned integer.
108    ///
109    /// The output will have the same type as the element.
110    pub fn constant_from_u64(&self, val: u64) -> Variable {
111        Variable::constant(match self {
112            ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
113            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
114            ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
115            ElemType::Bool => ConstantScalarValue::Bool(val > 0),
116        })
117    }
118    /// Create a constant scalar from a boolean.
119    ///
120    /// The output will have the same type as the element.
121    pub fn constant_from_bool(&self, val: bool) -> Variable {
122        Variable::constant(match self {
123            ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
124            ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
125            ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126            ElemType::Bool => ConstantScalarValue::Bool(val),
127        })
128    }
129
130    /// Ensure that the variable provided, when a constant, is the same type as elem.
131    pub fn from_constant(&self, constant: Variable) -> Variable {
132        let value = match constant.kind {
133            VariableKind::ConstantScalar(value) => value,
134            _ => return constant,
135        };
136
137        match value {
138            ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
139            ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
140            ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
141            ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
142        }
143    }
144    /// Get the size in bytes.
145    pub const fn size(&self) -> usize {
146        match self {
147            ElemType::Float(kind) => match kind {
148                FloatKind::E2M1
149                | FloatKind::E2M3
150                | FloatKind::E3M2
151                | FloatKind::E4M3
152                | FloatKind::E5M2
153                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
154                FloatKind::F16 => core::mem::size_of::<half::f16>(),
155                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
156                FloatKind::F32 => core::mem::size_of::<f32>(),
157                FloatKind::F64 => core::mem::size_of::<f64>(),
158                FloatKind::Flex32 => core::mem::size_of::<f32>(),
159                FloatKind::TF32 => core::mem::size_of::<f32>(),
160            },
161            ElemType::Int(kind) => match kind {
162                IntKind::I8 => core::mem::size_of::<i8>(),
163                IntKind::I16 => core::mem::size_of::<i16>(),
164                IntKind::I32 => core::mem::size_of::<i32>(),
165                IntKind::I64 => core::mem::size_of::<i64>(),
166            },
167            ElemType::UInt(kind) => match kind {
168                UIntKind::U8 => core::mem::size_of::<u8>(),
169                UIntKind::U16 => core::mem::size_of::<u16>(),
170                UIntKind::U32 => core::mem::size_of::<u32>(),
171                UIntKind::U64 => core::mem::size_of::<u64>(),
172            },
173            ElemType::Bool => core::mem::size_of::<bool>(),
174        }
175    }
176
177    /// Get the size in bits.
178    pub const fn size_bits(&self) -> usize {
179        match self {
180            ElemType::Float(kind) => match kind {
181                FloatKind::E2M3
182                | FloatKind::E3M2
183                | FloatKind::E4M3
184                | FloatKind::E5M2
185                | FloatKind::UE8M0
186                | FloatKind::F16
187                | FloatKind::BF16
188                | FloatKind::F32
189                | FloatKind::F64
190                | FloatKind::Flex32
191                | FloatKind::TF32 => self.size() * 8,
192                FloatKind::E2M1 => 4,
193            },
194            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
195        }
196    }
197
198    pub const fn min_line_size(&self) -> u8 {
199        match self {
200            ElemType::Float(FloatKind::E2M1) => 2,
201            _ => 1,
202        }
203    }
204
205    pub fn is_int(&self) -> bool {
206        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
207    }
208
209    pub fn is_signed_int(&self) -> bool {
210        matches!(self, ElemType::Int(_))
211    }
212
213    pub fn is_unsigned_int(&self) -> bool {
214        matches!(self, ElemType::UInt(_) | ElemType::Bool)
215    }
216
217    pub fn is_float(&self) -> bool {
218        matches!(self, ElemType::Float(_))
219    }
220
221    pub fn max_variable(&self) -> Variable {
222        let value = match self {
223            ElemType::Float(kind) => match kind {
224                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
225                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
226                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
227                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
228                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
229                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
230                FloatKind::F16 => {
231                    ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
232                }
233                FloatKind::BF16 => {
234                    ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
235                }
236                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
237                FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
238                FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
239                FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
240            },
241            ElemType::Int(kind) => match kind {
242                IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
243                IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
244                IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
245                IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
246            },
247            ElemType::UInt(kind) => match kind {
248                UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
249                UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
250                UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
251                UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
252            },
253            ElemType::Bool => ConstantScalarValue::Bool(true),
254        };
255
256        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
257    }
258
259    pub fn min_variable(&self) -> Variable {
260        let value = match self {
261            ElemType::Float(kind) => match kind {
262                FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
263                FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
264                FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
265                FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
266                FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
267                FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
268                FloatKind::F16 => {
269                    ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
270                }
271                FloatKind::BF16 => {
272                    ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
273                }
274                FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
275                FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
276                FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
277                FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
278            },
279            ElemType::Int(kind) => match kind {
280                IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
281                IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
282                IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
283                IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
284            },
285            ElemType::UInt(kind) => match kind {
286                UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
287                UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
288                UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
289                UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
290            },
291            ElemType::Bool => ConstantScalarValue::Bool(false),
292        };
293
294        Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
295    }
296}
297
298impl StorageType {
299    pub fn elem_type(&self) -> ElemType {
300        match self {
301            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
302        }
303    }
304
305    pub fn packing_factor(&self) -> u32 {
306        match self {
307            StorageType::Packed(_, factor) => *factor,
308            _ => 1,
309        }
310    }
311
312    pub fn is_atomic(&self) -> bool {
313        matches!(self, StorageType::Atomic(_))
314    }
315
316    pub fn size(&self) -> usize {
317        self.size_bits().div_ceil(8)
318    }
319
320    pub fn size_bits(&self) -> usize {
321        match self {
322            StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
323            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
324        }
325    }
326
327    /// Ensure that the variable provided, when a constant, is the same type as elem.
328    pub fn from_constant(&self, constant: Variable) -> Variable {
329        self.elem_type().from_constant(constant)
330    }
331
332    pub fn is_int(&self) -> bool {
333        self.elem_type().is_int()
334    }
335
336    pub fn is_signed_int(&self) -> bool {
337        self.elem_type().is_signed_int()
338    }
339
340    pub fn is_unsigned_int(&self) -> bool {
341        self.elem_type().is_unsigned_int()
342    }
343
344    pub fn is_float(&self) -> bool {
345        self.elem_type().is_float()
346    }
347}
348
349impl From<ElemType> for Type {
350    fn from(val: ElemType) -> Self {
351        Type::scalar(val)
352    }
353}
354
355impl From<ElemType> for StorageType {
356    fn from(val: ElemType) -> Self {
357        StorageType::Scalar(val)
358    }
359}
360
361impl From<StorageType> for Type {
362    fn from(val: StorageType) -> Self {
363        Type::new(val)
364    }
365}
366
367impl From<SemanticType> for Type {
368    fn from(val: SemanticType) -> Self {
369        Type::semantic(val)
370    }
371}
372
373#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
374#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
375pub enum Type {
376    /// Scalar type containing a single storage element
377    Scalar(StorageType),
378    /// Line wrapping `n` storage elements
379    Line(StorageType, u32),
380    /// No defined physical representation, purely semantic. i.e. barrier, pipeline
381    Semantic(SemanticType),
382}
383
384pub type LineSize = u32;
385
386impl Type {
387    /// Fetch the elem of the item.
388    pub fn elem_type(&self) -> ElemType {
389        self.storage_type().elem_type()
390    }
391
392    /// Create a new item
393    pub fn new(storage: StorageType) -> Self {
394        Type::Scalar(storage)
395    }
396
397    pub fn scalar(elem: ElemType) -> Self {
398        Self::new(StorageType::Scalar(elem))
399    }
400
401    pub fn semantic(ty: SemanticType) -> Self {
402        Self::Semantic(ty)
403    }
404
405    pub fn line(self, line_size: LineSize) -> Type {
406        match line_size > 1 {
407            true => Type::Line(self.storage_type(), line_size),
408            false => Type::Scalar(self.storage_type()),
409        }
410    }
411
412    pub fn line_size(&self) -> u32 {
413        match self {
414            Type::Scalar(_) => 1,
415            Type::Line(_, line_size) => *line_size,
416            Type::Semantic(_) => 0,
417        }
418    }
419
420    pub fn size(&self) -> usize {
421        match self {
422            Type::Scalar(ty) => ty.size(),
423            Type::Line(ty, line_size) => ty.size() * *line_size as usize,
424            Type::Semantic(_) => 0,
425        }
426    }
427
428    pub fn size_bits(&self) -> usize {
429        match self {
430            Type::Scalar(ty) => ty.size_bits(),
431            Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
432            Type::Semantic(_) => 0,
433        }
434    }
435
436    pub fn is_atomic(&self) -> bool {
437        !self.is_semantic() && self.storage_type().is_atomic()
438    }
439
440    pub fn is_int(&self) -> bool {
441        !self.is_semantic() && self.storage_type().is_int()
442    }
443
444    pub fn is_signed_int(&self) -> bool {
445        !self.is_semantic() && self.storage_type().is_signed_int()
446    }
447
448    pub fn is_unsigned_int(&self) -> bool {
449        !self.is_semantic() && self.storage_type().is_unsigned_int()
450    }
451
452    pub fn is_float(&self) -> bool {
453        !self.is_semantic() && self.storage_type().is_float()
454    }
455
456    pub fn storage_type(&self) -> StorageType {
457        match self {
458            Type::Scalar(ty) | Type::Line(ty, _) => *ty,
459            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
460        }
461    }
462
463    pub fn is_semantic(&self) -> bool {
464        match self {
465            Type::Scalar(_) | Type::Line(_, _) => false,
466            Type::Semantic(_) => true,
467        }
468    }
469}
470
471impl Display for Type {
472    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
473        match self {
474            Type::Scalar(ty) => write!(f, "{ty}"),
475            Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
476            Type::Semantic(ty) => write!(f, "{ty}"),
477        }
478    }
479}
480
481impl Display for StorageType {
482    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
483        match self {
484            StorageType::Scalar(ty) => write!(f, "{ty}"),
485            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
486            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
487        }
488    }
489}
490
491impl Display for ElemType {
492    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
493        match self {
494            Self::Float(kind) => match kind {
495                FloatKind::E2M1 => f.write_str("e2m1"),
496                FloatKind::E2M3 => f.write_str("e2m3"),
497                FloatKind::E3M2 => f.write_str("e3m2"),
498                FloatKind::E4M3 => f.write_str("e4m3"),
499                FloatKind::E5M2 => f.write_str("e5m2"),
500                FloatKind::UE8M0 => f.write_str("ue8m0"),
501                FloatKind::F16 => f.write_str("f16"),
502                FloatKind::BF16 => f.write_str("bf16"),
503                FloatKind::Flex32 => f.write_str("flex32"),
504                FloatKind::TF32 => f.write_str("tf32"),
505                FloatKind::F32 => f.write_str("f32"),
506                FloatKind::F64 => f.write_str("f64"),
507            },
508            Self::Int(kind) => match kind {
509                IntKind::I8 => f.write_str("i8"),
510                IntKind::I16 => f.write_str("i16"),
511                IntKind::I32 => f.write_str("i32"),
512                IntKind::I64 => f.write_str("i64"),
513            },
514            Self::UInt(kind) => match kind {
515                UIntKind::U8 => f.write_str("u8"),
516                UIntKind::U16 => f.write_str("u16"),
517                UIntKind::U32 => f.write_str("u32"),
518                UIntKind::U64 => f.write_str("u64"),
519            },
520            Self::Bool => f.write_str("bool"),
521        }
522    }
523}
524
525impl Display for SemanticType {
526    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
527        match self {
528            SemanticType::Barrier => f.write_str("barrier"),
529            SemanticType::BarrierToken => f.write_str("barrier_token"),
530            SemanticType::Pipeline => f.write_str("pipeline"),
531            SemanticType::TensorMap => f.write_str("tensor_map"),
532        }
533    }
534}
535
536impl From<bool> for Variable {
537    fn from(value: bool) -> Self {
538        Variable::constant(ConstantScalarValue::Bool(value))
539    }
540}
541
542impl From<i8> for Variable {
543    fn from(value: i8) -> Self {
544        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
545    }
546}
547
548impl From<i16> for Variable {
549    fn from(value: i16) -> Self {
550        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
551    }
552}
553
554impl From<i32> for Variable {
555    fn from(value: i32) -> Self {
556        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
557    }
558}
559
560impl From<i64> for Variable {
561    fn from(value: i64) -> Self {
562        Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
563    }
564}
565
566impl From<e2m1> for Variable {
567    fn from(_value: e2m1) -> Self {
568        unimplemented!("Can't currently construct minifloats")
569    }
570}
571
572impl From<e2m1x2> for Variable {
573    fn from(_value: e2m1x2) -> Self {
574        unimplemented!("Can't currently construct minifloats")
575    }
576}
577
578impl From<e2m3> for Variable {
579    fn from(_value: e2m3) -> Self {
580        unimplemented!("Can't currently construct minifloats")
581    }
582}
583
584impl From<e3m2> for Variable {
585    fn from(_value: e3m2) -> Self {
586        unimplemented!("Can't currently construct minifloats")
587    }
588}
589
590impl From<e4m3> for Variable {
591    fn from(_value: e4m3) -> Self {
592        unimplemented!("Can't currently construct minifloats")
593    }
594}
595
596impl From<e5m2> for Variable {
597    fn from(_value: e5m2) -> Self {
598        unimplemented!("Can't currently construct minifloats")
599    }
600}
601
602impl From<ue8m0> for Variable {
603    fn from(_value: ue8m0) -> Self {
604        unimplemented!("Can't currently construct minifloats")
605    }
606}
607
608impl From<half::f16> for Variable {
609    fn from(value: half::f16) -> Self {
610        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
611    }
612}
613
614impl From<half::bf16> for Variable {
615    fn from(value: half::bf16) -> Self {
616        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
617    }
618}
619
620impl From<flex32> for Variable {
621    fn from(value: flex32) -> Self {
622        Variable::constant(ConstantScalarValue::Float(
623            value.to_f64(),
624            FloatKind::Flex32,
625        ))
626    }
627}
628
629impl From<tf32> for Variable {
630    fn from(value: tf32) -> Self {
631        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
632    }
633}
634
635impl From<f32> for Variable {
636    fn from(value: f32) -> Self {
637        Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
638    }
639}
640
641impl From<f64> for Variable {
642    fn from(value: f64) -> Self {
643        Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
644    }
645}
646
647impl From<u8> for Variable {
648    fn from(value: u8) -> Self {
649        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
650    }
651}
652
653impl From<u16> for Variable {
654    fn from(value: u16) -> Self {
655        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
656    }
657}
658
659impl From<u32> for Variable {
660    fn from(value: u32) -> Self {
661        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
662    }
663}
664
665impl From<u64> for Variable {
666    fn from(value: u64) -> Self {
667        Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
668    }
669}
670
671impl From<usize> for Variable {
672    fn from(value: usize) -> Self {
673        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
674    }
675}