1use super::{ConstantScalarValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5 e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6 quant::scheme::{QuantParam, QuantValue},
7 tf32, ue8m0,
8};
9
/// Floating-point element kinds understood by the IR.
///
/// Variant declaration order drives the derived `PartialOrd`/`Ord`, so reordering variants
/// changes comparisons between kinds.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// `e2m1` minifloat; `ElemType::size_bits` reports 4 bits (one byte from `ElemType::size`).
    E2M1,
    /// `e2m3` minifloat, sized as one byte by `ElemType::size`.
    E2M3,
    /// `e3m2` minifloat, sized as one byte by `ElemType::size`.
    E3M2,
    /// `e4m3` minifloat, sized as one byte by `ElemType::size`.
    E4M3,
    /// `e5m2` minifloat, sized as one byte by `ElemType::size`.
    E5M2,
    /// `ue8m0` (unsigned, presumably exponent-only) format, sized as one byte.
    UE8M0,
    /// IEEE half precision (`half::f16`).
    F16,
    /// Brain float 16 (`half::bf16`).
    BF16,
    /// `flex32`, stored with the size of an `f32`.
    /// NOTE(review): exact precision semantics live in `cubecl_common::flex32` — confirm there.
    Flex32,
    /// IEEE single precision.
    F32,
    /// `tf32`, stored with the size of an `f32`.
    TF32,
    /// IEEE double precision.
    F64,
}
35
/// Signed integer element widths.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
}
45
/// Unsigned integer element widths.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    /// 8-bit unsigned integer.
    U8,
    /// 16-bit unsigned integer.
    U16,
    /// 32-bit unsigned integer.
    U32,
    /// 64-bit unsigned integer.
    U64,
}
55
/// The scalar element type of a value, independent of how it is stored (plain, packed, atomic).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum ElemType {
    /// A floating-point element of the given kind.
    Float(FloatKind),
    /// A signed integer element of the given width.
    Int(IntKind),
    /// An unsigned integer element of the given width.
    UInt(UIntKind),
    /// A boolean element (treated as an unsigned integer by `ElemType::is_int`/`is_unsigned_int`).
    Bool,
}
66
/// Storage types with no element type, handled as opaque handles (see `OpaqueType::size`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
    /// A barrier object at the given [`BarrierLevel`]; displayed as `barrier<level>`.
    Barrier(BarrierLevel),
}
72
/// Types that carry meaning but no storage; `Type::size` reports 0 bytes for them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Displayed as `barrier_token`.
    BarrierToken,
    /// Displayed as `pipeline`.
    Pipeline,
    /// Displayed as `tensor_map`.
    TensorMap,
}
80
/// How an element is laid out in memory.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// A single plain element.
    Scalar(ElemType),
    /// Several elements packed together; the `u32` is the packing factor
    /// (returned by `StorageType::packing_factor`).
    Packed(ElemType, u32),
    /// An element accessed atomically.
    Atomic(ElemType),
    /// An opaque handle with no element type.
    Opaque(OpaqueType),
}
95
96impl ElemType {
97 pub fn from_quant_param(quant_param: QuantParam) -> Self {
99 match quant_param {
100 QuantParam::F32 => Self::Float(FloatKind::F32),
101 QuantParam::F16 => Self::Float(FloatKind::F16),
102 QuantParam::BF16 => Self::Float(FloatKind::BF16),
103 QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
104 QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
105 }
106 }
107
108 pub fn from_quant_value(quant_value: QuantValue) -> Self {
110 match quant_value {
111 QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
112 QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
113 QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
114 QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
115 other => panic!("Unsupported quant value {other:?}"),
116 }
117 }
118 pub fn constant_from_f64(&self, val: f64) -> Variable {
122 Variable::constant(match self {
123 ElemType::Float(kind) => ConstantScalarValue::Float(val, *kind),
124 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
125 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
126 ElemType::Bool => ConstantScalarValue::Bool(val > 0.0),
127 })
128 }
129 pub fn constant_from_i64(&self, val: i64) -> Variable {
133 Variable::constant(match self {
134 ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
135 ElemType::Int(kind) => ConstantScalarValue::Int(val, *kind),
136 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
137 ElemType::Bool => ConstantScalarValue::Bool(val > 0),
138 })
139 }
140 pub fn constant_from_u64(&self, val: u64) -> Variable {
144 Variable::constant(match self {
145 ElemType::Float(kind) => ConstantScalarValue::Float(val as f64, *kind),
146 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
147 ElemType::UInt(kind) => ConstantScalarValue::UInt(val, *kind),
148 ElemType::Bool => ConstantScalarValue::Bool(val > 0),
149 })
150 }
151 pub fn constant_from_bool(&self, val: bool) -> Variable {
155 Variable::constant(match self {
156 ElemType::Float(kind) => ConstantScalarValue::Float(val as u32 as f64, *kind),
157 ElemType::Int(kind) => ConstantScalarValue::Int(val as i64, *kind),
158 ElemType::UInt(kind) => ConstantScalarValue::UInt(val as u64, *kind),
159 ElemType::Bool => ConstantScalarValue::Bool(val),
160 })
161 }
162
163 pub fn from_constant(&self, constant: Variable) -> Variable {
165 let value = match constant.kind {
166 VariableKind::ConstantScalar(value) => value,
167 _ => return constant,
168 };
169
170 match value {
171 ConstantScalarValue::Int(val, _) => self.constant_from_i64(val),
172 ConstantScalarValue::Float(val, _) => self.constant_from_f64(val),
173 ConstantScalarValue::UInt(val, _) => self.constant_from_u64(val),
174 ConstantScalarValue::Bool(val) => self.constant_from_bool(val),
175 }
176 }
177 pub const fn size(&self) -> usize {
179 match self {
180 ElemType::Float(kind) => match kind {
181 FloatKind::E2M1
182 | FloatKind::E2M3
183 | FloatKind::E3M2
184 | FloatKind::E4M3
185 | FloatKind::E5M2
186 | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
187 FloatKind::F16 => core::mem::size_of::<half::f16>(),
188 FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
189 FloatKind::F32 => core::mem::size_of::<f32>(),
190 FloatKind::F64 => core::mem::size_of::<f64>(),
191 FloatKind::Flex32 => core::mem::size_of::<f32>(),
192 FloatKind::TF32 => core::mem::size_of::<f32>(),
193 },
194 ElemType::Int(kind) => match kind {
195 IntKind::I8 => core::mem::size_of::<i8>(),
196 IntKind::I16 => core::mem::size_of::<i16>(),
197 IntKind::I32 => core::mem::size_of::<i32>(),
198 IntKind::I64 => core::mem::size_of::<i64>(),
199 },
200 ElemType::UInt(kind) => match kind {
201 UIntKind::U8 => core::mem::size_of::<u8>(),
202 UIntKind::U16 => core::mem::size_of::<u16>(),
203 UIntKind::U32 => core::mem::size_of::<u32>(),
204 UIntKind::U64 => core::mem::size_of::<u64>(),
205 },
206 ElemType::Bool => core::mem::size_of::<bool>(),
207 }
208 }
209
210 pub const fn size_bits(&self) -> usize {
212 match self {
213 ElemType::Float(kind) => match kind {
214 FloatKind::E2M3
215 | FloatKind::E3M2
216 | FloatKind::E4M3
217 | FloatKind::E5M2
218 | FloatKind::UE8M0
219 | FloatKind::F16
220 | FloatKind::BF16
221 | FloatKind::F32
222 | FloatKind::F64
223 | FloatKind::Flex32
224 | FloatKind::TF32 => self.size() * 8,
225 FloatKind::E2M1 => 4,
226 },
227 ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
228 }
229 }
230
231 pub const fn min_line_size(&self) -> u8 {
232 match self {
233 ElemType::Float(FloatKind::E2M1) => 2,
234 _ => 1,
235 }
236 }
237
238 pub fn is_int(&self) -> bool {
239 matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
240 }
241
242 pub fn is_signed_int(&self) -> bool {
243 matches!(self, ElemType::Int(_))
244 }
245
246 pub fn is_unsigned_int(&self) -> bool {
247 matches!(self, ElemType::UInt(_) | ElemType::Bool)
248 }
249
250 pub fn is_float(&self) -> bool {
251 matches!(self, ElemType::Float(_))
252 }
253
254 pub fn max_variable(&self) -> Variable {
255 let value = match self {
256 ElemType::Float(kind) => match kind {
257 FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MAX, FloatKind::E2M1),
258 FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MAX, FloatKind::E2M3),
259 FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MAX, FloatKind::E3M2),
260 FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MAX, FloatKind::E4M3),
261 FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MAX, FloatKind::E5M2),
262 FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MAX, FloatKind::UE8M0),
263 FloatKind::F16 => {
264 ConstantScalarValue::Float(half::f16::MAX.to_f64(), FloatKind::F16)
265 }
266 FloatKind::BF16 => {
267 ConstantScalarValue::Float(half::bf16::MAX.to_f64(), FloatKind::BF16)
268 }
269 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::Flex32),
270 FloatKind::F32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::F32),
271 FloatKind::TF32 => ConstantScalarValue::Float(f32::MAX.into(), FloatKind::TF32),
272 FloatKind::F64 => ConstantScalarValue::Float(f64::MAX, FloatKind::F64),
273 },
274 ElemType::Int(kind) => match kind {
275 IntKind::I8 => ConstantScalarValue::Int(i8::MAX.into(), IntKind::I8),
276 IntKind::I16 => ConstantScalarValue::Int(i16::MAX.into(), IntKind::I16),
277 IntKind::I32 => ConstantScalarValue::Int(i32::MAX.into(), IntKind::I32),
278 IntKind::I64 => ConstantScalarValue::Int(i64::MAX, IntKind::I64),
279 },
280 ElemType::UInt(kind) => match kind {
281 UIntKind::U8 => ConstantScalarValue::UInt(u8::MAX.into(), UIntKind::U8),
282 UIntKind::U16 => ConstantScalarValue::UInt(u16::MAX.into(), UIntKind::U16),
283 UIntKind::U32 => ConstantScalarValue::UInt(u32::MAX.into(), UIntKind::U32),
284 UIntKind::U64 => ConstantScalarValue::UInt(u64::MAX, UIntKind::U64),
285 },
286 ElemType::Bool => ConstantScalarValue::Bool(true),
287 };
288
289 Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
290 }
291
292 pub fn min_variable(&self) -> Variable {
293 let value = match self {
294 ElemType::Float(kind) => match kind {
295 FloatKind::E2M1 => ConstantScalarValue::Float(e2m1::MIN, FloatKind::E2M1),
296 FloatKind::E2M3 => ConstantScalarValue::Float(e2m3::MIN, FloatKind::E2M3),
297 FloatKind::E3M2 => ConstantScalarValue::Float(e3m2::MIN, FloatKind::E3M2),
298 FloatKind::E4M3 => ConstantScalarValue::Float(e4m3::MIN, FloatKind::E4M3),
299 FloatKind::E5M2 => ConstantScalarValue::Float(e5m2::MIN, FloatKind::E5M2),
300 FloatKind::UE8M0 => ConstantScalarValue::Float(ue8m0::MIN, FloatKind::UE8M0),
301 FloatKind::F16 => {
302 ConstantScalarValue::Float(half::f16::MIN.to_f64(), FloatKind::F16)
303 }
304 FloatKind::BF16 => {
305 ConstantScalarValue::Float(half::bf16::MIN.to_f64(), FloatKind::BF16)
306 }
307 FloatKind::Flex32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::Flex32),
308 FloatKind::F32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::F32),
309 FloatKind::TF32 => ConstantScalarValue::Float(f32::MIN.into(), FloatKind::TF32),
310 FloatKind::F64 => ConstantScalarValue::Float(f64::MIN, FloatKind::F64),
311 },
312 ElemType::Int(kind) => match kind {
313 IntKind::I8 => ConstantScalarValue::Int(i8::MIN.into(), IntKind::I8),
314 IntKind::I16 => ConstantScalarValue::Int(i16::MIN.into(), IntKind::I16),
315 IntKind::I32 => ConstantScalarValue::Int(i32::MIN.into(), IntKind::I32),
316 IntKind::I64 => ConstantScalarValue::Int(i64::MIN, IntKind::I64),
317 },
318 ElemType::UInt(kind) => match kind {
319 UIntKind::U8 => ConstantScalarValue::UInt(u8::MIN.into(), UIntKind::U8),
320 UIntKind::U16 => ConstantScalarValue::UInt(u16::MIN.into(), UIntKind::U16),
321 UIntKind::U32 => ConstantScalarValue::UInt(u32::MIN.into(), UIntKind::U32),
322 UIntKind::U64 => ConstantScalarValue::UInt(u64::MIN, UIntKind::U64),
323 },
324 ElemType::Bool => ConstantScalarValue::Bool(false),
325 };
326
327 Variable::new(VariableKind::ConstantScalar(value), Type::scalar(*self))
328 }
329
330 pub fn epsilon(&self) -> f64 {
331 match self {
332 ElemType::Float(kind) => match kind {
333 FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
334 FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
335 FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
336 FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
337 FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
338 FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
339 FloatKind::F16 => half::f16::EPSILON.to_f64(),
340 FloatKind::BF16 => 0.0078125, FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
342 FloatKind::F64 => f64::EPSILON,
343 },
344 ElemType::Int(_) | ElemType::UInt(_) => 1.0, ElemType::Bool => 1.0,
346 }
347 }
348}
349
350impl OpaqueType {
351 pub const fn size(&self) -> usize {
353 match self {
354 OpaqueType::Barrier(_) => 8,
355 }
356 }
357
358 pub const fn size_bits(&self) -> usize {
360 match self {
361 OpaqueType::Barrier(_) => 64,
362 }
363 }
364}
365
366impl StorageType {
367 pub fn elem_type(&self) -> ElemType {
368 match self {
369 StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
370 StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
371 }
372 }
373
374 pub fn packing_factor(&self) -> u32 {
375 match self {
376 StorageType::Packed(_, factor) => *factor,
377 _ => 1,
378 }
379 }
380
381 pub fn is_atomic(&self) -> bool {
382 matches!(self, StorageType::Atomic(_))
383 }
384
385 pub fn size(&self) -> usize {
386 self.size_bits().div_ceil(8)
387 }
388
389 pub fn size_bits(&self) -> usize {
390 match self {
391 StorageType::Packed(ty, factor) => ty.size_bits() * *factor as usize,
392 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
393 StorageType::Opaque(ty) => ty.size_bits(),
394 }
395 }
396
397 pub fn from_constant(&self, constant: Variable) -> Variable {
399 self.elem_type().from_constant(constant)
400 }
401
402 pub fn is_int(&self) -> bool {
403 self.elem_type().is_int()
404 }
405
406 pub fn is_signed_int(&self) -> bool {
407 self.elem_type().is_signed_int()
408 }
409
410 pub fn is_unsigned_int(&self) -> bool {
411 self.elem_type().is_unsigned_int()
412 }
413
414 pub fn is_float(&self) -> bool {
415 self.elem_type().is_float()
416 }
417
418 pub fn epsilon(&self) -> f64 {
420 match self {
421 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
422 StorageType::Packed(ty, factor) => {
423 ty.epsilon() * (*factor as f64)
425 }
426 StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
427 }
428 }
429}
430
431impl From<ElemType> for StorageType {
432 fn from(val: ElemType) -> Self {
433 StorageType::Scalar(val)
434 }
435}
436
437impl From<OpaqueType> for StorageType {
438 fn from(val: OpaqueType) -> Self {
439 StorageType::Opaque(val)
440 }
441}
442
443impl<T: Into<StorageType>> From<T> for Type {
444 fn from(val: T) -> Self {
445 Type::new(val.into())
446 }
447}
448
449impl From<SemanticType> for Type {
450 fn from(val: SemanticType) -> Self {
451 Type::semantic(val)
452 }
453}
454
/// The full type of an IR value: scalar storage, a line (vector) of storage, or a semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// A single value of the given storage type.
    Scalar(StorageType),
    /// A line of values; the `u32` is the line size (`Type::line` only builds this when > 1).
    Line(StorageType, u32),
    /// A storage-less semantic type.
    Semantic(SemanticType),
}

