Skip to main content

cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
/// A value reference in the IR: what the variable refers to (`kind`) together with its
/// type (`ty`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What the variable refers to (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type of the variable.
    pub ty: Type,
}
17
18impl Variable {
19    pub fn new(kind: VariableKind, item: Type) -> Self {
20        Self { kind, ty: item }
21    }
22
23    pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24        Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25    }
26
27    pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28        let ty = ty.into();
29        let value = value.cast_to(ty);
30        Self::new(VariableKind::Constant(value), ty)
31    }
32
33    pub fn elem_type(&self) -> ElemType {
34        self.ty.elem_type()
35    }
36
37    pub fn storage_type(&self) -> StorageType {
38        self.ty.storage_type()
39    }
40}
41
/// Numeric identifier carried by the id-bearing [`VariableKind`] variants.
pub type Id = u32;
43
/// What a [`Variable`] refers to. Most variants are identified by an [`Id`].
///
/// Variant order and field layout feed the derived `Hash`, `TypeHash` and (optionally)
/// serde implementations.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array (displayed as `input(id)`).
    GlobalInputArray(Id),
    /// Global output array (displayed as `output(id)`).
    GlobalOutputArray(Id),
    /// Global scalar (displayed as `scalar<ty>(id)`); treated as immutable.
    GlobalScalar(Id),
    /// Input tensor map; treated as immutable.
    TensorMapInput(Id),
    /// Output tensor map.
    TensorMapOutput(Id),
    /// Local array of `length` elements.
    /// NOTE(review): the meaning of `unroll_factor` is not visible in this file — confirm.
    LocalArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Mutable local variable (displayed as `local(id)`).
    LocalMut {
        id: Id,
    },
    /// Immutable local binding (displayed as `binding(id)`).
    LocalConst {
        id: Id,
    },
    /// Immutable versioned local (displayed as `local(id).v{version}`).
    Versioned {
        id: Id,
        version: u16,
    },
    /// Constant value.
    Constant(ConstantValue),
    /// Immutable constant array of `length` elements.
    ConstantArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
    },
    /// Shared array of `length` elements with an optional alignment
    /// (units not visible here — presumably bytes, confirm).
    SharedArray {
        id: Id,
        length: usize,
        unroll_factor: usize,
        alignment: Option<usize>,
    },
    /// Non-array shared variable (displayed as `shared(id)`).
    Shared {
        id: Id,
    },
    /// Matrix variable described by `mat`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// Built-in value, see [`Builtin`].
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages.
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier token at the given [`BarrierLevel`].
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
96
/// Built-in values (unit/cube positions, dimensions and counts) referenced through
/// [`VariableKind::Builtin`].
///
/// The enum is `#[repr(u8)]` and derives `Ord`, so declaration order fixes both the
/// discriminant values and the ordering — append new variants instead of reordering.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    // `UnitPos` and its X/Y/Z variants.
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    // `CubePosCluster` and its X/Y/Z variants.
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    // `CubePos` and its X/Y/Z variants.
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    // `CubeDim` and its X/Y/Z variants.
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    // `CubeClusterDim` and its X/Y/Z variants.
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    // `CubeCount` and its X/Y/Z variants.
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    // Plane-related values.
    PlaneDim,
    PlanePos,
    UnitPosPlane,
    // `AbsolutePos` and its X/Y/Z variants.
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
133
134impl Variable {
135    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
136    /// safe to inline/merge
137    pub fn is_immutable(&self) -> bool {
138        match self.kind {
139            VariableKind::GlobalOutputArray { .. } => false,
140            VariableKind::TensorMapInput(_) => true,
141            VariableKind::TensorMapOutput(_) => false,
142            VariableKind::LocalMut { .. } => false,
143            VariableKind::SharedArray { .. } => false,
144            VariableKind::Shared { .. } => false,
145            VariableKind::Matrix { .. } => false,
146            VariableKind::LocalArray { .. } => false,
147            VariableKind::GlobalInputArray { .. } => false,
148            VariableKind::GlobalScalar { .. } => true,
149            VariableKind::Versioned { .. } => true,
150            VariableKind::LocalConst { .. } => true,
151            VariableKind::Constant(_) => true,
152            VariableKind::ConstantArray { .. } => true,
153            VariableKind::Builtin(_) => true,
154            VariableKind::Pipeline { .. } => false,
155            VariableKind::BarrierToken { .. } => false,
156        }
157    }
158
159    /// Is this an array type that yields items when indexed,
160    /// or a scalar/vector that yields elems/slices when indexed?
161    pub fn is_array(&self) -> bool {
162        matches!(
163            self.kind,
164            VariableKind::GlobalInputArray { .. }
165                | VariableKind::GlobalOutputArray { .. }
166                | VariableKind::ConstantArray { .. }
167                | VariableKind::SharedArray { .. }
168                | VariableKind::LocalArray { .. }
169                | VariableKind::Matrix { .. }
170        )
171    }
172
173    /// Is this an array type that is contained in concrete memory,
174    /// or a local array/scalar/vector?
175    pub fn is_memory(&self) -> bool {
176        matches!(
177            self.kind,
178            VariableKind::GlobalInputArray { .. }
179                | VariableKind::GlobalOutputArray { .. }
180                | VariableKind::SharedArray { .. }
181        )
182    }
183
184    pub fn has_length(&self) -> bool {
185        matches!(
186            self.kind,
187            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188        )
189    }
190
191    pub fn has_buffer_length(&self) -> bool {
192        matches!(
193            self.kind,
194            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
195        )
196    }
197
198    /// Determines if the value is a constant with the specified value (converted if necessary)
199    pub fn is_constant(&self, value: i64) -> bool {
200        match self.kind {
201            VariableKind::Constant(ConstantValue::Int(val)) => val == value,
202            VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
203            VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
204            _ => false,
205        }
206    }
207
208    /// Determines if the value is a boolean constant with the `true` value
209    pub fn is_true(&self) -> bool {
210        match self.kind {
211            VariableKind::Constant(ConstantValue::Bool(val)) => val,
212            _ => false,
213        }
214    }
215
216    /// Determines if the value is a boolean constant with the `false` value
217    pub fn is_false(&self) -> bool {
218        match self.kind {
219            VariableKind::Constant(ConstantValue::Bool(val)) => !val,
220            _ => false,
221        }
222    }
223}
224
225/// The scalars are stored with the highest precision possible, but they might get reduced during
226/// compilation. For constant propagation, casts are always executed before converting back to the
227/// larger type to ensure deterministic output.
228#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
229#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
230#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
231pub enum ConstantValue {
232    Int(i64),
233    Float(f64),
234    UInt(u64),
235    Bool(bool),
236}
237
238impl Ord for ConstantValue {
239    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
240        // Override float-float comparison with `FloatOrd` since `f64` isn't `Ord`. All other
241        // comparisons are safe to unwrap since they're either `Ord` or only compare discriminants.
242        match (self, other) {
243            (ConstantValue::Float(this), ConstantValue::Float(other)) => {
244                FloatOrd(*this).cmp(&FloatOrd(*other))
245            }
246            _ => self.partial_cmp(other).unwrap(),
247        }
248    }
249}
250
251impl Eq for ConstantValue {}
252impl Hash for ConstantValue {
253    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
254        core::mem::discriminant(self).hash(ra_expand_state);
255        match self {
256            ConstantValue::Int(f0) => {
257                f0.hash(ra_expand_state);
258            }
259            ConstantValue::Float(f0) => {
260                FloatOrd(*f0).hash(ra_expand_state);
261            }
262            ConstantValue::UInt(f0) => {
263                f0.hash(ra_expand_state);
264            }
265            ConstantValue::Bool(f0) => {
266                f0.hash(ra_expand_state);
267            }
268        }
269    }
270}
271
272impl ConstantValue {
273    /// Returns the value of the constant as a usize.
274    ///
275    /// It will return [None] if the constant type is a float or a bool.
276    pub fn try_as_usize(&self) -> Option<usize> {
277        match self {
278            ConstantValue::UInt(val) => Some(*val as usize),
279            ConstantValue::Int(val) => Some(*val as usize),
280            ConstantValue::Float(_) => None,
281            ConstantValue::Bool(_) => None,
282        }
283    }
284
285    /// Returns the value of the constant as a usize.
286    pub fn as_usize(&self) -> usize {
287        match self {
288            ConstantValue::UInt(val) => *val as usize,
289            ConstantValue::Int(val) => *val as usize,
290            ConstantValue::Float(val) => *val as usize,
291            ConstantValue::Bool(val) => *val as usize,
292        }
293    }
294
295    /// Returns the value of the scalar as a u32.
296    ///
297    /// It will return [None] if the scalar type is a float or a bool.
298    pub fn try_as_u32(&self) -> Option<u32> {
299        self.try_as_u64().map(|it| it as u32)
300    }
301
302    /// Returns the value of the scalar as a u32.
303    ///
304    /// It will panic if the scalar type is a float or a bool.
305    pub fn as_u32(&self) -> u32 {
306        self.as_u64() as u32
307    }
308
309    /// Returns the value of the scalar as a u64.
310    ///
311    /// It will return [None] if the scalar type is a float or a bool.
312    pub fn try_as_u64(&self) -> Option<u64> {
313        match self {
314            ConstantValue::UInt(val) => Some(*val),
315            ConstantValue::Int(val) => Some(*val as u64),
316            ConstantValue::Float(_) => None,
317            ConstantValue::Bool(_) => None,
318        }
319    }
320
321    /// Returns the value of the scalar as a u64.
322    pub fn as_u64(&self) -> u64 {
323        match self {
324            ConstantValue::UInt(val) => *val,
325            ConstantValue::Int(val) => *val as u64,
326            ConstantValue::Float(val) => *val as u64,
327            ConstantValue::Bool(val) => *val as u64,
328        }
329    }
330
331    /// Returns the value of the scalar as a i64.
332    ///
333    /// It will return [None] if the scalar type is a float or a bool.
334    pub fn try_as_i64(&self) -> Option<i64> {
335        match self {
336            ConstantValue::UInt(val) => Some(*val as i64),
337            ConstantValue::Int(val) => Some(*val),
338            ConstantValue::Float(_) => None,
339            ConstantValue::Bool(_) => None,
340        }
341    }
342
343    /// Returns the value of the scalar as a i128.
344    pub fn as_i128(&self) -> i128 {
345        match self {
346            ConstantValue::UInt(val) => *val as i128,
347            ConstantValue::Int(val) => *val as i128,
348            ConstantValue::Float(val) => *val as i128,
349            ConstantValue::Bool(val) => *val as i128,
350        }
351    }
352
353    /// Returns the value of the scalar as a i64.
354    pub fn as_i64(&self) -> i64 {
355        match self {
356            ConstantValue::UInt(val) => *val as i64,
357            ConstantValue::Int(val) => *val,
358            ConstantValue::Float(val) => *val as i64,
359            ConstantValue::Bool(val) => *val as i64,
360        }
361    }
362
363    /// Returns the value of the scalar as a i64.
364    pub fn as_i32(&self) -> i32 {
365        match self {
366            ConstantValue::UInt(val) => *val as i32,
367            ConstantValue::Int(val) => *val as i32,
368            ConstantValue::Float(val) => *val as i32,
369            ConstantValue::Bool(val) => *val as i32,
370        }
371    }
372
373    /// Returns the value of the scalar as a f64.
374    ///
375    /// It will return [None] if the scalar type is an int or a bool.
376    pub fn try_as_f64(&self) -> Option<f64> {
377        match self {
378            ConstantValue::Float(val) => Some(*val),
379            _ => None,
380        }
381    }
382
383    /// Returns the value of the scalar as a f64.
384    pub fn as_f64(&self) -> f64 {
385        match self {
386            ConstantValue::UInt(val) => *val as f64,
387            ConstantValue::Int(val) => *val as f64,
388            ConstantValue::Float(val) => *val,
389            ConstantValue::Bool(val) => *val as u8 as f64,
390        }
391    }
392
393    /// Returns the value of the variable as a bool if it actually is a bool.
394    pub fn try_as_bool(&self) -> Option<bool> {
395        match self {
396            ConstantValue::Bool(val) => Some(*val),
397            _ => None,
398        }
399    }
400
401    /// Returns the value of the variable as a bool.
402    ///
403    /// It will panic if the scalar isn't a bool.
404    pub fn as_bool(&self) -> bool {
405        match self {
406            ConstantValue::UInt(val) => *val != 0,
407            ConstantValue::Int(val) => *val != 0,
408            ConstantValue::Float(val) => *val != 0.,
409            ConstantValue::Bool(val) => *val,
410        }
411    }
412
413    pub fn is_zero(&self) -> bool {
414        match self {
415            ConstantValue::Int(val) => *val == 0,
416            ConstantValue::Float(val) => *val == 0.0,
417            ConstantValue::UInt(val) => *val == 0,
418            ConstantValue::Bool(val) => !*val,
419        }
420    }
421
422    pub fn is_one(&self) -> bool {
423        match self {
424            ConstantValue::Int(val) => *val == 1,
425            ConstantValue::Float(val) => *val == 1.0,
426            ConstantValue::UInt(val) => *val == 1,
427            ConstantValue::Bool(val) => *val,
428        }
429    }
430
431    pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
432        match other.into().storage_type() {
433            StorageType::Scalar(elem_type) => match elem_type {
434                ElemType::Float(kind) => match kind {
435                    FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
436                    FloatKind::E2M3 | FloatKind::E3M2 => {
437                        unimplemented!("FP6 constants not yet supported")
438                    }
439                    FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
440                    FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
441                    FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
442                    FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
443                    FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
444                    FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
445                        self.as_f64() as f32 as f64
446                    }
447                    FloatKind::F64 => self.as_f64(),
448                }
449                .into(),
450                ElemType::Int(kind) => match kind {
451                    IntKind::I8 => self.as_i64() as i8 as i64,
452                    IntKind::I16 => self.as_i64() as i16 as i64,
453                    IntKind::I32 => self.as_i64() as i32 as i64,
454                    IntKind::I64 => self.as_i64(),
455                }
456                .into(),
457                ElemType::UInt(kind) => match kind {
458                    UIntKind::U8 => self.as_u64() as u8 as u64,
459                    UIntKind::U16 => self.as_u64() as u16 as u64,
460                    UIntKind::U32 => self.as_u64() as u32 as u64,
461                    UIntKind::U64 => self.as_u64(),
462                }
463                .into(),
464                ElemType::Bool => self.as_bool().into(),
465            },
466            StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
467                e2m1::from_f64(self.as_f64()).to_f64().into()
468            }
469            StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
470            StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
471            StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
472        }
473    }
474}
475
476impl Display for ConstantValue {
477    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
478        match self {
479            ConstantValue::Int(val) => write!(f, "{val}"),
480            ConstantValue::Float(val) => write!(f, "{val:?}"),
481            ConstantValue::UInt(val) => write!(f, "{val}"),
482            ConstantValue::Bool(val) => write!(f, "{val}"),
483        }
484    }
485}
486
487impl Variable {
488    pub fn vector_size(&self) -> usize {
489        self.ty.vector_size()
490    }
491
492    pub fn index(&self) -> Option<Id> {
493        match self.kind {
494            VariableKind::GlobalInputArray(id)
495            | VariableKind::GlobalOutputArray(id)
496            | VariableKind::TensorMapInput(id)
497            | VariableKind::TensorMapOutput(id)
498            | VariableKind::GlobalScalar(id)
499            | VariableKind::LocalMut { id, .. }
500            | VariableKind::Versioned { id, .. }
501            | VariableKind::LocalConst { id, .. }
502            | VariableKind::ConstantArray { id, .. }
503            | VariableKind::SharedArray { id, .. }
504            | VariableKind::Shared { id, .. }
505            | VariableKind::LocalArray { id, .. }
506            | VariableKind::Matrix { id, .. } => Some(id),
507            _ => None,
508        }
509    }
510
511    pub fn as_const(&self) -> Option<ConstantValue> {
512        match self.kind {
513            VariableKind::Constant(constant) => Some(constant),
514            _ => None,
515        }
516    }
517}
518
519impl Display for Variable {
520    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
521        match self.kind {
522            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
523            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
524            VariableKind::GlobalScalar(id) => write!(f, "scalar<{}>({id})", self.ty),
525            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
526            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
527            VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
528            VariableKind::LocalMut { id } => write!(f, "local({id})"),
529            VariableKind::Versioned { id, version } => {
530                write!(f, "local({id}).v{version}")
531            }
532            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
533            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
534            VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
535            VariableKind::Shared { id } => write!(f, "shared({id})"),
536            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
537            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
538            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
539            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
540            VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
541        }
542    }
543}
544
545// Useful with the cube_inline macro.
546impl From<&Variable> for Variable {
547    fn from(value: &Variable) -> Self {
548        *value
549    }
550}