1use super::{ConstantValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5 e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6 quant::scheme::{QuantParam, QuantValue},
7 tf32, ue8m0,
8};
9use derive_more::From;
10use half::{bf16, f16};
11
/// Floating-point element kinds known to the IR.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare enum variants in
/// declaration order.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// 2 exponent bits, 1 mantissa bit. Reported as 4 bits wide by `ElemType::size_bits`.
    E2M1,
    /// 2 exponent bits, 3 mantissa bits. Occupies one byte in this IR.
    E2M3,
    /// 3 exponent bits, 2 mantissa bits. Occupies one byte in this IR.
    E3M2,
    /// 8-bit float with 4 exponent bits and 3 mantissa bits.
    E4M3,
    /// 8-bit float with 5 exponent bits and 2 mantissa bits.
    E5M2,
    /// Unsigned 8-bit exponent-only format (8 exponent bits, 0 mantissa bits).
    UE8M0,
    /// Half precision, backed by `half::f16`.
    F16,
    /// Brain float 16, backed by `half::bf16`.
    BF16,
    /// `cubecl_common::flex32`; sized and bounded like `f32` in this file.
    /// NOTE(review): presumably a relaxed-precision f32 — confirm against cubecl_common.
    Flex32,
    /// Single precision (`f32`).
    F32,
    /// `cubecl_common::tf32`; stored in 32 bits and bounded like `f32` in this file.
    TF32,
    /// Double precision (`f64`).
    F64,
}
37
/// Signed integer element kinds, from 8 to 64 bits.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// 8-bit signed integer (`i8`).
    I8,
    /// 16-bit signed integer (`i16`).
    I16,
    /// 32-bit signed integer (`i32`).
    I32,
    /// 64-bit signed integer (`i64`).
    I64,
}
47
/// Unsigned integer element kinds, from 8 to 64 bits.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// 8-bit unsigned integer (`u8`).
    U8,
    /// 16-bit unsigned integer (`u16`).
    U16,
    /// 32-bit unsigned integer (`u32`).
    U32,
    /// 64-bit unsigned integer (`u64`).
    U64,
}
57
/// Scalar element type: a float, signed integer, unsigned integer, or boolean.
///
/// `derive_more::From` generates conversions from the wrapped kind enums
/// (e.g. `FloatKind` -> `ElemType::Float`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)]
#[allow(missing_docs)]
pub enum ElemType {
    /// A floating-point element of the given kind.
    Float(FloatKind),
    /// A signed integer element of the given kind.
    Int(IntKind),
    /// An unsigned integer element of the given kind.
    UInt(UIntKind),
    /// A boolean element. Note that `is_int` and `is_unsigned_int` also return true for it.
    Bool,
}
68
/// Storage types with no element type (e.g. barriers).
///
/// They have a fixed size (see `OpaqueType::size`) but `StorageType::elem_type` cannot be
/// queried for them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
    /// A barrier at the given level; displayed as `barrier<level>` and sized as 8 bytes.
    Barrier(BarrierLevel),
}
74
/// Types that carry meaning for the compiler but have no storage.
///
/// `Type::size`, `Type::size_bits` and `Type::line_size` all return 0 for them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Displayed as `barrier_token`.
    BarrierToken,
    /// Displayed as `pipeline`.
    Pipeline,
    /// Displayed as `tensor_map`.
    TensorMap,
}
82
/// How elements are laid out in storage.
///
/// `Debug` is implemented manually (not derived) below.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// A single element.
    Scalar(ElemType),
    /// `factor` elements packed together; `size_bits` is the element's bits times `factor`.
    Packed(ElemType, usize),
    /// An element wrapped as atomic; displayed as `atomic<elem>`.
    Atomic(ElemType),
    /// Storage with no element type (see [`OpaqueType`]).
    Opaque(OpaqueType),
}
97
98impl core::fmt::Debug for StorageType {
99 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
100 struct Dummy<'a>(&'a StorageType);
103
104 impl<'a> core::fmt::Debug for Dummy<'a> {
105 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
106 match self.0 {
107 StorageType::Scalar(f0) => f.debug_tuple("Scalar").field(&f0).finish(),
108 StorageType::Packed(f0, f1) => {
109 f.debug_tuple("Packed").field(&f0).field(&f1).finish()
110 }
111 StorageType::Atomic(f0) => f.debug_tuple("Atomic").field(&f0).finish(),
112 StorageType::Opaque(f0) => f.debug_tuple("Opaque").field(&f0).finish(),
113 }
114 }
115 }
116
117 write!(f, "{:?}", Dummy(self))
118 }
119}
120
121impl ElemType {
122 pub fn from_quant_param(quant_param: QuantParam) -> Self {
124 match quant_param {
125 QuantParam::F32 => Self::Float(FloatKind::F32),
126 QuantParam::F16 => Self::Float(FloatKind::F16),
127 QuantParam::BF16 => Self::Float(FloatKind::BF16),
128 QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
129 QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
130 }
131 }
132
133 pub fn from_quant_value(quant_value: QuantValue) -> Self {
135 match quant_value {
136 QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
137 QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
138 QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
139 QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
140 other => panic!("Unsupported quant value {other:?}"),
141 }
142 }
143
144 pub fn constant(&self, val: ConstantValue) -> Variable {
148 Variable::constant(val, Type::scalar(*self))
149 }
150
151 pub const fn size(&self) -> usize {
153 match self {
154 ElemType::Float(kind) => match kind {
155 FloatKind::E2M1
156 | FloatKind::E2M3
157 | FloatKind::E3M2
158 | FloatKind::E4M3
159 | FloatKind::E5M2
160 | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
161 FloatKind::F16 => core::mem::size_of::<half::f16>(),
162 FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
163 FloatKind::F32 => core::mem::size_of::<f32>(),
164 FloatKind::F64 => core::mem::size_of::<f64>(),
165 FloatKind::Flex32 => core::mem::size_of::<f32>(),
166 FloatKind::TF32 => core::mem::size_of::<f32>(),
167 },
168 ElemType::Int(kind) => match kind {
169 IntKind::I8 => core::mem::size_of::<i8>(),
170 IntKind::I16 => core::mem::size_of::<i16>(),
171 IntKind::I32 => core::mem::size_of::<i32>(),
172 IntKind::I64 => core::mem::size_of::<i64>(),
173 },
174 ElemType::UInt(kind) => match kind {
175 UIntKind::U8 => core::mem::size_of::<u8>(),
176 UIntKind::U16 => core::mem::size_of::<u16>(),
177 UIntKind::U32 => core::mem::size_of::<u32>(),
178 UIntKind::U64 => core::mem::size_of::<u64>(),
179 },
180 ElemType::Bool => core::mem::size_of::<bool>(),
181 }
182 }
183
184 pub const fn size_bits(&self) -> usize {
186 match self {
187 ElemType::Float(kind) => match kind {
188 FloatKind::E2M3
189 | FloatKind::E3M2
190 | FloatKind::E4M3
191 | FloatKind::E5M2
192 | FloatKind::UE8M0
193 | FloatKind::F16
194 | FloatKind::BF16
195 | FloatKind::F32
196 | FloatKind::F64
197 | FloatKind::Flex32
198 | FloatKind::TF32 => self.size() * 8,
199 FloatKind::E2M1 => 4,
200 },
201 ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
202 }
203 }
204
205 pub const fn min_line_size(&self) -> u8 {
206 match self {
207 ElemType::Float(FloatKind::E2M1) => 2,
208 _ => 1,
209 }
210 }
211
212 pub fn is_int(&self) -> bool {
213 matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
214 }
215
216 pub fn is_signed_int(&self) -> bool {
217 matches!(self, ElemType::Int(_))
218 }
219
220 pub fn is_unsigned_int(&self) -> bool {
221 matches!(self, ElemType::UInt(_) | ElemType::Bool)
222 }
223
224 pub fn is_float(&self) -> bool {
225 matches!(self, ElemType::Float(_))
226 }
227
228 pub fn is_bool(&self) -> bool {
229 matches!(self, ElemType::Bool)
230 }
231
232 pub fn as_float(&self) -> Option<FloatKind> {
233 match self {
234 ElemType::Float(kind) => Some(*kind),
235 _ => None,
236 }
237 }
238
239 pub fn max_variable(&self) -> Variable {
240 let value = match self {
241 ElemType::Float(kind) => match kind {
242 FloatKind::E2M1 => e2m1::MAX,
243 FloatKind::E2M3 => e2m3::MAX,
244 FloatKind::E3M2 => e3m2::MAX,
245 FloatKind::E4M3 => e4m3::MAX,
246 FloatKind::E5M2 => e5m2::MAX,
247 FloatKind::UE8M0 => ue8m0::MAX,
248 FloatKind::F16 => half::f16::MAX.to_f64(),
249 FloatKind::BF16 => half::bf16::MAX.to_f64(),
250 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MAX as f64,
251 FloatKind::F64 => f64::MAX,
252 }
253 .into(),
254 ElemType::Int(kind) => match kind {
255 IntKind::I8 => i8::MAX as i64,
256 IntKind::I16 => i16::MAX as i64,
257 IntKind::I32 => i32::MAX as i64,
258 IntKind::I64 => i64::MAX,
259 }
260 .into(),
261 ElemType::UInt(kind) => match kind {
262 UIntKind::U8 => u8::MAX as u64,
263 UIntKind::U16 => u16::MAX as u64,
264 UIntKind::U32 => u32::MAX as u64,
265 UIntKind::U64 => u64::MAX,
266 }
267 .into(),
268 ElemType::Bool => true.into(),
269 };
270
271 Variable::new(VariableKind::Constant(value), Type::scalar(*self))
272 }
273
274 pub fn min_variable(&self) -> Variable {
275 let value = match self {
276 ElemType::Float(kind) => match kind {
277 FloatKind::E2M1 => e2m1::MIN,
278 FloatKind::E2M3 => e2m3::MIN,
279 FloatKind::E3M2 => e3m2::MIN,
280 FloatKind::E4M3 => e4m3::MIN,
281 FloatKind::E5M2 => e5m2::MIN,
282 FloatKind::UE8M0 => ue8m0::MIN,
283 FloatKind::F16 => half::f16::MIN.to_f64(),
284 FloatKind::BF16 => half::bf16::MIN.to_f64(),
285 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MIN as f64,
286 FloatKind::F64 => f64::MIN,
287 }
288 .into(),
289 ElemType::Int(kind) => match kind {
290 IntKind::I8 => i8::MIN as i64,
291 IntKind::I16 => i16::MIN as i64,
292 IntKind::I32 => i32::MIN as i64,
293 IntKind::I64 => i64::MIN,
294 }
295 .into(),
296 ElemType::UInt(kind) => match kind {
297 UIntKind::U8 => u8::MIN as u64,
298 UIntKind::U16 => u16::MIN as u64,
299 UIntKind::U32 => u32::MIN as u64,
300 UIntKind::U64 => u64::MIN,
301 }
302 .into(),
303 ElemType::Bool => false.into(),
304 };
305
306 Variable::new(VariableKind::Constant(value), Type::scalar(*self))
307 }
308
309 pub fn epsilon(&self) -> f64 {
310 match self {
311 ElemType::Float(kind) => match kind {
312 FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
313 FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
314 FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
315 FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
316 FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
317 FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
318 FloatKind::F16 => half::f16::EPSILON.to_f64(),
319 FloatKind::BF16 => 0.0078125, FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
321 FloatKind::F64 => f64::EPSILON,
322 },
323 ElemType::Int(_) | ElemType::UInt(_) => 1.0, ElemType::Bool => 1.0,
325 }
326 }
327}
328
329impl OpaqueType {
330 pub const fn size(&self) -> usize {
332 match self {
333 OpaqueType::Barrier(_) => 8,
334 }
335 }
336
337 pub const fn size_bits(&self) -> usize {
339 match self {
340 OpaqueType::Barrier(_) => 64,
341 }
342 }
343}
344
345impl StorageType {
346 pub fn elem_type(&self) -> ElemType {
347 match self {
348 StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
349 StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
350 }
351 }
352
353 pub fn packing_factor(&self) -> usize {
354 match self {
355 StorageType::Packed(_, factor) => *factor,
356 _ => 1,
357 }
358 }
359
360 pub fn is_atomic(&self) -> bool {
361 matches!(self, StorageType::Atomic(_))
362 }
363
364 pub fn size(&self) -> usize {
365 self.size_bits().div_ceil(8)
366 }
367
368 pub fn size_bits(&self) -> usize {
369 match self {
370 StorageType::Packed(ty, factor) => ty.size_bits() * *factor,
371 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
372 StorageType::Opaque(ty) => ty.size_bits(),
373 }
374 }
375
376 pub fn is_int(&self) -> bool {
377 self.elem_type().is_int()
378 }
379
380 pub fn is_signed_int(&self) -> bool {
381 self.elem_type().is_signed_int()
382 }
383
384 pub fn is_unsigned_int(&self) -> bool {
385 self.elem_type().is_unsigned_int()
386 }
387
388 pub fn is_float(&self) -> bool {
389 self.elem_type().is_float()
390 }
391
392 pub fn is_bool(&self) -> bool {
393 self.elem_type().is_bool()
394 }
395
396 pub fn epsilon(&self) -> f64 {
398 match self {
399 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
400 StorageType::Packed(ty, factor) => {
401 ty.epsilon() * (*factor as f64)
403 }
404 StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
405 }
406 }
407
408 pub fn constant(&self, value: ConstantValue) -> Variable {
409 Variable::constant(value, Type::new(*self))
410 }
411}
412
413macro_rules! storage_from_elem {
414 ($($ty: ty),*) => {
415 $(impl From<$ty> for StorageType {
416 fn from(value: $ty) -> Self {
417 StorageType::Scalar(value.into())
418 }
419 })*
420 };
421}
422
423storage_from_elem!(FloatKind, IntKind, UIntKind, ElemType);
424
425impl From<OpaqueType> for StorageType {
426 fn from(val: OpaqueType) -> Self {
427 StorageType::Opaque(val)
428 }
429}
430
431impl<T: Into<StorageType>> From<T> for Type {
432 fn from(val: T) -> Self {
433 Type::new(val.into())
434 }
435}
436
437impl From<SemanticType> for Type {
438 fn from(val: SemanticType) -> Self {
439 Type::semantic(val)
440 }
441}
442
/// A full IR type: scalar storage, a line (vector) of storage, or a storage-less semantic
/// type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// A single value of the given storage type.
    Scalar(StorageType),
    /// `LineSize` values of the given storage type; displayed as `line<storage, n>`.
    /// `Type::line` only produces this variant for line sizes greater than 1.
    Line(StorageType, LineSize),
    /// A type with no storage; size and line size are reported as 0.
    Semantic(SemanticType),
}

