1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
/// A variable: a [`VariableKind`] describing what it refers to, paired with its [`Type`].
///
/// NOTE(review): field order participates in the derived `Hash` (and serde, when enabled);
/// keep it stable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What the variable refers to (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type of the variable's value.
    pub ty: Type,
}
17
18impl Variable {
19 pub fn new(kind: VariableKind, item: Type) -> Self {
20 Self { kind, ty: item }
21 }
22
23 pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24 Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25 }
26
27 pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28 let ty = ty.into();
29 let value = value.cast_to(ty);
30 Self::new(VariableKind::Constant(value), ty)
31 }
32
33 pub fn elem_type(&self) -> ElemType {
34 self.ty.elem_type()
35 }
36
37 pub fn storage_type(&self) -> StorageType {
38 self.ty.storage_type()
39 }
40}
41
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
43
/// What a [`Variable`] refers to.
///
/// Most variants carry an [`Id`]. `Constant` stores its value inline, and `Builtin` names a
/// [`Builtin`].
///
/// NOTE(review): variant order is part of the derived `Hash` (and serde, when enabled)
/// encoding; avoid reordering variants without checking consumers.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array (displayed as `input(id)`).
    GlobalInputArray(Id),
    /// Global output array (displayed as `output(id)`).
    GlobalOutputArray(Id),
    /// Global scalar (displayed as `scalar(id)`; classified as immutable).
    GlobalScalar(Id),
    /// Input tensor map (displayed as `tensor_map(id)`; classified as immutable).
    TensorMapInput(Id),
    /// Output tensor map (also displayed as `tensor_map(id)`).
    TensorMapOutput(Id),
    /// Local array (displayed as `array(id)`).
    LocalArray {
        id: Id,
        /// Declared length — presumably in elements of the variable's type; confirm.
        length: usize,
        /// NOTE(review): meaning of `unroll_factor` is not visible in this file.
        unroll_factor: usize,
    },
    /// Mutable local (displayed as `local(id)`).
    LocalMut {
        id: Id,
    },
    /// Immutable local binding (displayed as `binding(id)`).
    LocalConst {
        id: Id,
    },
    /// Immutable versioned local (displayed as `local(id).v<version>`).
    Versioned {
        id: Id,
        version: u16,
    },
    /// Constant value stored inline.
    Constant(ConstantValue),
    /// Constant array (displayed as `const_array(id)`; classified as immutable).
    ConstantArray {
        id: Id,
        /// Declared length — presumably in elements of the variable's type; confirm.
        length: usize,
        /// NOTE(review): meaning of `unroll_factor` is not visible in this file.
        unroll_factor: usize,
    },
    /// Shared array (displayed as `shared_array(id)`).
    SharedArray {
        id: Id,
        /// Declared length — presumably in elements of the variable's type; confirm.
        length: usize,
        /// NOTE(review): meaning of `unroll_factor` is not visible in this file.
        unroll_factor: usize,
        /// Optional alignment requirement — presumably in bytes; confirm.
        alignment: Option<usize>,
    },
    /// Shared non-array variable (displayed as `shared(id)`).
    Shared {
        id: Id,
    },
    /// Matrix variable described by `mat` (displayed as `matrix(id)`).
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A built-in value (displayed via its `Debug` name).
    Builtin(Builtin),
    /// Pipeline object with `num_stages` stages (displayed as `pipeline(id)`).
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given [`BarrierLevel`] (displayed as `barrier_token(id)`).
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
96
/// Built-in values referenced through [`VariableKind::Builtin`].
///
/// The enum is `#[repr(u8)]` and derives `Ord`, so declaration order determines both the
/// discriminant values and the ordering between variants. Reordering or inserting variants is
/// an observable change.
///
/// NOTE(review): names ending in `X`/`Y`/`Z` presumably denote per-axis components of the
/// un-suffixed builtin; confirm against the backends.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    // `UnitPos` family.
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    // `CubePosCluster` family.
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    // `CubePos` family.
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    // `CubeDim` family.
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    // `CubeClusterDim` family.
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    // `CubeCount` family.
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    // Plane-related builtins.
    PlaneDim,
    PlanePos,
    UnitPosPlane,
    // `AbsolutePos` family.
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
133
134impl Variable {
135 pub fn is_immutable(&self) -> bool {
138 match self.kind {
139 VariableKind::GlobalOutputArray { .. } => false,
140 VariableKind::TensorMapInput(_) => true,
141 VariableKind::TensorMapOutput(_) => false,
142 VariableKind::LocalMut { .. } => false,
143 VariableKind::SharedArray { .. } => false,
144 VariableKind::Shared { .. } => false,
145 VariableKind::Matrix { .. } => false,
146 VariableKind::LocalArray { .. } => false,
147 VariableKind::GlobalInputArray { .. } => false,
148 VariableKind::GlobalScalar { .. } => true,
149 VariableKind::Versioned { .. } => true,
150 VariableKind::LocalConst { .. } => true,
151 VariableKind::Constant(_) => true,
152 VariableKind::ConstantArray { .. } => true,
153 VariableKind::Builtin(_) => true,
154 VariableKind::Pipeline { .. } => false,
155 VariableKind::BarrierToken { .. } => false,
156 }
157 }
158
159 pub fn is_array(&self) -> bool {
162 matches!(
163 self.kind,
164 VariableKind::GlobalInputArray { .. }
165 | VariableKind::GlobalOutputArray { .. }
166 | VariableKind::ConstantArray { .. }
167 | VariableKind::SharedArray { .. }
168 | VariableKind::LocalArray { .. }
169 | VariableKind::Matrix { .. }
170 )
171 }
172
173 pub fn is_memory(&self) -> bool {
176 matches!(
177 self.kind,
178 VariableKind::GlobalInputArray { .. }
179 | VariableKind::GlobalOutputArray { .. }
180 | VariableKind::SharedArray { .. }
181 )
182 }
183
184 pub fn has_length(&self) -> bool {
185 matches!(
186 self.kind,
187 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188 )
189 }
190
191 pub fn has_buffer_length(&self) -> bool {
192 matches!(
193 self.kind,
194 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
195 )
196 }
197
198 pub fn is_constant(&self, value: i64) -> bool {
200 match self.kind {
201 VariableKind::Constant(ConstantValue::Int(val)) => val == value,
202 VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
203 VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
204 _ => false,
205 }
206 }
207
208 pub fn is_true(&self) -> bool {
210 match self.kind {
211 VariableKind::Constant(ConstantValue::Bool(val)) => val,
212 _ => false,
213 }
214 }
215
216 pub fn is_false(&self) -> bool {
218 match self.kind {
219 VariableKind::Constant(ConstantValue::Bool(val)) => !val,
220 _ => false,
221 }
222 }
223}
224
225#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
229#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
230#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
231pub enum ConstantValue {
232 Int(i64),
233 Float(f64),
234 UInt(u64),
235 Bool(bool),
236}
237
238impl Ord for ConstantValue {
239 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
240 match (self, other) {
243 (ConstantValue::Float(this), ConstantValue::Float(other)) => {
244 FloatOrd(*this).cmp(&FloatOrd(*other))
245 }
246 _ => self.partial_cmp(other).unwrap(),
247 }
248 }
249}
250
251impl Eq for ConstantValue {}
252impl Hash for ConstantValue {
253 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
254 core::mem::discriminant(self).hash(ra_expand_state);
255 match self {
256 ConstantValue::Int(f0) => {
257 f0.hash(ra_expand_state);
258 }
259 ConstantValue::Float(f0) => {
260 FloatOrd(*f0).hash(ra_expand_state);
261 }
262 ConstantValue::UInt(f0) => {
263 f0.hash(ra_expand_state);
264 }
265 ConstantValue::Bool(f0) => {
266 f0.hash(ra_expand_state);
267 }
268 }
269 }
270}
271
272impl ConstantValue {
273 pub fn try_as_usize(&self) -> Option<usize> {
277 match self {
278 ConstantValue::UInt(val) => Some(*val as usize),
279 ConstantValue::Int(val) => Some(*val as usize),
280 ConstantValue::Float(_) => None,
281 ConstantValue::Bool(_) => None,
282 }
283 }
284
285 pub fn as_usize(&self) -> usize {
287 match self {
288 ConstantValue::UInt(val) => *val as usize,
289 ConstantValue::Int(val) => *val as usize,
290 ConstantValue::Float(val) => *val as usize,
291 ConstantValue::Bool(val) => *val as usize,
292 }
293 }
294
295 pub fn try_as_u32(&self) -> Option<u32> {
299 self.try_as_u64().map(|it| it as u32)
300 }
301
302 pub fn as_u32(&self) -> u32 {
306 self.as_u64() as u32
307 }
308
309 pub fn try_as_u64(&self) -> Option<u64> {
313 match self {
314 ConstantValue::UInt(val) => Some(*val),
315 ConstantValue::Int(val) => Some(*val as u64),
316 ConstantValue::Float(_) => None,
317 ConstantValue::Bool(_) => None,
318 }
319 }
320
321 pub fn as_u64(&self) -> u64 {
323 match self {
324 ConstantValue::UInt(val) => *val,
325 ConstantValue::Int(val) => *val as u64,
326 ConstantValue::Float(val) => *val as u64,
327 ConstantValue::Bool(val) => *val as u64,
328 }
329 }
330
331 pub fn try_as_i64(&self) -> Option<i64> {
335 match self {
336 ConstantValue::UInt(val) => Some(*val as i64),
337 ConstantValue::Int(val) => Some(*val),
338 ConstantValue::Float(_) => None,
339 ConstantValue::Bool(_) => None,
340 }
341 }
342
343 pub fn as_i128(&self) -> i128 {
345 match self {
346 ConstantValue::UInt(val) => *val as i128,
347 ConstantValue::Int(val) => *val as i128,
348 ConstantValue::Float(val) => *val as i128,
349 ConstantValue::Bool(val) => *val as i128,
350 }
351 }
352
353 pub fn as_i64(&self) -> i64 {
355 match self {
356 ConstantValue::UInt(val) => *val as i64,
357 ConstantValue::Int(val) => *val,
358 ConstantValue::Float(val) => *val as i64,
359 ConstantValue::Bool(val) => *val as i64,
360 }
361 }
362
363 pub fn try_as_f64(&self) -> Option<f64> {
367 match self {
368 ConstantValue::Float(val) => Some(*val),
369 _ => None,
370 }
371 }
372
373 pub fn as_f64(&self) -> f64 {
375 match self {
376 ConstantValue::UInt(val) => *val as f64,
377 ConstantValue::Int(val) => *val as f64,
378 ConstantValue::Float(val) => *val,
379 ConstantValue::Bool(val) => *val as u8 as f64,
380 }
381 }
382
383 pub fn try_as_bool(&self) -> Option<bool> {
385 match self {
386 ConstantValue::Bool(val) => Some(*val),
387 _ => None,
388 }
389 }
390
391 pub fn as_bool(&self) -> bool {
395 match self {
396 ConstantValue::UInt(val) => *val != 0,
397 ConstantValue::Int(val) => *val != 0,
398 ConstantValue::Float(val) => *val != 0.,
399 ConstantValue::Bool(val) => *val,
400 }
401 }
402
403 pub fn is_zero(&self) -> bool {
404 match self {
405 ConstantValue::Int(val) => *val == 0,
406 ConstantValue::Float(val) => *val == 0.0,
407 ConstantValue::UInt(val) => *val == 0,
408 ConstantValue::Bool(val) => !*val,
409 }
410 }
411
412 pub fn is_one(&self) -> bool {
413 match self {
414 ConstantValue::Int(val) => *val == 1,
415 ConstantValue::Float(val) => *val == 1.0,
416 ConstantValue::UInt(val) => *val == 1,
417 ConstantValue::Bool(val) => *val,
418 }
419 }
420
421 pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
422 match other.into().storage_type() {
423 StorageType::Scalar(elem_type) => match elem_type {
424 ElemType::Float(kind) => match kind {
425 FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
426 FloatKind::E2M3 | FloatKind::E3M2 => {
427 unimplemented!("FP6 constants not yet supported")
428 }
429 FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
430 FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
431 FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
432 FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
433 FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
434 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
435 self.as_f64() as f32 as f64
436 }
437 FloatKind::F64 => self.as_f64(),
438 }
439 .into(),
440 ElemType::Int(kind) => match kind {
441 IntKind::I8 => self.as_i64() as i8 as i64,
442 IntKind::I16 => self.as_i64() as i16 as i64,
443 IntKind::I32 => self.as_i64() as i32 as i64,
444 IntKind::I64 => self.as_i64(),
445 }
446 .into(),
447 ElemType::UInt(kind) => match kind {
448 UIntKind::U8 => self.as_u64() as u8 as u64,
449 UIntKind::U16 => self.as_u64() as u16 as u64,
450 UIntKind::U32 => self.as_u64() as u32 as u64,
451 UIntKind::U64 => self.as_u64(),
452 }
453 .into(),
454 ElemType::Bool => self.as_bool().into(),
455 },
456 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
457 e2m1::from_f64(self.as_f64()).to_f64().into()
458 }
459 StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
460 StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
461 StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
462 }
463 }
464}
465
466impl Display for ConstantValue {
467 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
468 match self {
469 ConstantValue::Int(val) => write!(f, "{val}"),
470 ConstantValue::Float(val) => write!(f, "{val:?}"),
471 ConstantValue::UInt(val) => write!(f, "{val}"),
472 ConstantValue::Bool(val) => write!(f, "{val}"),
473 }
474 }
475}
476
477impl Variable {
478 pub fn line_size(&self) -> usize {
479 self.ty.line_size()
480 }
481
482 pub fn index(&self) -> Option<Id> {
483 match self.kind {
484 VariableKind::GlobalInputArray(id)
485 | VariableKind::GlobalOutputArray(id)
486 | VariableKind::TensorMapInput(id)
487 | VariableKind::TensorMapOutput(id)
488 | VariableKind::GlobalScalar(id)
489 | VariableKind::LocalMut { id, .. }
490 | VariableKind::Versioned { id, .. }
491 | VariableKind::LocalConst { id, .. }
492 | VariableKind::ConstantArray { id, .. }
493 | VariableKind::SharedArray { id, .. }
494 | VariableKind::Shared { id, .. }
495 | VariableKind::LocalArray { id, .. }
496 | VariableKind::Matrix { id, .. } => Some(id),
497 _ => None,
498 }
499 }
500
501 pub fn as_const(&self) -> Option<ConstantValue> {
502 match self.kind {
503 VariableKind::Constant(constant) => Some(constant),
504 _ => None,
505 }
506 }
507}
508
509impl Display for Variable {
510 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
511 match self.kind {
512 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
513 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
514 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
515 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
516 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
517 VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
518 VariableKind::LocalMut { id } => write!(f, "local({id})"),
519 VariableKind::Versioned { id, version } => {
520 write!(f, "local({id}).v{version}")
521 }
522 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
523 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
524 VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
525 VariableKind::Shared { id } => write!(f, "shared({id})"),
526 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
527 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
528 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
529 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
530 VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
531 }
532 }
533}
534
535impl From<&Variable> for Variable {
537 fn from(value: &Variable) -> Self {
538 *value
539 }
540}