/// Number of lanes in a [`Type::Line`].
pub type LineSize = u32;
467
468impl Type {
469 pub fn elem_type(&self) -> ElemType {
471 self.storage_type().elem_type()
472 }
473
474 pub fn new(storage: StorageType) -> Self {
476 Type::Scalar(storage)
477 }
478
479 pub fn scalar(elem: ElemType) -> Self {
480 Self::new(StorageType::Scalar(elem))
481 }
482
483 pub fn semantic(ty: SemanticType) -> Self {
484 Self::Semantic(ty)
485 }
486
487 pub fn line(self, line_size: LineSize) -> Type {
488 match line_size > 1 {
489 true => Type::Line(self.storage_type(), line_size),
490 false => Type::Scalar(self.storage_type()),
491 }
492 }
493
494 pub fn line_size(&self) -> u32 {
495 match self {
496 Type::Scalar(_) => 1,
497 Type::Line(_, line_size) => *line_size,
498 Type::Semantic(_) => 0,
499 }
500 }
501
502 pub fn size(&self) -> usize {
503 match self {
504 Type::Scalar(ty) => ty.size(),
505 Type::Line(ty, line_size) => ty.size() * *line_size as usize,
506 Type::Semantic(_) => 0,
507 }
508 }
509
510 pub fn size_bits(&self) -> usize {
511 match self {
512 Type::Scalar(ty) => ty.size_bits(),
513 Type::Line(ty, line_size) => ty.size_bits() * *line_size as usize,
514 Type::Semantic(_) => 0,
515 }
516 }
517
518 pub fn is_atomic(&self) -> bool {
519 !self.is_semantic() && self.storage_type().is_atomic()
520 }
521
522 pub fn is_int(&self) -> bool {
523 !self.is_semantic() && self.storage_type().is_int()
524 }
525
526 pub fn is_signed_int(&self) -> bool {
527 !self.is_semantic() && self.storage_type().is_signed_int()
528 }
529
530 pub fn is_unsigned_int(&self) -> bool {
531 !self.is_semantic() && self.storage_type().is_unsigned_int()
532 }
533
534 pub fn is_float(&self) -> bool {
535 !self.is_semantic() && self.storage_type().is_float()
536 }
537
538 pub fn storage_type(&self) -> StorageType {
539 match self {
540 Type::Scalar(ty) | Type::Line(ty, _) => *ty,
541 Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
542 }
543 }
544
545 pub fn is_semantic(&self) -> bool {
546 match self {
547 Type::Scalar(_) | Type::Line(_, _) => false,
548 Type::Semantic(_) => true,
549 }
550 }
551}
552
553impl Display for Type {
554 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
555 match self {
556 Type::Scalar(ty) => write!(f, "{ty}"),
557 Type::Line(ty, line_size) => write!(f, "line<{ty}, {line_size}>"),
558 Type::Semantic(ty) => write!(f, "{ty}"),
559 }
560 }
561}
562
563impl Display for StorageType {
564 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
565 match self {
566 StorageType::Scalar(ty) => write!(f, "{ty}"),
567 StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
568 StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
569 StorageType::Opaque(ty) => write!(f, "{ty}"),
570 }
571 }
572}
573
574impl Display for ElemType {
575 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
576 match self {
577 Self::Float(kind) => match kind {
578 FloatKind::E2M1 => f.write_str("e2m1"),
579 FloatKind::E2M3 => f.write_str("e2m3"),
580 FloatKind::E3M2 => f.write_str("e3m2"),
581 FloatKind::E4M3 => f.write_str("e4m3"),
582 FloatKind::E5M2 => f.write_str("e5m2"),
583 FloatKind::UE8M0 => f.write_str("ue8m0"),
584 FloatKind::F16 => f.write_str("f16"),
585 FloatKind::BF16 => f.write_str("bf16"),
586 FloatKind::Flex32 => f.write_str("flex32"),
587 FloatKind::TF32 => f.write_str("tf32"),
588 FloatKind::F32 => f.write_str("f32"),
589 FloatKind::F64 => f.write_str("f64"),
590 },
591 Self::Int(kind) => match kind {
592 IntKind::I8 => f.write_str("i8"),
593 IntKind::I16 => f.write_str("i16"),
594 IntKind::I32 => f.write_str("i32"),
595 IntKind::I64 => f.write_str("i64"),
596 },
597 Self::UInt(kind) => match kind {
598 UIntKind::U8 => f.write_str("u8"),
599 UIntKind::U16 => f.write_str("u16"),
600 UIntKind::U32 => f.write_str("u32"),
601 UIntKind::U64 => f.write_str("u64"),
602 },
603 Self::Bool => f.write_str("bool"),
604 }
605 }
606}
607
608impl Display for SemanticType {
609 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
610 match self {
611 SemanticType::BarrierToken => f.write_str("barrier_token"),
612 SemanticType::Pipeline => f.write_str("pipeline"),
613 SemanticType::TensorMap => f.write_str("tensor_map"),
614 }
615 }
616}
617
618impl Display for OpaqueType {
619 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
620 match self {
621 OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
622 }
623 }
624}
625
626impl From<bool> for Variable {
627 fn from(value: bool) -> Self {
628 Variable::constant(ConstantScalarValue::Bool(value))
629 }
630}
631
632impl From<i8> for Variable {
633 fn from(value: i8) -> Self {
634 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
635 }
636}
637
638impl From<i16> for Variable {
639 fn from(value: i16) -> Self {
640 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
641 }
642}
643
644impl From<i32> for Variable {
645 fn from(value: i32) -> Self {
646 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
647 }
648}
649
650impl From<i64> for Variable {
651 fn from(value: i64) -> Self {
652 Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
653 }
654}
655
656impl From<e2m1> for Variable {
657 fn from(_value: e2m1) -> Self {
658 unimplemented!("Can't currently construct minifloats")
659 }
660}
661
662impl From<e2m1x2> for Variable {
663 fn from(_value: e2m1x2) -> Self {
664 unimplemented!("Can't currently construct minifloats")
665 }
666}
667
668impl From<e2m3> for Variable {
669 fn from(_value: e2m3) -> Self {
670 unimplemented!("Can't currently construct minifloats")
671 }
672}
673
674impl From<e3m2> for Variable {
675 fn from(_value: e3m2) -> Self {
676 unimplemented!("Can't currently construct minifloats")
677 }
678}
679
680impl From<e4m3> for Variable {
681 fn from(_value: e4m3) -> Self {
682 unimplemented!("Can't currently construct minifloats")
683 }
684}
685
686impl From<e5m2> for Variable {
687 fn from(_value: e5m2) -> Self {
688 unimplemented!("Can't currently construct minifloats")
689 }
690}
691
692impl From<ue8m0> for Variable {
693 fn from(_value: ue8m0) -> Self {
694 unimplemented!("Can't currently construct minifloats")
695 }
696}
697
698impl From<half::f16> for Variable {
699 fn from(value: half::f16) -> Self {
700 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
701 }
702}
703
704impl From<half::bf16> for Variable {
705 fn from(value: half::bf16) -> Self {
706 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
707 }
708}
709
710impl From<flex32> for Variable {
711 fn from(value: flex32) -> Self {
712 Variable::constant(ConstantScalarValue::Float(
713 value.to_f64(),
714 FloatKind::Flex32,
715 ))
716 }
717}
718
719impl From<tf32> for Variable {
720 fn from(value: tf32) -> Self {
721 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
722 }
723}
724
725impl From<f32> for Variable {
726 fn from(value: f32) -> Self {
727 Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
728 }
729}
730
731impl From<f64> for Variable {
732 fn from(value: f64) -> Self {
733 Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
734 }
735}
736
737impl From<u8> for Variable {
738 fn from(value: u8) -> Self {
739 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
740 }
741}
742
743impl From<u16> for Variable {
744 fn from(value: u16) -> Self {
745 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
746 }
747}
748
749impl From<u32> for Variable {
750 fn from(value: u32) -> Self {
751 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
752 }
753}
754
755impl From<u64> for Variable {
756 fn from(value: u64) -> Self {
757 Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
758 }
759}
760
761impl From<usize> for Variable {
762 fn from(value: usize) -> Self {
763 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
764 }
765}