1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
/// A variable: what kind of binding or storage it refers to ([`VariableKind`]) together
/// with its [`Type`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// How the variable is bound (global, local, shared, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type of the variable.
    pub ty: Type,
}
17
18impl Variable {
19 pub fn new(kind: VariableKind, item: Type) -> Self {
20 Self { kind, ty: item }
21 }
22
23 pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24 Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25 }
26
27 pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28 let ty = ty.into();
29 let value = value.cast_to(ty);
30 Self::new(VariableKind::Constant(value), ty)
31 }
32
33 pub fn elem_type(&self) -> ElemType {
34 self.ty.elem_type()
35 }
36
37 pub fn storage_type(&self) -> StorageType {
38 self.ty.storage_type()
39 }
40}
41
/// Numeric identifier carried by most [`VariableKind`]s (globals, locals, arrays, ...).
pub type Id = u32;
43
/// How a [`Variable`] is bound: global argument, local, shared, constant, builtin, etc.
///
/// Most kinds carry an [`Id`]. The `Display` impl for [`Variable`] renders each kind with a
/// short prefix such as `input(..)`, `local(..)` or `shared_array(..)`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array (displayed as `input(id)`).
    GlobalInputArray(Id),
    /// Global output array (displayed as `output(id)`).
    GlobalOutputArray(Id),
    /// Global scalar (displayed as `scalar<ty>(id)`).
    GlobalScalar(Id),
    /// Input tensor map (displayed as `tensor_map(id)`).
    TensorMapInput(Id),
    /// Output tensor map (displayed as `tensor_map(id)`).
    TensorMapOutput(Id),
    /// Local array of `length` (displayed as `array(id)`).
    // NOTE(review): `length` is presumably an element count, and the meaning of
    // `unroll_factor` is not visible in this file; confirm before relying on either.
    LocalArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Mutable local (displayed as `local(id)`).
    LocalMut {
        id: Id,
    },
    /// Immutable local binding (displayed as `binding(id)`).
    LocalConst {
        id: Id,
    },
    /// Local with a version number (displayed as `local(id).v<version>`); reported as
    /// immutable by `Variable::is_immutable`.
    Versioned {
        id: Id,
        version: u16,
    },
    /// A constant value stored directly in the variable.
    Constant(ConstantValue),
    /// Constant array (displayed as `const_array(id)`).
    // NOTE(review): same open questions about `length` / `unroll_factor` as `LocalArray`.
    ConstantArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Shared array (displayed as `shared_array(id)`) with an optional alignment.
    // NOTE(review): the unit of `alignment` (presumably bytes) is not visible here; confirm.
    SharedArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
        alignment: Option<usize>,
    },
    /// Shared non-array variable (displayed as `shared(id)`).
    Shared {
        id: Id,
    },
    /// Matrix variable described by `mat` (displayed as `matrix(id)`).
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A built-in value, see [`Builtin`].
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages (displayed as `pipeline(id)`).
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given [`BarrierLevel`] (displayed as `barrier_token(id)`).
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
96
/// Built-in values (unit / cube / plane positions, dimensions and counts) referenced through
/// [`VariableKind::Builtin`].
///
/// The enum is `#[repr(u8)]` and derives `Ord`, so the declaration order of the variants is
/// observable and should not be changed casually.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    PlanePos,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
133
134impl Variable {
135 pub fn is_immutable(&self) -> bool {
138 match self.kind {
139 VariableKind::GlobalOutputArray { .. } => false,
140 VariableKind::TensorMapInput(_) => true,
141 VariableKind::TensorMapOutput(_) => false,
142 VariableKind::LocalMut { .. } => false,
143 VariableKind::SharedArray { .. } => false,
144 VariableKind::Shared { .. } => false,
145 VariableKind::Matrix { .. } => false,
146 VariableKind::LocalArray { .. } => false,
147 VariableKind::GlobalInputArray { .. } => false,
148 VariableKind::GlobalScalar { .. } => true,
149 VariableKind::Versioned { .. } => true,
150 VariableKind::LocalConst { .. } => true,
151 VariableKind::Constant(_) => true,
152 VariableKind::ConstantArray { .. } => true,
153 VariableKind::Builtin(_) => true,
154 VariableKind::Pipeline { .. } => false,
155 VariableKind::BarrierToken { .. } => false,
156 }
157 }
158
159 pub fn is_array(&self) -> bool {
162 matches!(
163 self.kind,
164 VariableKind::GlobalInputArray { .. }
165 | VariableKind::GlobalOutputArray { .. }
166 | VariableKind::ConstantArray { .. }
167 | VariableKind::SharedArray { .. }
168 | VariableKind::LocalArray { .. }
169 | VariableKind::Matrix { .. }
170 )
171 }
172
173 pub fn is_memory(&self) -> bool {
176 matches!(
177 self.kind,
178 VariableKind::GlobalInputArray { .. }
179 | VariableKind::GlobalOutputArray { .. }
180 | VariableKind::SharedArray { .. }
181 )
182 }
183
184 pub fn has_length(&self) -> bool {
185 matches!(
186 self.kind,
187 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188 )
189 }
190
191 pub fn has_buffer_length(&self) -> bool {
192 matches!(
193 self.kind,
194 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
195 )
196 }
197
198 pub fn is_constant(&self, value: i64) -> bool {
200 match self.kind {
201 VariableKind::Constant(ConstantValue::Int(val)) => val == value,
202 VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
203 VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
204 _ => false,
205 }
206 }
207
208 pub fn is_true(&self) -> bool {
210 match self.kind {
211 VariableKind::Constant(ConstantValue::Bool(val)) => val,
212 _ => false,
213 }
214 }
215
216 pub fn is_false(&self) -> bool {
218 match self.kind {
219 VariableKind::Constant(ConstantValue::Bool(val)) => !val,
220 _ => false,
221 }
222 }
223}
224
225#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
229#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
230#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
231pub enum ConstantValue {
232 Int(i64),
233 Float(f64),
234 UInt(u64),
235 Bool(bool),
236}
237
238impl Ord for ConstantValue {
239 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
240 match (self, other) {
243 (ConstantValue::Float(this), ConstantValue::Float(other)) => {
244 FloatOrd(*this).cmp(&FloatOrd(*other))
245 }
246 _ => self.partial_cmp(other).unwrap(),
247 }
248 }
249}
250
251impl Eq for ConstantValue {}
252impl Hash for ConstantValue {
253 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
254 core::mem::discriminant(self).hash(ra_expand_state);
255 match self {
256 ConstantValue::Int(f0) => {
257 f0.hash(ra_expand_state);
258 }
259 ConstantValue::Float(f0) => {
260 FloatOrd(*f0).hash(ra_expand_state);
261 }
262 ConstantValue::UInt(f0) => {
263 f0.hash(ra_expand_state);
264 }
265 ConstantValue::Bool(f0) => {
266 f0.hash(ra_expand_state);
267 }
268 }
269 }
270}
271
272impl ConstantValue {
273 pub fn try_as_usize(&self) -> Option<usize> {
277 match self {
278 ConstantValue::UInt(val) => Some(*val as usize),
279 ConstantValue::Int(val) => Some(*val as usize),
280 ConstantValue::Float(_) => None,
281 ConstantValue::Bool(_) => None,
282 }
283 }
284
285 pub fn as_usize(&self) -> usize {
287 match self {
288 ConstantValue::UInt(val) => *val as usize,
289 ConstantValue::Int(val) => *val as usize,
290 ConstantValue::Float(val) => *val as usize,
291 ConstantValue::Bool(val) => *val as usize,
292 }
293 }
294
295 pub fn try_as_u32(&self) -> Option<u32> {
299 self.try_as_u64().map(|it| it as u32)
300 }
301
302 pub fn as_u32(&self) -> u32 {
306 self.as_u64() as u32
307 }
308
309 pub fn try_as_u64(&self) -> Option<u64> {
313 match self {
314 ConstantValue::UInt(val) => Some(*val),
315 ConstantValue::Int(val) => Some(*val as u64),
316 ConstantValue::Float(_) => None,
317 ConstantValue::Bool(_) => None,
318 }
319 }
320
321 pub fn as_u64(&self) -> u64 {
323 match self {
324 ConstantValue::UInt(val) => *val,
325 ConstantValue::Int(val) => *val as u64,
326 ConstantValue::Float(val) => *val as u64,
327 ConstantValue::Bool(val) => *val as u64,
328 }
329 }
330
331 pub fn try_as_i64(&self) -> Option<i64> {
335 match self {
336 ConstantValue::UInt(val) => Some(*val as i64),
337 ConstantValue::Int(val) => Some(*val),
338 ConstantValue::Float(_) => None,
339 ConstantValue::Bool(_) => None,
340 }
341 }
342
343 pub fn as_i128(&self) -> i128 {
345 match self {
346 ConstantValue::UInt(val) => *val as i128,
347 ConstantValue::Int(val) => *val as i128,
348 ConstantValue::Float(val) => *val as i128,
349 ConstantValue::Bool(val) => *val as i128,
350 }
351 }
352
353 pub fn as_i64(&self) -> i64 {
355 match self {
356 ConstantValue::UInt(val) => *val as i64,
357 ConstantValue::Int(val) => *val,
358 ConstantValue::Float(val) => *val as i64,
359 ConstantValue::Bool(val) => *val as i64,
360 }
361 }
362
363 pub fn as_i32(&self) -> i32 {
365 match self {
366 ConstantValue::UInt(val) => *val as i32,
367 ConstantValue::Int(val) => *val as i32,
368 ConstantValue::Float(val) => *val as i32,
369 ConstantValue::Bool(val) => *val as i32,
370 }
371 }
372
373 pub fn try_as_f64(&self) -> Option<f64> {
377 match self {
378 ConstantValue::Float(val) => Some(*val),
379 _ => None,
380 }
381 }
382
383 pub fn as_f64(&self) -> f64 {
385 match self {
386 ConstantValue::UInt(val) => *val as f64,
387 ConstantValue::Int(val) => *val as f64,
388 ConstantValue::Float(val) => *val,
389 ConstantValue::Bool(val) => *val as u8 as f64,
390 }
391 }
392
393 pub fn try_as_bool(&self) -> Option<bool> {
395 match self {
396 ConstantValue::Bool(val) => Some(*val),
397 _ => None,
398 }
399 }
400
401 pub fn as_bool(&self) -> bool {
405 match self {
406 ConstantValue::UInt(val) => *val != 0,
407 ConstantValue::Int(val) => *val != 0,
408 ConstantValue::Float(val) => *val != 0.,
409 ConstantValue::Bool(val) => *val,
410 }
411 }
412
413 pub fn is_zero(&self) -> bool {
414 match self {
415 ConstantValue::Int(val) => *val == 0,
416 ConstantValue::Float(val) => *val == 0.0,
417 ConstantValue::UInt(val) => *val == 0,
418 ConstantValue::Bool(val) => !*val,
419 }
420 }
421
422 pub fn is_one(&self) -> bool {
423 match self {
424 ConstantValue::Int(val) => *val == 1,
425 ConstantValue::Float(val) => *val == 1.0,
426 ConstantValue::UInt(val) => *val == 1,
427 ConstantValue::Bool(val) => *val,
428 }
429 }
430
431 pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
432 match other.into().storage_type() {
433 StorageType::Scalar(elem_type) => match elem_type {
434 ElemType::Float(kind) => match kind {
435 FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
436 FloatKind::E2M3 | FloatKind::E3M2 => {
437 unimplemented!("FP6 constants not yet supported")
438 }
439 FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
440 FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
441 FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
442 FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
443 FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
444 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
445 self.as_f64() as f32 as f64
446 }
447 FloatKind::F64 => self.as_f64(),
448 }
449 .into(),
450 ElemType::Int(kind) => match kind {
451 IntKind::I8 => self.as_i64() as i8 as i64,
452 IntKind::I16 => self.as_i64() as i16 as i64,
453 IntKind::I32 => self.as_i64() as i32 as i64,
454 IntKind::I64 => self.as_i64(),
455 }
456 .into(),
457 ElemType::UInt(kind) => match kind {
458 UIntKind::U8 => self.as_u64() as u8 as u64,
459 UIntKind::U16 => self.as_u64() as u16 as u64,
460 UIntKind::U32 => self.as_u64() as u32 as u64,
461 UIntKind::U64 => self.as_u64(),
462 }
463 .into(),
464 ElemType::Bool => self.as_bool().into(),
465 },
466 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
467 e2m1::from_f64(self.as_f64()).to_f64().into()
468 }
469 StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
470 StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
471 StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
472 }
473 }
474}
475
476impl Display for ConstantValue {
477 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
478 match self {
479 ConstantValue::Int(val) => write!(f, "{val}"),
480 ConstantValue::Float(val) => write!(f, "{val:?}"),
481 ConstantValue::UInt(val) => write!(f, "{val}"),
482 ConstantValue::Bool(val) => write!(f, "{val}"),
483 }
484 }
485}
486
487impl Variable {
488 pub fn vector_size(&self) -> usize {
489 self.ty.vector_size()
490 }
491
492 pub fn index(&self) -> Option<Id> {
493 match self.kind {
494 VariableKind::GlobalInputArray(id)
495 | VariableKind::GlobalOutputArray(id)
496 | VariableKind::TensorMapInput(id)
497 | VariableKind::TensorMapOutput(id)
498 | VariableKind::GlobalScalar(id)
499 | VariableKind::LocalMut { id, .. }
500 | VariableKind::Versioned { id, .. }
501 | VariableKind::LocalConst { id, .. }
502 | VariableKind::ConstantArray { id, .. }
503 | VariableKind::SharedArray { id, .. }
504 | VariableKind::Shared { id, .. }
505 | VariableKind::LocalArray { id, .. }
506 | VariableKind::Matrix { id, .. } => Some(id),
507 _ => None,
508 }
509 }
510
511 pub fn as_const(&self) -> Option<ConstantValue> {
512 match self.kind {
513 VariableKind::Constant(constant) => Some(constant),
514 _ => None,
515 }
516 }
517}
518
519impl Display for Variable {
520 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
521 match self.kind {
522 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
523 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
524 VariableKind::GlobalScalar(id) => write!(f, "scalar<{}>({id})", self.ty),
525 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
526 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
527 VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
528 VariableKind::LocalMut { id } => write!(f, "local({id})"),
529 VariableKind::Versioned { id, version } => {
530 write!(f, "local({id}).v{version}")
531 }
532 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
533 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
534 VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
535 VariableKind::Shared { id } => write!(f, "shared({id})"),
536 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
537 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
538 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
539 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
540 VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
541 }
542 }
543}
544
545impl From<&Variable> for Variable {
547 fn from(value: &Variable) -> Self {
548 *value
549 }
550}