Skip to main content

cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
12#[allow(missing_docs)]
13pub struct Variable {
14    pub kind: VariableKind,
15    pub ty: Type,
16}
17
18impl Variable {
19    pub fn new(kind: VariableKind, item: Type) -> Self {
20        Self { kind, ty: item }
21    }
22
23    pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24        Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25    }
26
27    pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28        let ty = ty.into();
29        let value = value.cast_to(ty);
30        Self::new(VariableKind::Constant(value), ty)
31    }
32
33    pub fn elem_type(&self) -> ElemType {
34        self.ty.elem_type()
35    }
36
37    pub fn storage_type(&self) -> StorageType {
38        self.ty.storage_type()
39    }
40}
41
42pub type Id = u32;
43
44#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
45#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
46pub enum VariableKind {
47    GlobalInputArray(Id),
48    GlobalOutputArray(Id),
49    GlobalScalar(Id),
50    TensorMapInput(Id),
51    TensorMapOutput(Id),
52    LocalArray {
53        id: Id,
54        length: usize,
55        unroll_factor: usize,
56    },
57    LocalMut {
58        id: Id,
59    },
60    LocalConst {
61        id: Id,
62    },
63    Versioned {
64        id: Id,
65        version: u16,
66    },
67    Constant(ConstantValue),
68    ConstantArray {
69        id: Id,
70        length: usize,
71        unroll_factor: usize,
72    },
73    SharedArray {
74        id: Id,
75        length: usize,
76        unroll_factor: usize,
77        alignment: Option<usize>,
78    },
79    Shared {
80        id: Id,
81    },
82    Matrix {
83        id: Id,
84        mat: Matrix,
85    },
86    Builtin(Builtin),
87    Pipeline {
88        id: Id,
89        num_stages: u8,
90    },
91    BarrierToken {
92        id: Id,
93        level: BarrierLevel,
94    },
95}
96
97#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
98#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
99#[repr(u8)]
100pub enum Builtin {
101    UnitPos,
102    UnitPosX,
103    UnitPosY,
104    UnitPosZ,
105    CubePosCluster,
106    CubePosClusterX,
107    CubePosClusterY,
108    CubePosClusterZ,
109    CubePos,
110    CubePosX,
111    CubePosY,
112    CubePosZ,
113    CubeDim,
114    CubeDimX,
115    CubeDimY,
116    CubeDimZ,
117    CubeClusterDim,
118    CubeClusterDimX,
119    CubeClusterDimY,
120    CubeClusterDimZ,
121    CubeCount,
122    CubeCountX,
123    CubeCountY,
124    CubeCountZ,
125    PlaneDim,
126    PlanePos,
127    UnitPosPlane,
128    AbsolutePos,
129    AbsolutePosX,
130    AbsolutePosY,
131    AbsolutePosZ,
132}
133
134impl Variable {
135    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
136    /// safe to inline/merge
137    pub fn is_immutable(&self) -> bool {
138        match self.kind {
139            VariableKind::GlobalOutputArray { .. } => false,
140            VariableKind::TensorMapInput(_) => true,
141            VariableKind::TensorMapOutput(_) => false,
142            VariableKind::LocalMut { .. } => false,
143            VariableKind::SharedArray { .. } => false,
144            VariableKind::Shared { .. } => false,
145            VariableKind::Matrix { .. } => false,
146            VariableKind::LocalArray { .. } => false,
147            VariableKind::GlobalInputArray { .. } => false,
148            VariableKind::GlobalScalar { .. } => true,
149            VariableKind::Versioned { .. } => true,
150            VariableKind::LocalConst { .. } => true,
151            VariableKind::Constant(_) => true,
152            VariableKind::ConstantArray { .. } => true,
153            VariableKind::Builtin(_) => true,
154            VariableKind::Pipeline { .. } => false,
155            VariableKind::BarrierToken { .. } => false,
156        }
157    }
158
159    /// Is this an array type that yields items when indexed,
160    /// or a scalar/vector that yields elems/slices when indexed?
161    pub fn is_array(&self) -> bool {
162        matches!(
163            self.kind,
164            VariableKind::GlobalInputArray { .. }
165                | VariableKind::GlobalOutputArray { .. }
166                | VariableKind::ConstantArray { .. }
167                | VariableKind::SharedArray { .. }
168                | VariableKind::LocalArray { .. }
169                | VariableKind::Matrix { .. }
170        )
171    }
172
173    /// Is this an array type that is contained in concrete memory,
174    /// or a local array/scalar/vector?
175    pub fn is_memory(&self) -> bool {
176        matches!(
177            self.kind,
178            VariableKind::GlobalInputArray { .. }
179                | VariableKind::GlobalOutputArray { .. }
180                | VariableKind::SharedArray { .. }
181        )
182    }
183
184    pub fn has_length(&self) -> bool {
185        matches!(
186            self.kind,
187            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188        )
189    }
190
191    pub fn has_buffer_length(&self) -> bool {
192        matches!(
193            self.kind,
194            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
195        )
196    }
197
198    /// Determines if the value is a constant with the specified value (converted if necessary)
199    pub fn is_constant(&self, value: i64) -> bool {
200        match self.kind {
201            VariableKind::Constant(ConstantValue::Int(val)) => val == value,
202            VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
203            VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
204            _ => false,
205        }
206    }
207
208    /// Determines if the value is a boolean constant with the `true` value
209    pub fn is_true(&self) -> bool {
210        match self.kind {
211            VariableKind::Constant(ConstantValue::Bool(val)) => val,
212            _ => false,
213        }
214    }
215
216    /// Determines if the value is a boolean constant with the `false` value
217    pub fn is_false(&self) -> bool {
218        match self.kind {
219            VariableKind::Constant(ConstantValue::Bool(val)) => !val,
220            _ => false,
221        }
222    }
223}
224
225/// The scalars are stored with the highest precision possible, but they might get reduced during
226/// compilation. For constant propagation, casts are always executed before converting back to the
227/// larger type to ensure deterministic output.
228#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
229#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
230#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
231pub enum ConstantValue {
232    Int(i64),
233    Float(f64),
234    UInt(u64),
235    Bool(bool),
236}
237
238impl Ord for ConstantValue {
239    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
240        // Override float-float comparison with `FloatOrd` since `f64` isn't `Ord`. All other
241        // comparisons are safe to unwrap since they're either `Ord` or only compare discriminants.
242        match (self, other) {
243            (ConstantValue::Float(this), ConstantValue::Float(other)) => {
244                FloatOrd(*this).cmp(&FloatOrd(*other))
245            }
246            _ => self.partial_cmp(other).unwrap(),
247        }
248    }
249}
250
251impl Eq for ConstantValue {}
252impl Hash for ConstantValue {
253    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
254        core::mem::discriminant(self).hash(ra_expand_state);
255        match self {
256            ConstantValue::Int(f0) => {
257                f0.hash(ra_expand_state);
258            }
259            ConstantValue::Float(f0) => {
260                FloatOrd(*f0).hash(ra_expand_state);
261            }
262            ConstantValue::UInt(f0) => {
263                f0.hash(ra_expand_state);
264            }
265            ConstantValue::Bool(f0) => {
266                f0.hash(ra_expand_state);
267            }
268        }
269    }
270}
271
272impl ConstantValue {
273    /// Returns the value of the constant as a usize.
274    ///
275    /// It will return [None] if the constant type is a float or a bool.
276    pub fn try_as_usize(&self) -> Option<usize> {
277        match self {
278            ConstantValue::UInt(val) => Some(*val as usize),
279            ConstantValue::Int(val) => Some(*val as usize),
280            ConstantValue::Float(_) => None,
281            ConstantValue::Bool(_) => None,
282        }
283    }
284
285    /// Returns the value of the constant as a usize.
286    pub fn as_usize(&self) -> usize {
287        match self {
288            ConstantValue::UInt(val) => *val as usize,
289            ConstantValue::Int(val) => *val as usize,
290            ConstantValue::Float(val) => *val as usize,
291            ConstantValue::Bool(val) => *val as usize,
292        }
293    }
294
295    /// Returns the value of the scalar as a u32.
296    ///
297    /// It will return [None] if the scalar type is a float or a bool.
298    pub fn try_as_u32(&self) -> Option<u32> {
299        self.try_as_u64().map(|it| it as u32)
300    }
301
302    /// Returns the value of the scalar as a u32.
303    ///
304    /// It will panic if the scalar type is a float or a bool.
305    pub fn as_u32(&self) -> u32 {
306        self.as_u64() as u32
307    }
308
309    /// Returns the value of the scalar as a u64.
310    ///
311    /// It will return [None] if the scalar type is a float or a bool.
312    pub fn try_as_u64(&self) -> Option<u64> {
313        match self {
314            ConstantValue::UInt(val) => Some(*val),
315            ConstantValue::Int(val) => Some(*val as u64),
316            ConstantValue::Float(_) => None,
317            ConstantValue::Bool(_) => None,
318        }
319    }
320
321    /// Returns the value of the scalar as a u64.
322    pub fn as_u64(&self) -> u64 {
323        match self {
324            ConstantValue::UInt(val) => *val,
325            ConstantValue::Int(val) => *val as u64,
326            ConstantValue::Float(val) => *val as u64,
327            ConstantValue::Bool(val) => *val as u64,
328        }
329    }
330
331    /// Returns the value of the scalar as a i64.
332    ///
333    /// It will return [None] if the scalar type is a float or a bool.
334    pub fn try_as_i64(&self) -> Option<i64> {
335        match self {
336            ConstantValue::UInt(val) => Some(*val as i64),
337            ConstantValue::Int(val) => Some(*val),
338            ConstantValue::Float(_) => None,
339            ConstantValue::Bool(_) => None,
340        }
341    }
342
343    /// Returns the value of the scalar as a i128.
344    pub fn as_i128(&self) -> i128 {
345        match self {
346            ConstantValue::UInt(val) => *val as i128,
347            ConstantValue::Int(val) => *val as i128,
348            ConstantValue::Float(val) => *val as i128,
349            ConstantValue::Bool(val) => *val as i128,
350        }
351    }
352
353    /// Returns the value of the scalar as a i64.
354    pub fn as_i64(&self) -> i64 {
355        match self {
356            ConstantValue::UInt(val) => *val as i64,
357            ConstantValue::Int(val) => *val,
358            ConstantValue::Float(val) => *val as i64,
359            ConstantValue::Bool(val) => *val as i64,
360        }
361    }
362
363    /// Returns the value of the scalar as a f64.
364    ///
365    /// It will return [None] if the scalar type is an int or a bool.
366    pub fn try_as_f64(&self) -> Option<f64> {
367        match self {
368            ConstantValue::Float(val) => Some(*val),
369            _ => None,
370        }
371    }
372
373    /// Returns the value of the scalar as a f64.
374    pub fn as_f64(&self) -> f64 {
375        match self {
376            ConstantValue::UInt(val) => *val as f64,
377            ConstantValue::Int(val) => *val as f64,
378            ConstantValue::Float(val) => *val,
379            ConstantValue::Bool(val) => *val as u8 as f64,
380        }
381    }
382
383    /// Returns the value of the variable as a bool if it actually is a bool.
384    pub fn try_as_bool(&self) -> Option<bool> {
385        match self {
386            ConstantValue::Bool(val) => Some(*val),
387            _ => None,
388        }
389    }
390
391    /// Returns the value of the variable as a bool.
392    ///
393    /// It will panic if the scalar isn't a bool.
394    pub fn as_bool(&self) -> bool {
395        match self {
396            ConstantValue::UInt(val) => *val != 0,
397            ConstantValue::Int(val) => *val != 0,
398            ConstantValue::Float(val) => *val != 0.,
399            ConstantValue::Bool(val) => *val,
400        }
401    }
402
403    pub fn is_zero(&self) -> bool {
404        match self {
405            ConstantValue::Int(val) => *val == 0,
406            ConstantValue::Float(val) => *val == 0.0,
407            ConstantValue::UInt(val) => *val == 0,
408            ConstantValue::Bool(val) => !*val,
409        }
410    }
411
412    pub fn is_one(&self) -> bool {
413        match self {
414            ConstantValue::Int(val) => *val == 1,
415            ConstantValue::Float(val) => *val == 1.0,
416            ConstantValue::UInt(val) => *val == 1,
417            ConstantValue::Bool(val) => *val,
418        }
419    }
420
421    pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
422        match other.into().storage_type() {
423            StorageType::Scalar(elem_type) => match elem_type {
424                ElemType::Float(kind) => match kind {
425                    FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
426                    FloatKind::E2M3 | FloatKind::E3M2 => {
427                        unimplemented!("FP6 constants not yet supported")
428                    }
429                    FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
430                    FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
431                    FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
432                    FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
433                    FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
434                    FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
435                        self.as_f64() as f32 as f64
436                    }
437                    FloatKind::F64 => self.as_f64(),
438                }
439                .into(),
440                ElemType::Int(kind) => match kind {
441                    IntKind::I8 => self.as_i64() as i8 as i64,
442                    IntKind::I16 => self.as_i64() as i16 as i64,
443                    IntKind::I32 => self.as_i64() as i32 as i64,
444                    IntKind::I64 => self.as_i64(),
445                }
446                .into(),
447                ElemType::UInt(kind) => match kind {
448                    UIntKind::U8 => self.as_u64() as u8 as u64,
449                    UIntKind::U16 => self.as_u64() as u16 as u64,
450                    UIntKind::U32 => self.as_u64() as u32 as u64,
451                    UIntKind::U64 => self.as_u64(),
452                }
453                .into(),
454                ElemType::Bool => self.as_bool().into(),
455            },
456            StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
457                e2m1::from_f64(self.as_f64()).to_f64().into()
458            }
459            StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
460            StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
461            StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
462        }
463    }
464}
465
466impl Display for ConstantValue {
467    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
468        match self {
469            ConstantValue::Int(val) => write!(f, "{val}"),
470            ConstantValue::Float(val) => write!(f, "{val:?}"),
471            ConstantValue::UInt(val) => write!(f, "{val}"),
472            ConstantValue::Bool(val) => write!(f, "{val}"),
473        }
474    }
475}
476
477impl Variable {
478    pub fn line_size(&self) -> usize {
479        self.ty.line_size()
480    }
481
482    pub fn index(&self) -> Option<Id> {
483        match self.kind {
484            VariableKind::GlobalInputArray(id)
485            | VariableKind::GlobalOutputArray(id)
486            | VariableKind::TensorMapInput(id)
487            | VariableKind::TensorMapOutput(id)
488            | VariableKind::GlobalScalar(id)
489            | VariableKind::LocalMut { id, .. }
490            | VariableKind::Versioned { id, .. }
491            | VariableKind::LocalConst { id, .. }
492            | VariableKind::ConstantArray { id, .. }
493            | VariableKind::SharedArray { id, .. }
494            | VariableKind::Shared { id, .. }
495            | VariableKind::LocalArray { id, .. }
496            | VariableKind::Matrix { id, .. } => Some(id),
497            _ => None,
498        }
499    }
500
501    pub fn as_const(&self) -> Option<ConstantValue> {
502        match self.kind {
503            VariableKind::Constant(constant) => Some(constant),
504            _ => None,
505        }
506    }
507}
508
509impl Display for Variable {
510    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
511        match self.kind {
512            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
513            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
514            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
515            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
516            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
517            VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
518            VariableKind::LocalMut { id } => write!(f, "local({id})"),
519            VariableKind::Versioned { id, version } => {
520                write!(f, "local({id}).v{version}")
521            }
522            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
523            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
524            VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
525            VariableKind::Shared { id } => write!(f, "shared({id})"),
526            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
527            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
528            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
529            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
530            VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
531        }
532    }
533}
534
535// Useful with the cube_inline macro.
536impl From<&Variable> for Variable {
537    fn from(value: &Variable) -> Self {
538        *value
539    }
540}