/// Number of elements in a [`Type::Line`].
pub type LineSize = usize;
455
456impl Type {
457 pub fn elem_type(&self) -> ElemType {
459 self.storage_type().elem_type()
460 }
461
462 pub fn new(storage: StorageType) -> Self {
464 Type::Scalar(storage)
465 }
466
467 pub fn scalar(elem: ElemType) -> Self {
468 Self::new(StorageType::Scalar(elem))
469 }
470
471 pub fn semantic(ty: SemanticType) -> Self {
472 Self::Semantic(ty)
473 }
474
475 pub fn line(self, line_size: LineSize) -> Type {
476 match line_size > 1 {
477 true => Type::Line(self.storage_type(), line_size),
478 false => Type::Scalar(self.storage_type()),
479 }
480 }
481
482 pub fn line_size(&self) -> LineSize {
483 match self {
484 Type::Scalar(_) => 1,
485 Type::Line(_, line_size) => *line_size,
486 Type::Semantic(_) => 0,
487 }
488 }
489
490 pub fn size(&self) -> usize {
491 match self {
492 Type::Scalar(ty) => ty.size(),
493 Type::Line(ty, line_size) => ty.size() * *line_size,
494 Type::Semantic(_) => 0,
495 }
496 }
497
498 pub fn size_bits(&self) -> usize {
499 match self {
500 Type::Scalar(ty) => ty.size_bits(),
501 Type::Line(ty, line_size) => ty.size_bits() * *line_size,
502 Type::Semantic(_) => 0,
503 }
504 }
505
506 pub fn is_atomic(&self) -> bool {
507 !self.is_semantic() && self.storage_type().is_atomic()
508 }
509
510 pub fn is_int(&self) -> bool {
511 !self.is_semantic() && self.storage_type().is_int()
512 }
513
514 pub fn is_signed_int(&self) -> bool {
515 !self.is_semantic() && self.storage_type().is_signed_int()
516 }
517
518 pub fn is_unsigned_int(&self) -> bool {
519 !self.is_semantic() && self.storage_type().is_unsigned_int()
520 }
521
522 pub fn is_float(&self) -> bool {
523 !self.is_semantic() && self.storage_type().is_float()
524 }
525
526 pub fn is_bool(&self) -> bool {
527 !self.is_semantic() && self.storage_type().is_bool()
528 }
529
530 pub fn storage_type(&self) -> StorageType {
531 match self {
532 Type::Scalar(ty) | Type::Line(ty, _) => *ty,
533 Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
534 }
535 }
536
537 pub fn is_semantic(&self) -> bool {
538 match self {
539 Type::Scalar(_) | Type::Line(_, _) => false,
540 Type::Semantic(_) => true,
541 }
542 }
543
544 pub fn constant(&self, value: ConstantValue) -> Variable {
545 Variable::constant(value, *self)
546 }
547}
548
549impl Display for Type {
550 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
551 match self {
552 Type::Scalar(ty) => write!(f, "{ty}"),
553 Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
554 Type::Semantic(ty) => write!(f, "{ty}"),
555 }
556 }
557}
558
559impl Display for StorageType {
560 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
561 match self {
562 StorageType::Scalar(ty) => write!(f, "{ty}"),
563 StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
564 StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
565 StorageType::Opaque(ty) => write!(f, "{ty}"),
566 }
567 }
568}
569
570impl Display for ElemType {
571 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
572 match self {
573 Self::Float(kind) => match kind {
574 FloatKind::E2M1 => f.write_str("e2m1"),
575 FloatKind::E2M3 => f.write_str("e2m3"),
576 FloatKind::E3M2 => f.write_str("e3m2"),
577 FloatKind::E4M3 => f.write_str("e4m3"),
578 FloatKind::E5M2 => f.write_str("e5m2"),
579 FloatKind::UE8M0 => f.write_str("ue8m0"),
580 FloatKind::F16 => f.write_str("f16"),
581 FloatKind::BF16 => f.write_str("bf16"),
582 FloatKind::Flex32 => f.write_str("flex32"),
583 FloatKind::TF32 => f.write_str("tf32"),
584 FloatKind::F32 => f.write_str("f32"),
585 FloatKind::F64 => f.write_str("f64"),
586 },
587 Self::Int(kind) => match kind {
588 IntKind::I8 => f.write_str("i8"),
589 IntKind::I16 => f.write_str("i16"),
590 IntKind::I32 => f.write_str("i32"),
591 IntKind::I64 => f.write_str("i64"),
592 },
593 Self::UInt(kind) => match kind {
594 UIntKind::U8 => f.write_str("u8"),
595 UIntKind::U16 => f.write_str("u16"),
596 UIntKind::U32 => f.write_str("u32"),
597 UIntKind::U64 => f.write_str("u64"),
598 },
599 Self::Bool => f.write_str("bool"),
600 }
601 }
602}
603
604impl Display for SemanticType {
605 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
606 match self {
607 SemanticType::BarrierToken => f.write_str("barrier_token"),
608 SemanticType::Pipeline => f.write_str("pipeline"),
609 SemanticType::TensorMap => f.write_str("tensor_map"),
610 }
611 }
612}
613
614impl Display for OpaqueType {
615 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
616 match self {
617 OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
618 }
619 }
620}
621
622impl From<e2m1x2> for Variable {
623 fn from(_value: e2m1x2) -> Self {
624 unimplemented!("Can't currently construct e2m1x2")
625 }
626}
627
628impl From<e2m3> for Variable {
629 fn from(_value: e2m3) -> Self {
630 unimplemented!("Can't currently construct fp6")
631 }
632}
633
634impl From<e3m2> for Variable {
635 fn from(_value: e3m2) -> Self {
636 unimplemented!("Can't currently construct fp6")
637 }
638}
639
640impl From<i8> for ConstantValue {
641 fn from(value: i8) -> Self {
642 ConstantValue::Int(value as i64)
643 }
644}
645
646impl From<i16> for ConstantValue {
647 fn from(value: i16) -> Self {
648 ConstantValue::Int(value as i64)
649 }
650}
651
652impl From<i32> for ConstantValue {
653 fn from(value: i32) -> Self {
654 ConstantValue::Int(value as i64)
655 }
656}
657
658impl From<isize> for ConstantValue {
659 fn from(value: isize) -> Self {
660 ConstantValue::Int(value as i64)
661 }
662}
663
664impl From<u8> for ConstantValue {
665 fn from(value: u8) -> Self {
666 ConstantValue::UInt(value as u64)
667 }
668}
669
670impl From<u16> for ConstantValue {
671 fn from(value: u16) -> Self {
672 ConstantValue::UInt(value as u64)
673 }
674}
675
676impl From<u32> for ConstantValue {
677 fn from(value: u32) -> Self {
678 ConstantValue::UInt(value as u64)
679 }
680}
681
682impl From<usize> for ConstantValue {
683 fn from(value: usize) -> Self {
684 ConstantValue::UInt(value as u64)
685 }
686}
687
688impl From<e2m1> for ConstantValue {
689 fn from(value: e2m1) -> Self {
690 ConstantValue::Float(value.to_f64())
691 }
692}
693
694impl From<e4m3> for ConstantValue {
695 fn from(value: e4m3) -> Self {
696 ConstantValue::Float(value.to_f64())
697 }
698}
699
700impl From<e5m2> for ConstantValue {
701 fn from(value: e5m2) -> Self {
702 ConstantValue::Float(value.to_f64())
703 }
704}
705
706impl From<ue8m0> for ConstantValue {
707 fn from(value: ue8m0) -> Self {
708 ConstantValue::Float(value.to_f64())
709 }
710}
711
712impl From<half::f16> for ConstantValue {
713 fn from(value: half::f16) -> Self {
714 ConstantValue::Float(value.to_f64())
715 }
716}
717
718impl From<half::bf16> for ConstantValue {
719 fn from(value: half::bf16) -> Self {
720 ConstantValue::Float(value.to_f64())
721 }
722}
723
724impl From<flex32> for ConstantValue {
725 fn from(value: flex32) -> Self {
726 ConstantValue::Float(value.to_f64())
727 }
728}
729
730impl From<tf32> for ConstantValue {
731 fn from(value: tf32) -> Self {
732 ConstantValue::Float(value.to_f64())
733 }
734}
735
736impl From<f32> for ConstantValue {
737 fn from(value: f32) -> Self {
738 ConstantValue::Float(value as f64)
739 }
740}
741
742macro_rules! impl_into_variable {
743 ($($ty: ty => $kind: path,)*) => {
744 $(
745 impl From<$ty> for Variable {
746 fn from(value: $ty) -> Self {
747 Variable::new(VariableKind::Constant(value.into()), $kind.into())
748 }
749 }
750 )*
751 };
752}
753
754impl_into_variable!(
755 bool => ElemType::Bool,
756
757 i8 => IntKind::I8,
758 i16 => IntKind::I16,
759 i32 => IntKind::I32,
760 i64 => IntKind::I64,
761
762 u8 => UIntKind::U8,
763 u16 => UIntKind::U16,
764 u32 => UIntKind::U32,
765 u64 => UIntKind::U64,
766
767 e2m1 => FloatKind::E2M1,
768 e4m3 => FloatKind::E4M3,
769 e5m2 => FloatKind::E5M2,
770 ue8m0 => FloatKind::UE8M0,
771 f16 => FloatKind::F16,
772 bf16 => FloatKind::BF16,
773 f32 => FloatKind::F32,
774 flex32 => FloatKind::Flex32,
775 tf32 => FloatKind::TF32,
776 f64 => FloatKind::F64,
777
778 usize => UIntKind::U32,
779 isize => IntKind::I32,
780);