1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5 e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6 quant::scheme::{QuantParam, QuantValue},
7 tf32, ue8m0,
8};
9
/// Floating-point formats an element can use.
///
/// The derived ordering follows the declaration order of the variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// Displayed as `e2m1`; `ElemType::size_bits` reports 4 bits for this kind.
    E2M1,
    /// Displayed as `e2m3`; sized as one byte by `ElemType::size`.
    E2M3,
    /// Displayed as `e3m2`; sized as one byte by `ElemType::size`.
    E3M2,
    /// Displayed as `e4m3`; sized as one byte by `ElemType::size`.
    E4M3,
    /// Displayed as `e5m2`; sized as one byte by `ElemType::size`.
    E5M2,
    /// Displayed as `ue8m0`; sized as one byte by `ElemType::size`.
    UE8M0,
    /// Sized like `half::f16`.
    F16,
    /// Sized like `half::bf16`.
    BF16,
    /// Displayed as `flex32`; sized like `f32`.
    Flex32,
    /// Sized like `f32`.
    F32,
    /// Displayed as `tf32`; sized like `f32`.
    TF32,
    /// Sized like `f64`.
    F64,
}
35
/// Signed integer widths an element can use.
///
/// The derived ordering follows the declaration order of the variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// Sized like `i8`.
    I8,
    /// Sized like `i16`.
    I16,
    /// Sized like `i32`.
    I32,
    /// Sized like `i64`.
    I64,
}
45
/// Unsigned integer widths an element can use.
///
/// The derived ordering follows the declaration order of the variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// Sized like `u8`.
    U8,
    /// Sized like `u16`.
    U16,
    /// Sized like `u32`.
    U32,
    /// Sized like `u64`.
    U64,
}
55
/// Element type of a value: a float, signed integer, unsigned integer or boolean.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum ElemType {
    /// Floating-point element of the given [`FloatKind`].
    Float(FloatKind),
    /// Signed integer element of the given [`IntKind`].
    Int(IntKind),
    /// Unsigned integer element of the given [`UIntKind`].
    UInt(UIntKind),
    /// Boolean element; sized like Rust's `bool` by `ElemType::size`.
    Bool,
}
66
/// Storage types that do not expose an element type.
///
/// `StorageType::elem_type` is unimplemented for values wrapping this enum.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
    /// Barrier at the given [`BarrierLevel`]; `OpaqueType::size` reports 8 bytes for it.
    Barrier(BarrierLevel),
}
72
/// Types that carry no storage type.
///
/// `Type::size` and `Type::size_bits` report 0 for them and `Type::storage_type` is
/// unimplemented.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Displayed as `barrier_token`.
    BarrierToken,
    /// Displayed as `pipeline`.
    Pipeline,
    /// Displayed as `tensor_map`.
    TensorMap,
}
80
/// How an element is stored: plain, packed, atomic, or as an opaque object.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// A single element of the given type.
    Scalar(ElemType),
    /// Elements of the given type packed together; the `u32` is the packing factor returned by
    /// `StorageType::packing_factor` and multiplies the bit size in `StorageType::size_bits`.
    Packed(ElemType, u32),
    /// Atomic storage of the given element type.
    Atomic(ElemType),
    /// Opaque storage with no element type.
    Opaque(OpaqueType),
}
95
96impl ElemType {
97 pub fn from_quant_param(quant_param: QuantParam) -> Self {
99 match quant_param {
100 QuantParam::F32 => Self::Float(FloatKind::F32),
101 QuantParam::F16 => Self::Float(FloatKind::F16),
102 QuantParam::BF16 => Self::Float(FloatKind::BF16),
103 QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
104 QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
105 }
106 }
107
108 pub fn from_quant_value(quant_value: QuantValue) -> Self {
110 match quant_value {
111 QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
112 QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
113 QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
114 QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
115 other => panic!("Unsupported quant value {other:?}"),
116 }
117 }
118 pub fn constant_from_f64(&self, val: f64) -> Variable {
122 Variable::constant(match self {
123 ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
124 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
125 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126 ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
127 })
128 }
129 pub fn constant_from_i64(&self, val: i64) -> Variable {
133 Variable::constant(match self {
134 ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
135 ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
136 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
137 ElemType::Bool => ConstantScalarValue::Bool(val > 0),
138 })
139 }
140 pub fn constant_from_u64(&self, val: u64) -> Variable {
144 Variable::constant(match self {
145 ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
146 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
147 ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
148 ElemType::Bool => ConstantScalarValue::Bool(val > 0),
149 })
150 }
151 pub fn constant_from_bool(&self, val: bool) -> Variable {
155 Variable::constant(match self {
156 ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
157 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
158 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
159 ElemType::Bool => ConstantScalarValue::Bool(val),
160 })
161 }
162
163 pub fn from_constant(&self, constant: Variable) -> Variable {
165 let value = match constant.kind {
166 VariableKind::ConstantScalar(value) => value,
167 _ => return constant,
168 };
169
170 match value {
171 ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
172 ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
173 ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
174 ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
175 }
176 }
177 pub const fn size(&self) -> usize {
179 match self {
180 ElemType::Float(kind) => match kind {
181 FloatKind::E2M1
182 | FloatKind::E2M3
183 | FloatKind::E3M2
184 | FloatKind::E4M3
185 | FloatKind::E5M2
186 | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
187 FloatKind::F16 => core::mem::size_of::<half::f16>(),
188 FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
189 FloatKind::F32 => core::mem::size_of::<f32>(),
190 FloatKind::F64 => core::mem::size_of::<f64>(),
191 FloatKind::Flex32 => core::mem::size_of::<f32>(),
192 FloatKind::TF32 => core::mem::size_of::<f32>(),
193 },
194 ElemType::Int(kind) => match kind {
195 IntKind::I8 => core::mem::size_of::<i8>(),
196 IntKind::I16 => core::mem::size_of::<i16>(),
197 IntKind::I32 => core::mem::size_of::<i32>(),
198 IntKind::I64 => core::mem::size_of::<i64>(),
199 },
200 ElemType::UInt(kind) => match kind {
201 UIntKind::U8 => core::mem::size_of::<u8>(),
202 UIntKind::U16 => core::mem::size_of::<u16>(),
203 UIntKind::U32 => core::mem::size_of::<u32>(),
204 UIntKind::U64 => core::mem::size_of::<u64>(),
205 },
206 ElemType::Bool => core::mem::size_of::<bool>(),
207 }
208 }
209
210 pub const fn size_bits(&self) -> usize {
212 match self {
213 ElemType::Float(kind) => match kind {
214 FloatKind::E2M3
215 | FloatKind::E3M2
216 | FloatKind::E4M3
217 | FloatKind::E5M2
218 | FloatKind::UE8M0
219 | FloatKind::F16
220 | FloatKind::BF16
221 | FloatKind::F32
222 | FloatKind::F64
223 | FloatKind::Flex32
224 | FloatKind::TF32 => self.size() * 8,
225 FloatKind::E2M1 => 4,
226 },
227 ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
228 }
229 }
230
231 pub const fn min_line_size(&self) -> u8 {
232 match self {
233 ElemType::Float(FloatKind::E2M1) => 2,
234 _ => 1,
235 }
236 }
237
238 pub fn is_int(&self) -> bool {
239 matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
240 }
241
242 pub fn is_signed_int(&self) -> bool {
243 matches!(self, ElemType::Int(_))
244 }
245
246 pub fn is_unsigned_int(&self) -> bool {
247 matches!(self, ElemType::UInt(_) | ElemType::Bool)
248 }
249
250 pub fn is_float(&self) -> bool {
251 matches!(self, ElemType::Float(_))
252 }
253
254 pub fn max_variable(&self) -> Variable {
255 let value = match self {
256 ElemType::Float(kind) => match kind {
257 FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
258 FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
259 FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
260 FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
261 FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
262 FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
263 FloatKind::F16 => {
264 ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
265 }
266 FloatKind::BF16 => {
267 ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
268 }
269 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
270 FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
271 FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
272 FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
273 },
274 ElemType::Int(kind) => match kind {
275 IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
276 IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
277 IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
278 IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
279 },
280 ElemType::UInt(kind) => match kind {
281 UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
282 UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
283 UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
284 UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
285 },
286 ElemType::Bool => ConstantScalarValue::Bool(true),
287 };
288
289 Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
290 }
291
292 pub fn min_variable(&self) -> Variable {
293 let value = match self {
294 ElemType::Float(kind) => match kind {
295 FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
296 FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
297 FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
298 FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
299 FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
300 FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
301 FloatKind::F16 => {
302 ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
303 }
304 FloatKind::BF16 => {
305 ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
306 }
307 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
308 FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
309 FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
310 FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
311 },
312 ElemType::Int(kind) => match kind {
313 IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
314 IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
315 IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
316 IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
317 },
318 ElemType::UInt(kind) => match kind {
319 UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
320 UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
321 UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
322 UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
323 },
324 ElemType::Bool => ConstantScalarValue::Bool(false),
325 };
326
327 Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
328 }
329}
330
331impl OpaqueType {
332 pub const fn size(&self) -> usize {
334 match self {
335 OpaqueType::Barrier(_) => 8,
336 }
337 }
338
339 pub const fn size_bits(&self) -> usize {
341 match self {
342 OpaqueType::Barrier(_) => 64,
343 }
344 }
345}
346
347impl StorageType {
348 pub fn elem_type(&self) -> ElemType {
349 match self {
350 StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
351 StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
352 }
353 }
354
355 pub fn packing_factor(&self) -> u32 {
356 match self {
357 StorageType::Packed(_, factor) => *factor,
358 _ => 1,
359 }
360 }
361
362 pub fn is_atomic(&self) -> bool {
363 matches!(self, StorageType::Atomic(_))
364 }
365
366 pub fn size(&self) -> usize {
367 self.size_bits().div_ceil(8)
368 }
369
370 pub fn size_bits(&self) -> usize {
371 match self {
372 StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
373 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
374 StorageType::Opaque(ty) => ty.size_bits(),
375 }
376 }
377
378 pub fn from_constant(&self, constant: Variable) -> Variable {
380 self.elem_type().from_constant(constant)
381 }
382
383 pub fn is_int(&self) -> bool {
384 self.elem_type().is_int()
385 }
386
387 pub fn is_signed_int(&self) -> bool {
388 self.elem_type().is_signed_int()
389 }
390
391 pub fn is_unsigned_int(&self) -> bool {
392 self.elem_type().is_unsigned_int()
393 }
394
395 pub fn is_float(&self) -> bool {
396 self.elem_type().is_float()
397 }
398}
399
400impl From<ElemType> for StorageType {
401 fn from(val: ElemType) -> Self {
402 StorageType::Scalar(val)
403 }
404}
405
406impl From<OpaqueType> for StorageType {
407 fn from(val: OpaqueType) -> Self {
408 StorageType::Opaque(val)
409 }
410}
411
412impl<T: Into<StorageType>> From<T> for Type {
413 fn from(val: T) -> Self {
414 Type::new(val.into())
415 }
416}
417
418impl From<SemanticType> for Type {
419 fn from(val: SemanticType) -> Self {
420 Type::semantic(val)
421 }
422}
423
/// Type of a value: a single stored value, a line of them, or a semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// A single value of the given storage type; `Type::line_size` reports 1.
    Scalar(StorageType),
    /// A line of values of the given storage type; the `u32` is the line size.
    Line(StorageType, u32),
    /// A semantic type with no storage; `Type::size` and `Type::line_size` report 0.
    Semantic(SemanticType),
}
434
/// Line size as accepted by `Type::line`.
pub type LineSize = u32;
436
437impl Type {
438 pub fn elem_type(&self) -> ElemType {
440 self.storage_type().elem_type()
441 }
442
443 pub fn new(storage: StorageType) -> Self {
445 Type::Scalar(storage)
446 }
447
448 pub fn scalar(elem: ElemType) -> Self {
449 Self::new(StorageType::Scalar(elem))
450 }
451
452 pub fn semantic(ty: SemanticType) -> Self {
453 Self::Semantic(ty)
454 }
455
456 pub fn line(self, line_size: LineSize) -> Type {
457 match line_size > 1 {
458 true => Type::Line(self.storage_type(), line_size),
459 false => Type::Scalar(self.storage_type()),
460 }
461 }
462
463 pub fn line_size(&self) -> u32 {
464 match self {
465 Type::Scalar(_) => 1,
466 Type::Line(_, line_size) => *line_size,
467 Type::Semantic(_) => 0,
468 }
469 }
470
471 pub fn size(&self) -> usize {
472 match self {
473 Type::Scalar(ty) => ty.size(),
474 Type::Line(ty, line_size) => ty.size() * *line_size as usize,
475 Type::Semantic(_) => 0,
476 }
477 }
478
479 pub fn size_bits(&self) -> usize {
480 match self {
481 Type::Scalar(ty) => ty.size_bits(),
482 Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
483 Type::Semantic(_) => 0,
484 }
485 }
486
487 pub fn is_atomic(&self) -> bool {
488 !self.is_semantic() && self.storage_type().is_atomic()
489 }
490
491 pub fn is_int(&self) -> bool {
492 !self.is_semantic() && self.storage_type().is_int()
493 }
494
495 pub fn is_signed_int(&self) -> bool {
496 !self.is_semantic() && self.storage_type().is_signed_int()
497 }
498
499 pub fn is_unsigned_int(&self) -> bool {
500 !self.is_semantic() && self.storage_type().is_unsigned_int()
501 }
502
503 pub fn is_float(&self) -> bool {
504 !self.is_semantic() && self.storage_type().is_float()
505 }
506
507 pub fn storage_type(&self) -> StorageType {
508 match self {
509 Type::Scalar(ty) | Type::Line(ty, _) => *ty,
510 Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
511 }
512 }
513
514 pub fn is_semantic(&self) -> bool {
515 match self {
516 Type::Scalar(_) | Type::Line(_, _) => false,
517 Type::Semantic(_) => true,
518 }
519 }
520}
521
522impl Display for Type {
523 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524 match self {
525 Type::Scalar(ty) => write!(f, "{ty}"),
526 Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
527 Type::Semantic(ty) => write!(f, "{ty}"),
528 }
529 }
530}
531
532impl Display for StorageType {
533 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
534 match self {
535 StorageType::Scalar(ty) => write!(f, "{ty}"),
536 StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
537 StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
538 StorageType::Opaque(ty) => write!(f, "{ty}"),
539 }
540 }
541}
542
543impl Display for ElemType {
544 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
545 match self {
546 Self::Float(kind) => match kind {
547 FloatKind::E2M1 => f.write_str("e2m1"),
548 FloatKind::E2M3 => f.write_str("e2m3"),
549 FloatKind::E3M2 => f.write_str("e3m2"),
550 FloatKind::E4M3 => f.write_str("e4m3"),
551 FloatKind::E5M2 => f.write_str("e5m2"),
552 FloatKind::UE8M0 => f.write_str("ue8m0"),
553 FloatKind::F16 => f.write_str("f16"),
554 FloatKind::BF16 => f.write_str("bf16"),
555 FloatKind::Flex32 => f.write_str("flex32"),
556 FloatKind::TF32 => f.write_str("tf32"),
557 FloatKind::F32 => f.write_str("f32"),
558 FloatKind::F64 => f.write_str("f64"),
559 },
560 Self::Int(kind) => match kind {
561 IntKind::I8 => f.write_str("i8"),
562 IntKind::I16 => f.write_str("i16"),
563 IntKind::I32 => f.write_str("i32"),
564 IntKind::I64 => f.write_str("i64"),
565 },
566 Self::UInt(kind) => match kind {
567 UIntKind::U8 => f.write_str("u8"),
568 UIntKind::U16 => f.write_str("u16"),
569 UIntKind::U32 => f.write_str("u32"),
570 UIntKind::U64 => f.write_str("u64"),
571 },
572 Self::Bool => f.write_str("bool"),
573 }
574 }
575}
576
577impl Display for SemanticType {
578 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
579 match self {
580 SemanticType::BarrierToken => f.write_str("barrier_token"),
581 SemanticType::Pipeline => f.write_str("pipeline"),
582 SemanticType::TensorMap => f.write_str("tensor_map"),
583 }
584 }
585}
586
587impl Display for OpaqueType {
588 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
589 match self {
590 OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
591 }
592 }
593}
594
595impl From<bool> for Variable {
596 fn from(value: bool) -> Self {
597 Variable::constant(ConstantScalarValue::Bool(value))
598 }
599}
600
601impl From<i8> for Variable {
602 fn from(value: i8) -> Self {
603 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
604 }
605}
606
607impl From<i16> for Variable {
608 fn from(value: i16) -> Self {
609 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
610 }
611}
612
613impl From<i32> for Variable {
614 fn from(value: i32) -> Self {
615 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
616 }
617}
618
619impl From<i64> for Variable {
620 fn from(value: i64) -> Self {
621 Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
622 }
623}
624
/// Placeholder conversion from `e2m1`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<e2m1> for Variable {
    fn from(_value: e2m1) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
