1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::TypeHash;
3use core::fmt::Display;
4use cubecl_common::{
5 e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6 quant::scheme::{QuantParam, QuantValue},
7 tf32, ue8m0,
8};
9
/// Floating-point element formats.
///
/// Variant declaration order is significant: the derived `PartialOrd`/`Ord` compare
/// variants by their position, so reordering changes comparison results.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// 4-bit minifloat (ExMy naming: 2 exponent bits, 1 mantissa bit). `ElemType::size_bits`
    /// reports 4 bits, `ElemType::size` reports 1 byte, and `ElemType::min_line_size` is 2.
    E2M1,
    /// Minifloat (2 exponent / 3 mantissa bits by name); currently reported as 8 bits wide.
    E2M3,
    /// Minifloat (3 exponent / 2 mantissa bits by name); currently reported as 8 bits wide.
    E3M2,
    /// 8-bit float (4 exponent / 3 mantissa bits by name).
    E4M3,
    /// 8-bit float (5 exponent / 2 mantissa bits by name).
    E5M2,
    /// 8-bit exponent-only format (unsigned, 8 exponent / 0 mantissa bits by name).
    UE8M0,
    /// Half precision, sized and bounded like `half::f16`.
    F16,
    /// bfloat16, sized and bounded like `half::bf16`.
    BF16,
    /// Flexible 32-bit float; sized and bounded like `f32`.
    Flex32,
    /// Standard 32-bit float (`f32`).
    F32,
    /// TensorFloat-32; occupies an `f32`-sized slot and uses `f32` bounds.
    TF32,
    /// Standard 64-bit float (`f64`).
    F64,
}
35
/// Signed integer element widths.
///
/// Variant order is significant for the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// 8-bit signed integer (`i8`).
    I8,
    /// 16-bit signed integer (`i16`).
    I16,
    /// 32-bit signed integer (`i32`).
    I32,
    /// 64-bit signed integer (`i64`).
    I64,
}
45
/// Unsigned integer element widths.
///
/// Variant order is significant for the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// 8-bit unsigned integer (`u8`).
    U8,
    /// 16-bit unsigned integer (`u16`).
    U16,
    /// 32-bit unsigned integer (`u32`).
    U32,
    /// 64-bit unsigned integer (`u64`).
    U64,
}
55
/// A scalar element type: a float, a signed or unsigned integer, or a boolean.
///
/// Variant order is significant for the derived `PartialOrd`/`Ord`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum ElemType {
    /// Floating-point element of the given format.
    Float(FloatKind),
    /// Signed integer element of the given width.
    Int(IntKind),
    /// Unsigned integer element of the given width.
    UInt(UIntKind),
    /// Boolean element; `is_int` and `is_unsigned_int` treat it as an unsigned integer.
    Bool,
}
66
/// Non-numeric types with no backing storage.
///
/// `Type::storage_type` panics for these, and `Type::size`/`Type::line_size` report 0.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Displayed as `barrier`.
    Barrier,
    /// Displayed as `barrier_token`.
    BarrierToken,
    /// Displayed as `pipeline`.
    Pipeline,
    /// Displayed as `tensor_map`.
    TensorMap,
}
75
/// How an element is laid out in storage.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// A single element.
    Scalar(ElemType),
    /// `factor` elements packed together; `size_bits` is the element width times `factor`.
    Packed(ElemType, u32),
    /// Atomic wrapper around an element; same width as the element itself.
    Atomic(ElemType),
}
87
88impl ElemType {
89 pub fn from_quant_param(quant_param: QuantParam) -> Self {
91 match quant_param {
92 QuantParam::F32 => Self::Float(FloatKind::F32),
93 QuantParam::F16 => Self::Float(FloatKind::F16),
94 QuantParam::BF16 => Self::Float(FloatKind::BF16),
95 QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
96 QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
97 }
98 }
99
100 pub fn from_quant_value(quant_value: QuantValue) -> Self {
102 match quant_value {
103 QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
104 QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
105 QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
106 QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
107 other => panic!("Unsupported quant value {other:?}"),
108 }
109 }
110 pub fn constant_from_f64(&self, val: f64) -> Variable {
114 Variable::constant(match self {
115 ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
116 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
117 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
118 ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
119 })
120 }
121 pub fn constant_from_i64(&self, val: i64) -> Variable {
125 Variable::constant(match self {
126 ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
127 ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
128 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
129 ElemType::Bool => ConstantScalarValue::Bool(val > 0),
130 })
131 }
132 pub fn constant_from_u64(&self, val: u64) -> Variable {
136 Variable::constant(match self {
137 ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
138 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
139 ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
140 ElemType::Bool => ConstantScalarValue::Bool(val > 0),
141 })
142 }
143 pub fn constant_from_bool(&self, val: bool) -> Variable {
147 Variable::constant(match self {
148 ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
149 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
150 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
151 ElemType::Bool => ConstantScalarValue::Bool(val),
152 })
153 }
154
155 pub fn from_constant(&self, constant: Variable) -> Variable {
157 let value = match constant.kind {
158 VariableKind::ConstantScalar(value) => value,
159 _ => return constant,
160 };
161
162 match value {
163 ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
164 ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
165 ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
166 ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
167 }
168 }
169 pub const fn size(&self) -> usize {
171 match self {
172 ElemType::Float(kind) => match kind {
173 FloatKind::E2M1
174 | FloatKind::E2M3
175 | FloatKind::E3M2
176 | FloatKind::E4M3
177 | FloatKind::E5M2
178 | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
179 FloatKind::F16 => core::mem::size_of::<half::f16>(),
180 FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
181 FloatKind::F32 => core::mem::size_of::<f32>(),
182 FloatKind::F64 => core::mem::size_of::<f64>(),
183 FloatKind::Flex32 => core::mem::size_of::<f32>(),
184 FloatKind::TF32 => core::mem::size_of::<f32>(),
185 },
186 ElemType::Int(kind) => match kind {
187 IntKind::I8 => core::mem::size_of::<i8>(),
188 IntKind::I16 => core::mem::size_of::<i16>(),
189 IntKind::I32 => core::mem::size_of::<i32>(),
190 IntKind::I64 => core::mem::size_of::<i64>(),
191 },
192 ElemType::UInt(kind) => match kind {
193 UIntKind::U8 => core::mem::size_of::<u8>(),
194 UIntKind::U16 => core::mem::size_of::<u16>(),
195 UIntKind::U32 => core::mem::size_of::<u32>(),
196 UIntKind::U64 => core::mem::size_of::<u64>(),
197 },
198 ElemType::Bool => core::mem::size_of::<bool>(),
199 }
200 }
201
202 pub const fn size_bits(&self) -> usize {
204 match self {
205 ElemType::Float(kind) => match kind {
206 FloatKind::E2M3
207 | FloatKind::E3M2
208 | FloatKind::E4M3
209 | FloatKind::E5M2
210 | FloatKind::UE8M0
211 | FloatKind::F16
212 | FloatKind::BF16
213 | FloatKind::F32
214 | FloatKind::F64
215 | FloatKind::Flex32
216 | FloatKind::TF32 => self.size() * 8,
217 FloatKind::E2M1 => 4,
218 },
219 ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
220 }
221 }
222
223 pub const fn min_line_size(&self) -> u8 {
224 match self {
225 ElemType::Float(FloatKind::E2M1) => 2,
226 _ => 1,
227 }
228 }
229
230 pub fn is_int(&self) -> bool {
231 matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
232 }
233
234 pub fn is_signed_int(&self) -> bool {
235 matches!(self, ElemType::Int(_))
236 }
237
238 pub fn is_unsigned_int(&self) -> bool {
239 matches!(self, ElemType::UInt(_) | ElemType::Bool)
240 }
241
242 pub fn is_float(&self) -> bool {
243 matches!(self, ElemType::Float(_))
244 }
245
246 pub fn max_variable(&self) -> Variable {
247 let value = match self {
248 ElemType::Float(kind) => match kind {
249 FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
250 FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
251 FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
252 FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
253 FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
254 FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
255 FloatKind::F16 => {
256 ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
257 }
258 FloatKind::BF16 => {
259 ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
260 }
261 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
262 FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
263 FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
264 FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
265 },
266 ElemType::Int(kind) => match kind {
267 IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
268 IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
269 IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
270 IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
271 },
272 ElemType::UInt(kind) => match kind {
273 UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
274 UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
275 UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
276 UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
277 },
278 ElemType::Bool => ConstantScalarValue::Bool(true),
279 };
280
281 Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
282 }
283
284 pub fn min_variable(&self) -> Variable {
285 let value = match self {
286 ElemType::Float(kind) => match kind {
287 FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
288 FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
289 FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
290 FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
291 FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
292 FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
293 FloatKind::F16 => {
294 ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
295 }
296 FloatKind::BF16 => {
297 ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
298 }
299 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
300 FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
301 FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
302 FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
303 },
304 ElemType::Int(kind) => match kind {
305 IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
306 IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
307 IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
308 IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
309 },
310 ElemType::UInt(kind) => match kind {
311 UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
312 UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
313 UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
314 UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
315 },
316 ElemType::Bool => ConstantScalarValue::Bool(false),
317 };
318
319 Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
320 }
321}
322
323impl StorageType {
324 pub fn elem_type(&self) -> ElemType {
325 match self {
326 StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
327 }
328 }
329
330 pub fn packing_factor(&self) -> u32 {
331 match self {
332 StorageType::Packed(_, factor) => *factor,
333 _ => 1,
334 }
335 }
336
337 pub fn is_atomic(&self) -> bool {
338 matches!(self, StorageType::Atomic(_))
339 }
340
341 pub fn size(&self) -> usize {
342 self.size_bits().div_ceil(8)
343 }
344
345 pub fn size_bits(&self) -> usize {
346 match self {
347 StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
348 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
349 }
350 }
351
352 pub fn from_constant(&self, constant: Variable) -> Variable {
354 self.elem_type().from_constant(constant)
355 }
356
357 pub fn is_int(&self) -> bool {
358 self.elem_type().is_int()
359 }
360
361 pub fn is_signed_int(&self) -> bool {
362 self.elem_type().is_signed_int()
363 }
364
365 pub fn is_unsigned_int(&self) -> bool {
366 self.elem_type().is_unsigned_int()
367 }
368
369 pub fn is_float(&self) -> bool {
370 self.elem_type().is_float()
371 }
372}
373
374impl From<ElemType> for Type {
375 fn from(val: ElemType) -> Self {
376 Type::scalar(val)
377 }
378}
379
380impl From<ElemType> for StorageType {
381 fn from(val: ElemType) -> Self {
382 StorageType::Scalar(val)
383 }
384}
385
386impl From<StorageType> for Type {
387 fn from(val: StorageType) -> Self {
388 Type::new(val)
389 }
390}
391
392impl From<SemanticType> for Type {
393 fn from(val: SemanticType) -> Self {
394 Type::semantic(val)
395 }
396}
397
/// A complete value type: a scalar, a line (vector) of scalars, or a semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// A single storage value; `line_size()` reports 1.
    Scalar(StorageType),
    /// `u32` lanes of the same storage type. `Type::line` only builds this for line sizes > 1.
    Line(StorageType, u32),
    /// A type without storage; `size()` and `line_size()` report 0.
    Semantic(SemanticType),
}
408
/// Number of lanes in a line type. `Type::line` treats values of 0 or 1 as scalar.
pub type LineSize = u32;
410
411impl Type {
412 pub fn elem_type(&self) -> ElemType {
414 self.storage_type().elem_type()
415 }
416
417 pub fn new(storage: StorageType) -> Self {
419 Type::Scalar(storage)
420 }
421
422 pub fn scalar(elem: ElemType) -> Self {
423 Self::new(StorageType::Scalar(elem))
424 }
425
426 pub fn semantic(ty: SemanticType) -> Self {
427 Self::Semantic(ty)
428 }
429
430 pub fn line(self, line_size: LineSize) -> Type {
431 match line_size > 1 {
432 true => Type::Line(self.storage_type(), line_size),
433 false => Type::Scalar(self.storage_type()),
434 }
435 }
436
437 pub fn line_size(&self) -> u32 {
438 match self {
439 Type::Scalar(_) => 1,
440 Type::Line(_, line_size) => *line_size,
441 Type::Semantic(_) => 0,
442 }
443 }
444
445 pub fn size(&self) -> usize {
446 match self {
447 Type::Scalar(ty) => ty.size(),
448 Type::Line(ty, line_size) => ty.size() * *line_size as usize,
449 Type::Semantic(_) => 0,
450 }
451 }
452
453 pub fn size_bits(&self) -> usize {
454 match self {
455 Type::Scalar(ty) => ty.size_bits(),
456 Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
457 Type::Semantic(_) => 0,
458 }
459 }
460
461 pub fn is_atomic(&self) -> bool {
462 !self.is_semantic() && self.storage_type().is_atomic()
463 }
464
465 pub fn is_int(&self) -> bool {
466 !self.is_semantic() && self.storage_type().is_int()
467 }
468
469 pub fn is_signed_int(&self) -> bool {
470 !self.is_semantic() && self.storage_type().is_signed_int()
471 }
472
473 pub fn is_unsigned_int(&self) -> bool {
474 !self.is_semantic() && self.storage_type().is_unsigned_int()
475 }
476
477 pub fn is_float(&self) -> bool {
478 !self.is_semantic() && self.storage_type().is_float()
479 }
480
481 pub fn storage_type(&self) -> StorageType {
482 match self {
483 Type::Scalar(ty) | Type::Line(ty, _) => *ty,
484 Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
485 }
486 }
487
488 pub fn is_semantic(&self) -> bool {
489 match self {
490 Type::Scalar(_) | Type::Line(_, _) => false,
491 Type::Semantic(_) => true,
492 }
493 }
494}
495
496impl Display for Type {
497 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
498 match self {
499 Type::Scalar(ty) => write!(f, "{ty}"),
500 Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
501 Type::Semantic(ty) => write!(f, "{ty}"),
502 }
503 }
504}
505
506impl Display for StorageType {
507 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
508 match self {
509 StorageType::Scalar(ty) => write!(f, "{ty}"),
510 StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
511 StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
512 }
513 }
514}
515
516impl Display for ElemType {
517 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
518 match self {
519 Self::Float(kind) => match kind {
520 FloatKind::E2M1 => f.write_str("e2m1"),
521 FloatKind::E2M3 => f.write_str("e2m3"),
522 FloatKind::E3M2 => f.write_str("e3m2"),
523 FloatKind::E4M3 => f.write_str("e4m3"),
524 FloatKind::E5M2 => f.write_str("e5m2"),
525 FloatKind::UE8M0 => f.write_str("ue8m0"),
526 FloatKind::F16 => f.write_str("f16"),
527 FloatKind::BF16 => f.write_str("bf16"),
528 FloatKind::Flex32 => f.write_str("flex32"),
529 FloatKind::TF32 => f.write_str("tf32"),
530 FloatKind::F32 => f.write_str("f32"),
531 FloatKind::F64 => f.write_str("f64"),
532 },
533 Self::Int(kind) => match kind {
534 IntKind::I8 => f.write_str("i8"),
535 IntKind::I16 => f.write_str("i16"),
536 IntKind::I32 => f.write_str("i32"),
537 IntKind::I64 => f.write_str("i64"),
538 },
539 Self::UInt(kind) => match kind {
540 UIntKind::U8 => f.write_str("u8"),
541 UIntKind::U16 => f.write_str("u16"),
542 UIntKind::U32 => f.write_str("u32"),
543 UIntKind::U64 => f.write_str("u64"),
544 },
545 Self::Bool => f.write_str("bool"),
546 }
547 }
548}
549
550impl Display for SemanticType {
551 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
552 match self {
553 SemanticType::Barrier => f.write_str("barrier"),
554 SemanticType::BarrierToken => f.write_str("barrier_token"),
555 SemanticType::Pipeline => f.write_str("pipeline"),
556 SemanticType::TensorMap => f.write_str("tensor_map"),
557 }
558 }
559}
560
561impl From<bool> for Variable {
562 fn from(value: bool) -> Self {
563 Variable::constant(ConstantScalarValue::Bool(value))
564 }
565}
566
567impl From<i8> for Variable {
568 fn from(value: i8) -> Self {
569 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
570 }
571}
572
573impl From<i16> for Variable {
574 fn from(value: i16) -> Self {
575 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
576 }
577}
578
579impl From<i32> for Variable {
580 fn from(value: i32) -> Self {
581 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
582 }
583}
584
585impl From<i64> for Variable {
586 fn from(value: i64) -> Self {
587 Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
588 }
589}
590
591impl From<e2m1> for Variable {
592 fn from(_value: e2m1) -> Self {
593 unimplemented!("Can't currently construct minifloats")
594 }
595}
596
597impl From<e2m1x2> for Variable {
598 fn from(_value: e2m1x2) -> Self {
599 unimplemented!("Can't currently construct minifloats")
600 }
601}
602
603impl From<e2m3> for Variable {
604 fn from(_value: e2m3) -> Self {
605 unimplemented!("Can't currently construct minifloats")
606 }
607}
608
609impl From<e3m2> for Variable {
610 fn from(_value: e3m2) -> Self {
611 unimplemented!("Can't currently construct minifloats")
612 }
613}
614
615impl From<e4m3> for Variable {
616 fn from(_value: e4m3) -> Self {
617 unimplemented!("Can't currently construct minifloats")
618 }
619}
620
621impl From<e5m2> for Variable {
622 fn from(_value: e5m2) -> Self {
623 unimplemented!("Can't currently construct minifloats")
624 }
625}
626
627impl From<ue8m0> for Variable {
628 fn from(_value: ue8m0) -> Self {
629 unimplemented!("Can't currently construct minifloats")
630 }
631}
632
633impl From<half::f16> for Variable {
634 fn from(value: half::f16) -> Self {
635 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
636 }
637}
638
639impl From<half::bf16> for Variable {
640 fn from(value: half::bf16) -> Self {
641 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
642 }
643}
644
645impl From<flex32> for Variable {
646 fn from(value: flex32) -> Self {
647 Variable::constant(ConstantScalarValue::Float(
648 value.to_f64(),
649 FloatKind::Flex32,
650 ))
651 }
652}
653
654impl From<tf32> for Variable {
655 fn from(value: tf32) -> Self {
656 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
657 }
658}
659
660impl From<f32> for Variable {
661 fn from(value: f32) -> Self {
662 Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
663 }
664}
665
666impl From<f64> for Variable {
667 fn from(value: f64) -> Self {
668 Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
669 }
670}
671
672impl From<u8> for Variable {
673 fn from(value: u8) -> Self {
674 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
675 }
676}
677
678impl From<u16> for Variable {
679 fn from(value: u16) -> Self {
680 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
681 }
682}
683
684impl From<u32> for Variable {
685 fn from(value: u32) -> Self {
686 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
687 }
688}
689
690impl From<u64> for Variable {
691 fn from(value: u64) -> Self {
692 Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
693 }
694}
695
696impl From<usize> for Variable {
697 fn from(value: usize) -> Self {
698 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
699 }
700}