1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ComplexKind, ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
/// A variable in the IR: the storage class it lives in ([`VariableKind`]) paired with its
/// [`Type`].
///
/// `Variable` is `Copy`, so it is passed around by value.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// Storage class of the variable, including its identifier or inline constant payload.
    pub kind: VariableKind,
    /// Type of the value(s) held by the variable.
    pub ty: Type,
}
17
18impl Variable {
19 pub fn new(kind: VariableKind, item: Type) -> Self {
20 Self { kind, ty: item }
21 }
22
23 pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24 Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25 }
26
27 pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28 let ty = ty.into();
29 let value = value.cast_to(ty);
30 Self::new(VariableKind::Constant(value), ty)
31 }
32
33 pub fn elem_type(&self) -> ElemType {
34 self.ty.elem_type()
35 }
36
37 pub fn storage_type(&self) -> StorageType {
38 self.ty.storage_type()
39 }
40}
41
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
43
/// The storage class or role a [`Variable`] refers to.
///
/// Most variants carry an [`Id`]. `Constant` carries its value inline and `Builtin` names a
/// [`Builtin`]. Whether a kind is treated as mutable is reported by `Variable::is_immutable`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array (displayed as `input(id)`).
    GlobalInputArray(Id),
    /// Global output array (displayed as `output(id)`).
    GlobalOutputArray(Id),
    /// Global scalar (displayed as `scalar<ty>(id)`); immutable.
    GlobalScalar(Id),
    /// Input tensor map; immutable.
    /// NOTE(review): displays identically to `TensorMapOutput` (`tensor_map(id)`).
    TensorMapInput(Id),
    /// Output tensor map.
    TensorMapOutput(Id),
    /// Local array of `length` elements (displayed as `array(id)`).
    LocalArray {
        id: Id,
        length: usize,
        // NOTE(review): the meaning of `unroll_factor` is not visible in this file — confirm.
        unroll_factor: usize,
    },
    /// Mutable local (displayed as `local(id)`).
    LocalMut {
        id: Id,
    },
    /// Immutable local binding (displayed as `binding(id)`).
    LocalConst {
        id: Id,
    },
    /// Immutable, versioned local (displayed as `local(id).vN`).
    /// Presumably an SSA-style version of a local — confirm.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Inline constant value.
    Constant(ConstantValue),
    /// Immutable array of `length` elements (displayed as `const_array(id)`).
    ConstantArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Shared array of `length` elements (displayed as `shared_array(id)`).
    SharedArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
        // Optional alignment; units are not visible in this file — TODO confirm.
        alignment: Option<usize>,
    },
    /// Non-array shared variable (displayed as `shared(id)`).
    Shared {
        id: Id,
    },
    /// Matrix described by a [`Matrix`] descriptor; counted as an array by `Variable::is_array`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// Built-in value, displayed through its `Debug` name.
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages (displayed as `pipeline(id)`).
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given [`BarrierLevel`] (displayed as `barrier_token(id)`).
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
96
/// Built-in values that can be referenced as [`Variable`]s (see `Variable::builtin`).
///
/// When a `Variable` is displayed, a builtin is printed using its `Debug` (variant) name.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
// `repr(u8)` fixes the discriminant width. The derived `PartialOrd`/`Ord` rank variants
// by declaration order, so reordering variants changes both the discriminants and the ordering.
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    PlanePos,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
133
134impl Variable {
135 pub fn is_immutable(&self) -> bool {
138 match self.kind {
139 VariableKind::GlobalOutputArray { .. } => false,
140 VariableKind::TensorMapInput(_) => true,
141 VariableKind::TensorMapOutput(_) => false,
142 VariableKind::LocalMut { .. } => false,
143 VariableKind::SharedArray { .. } => false,
144 VariableKind::Shared { .. } => false,
145 VariableKind::Matrix { .. } => false,
146 VariableKind::LocalArray { .. } => false,
147 VariableKind::GlobalInputArray { .. } => false,
148 VariableKind::GlobalScalar { .. } => true,
149 VariableKind::Versioned { .. } => true,
150 VariableKind::LocalConst { .. } => true,
151 VariableKind::Constant(_) => true,
152 VariableKind::ConstantArray { .. } => true,
153 VariableKind::Builtin(_) => true,
154 VariableKind::Pipeline { .. } => false,
155 VariableKind::BarrierToken { .. } => false,
156 }
157 }
158
159 pub fn is_array(&self) -> bool {
162 matches!(
163 self.kind,
164 VariableKind::GlobalInputArray { .. }
165 | VariableKind::GlobalOutputArray { .. }
166 | VariableKind::ConstantArray { .. }
167 | VariableKind::SharedArray { .. }
168 | VariableKind::LocalArray { .. }
169 | VariableKind::Matrix { .. }
170 )
171 }
172
173 pub fn is_memory(&self) -> bool {
176 matches!(
177 self.kind,
178 VariableKind::GlobalInputArray { .. }
179 | VariableKind::GlobalOutputArray { .. }
180 | VariableKind::SharedArray { .. }
181 )
182 }
183
184 pub fn has_length(&self) -> bool {
185 matches!(
186 self.kind,
187 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188 )
189 }
190
191 pub fn has_buffer_length(&self) -> bool {
192 matches!(
193 self.kind,
194 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
195 )
196 }
197
198 pub fn is_constant(&self, value: i64) -> bool {
200 match self.kind {
201 VariableKind::Constant(ConstantValue::Int(val)) => val == value,
202 VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
203 VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
204 _ => false,
205 }
206 }
207
208 pub fn is_true(&self) -> bool {
210 match self.kind {
211 VariableKind::Constant(ConstantValue::Bool(val)) => val,
212 _ => false,
213 }
214 }
215
216 pub fn is_false(&self) -> bool {
218 match self.kind {
219 VariableKind::Constant(ConstantValue::Bool(val)) => !val,
220 _ => false,
221 }
222 }
223}
224
/// A constant value embedded in the IR.
///
/// The derived `From` provides `.into()` conversions from the variant payloads, which
/// `ConstantValue::cast_to` relies on. `Eq`, `Ord` and `Hash` are implemented by hand
/// because the `f64` payloads do not implement them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
// `PartialOrd` is derived while `Ord` is written by hand (it total-orders floats through
// `FloatOrd`). That is the combination `clippy::derive_ord_xor_partial_ord` would flag.
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
// Variant order is significant: the derived `PartialOrd` (and therefore the manual `Ord`
// for mixed variants) ranks variants in declaration order.
pub enum ConstantValue {
    /// Signed integer payload.
    Int(i64),
    /// Floating-point payload.
    Float(f64),
    /// Unsigned integer payload.
    UInt(u64),
    /// Boolean payload.
    Bool(bool),
    /// Complex payload stored as `(real, imaginary)`.
    Complex(f64, f64),
}
238
239impl Ord for ConstantValue {
240 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
241 match (self, other) {
244 (ConstantValue::Float(this), ConstantValue::Float(other)) => {
245 FloatOrd(*this).cmp(&FloatOrd(*other))
246 }
247 (
248 ConstantValue::Complex(this_re, this_im),
249 ConstantValue::Complex(other_re, other_im),
250 ) => FloatOrd(*this_re)
251 .cmp(&FloatOrd(*other_re))
252 .then_with(|| FloatOrd(*this_im).cmp(&FloatOrd(*other_im))),
253 _ => self.partial_cmp(other).unwrap(),
254 }
255 }
256}
257
// `Eq` cannot be derived because of the `f64` payloads. It is asserted manually so that
// `ConstantValue` satisfies the `Eq` bound required by the manual `Ord` impl and by the
// `Eq` derives on `VariableKind`/`Variable`.
// NOTE(review): the derived `PartialEq` uses IEEE comparison for `f64`, so
// `Float(NAN) != Float(NAN)` and `Eq`'s reflexivity does not hold for NaN payloads — confirm
// this is acceptable for the places that use constants as keys.
impl Eq for ConstantValue {}
259impl Hash for ConstantValue {
260 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
261 core::mem::discriminant(self).hash(ra_expand_state);
262 match self {
263 ConstantValue::Int(f0) => {
264 f0.hash(ra_expand_state);
265 }
266 ConstantValue::Float(f0) => {
267 FloatOrd(*f0).hash(ra_expand_state);
268 }
269 ConstantValue::UInt(f0) => {
270 f0.hash(ra_expand_state);
271 }
272 ConstantValue::Bool(f0) => {
273 f0.hash(ra_expand_state);
274 }
275 ConstantValue::Complex(f0, f1) => {
276 FloatOrd(*f0).hash(ra_expand_state);
277 FloatOrd(*f1).hash(ra_expand_state);
278 }
279 }
280 }
281}
282
283impl ConstantValue {
284 pub fn try_as_usize(&self) -> Option<usize> {
288 match self {
289 ConstantValue::UInt(val) => Some(*val as usize),
290 ConstantValue::Int(val) => Some(*val as usize),
291 ConstantValue::Float(_) => None,
292 ConstantValue::Bool(_) => None,
293 ConstantValue::Complex(_, _) => None,
294 }
295 }
296
297 pub fn as_usize(&self) -> usize {
299 match self {
300 ConstantValue::UInt(val) => *val as usize,
301 ConstantValue::Int(val) => *val as usize,
302 ConstantValue::Float(val) => *val as usize,
303 ConstantValue::Bool(val) => *val as usize,
304 ConstantValue::Complex(_, _) => {
305 panic!("Complex constants can't be converted to usize")
306 }
307 }
308 }
309
310 pub fn try_as_u32(&self) -> Option<u32> {
314 self.try_as_u64().map(|it| it as u32)
315 }
316
317 pub fn as_u32(&self) -> u32 {
321 self.as_u64() as u32
322 }
323
324 pub fn try_as_u64(&self) -> Option<u64> {
328 match self {
329 ConstantValue::UInt(val) => Some(*val),
330 ConstantValue::Int(val) => Some(*val as u64),
331 ConstantValue::Float(_) => None,
332 ConstantValue::Bool(_) => None,
333 ConstantValue::Complex(_, _) => None,
334 }
335 }
336
337 pub fn as_u64(&self) -> u64 {
339 match self {
340 ConstantValue::UInt(val) => *val,
341 ConstantValue::Int(val) => *val as u64,
342 ConstantValue::Float(val) => *val as u64,
343 ConstantValue::Bool(val) => *val as u64,
344 ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to u64"),
345 }
346 }
347
348 pub fn try_as_i64(&self) -> Option<i64> {
352 match self {
353 ConstantValue::UInt(val) => Some(*val as i64),
354 ConstantValue::Int(val) => Some(*val),
355 ConstantValue::Float(_) => None,
356 ConstantValue::Bool(_) => None,
357 ConstantValue::Complex(_, _) => None,
358 }
359 }
360
361 pub fn as_i128(&self) -> i128 {
363 match self {
364 ConstantValue::UInt(val) => *val as i128,
365 ConstantValue::Int(val) => *val as i128,
366 ConstantValue::Float(val) => *val as i128,
367 ConstantValue::Bool(val) => *val as i128,
368 ConstantValue::Complex(_, _) => {
369 panic!("Complex constants can't be converted to i128")
370 }
371 }
372 }
373
374 pub fn as_i64(&self) -> i64 {
376 match self {
377 ConstantValue::UInt(val) => *val as i64,
378 ConstantValue::Int(val) => *val,
379 ConstantValue::Float(val) => *val as i64,
380 ConstantValue::Bool(val) => *val as i64,
381 ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to i64"),
382 }
383 }
384
385 pub fn as_i32(&self) -> i32 {
387 match self {
388 ConstantValue::UInt(val) => *val as i32,
389 ConstantValue::Int(val) => *val as i32,
390 ConstantValue::Float(val) => *val as i32,
391 ConstantValue::Bool(val) => *val as i32,
392 ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to i32"),
393 }
394 }
395
396 pub fn try_as_f64(&self) -> Option<f64> {
400 match self {
401 ConstantValue::Float(val) => Some(*val),
402 ConstantValue::Complex(re, _) => Some(*re),
403 _ => None,
404 }
405 }
406
407 pub fn as_f64(&self) -> f64 {
409 match self {
410 ConstantValue::UInt(val) => *val as f64,
411 ConstantValue::Int(val) => *val as f64,
412 ConstantValue::Float(val) => *val,
413 ConstantValue::Bool(val) => *val as u8 as f64,
414 ConstantValue::Complex(re, _) => *re,
415 }
416 }
417
418 pub fn try_as_bool(&self) -> Option<bool> {
420 match self {
421 ConstantValue::Bool(val) => Some(*val),
422 _ => None,
423 }
424 }
425
426 pub fn as_bool(&self) -> bool {
430 match self {
431 ConstantValue::UInt(val) => *val != 0,
432 ConstantValue::Int(val) => *val != 0,
433 ConstantValue::Float(val) => *val != 0.,
434 ConstantValue::Bool(val) => *val,
435 ConstantValue::Complex(_, _) => {
436 panic!("Complex constants can't be converted to bool")
437 }
438 }
439 }
440
441 pub fn is_zero(&self) -> bool {
442 match self {
443 ConstantValue::Int(val) => *val == 0,
444 ConstantValue::Float(val) => *val == 0.0,
445 ConstantValue::UInt(val) => *val == 0,
446 ConstantValue::Bool(val) => !*val,
447 ConstantValue::Complex(re, im) => *re == 0.0 && *im == 0.0,
448 }
449 }
450
451 pub fn is_one(&self) -> bool {
452 match self {
453 ConstantValue::Int(val) => *val == 1,
454 ConstantValue::Float(val) => *val == 1.0,
455 ConstantValue::UInt(val) => *val == 1,
456 ConstantValue::Bool(val) => *val,
457 ConstantValue::Complex(re, im) => *re == 1.0 && *im == 0.0,
458 }
459 }
460
461 pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
462 match other.into().storage_type() {
463 StorageType::Scalar(elem_type) => match elem_type {
464 ElemType::Float(kind) => match kind {
465 FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
466 FloatKind::E2M3 | FloatKind::E3M2 => {
467 unimplemented!("FP6 constants not yet supported")
468 }
469 FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
470 FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
471 FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
472 FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
473 FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
474 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
475 self.as_f64() as f32 as f64
476 }
477 FloatKind::F64 => self.as_f64(),
478 }
479 .into(),
480 ElemType::Int(kind) => {
481 let value = match self {
482 ConstantValue::Complex(re, _) => *re as i64,
483 _ => self.as_i64(),
484 };
485
486 match kind {
487 IntKind::I8 => value as i8 as i64,
488 IntKind::I16 => value as i16 as i64,
489 IntKind::I32 => value as i32 as i64,
490 IntKind::I64 => value,
491 }
492 }
493 .into(),
494 ElemType::UInt(kind) => {
495 let value = match self {
496 ConstantValue::Complex(re, _) => *re as u64,
497 _ => self.as_u64(),
498 };
499
500 match kind {
501 UIntKind::U8 => value as u8 as u64,
502 UIntKind::U16 => value as u16 as u64,
503 UIntKind::U32 => value as u32 as u64,
504 UIntKind::U64 => value,
505 }
506 }
507 .into(),
508 ElemType::Bool => self.as_bool().into(),
509 ElemType::Complex(kind) => match (self, kind) {
510 (ConstantValue::Complex(re, im), ComplexKind::C32) => {
511 ConstantValue::Complex(*re as f32 as f64, *im as f32 as f64)
512 }
513 (ConstantValue::Complex(re, im), ComplexKind::C64) => {
514 ConstantValue::Complex(*re, *im)
515 }
516 (_, ComplexKind::C32) => {
517 let re = self.as_f64() as f32 as f64;
518 ConstantValue::Complex(re, 0.0)
519 }
520 (_, ComplexKind::C64) => ConstantValue::Complex(self.as_f64(), 0.0),
521 },
522 },
523 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
524 e2m1::from_f64(self.as_f64()).to_f64().into()
525 }
526 StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
527 StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
528 StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
529 }
530 }
531}
532
533impl Display for ConstantValue {
534 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
535 match self {
536 ConstantValue::Int(val) => write!(f, "{val}"),
537 ConstantValue::Float(val) => write!(f, "{val:?}"),
538 ConstantValue::UInt(val) => write!(f, "{val}"),
539 ConstantValue::Bool(val) => write!(f, "{val}"),
540 ConstantValue::Complex(re, im) => write!(f, "({re:?}, {im:?})"),
541 }
542 }
543}
544
545impl Variable {
546 pub fn vector_size(&self) -> usize {
547 self.ty.vector_size()
548 }
549
550 pub fn index(&self) -> Option<Id> {
551 match self.kind {
552 VariableKind::GlobalInputArray(id)
553 | VariableKind::GlobalOutputArray(id)
554 | VariableKind::TensorMapInput(id)
555 | VariableKind::TensorMapOutput(id)
556 | VariableKind::GlobalScalar(id)
557 | VariableKind::LocalMut { id, .. }
558 | VariableKind::Versioned { id, .. }
559 | VariableKind::LocalConst { id, .. }
560 | VariableKind::ConstantArray { id, .. }
561 | VariableKind::SharedArray { id, .. }
562 | VariableKind::Shared { id, .. }
563 | VariableKind::LocalArray { id, .. }
564 | VariableKind::Matrix { id, .. } => Some(id),
565 _ => None,
566 }
567 }
568
569 pub fn as_const(&self) -> Option<ConstantValue> {
570 match self.kind {
571 VariableKind::Constant(constant) => Some(constant),
572 _ => None,
573 }
574 }
575}
576
577impl Display for Variable {
578 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
579 match self.kind {
580 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
581 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
582 VariableKind::GlobalScalar(id) => write!(f, "scalar<{}>({id})", self.ty),
583 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
584 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
585 VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
586 VariableKind::LocalMut { id } => write!(f, "local({id})"),
587 VariableKind::Versioned { id, version } => {
588 write!(f, "local({id}).v{version}")
589 }
590 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
591 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
592 VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
593 VariableKind::Shared { id } => write!(f, "shared({id})"),
594 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
595 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
596 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
597 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
598 VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
599 }
600 }
601}
602
603impl From<&Variable> for Variable {
605 fn from(value: &Variable) -> Self {
606 *value
607 }
608}