630
/// Placeholder conversion from `e2m1x2`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<e2m1x2> for Variable {
    fn from(_value: e2m1x2) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
636
/// Placeholder conversion from `e2m3`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<e2m3> for Variable {
    fn from(_value: e2m3) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
642
/// Placeholder conversion from `e3m2`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<e3m2> for Variable {
    fn from(_value: e3m2) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
648
/// Placeholder conversion from `e4m3`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<e4m3> for Variable {
    fn from(_value: e4m3) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
654
/// Placeholder conversion from `e5m2`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<e5m2> for Variable {
    fn from(_value: e5m2) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
660
/// Placeholder conversion from `ue8m0`.
///
/// # Panics
///
/// Always panics: the body is `unimplemented!`.
impl From<ue8m0> for Variable {
    fn from(_value: ue8m0) -> Self {
        unimplemented!("Can't currently construct minifloats")
    }
}
666
667impl From<half::f16> for Variable {
668 fn from(value: half::f16) -> Self {
669 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
670 }
671}
672
673impl From<half::bf16> for Variable {
674 fn from(value: half::bf16) -> Self {
675 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
676 }
677}
678
679impl From<flex32> for Variable {
680 fn from(value: flex32) -> Self {
681 Variable::constant(ConstantScalarValue::Float(
682 value.to_f64(),
683 FloatKind::Flex32,
684 ))
685 }
686}
687
688impl From<tf32> for Variable {
689 fn from(value: tf32) -> Self {
690 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
691 }
692}
693
694impl From<f32> for Variable {
695 fn from(value: f32) -> Self {
696 Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
697 }
698}
699
700impl From<f64> for Variable {
701 fn from(value: f64) -> Self {
702 Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
703 }
704}
705
706impl From<u8> for Variable {
707 fn from(value: u8) -> Self {
708 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
709 }
710}
711
712impl From<u16> for Variable {
713 fn from(value: u16) -> Self {
714 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
715 }
716}
717
718impl From<u32> for Variable {
719 fn from(value: u32) -> Self {
720 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
721 }
722}
723
724impl From<u64> for Variable {
725 fn from(value: u64) -> Self {
726 Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
727 }
728}
729
730impl From<usize> for Variable {
731 fn from(value: usize) -> Self {
732 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
733 }
734}