1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
/// An IR variable: a [`VariableKind`] paired with the [`Type`] of its value.
///
/// The struct is `Copy`, so variables are passed around by value.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What kind of variable this is (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type associated with this variable.
    pub ty: Type,
}
17
18impl Variable {
19 pub fn new(kind: VariableKind, item: Type) -> Self {
20 Self { kind, ty: item }
21 }
22
23 pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24 Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25 }
26
27 pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28 let ty = ty.into();
29 let value = value.cast_to(ty);
30 Self::new(VariableKind::Constant(value), ty)
31 }
32
33 pub fn elem_type(&self) -> ElemType {
34 self.ty.elem_type()
35 }
36
37 pub fn storage_type(&self) -> StorageType {
38 self.ty.storage_type()
39 }
40}
41
/// Numeric identifier carried by the id-bearing [`VariableKind`] variants.
pub type Id = u32;
43
/// The kind of an IR [`Variable`]: what it refers to and how it may be used.
///
/// Most variants carry an [`Id`]. `Constant` stores its value inline, and `Builtin` names a
/// [`Builtin`] value. Declaration order feeds the derived `Hash`/serde impls, so avoid
/// reordering variants.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array (displayed as `input(id)`).
    GlobalInputArray(Id),
    /// Global output array (displayed as `output(id)`).
    GlobalOutputArray(Id),
    /// Global scalar (displayed as `scalar(id)`).
    GlobalScalar(Id),
    /// Input tensor map (displayed as `tensor_map(id)`).
    TensorMapInput(Id),
    /// Output tensor map (also displayed as `tensor_map(id)`).
    TensorMapOutput(Id),
    /// Local array of `length` elements (displayed as `array(id)`).
    /// NOTE(review): the meaning of `unroll_factor` is not visible here — confirm with users.
    LocalArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Mutable local variable (displayed as `local(id)`).
    LocalMut {
        id: Id,
    },
    /// Immutable local binding (displayed as `binding(id)`).
    LocalConst {
        id: Id,
    },
    /// Version `version` of local `id` (displayed as `local(id).v<version>`).
    /// Presumably produced by an SSA-style renaming pass — TODO confirm.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Inline constant value.
    Constant(ConstantValue),
    /// Constant array of `length` elements (displayed as `const_array(id)`).
    ConstantArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Shared array of `length` elements with an optional `alignment`
    /// (displayed as `shared_array(id)`). NOTE(review): alignment units presumably bytes — confirm.
    SharedArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
        alignment: Option<usize>,
    },
    /// Single shared, non-array value (displayed as `shared(id)`).
    Shared {
        id: Id,
    },
    /// Matrix variable described by `mat` (displayed as `matrix(id)`).
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A built-in value; see [`Builtin`].
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages (displayed as `pipeline(id)`).
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given `level` (displayed as `barrier_token(id)`).
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
96
/// Built-in values a kernel can reference through [`VariableKind::Builtin`].
///
/// Most names come in groups: an aggregate variant followed by its `X`, `Y` and `Z`
/// variants. The enum is `#[repr(u8)]` with implicit discriminants and derives
/// `PartialOrd`/`Ord`. Declaration order therefore fixes both the discriminant values and the
/// ordering, so append new variants instead of reordering existing ones.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    // Presumably the unit's position inside its cube — TODO confirm against backends.
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    // Cluster-related counterparts of the cube position.
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
132
133impl Variable {
134 pub fn is_immutable(&self) -> bool {
137 match self.kind {
138 VariableKind::GlobalOutputArray { .. } => false,
139 VariableKind::TensorMapInput(_) => true,
140 VariableKind::TensorMapOutput(_) => false,
141 VariableKind::LocalMut { .. } => false,
142 VariableKind::SharedArray { .. } => false,
143 VariableKind::Shared { .. } => false,
144 VariableKind::Matrix { .. } => false,
145 VariableKind::LocalArray { .. } => false,
146 VariableKind::GlobalInputArray { .. } => false,
147 VariableKind::GlobalScalar { .. } => true,
148 VariableKind::Versioned { .. } => true,
149 VariableKind::LocalConst { .. } => true,
150 VariableKind::Constant(_) => true,
151 VariableKind::ConstantArray { .. } => true,
152 VariableKind::Builtin(_) => true,
153 VariableKind::Pipeline { .. } => false,
154 VariableKind::BarrierToken { .. } => false,
155 }
156 }
157
158 pub fn is_array(&self) -> bool {
161 matches!(
162 self.kind,
163 VariableKind::GlobalInputArray { .. }
164 | VariableKind::GlobalOutputArray { .. }
165 | VariableKind::ConstantArray { .. }
166 | VariableKind::SharedArray { .. }
167 | VariableKind::LocalArray { .. }
168 | VariableKind::Matrix { .. }
169 )
170 }
171
172 pub fn has_length(&self) -> bool {
173 matches!(
174 self.kind,
175 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
176 )
177 }
178
179 pub fn has_buffer_length(&self) -> bool {
180 matches!(
181 self.kind,
182 VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
183 )
184 }
185
186 pub fn is_constant(&self, value: i64) -> bool {
188 match self.kind {
189 VariableKind::Constant(ConstantValue::Int(val)) => val == value,
190 VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
191 VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
192 _ => false,
193 }
194 }
195
196 pub fn is_true(&self) -> bool {
198 match self.kind {
199 VariableKind::Constant(ConstantValue::Bool(val)) => val,
200 _ => false,
201 }
202 }
203
204 pub fn is_false(&self) -> bool {
206 match self.kind {
207 VariableKind::Constant(ConstantValue::Bool(val)) => !val,
208 _ => false,
209 }
210 }
211}
212
213#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
217#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
218#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
219pub enum ConstantValue {
220 Int(i64),
221 Float(f64),
222 UInt(u64),
223 Bool(bool),
224}
225
226impl Ord for ConstantValue {
227 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
228 match (self, other) {
231 (ConstantValue::Float(this), ConstantValue::Float(other)) => {
232 FloatOrd(*this).cmp(&FloatOrd(*other))
233 }
234 _ => self.partial_cmp(other).unwrap(),
235 }
236 }
237}
238
239impl Eq for ConstantValue {}
240impl Hash for ConstantValue {
241 fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
242 core::mem::discriminant(self).hash(ra_expand_state);
243 match self {
244 ConstantValue::Int(f0) => {
245 f0.hash(ra_expand_state);
246 }
247 ConstantValue::Float(f0) => {
248 FloatOrd(*f0).hash(ra_expand_state);
249 }
250 ConstantValue::UInt(f0) => {
251 f0.hash(ra_expand_state);
252 }
253 ConstantValue::Bool(f0) => {
254 f0.hash(ra_expand_state);
255 }
256 }
257 }
258}
259
260impl ConstantValue {
261 pub fn try_as_usize(&self) -> Option<usize> {
265 match self {
266 ConstantValue::UInt(val) => Some(*val as usize),
267 ConstantValue::Int(val) => Some(*val as usize),
268 ConstantValue::Float(_) => None,
269 ConstantValue::Bool(_) => None,
270 }
271 }
272
273 pub fn as_usize(&self) -> usize {
275 match self {
276 ConstantValue::UInt(val) => *val as usize,
277 ConstantValue::Int(val) => *val as usize,
278 ConstantValue::Float(val) => *val as usize,
279 ConstantValue::Bool(val) => *val as usize,
280 }
281 }
282
283 pub fn try_as_u32(&self) -> Option<u32> {
287 self.try_as_u64().map(|it| it as u32)
288 }
289
290 pub fn as_u32(&self) -> u32 {
294 self.as_u64() as u32
295 }
296
297 pub fn try_as_u64(&self) -> Option<u64> {
301 match self {
302 ConstantValue::UInt(val) => Some(*val),
303 ConstantValue::Int(val) => Some(*val as u64),
304 ConstantValue::Float(_) => None,
305 ConstantValue::Bool(_) => None,
306 }
307 }
308
309 pub fn as_u64(&self) -> u64 {
311 match self {
312 ConstantValue::UInt(val) => *val,
313 ConstantValue::Int(val) => *val as u64,
314 ConstantValue::Float(val) => *val as u64,
315 ConstantValue::Bool(val) => *val as u64,
316 }
317 }
318
319 pub fn try_as_i64(&self) -> Option<i64> {
323 match self {
324 ConstantValue::UInt(val) => Some(*val as i64),
325 ConstantValue::Int(val) => Some(*val),
326 ConstantValue::Float(_) => None,
327 ConstantValue::Bool(_) => None,
328 }
329 }
330
331 pub fn as_i128(&self) -> i128 {
333 match self {
334 ConstantValue::UInt(val) => *val as i128,
335 ConstantValue::Int(val) => *val as i128,
336 ConstantValue::Float(val) => *val as i128,
337 ConstantValue::Bool(val) => *val as i128,
338 }
339 }
340
341 pub fn as_i64(&self) -> i64 {
343 match self {
344 ConstantValue::UInt(val) => *val as i64,
345 ConstantValue::Int(val) => *val,
346 ConstantValue::Float(val) => *val as i64,
347 ConstantValue::Bool(val) => *val as i64,
348 }
349 }
350
351 pub fn try_as_f64(&self) -> Option<f64> {
355 match self {
356 ConstantValue::Float(val) => Some(*val),
357 _ => None,
358 }
359 }
360
361 pub fn as_f64(&self) -> f64 {
363 match self {
364 ConstantValue::UInt(val) => *val as f64,
365 ConstantValue::Int(val) => *val as f64,
366 ConstantValue::Float(val) => *val,
367 ConstantValue::Bool(val) => *val as u8 as f64,
368 }
369 }
370
371 pub fn try_as_bool(&self) -> Option<bool> {
373 match self {
374 ConstantValue::Bool(val) => Some(*val),
375 _ => None,
376 }
377 }
378
379 pub fn as_bool(&self) -> bool {
383 match self {
384 ConstantValue::UInt(val) => *val != 0,
385 ConstantValue::Int(val) => *val != 0,
386 ConstantValue::Float(val) => *val != 0.,
387 ConstantValue::Bool(val) => *val,
388 }
389 }
390
391 pub fn is_zero(&self) -> bool {
392 match self {
393 ConstantValue::Int(val) => *val == 0,
394 ConstantValue::Float(val) => *val == 0.0,
395 ConstantValue::UInt(val) => *val == 0,
396 ConstantValue::Bool(val) => !*val,
397 }
398 }
399
400 pub fn is_one(&self) -> bool {
401 match self {
402 ConstantValue::Int(val) => *val == 1,
403 ConstantValue::Float(val) => *val == 1.0,
404 ConstantValue::UInt(val) => *val == 1,
405 ConstantValue::Bool(val) => *val,
406 }
407 }
408
409 pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
410 match other.into().storage_type() {
411 StorageType::Scalar(elem_type) => match elem_type {
412 ElemType::Float(kind) => match kind {
413 FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
414 FloatKind::E2M3 | FloatKind::E3M2 => {
415 unimplemented!("FP6 constants not yet supported")
416 }
417 FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
418 FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
419 FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
420 FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
421 FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
422 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
423 self.as_f64() as f32 as f64
424 }
425 FloatKind::F64 => self.as_f64(),
426 }
427 .into(),
428 ElemType::Int(kind) => match kind {
429 IntKind::I8 => self.as_i64() as i8 as i64,
430 IntKind::I16 => self.as_i64() as i16 as i64,
431 IntKind::I32 => self.as_i64() as i32 as i64,
432 IntKind::I64 => self.as_i64(),
433 }
434 .into(),
435 ElemType::UInt(kind) => match kind {
436 UIntKind::U8 => self.as_u64() as u8 as u64,
437 UIntKind::U16 => self.as_u64() as u16 as u64,
438 UIntKind::U32 => self.as_u64() as u32 as u64,
439 UIntKind::U64 => self.as_u64(),
440 }
441 .into(),
442 ElemType::Bool => self.as_bool().into(),
443 },
444 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
445 e2m1::from_f64(self.as_f64()).to_f64().into()
446 }
447 StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
448 StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
449 StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
450 }
451 }
452}
453
454impl Display for ConstantValue {
455 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
456 match self {
457 ConstantValue::Int(val) => write!(f, "{val}"),
458 ConstantValue::Float(val) => write!(f, "{val:?}"),
459 ConstantValue::UInt(val) => write!(f, "{val}"),
460 ConstantValue::Bool(val) => write!(f, "{val}"),
461 }
462 }
463}
464
465impl Variable {
466 pub fn line_size(&self) -> usize {
467 self.ty.line_size()
468 }
469
470 pub fn index(&self) -> Option<Id> {
471 match self.kind {
472 VariableKind::GlobalInputArray(id)
473 | VariableKind::GlobalOutputArray(id)
474 | VariableKind::TensorMapInput(id)
475 | VariableKind::TensorMapOutput(id)
476 | VariableKind::GlobalScalar(id)
477 | VariableKind::LocalMut { id, .. }
478 | VariableKind::Versioned { id, .. }
479 | VariableKind::LocalConst { id, .. }
480 | VariableKind::ConstantArray { id, .. }
481 | VariableKind::SharedArray { id, .. }
482 | VariableKind::Shared { id, .. }
483 | VariableKind::LocalArray { id, .. }
484 | VariableKind::Matrix { id, .. } => Some(id),
485 _ => None,
486 }
487 }
488
489 pub fn as_const(&self) -> Option<ConstantValue> {
490 match self.kind {
491 VariableKind::Constant(constant) => Some(constant),
492 _ => None,
493 }
494 }
495}
496
497impl Display for Variable {
498 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
499 match self.kind {
500 VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
501 VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
502 VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
503 VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
504 VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
505 VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
506 VariableKind::LocalMut { id } => write!(f, "local({id})"),
507 VariableKind::Versioned { id, version } => {
508 write!(f, "local({id}).v{version}")
509 }
510 VariableKind::LocalConst { id } => write!(f, "binding({id})"),
511 VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
512 VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
513 VariableKind::Shared { id } => write!(f, "shared({id})"),
514 VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
515 VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
516 VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
517 VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
518 VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
519 }
520 }
521}
522
523impl From<&Variable> for Variable {
525 fn from(value: &Variable) -> Self {
526 *value
527 }
528}