cubecl_ir/
variable.rs

1use core::num::NonZero;
2use core::{fmt::Display, hash::Hash};
3
4use crate::{BarrierLevel, TypeHash};
5
6use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind};
7use float_ord::FloatOrd;
8
/// A variable of the IR: what kind of variable it is ([`VariableKind`]) paired with the type of
/// the values it holds ([`Item`]).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What this variable is (global array, local, constant, builtin, ...) and its identifier.
    pub kind: VariableKind,
    /// Element type and vectorization of the values held by this variable.
    pub item: Item,
}
16
17impl Variable {
18    pub fn new(kind: VariableKind, item: Item) -> Self {
19        Self { kind, item }
20    }
21
22    pub fn builtin(builtin: Builtin) -> Self {
23        Self::new(
24            VariableKind::Builtin(builtin),
25            Item::new(Elem::UInt(UIntKind::U32)),
26        )
27    }
28
29    pub fn constant(scalar: ConstantScalarValue) -> Self {
30        let elem = match scalar {
31            ConstantScalarValue::Int(_, int_kind) => Elem::Int(int_kind),
32            ConstantScalarValue::Float(_, float_kind) => Elem::Float(float_kind),
33            ConstantScalarValue::UInt(_, kind) => Elem::UInt(kind),
34            ConstantScalarValue::Bool(_) => Elem::Bool,
35        };
36        Self::new(VariableKind::ConstantScalar(scalar), Item::new(elem))
37    }
38
39    pub fn elem(&self) -> Elem {
40        self.item.elem
41    }
42}
43
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;

/// What a [`Variable`] refers to and how it is identified.
///
/// Most variants carry an [`Id`]; constant scalars and builtins are identified by their value
/// instead.
// NOTE: `Hash` is derived, so variant declaration order affects hash values; serde formats that
// encode variant indices are affected too.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array.
    GlobalInputArray(Id),
    /// Global output array.
    GlobalOutputArray(Id),
    /// Global scalar value.
    GlobalScalar(Id),
    /// Tensor map.
    TensorMap(Id),
    /// Local array with a fixed `length`.
    LocalArray {
        id: Id,
        length: u32,
    },
    /// Mutable local variable.
    LocalMut {
        id: Id,
    },
    /// Local variable that is never reassigned (treated as immutable by
    /// `Variable::is_immutable`).
    LocalConst {
        id: Id,
    },
    /// A specific `version` of local `id` (treated as immutable by `Variable::is_immutable`).
    Versioned {
        id: Id,
        version: u16,
    },
    /// A constant scalar, identified by its value rather than an [`Id`].
    ConstantScalar(ConstantScalarValue),
    /// Constant array with a fixed `length`.
    ConstantArray {
        id: Id,
        length: u32,
    },
    /// Shared memory of a fixed `length`.
    SharedMemory {
        id: Id,
        length: u32,
        // Optional alignment requirement; unit not established here (bytes?) — TODO confirm.
        alignment: Option<u32>,
    },
    /// Matrix variable described by `mat`.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// A builtin value, identified by which builtin it is.
    Builtin(Builtin),
    /// Pipeline with `num_stages` stages.
    Pipeline {
        id: Id,
        // NOTE(review): meaning of this item (vs. the outer `Variable::item`) is not shown here.
        item: Item,
        num_stages: u8,
    },
    /// Barrier operating at the given `level`.
    Barrier {
        id: Id,
        // NOTE(review): meaning of this item (vs. the outer `Variable::item`) is not shown here.
        item: Item,
        level: BarrierLevel,
    },
}
93
/// Builtin values a kernel can read: unit/cube positions, dimensions and counts.
///
/// Variables created with `Variable::builtin` are typed as `u32`.
// `PartialOrd`/`Ord` are derived, so the declaration order below *is* the ordering; reordering
// variants changes comparison results (and derived hashes).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
pub enum Builtin {
    // Unit position (presumably within its cube — confirm), plus its per-axis components.
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    // Cube position within its cluster, plus per-axis components.
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    // Cube position, plus per-axis components.
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    // Cube dimensions, plus per-axis components.
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    // Cluster dimensions (in cubes, presumably — confirm), plus per-axis components.
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    // Number of cubes, plus per-axis components.
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    // Plane size and the unit's position within its plane.
    PlaneDim,
    UnitPosPlane,
    // Absolute position, plus per-axis components.
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
128
129impl Variable {
130    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
131    /// safe to inline/merge
132    pub fn is_immutable(&self) -> bool {
133        match self.kind {
134            VariableKind::GlobalOutputArray { .. } => false,
135            VariableKind::TensorMap(_) => false,
136            VariableKind::LocalMut { .. } => false,
137            VariableKind::SharedMemory { .. } => false,
138            VariableKind::Matrix { .. } => false,
139            VariableKind::LocalArray { .. } => false,
140            VariableKind::GlobalInputArray { .. } => false,
141            VariableKind::GlobalScalar { .. } => true,
142            VariableKind::Versioned { .. } => true,
143            VariableKind::LocalConst { .. } => true,
144            VariableKind::ConstantScalar(_) => true,
145            VariableKind::ConstantArray { .. } => true,
146            VariableKind::Builtin(_) => true,
147            VariableKind::Pipeline { .. } => false,
148            VariableKind::Barrier { .. } => false,
149        }
150    }
151
152    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
153    /// [`Elem`]s when indexed?
154    pub fn is_array(&self) -> bool {
155        matches!(
156            self.kind,
157            VariableKind::GlobalInputArray { .. }
158                | VariableKind::GlobalOutputArray { .. }
159                | VariableKind::ConstantArray { .. }
160                | VariableKind::SharedMemory { .. }
161                | VariableKind::LocalArray { .. }
162                | VariableKind::Matrix { .. }
163        )
164    }
165
166    pub fn has_length(&self) -> bool {
167        matches!(
168            self.kind,
169            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
170        )
171    }
172
173    pub fn has_buffer_length(&self) -> bool {
174        matches!(
175            self.kind,
176            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
177        )
178    }
179
180    /// Determines if the value is a constant with the specified value (converted if necessary)
181    pub fn is_constant(&self, value: i64) -> bool {
182        match self.kind {
183            VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
184            VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
185            VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
186            _ => false,
187        }
188    }
189
190    /// Determines if the value is a boolean constant with the `true` value
191    pub fn is_true(&self) -> bool {
192        match self.kind {
193            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
194            _ => false,
195        }
196    }
197
198    /// Determines if the value is a boolean constant with the `false` value
199    pub fn is_false(&self) -> bool {
200        match self.kind {
201            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
202            _ => false,
203        }
204    }
205}
206
/// A compile-time scalar constant.
///
/// Values are stored in the widest type of their family (`i64`, `f64`, `u64`) regardless of the
/// recorded kind; they may be narrowed to that kind during compilation.
///
/// `Eq` and `Hash` are implemented manually for this type because of the `f64` payload.
// `PartialOrd` is derived, so variant declaration order participates in comparisons.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer; the [`IntKind`] records the declared width.
    Int(i64, IntKind),
    /// Floating-point value; the [`FloatKind`] records the declared format.
    Float(f64, FloatKind),
    /// Unsigned integer; the [`UIntKind`] records the declared width.
    UInt(u64, UIntKind),
    /// Boolean value.
    Bool(bool),
}
218
219impl Eq for ConstantScalarValue {}
220impl Hash for ConstantScalarValue {
221    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
222        core::mem::discriminant(self).hash(ra_expand_state);
223        match self {
224            ConstantScalarValue::Int(f0, f1) => {
225                f0.hash(ra_expand_state);
226                f1.hash(ra_expand_state);
227            }
228            ConstantScalarValue::Float(f0, f1) => {
229                FloatOrd(*f0).hash(ra_expand_state);
230                f1.hash(ra_expand_state);
231            }
232            ConstantScalarValue::UInt(f0, f1) => {
233                f0.hash(ra_expand_state);
234                f1.hash(ra_expand_state);
235            }
236            ConstantScalarValue::Bool(f0) => {
237                f0.hash(ra_expand_state);
238            }
239        }
240    }
241}
242
243impl ConstantScalarValue {
244    /// Returns the element type of the scalar.
245    pub fn elem(&self) -> Elem {
246        match self {
247            ConstantScalarValue::Int(_, kind) => Elem::Int(*kind),
248            ConstantScalarValue::Float(_, kind) => Elem::Float(*kind),
249            ConstantScalarValue::UInt(_, kind) => Elem::UInt(*kind),
250            ConstantScalarValue::Bool(_) => Elem::Bool,
251        }
252    }
253
254    /// Returns the value of the scalar as a usize.
255    ///
256    /// It will return [None] if the scalar type is a float or a bool.
257    pub fn try_as_usize(&self) -> Option<usize> {
258        match self {
259            ConstantScalarValue::UInt(val, _) => Some(*val as usize),
260            ConstantScalarValue::Int(val, _) => Some(*val as usize),
261            ConstantScalarValue::Float(_, _) => None,
262            ConstantScalarValue::Bool(_) => None,
263        }
264    }
265
266    /// Returns the value of the scalar as a usize.
267    ///
268    /// It will panic if the scalar type is a float or a bool.
269    pub fn as_usize(&self) -> usize {
270        self.try_as_usize()
271            .expect("Only Int and UInt kind can be made into usize.")
272    }
273
274    /// Returns the value of the scalar as a u32.
275    ///
276    /// It will return [None] if the scalar type is a float or a bool.
277    pub fn try_as_u32(&self) -> Option<u32> {
278        match self {
279            ConstantScalarValue::UInt(val, _) => Some(*val as u32),
280            ConstantScalarValue::Int(val, _) => Some(*val as u32),
281            ConstantScalarValue::Float(_, _) => None,
282            ConstantScalarValue::Bool(_) => None,
283        }
284    }
285
286    /// Returns the value of the scalar as a u32.
287    ///
288    /// It will panic if the scalar type is a float or a bool.
289    pub fn as_u32(&self) -> u32 {
290        self.try_as_u32()
291            .expect("Only Int and UInt kind can be made into u32.")
292    }
293
294    /// Returns the value of the scalar as a u64.
295    ///
296    /// It will return [None] if the scalar type is a float or a bool.
297    pub fn try_as_u64(&self) -> Option<u64> {
298        match self {
299            ConstantScalarValue::UInt(val, _) => Some(*val),
300            ConstantScalarValue::Int(val, _) => Some(*val as u64),
301            ConstantScalarValue::Float(_, _) => None,
302            ConstantScalarValue::Bool(_) => None,
303        }
304    }
305
306    /// Returns the value of the scalar as a u64.
307    ///
308    /// It will panic if the scalar type is a float or a bool.
309    pub fn as_u64(&self) -> u64 {
310        self.try_as_u64()
311            .expect("Only Int and UInt kind can be made into u64.")
312    }
313
314    /// Returns the value of the scalar as a i64.
315    ///
316    /// It will return [None] if the scalar type is a float or a bool.
317    pub fn try_as_i64(&self) -> Option<i64> {
318        match self {
319            ConstantScalarValue::UInt(val, _) => Some(*val as i64),
320            ConstantScalarValue::Int(val, _) => Some(*val),
321            ConstantScalarValue::Float(_, _) => None,
322            ConstantScalarValue::Bool(_) => None,
323        }
324    }
325
326    /// Returns the value of the scalar as a u32.
327    ///
328    /// It will panic if the scalar type is a float or a bool.
329    pub fn as_i64(&self) -> i64 {
330        self.try_as_i64()
331            .expect("Only Int and UInt kind can be made into i64.")
332    }
333
334    /// Returns the value of the variable as a bool if it actually is a bool.
335    pub fn try_as_bool(&self) -> Option<bool> {
336        match self {
337            ConstantScalarValue::Bool(val) => Some(*val),
338            _ => None,
339        }
340    }
341
342    /// Returns the value of the variable as a bool.
343    ///
344    /// It will panic if the scalar isn't a bool.
345    pub fn as_bool(&self) -> bool {
346        self.try_as_bool()
347            .expect("Only bool can be made into a bool")
348    }
349
350    pub fn is_zero(&self) -> bool {
351        match self {
352            ConstantScalarValue::Int(val, _) => *val == 0,
353            ConstantScalarValue::Float(val, _) => *val == 0.0,
354            ConstantScalarValue::UInt(val, _) => *val == 0,
355            ConstantScalarValue::Bool(_) => false,
356        }
357    }
358
359    pub fn is_one(&self) -> bool {
360        match self {
361            ConstantScalarValue::Int(val, _) => *val == 1,
362            ConstantScalarValue::Float(val, _) => *val == 1.0,
363            ConstantScalarValue::UInt(val, _) => *val == 1,
364            ConstantScalarValue::Bool(_) => false,
365        }
366    }
367
368    pub fn cast_to(&self, other: Elem) -> ConstantScalarValue {
369        match (self, other) {
370            (ConstantScalarValue::Int(val, _), Elem::Float(float_kind)) => {
371                ConstantScalarValue::Float(*val as f64, float_kind)
372            }
373            (ConstantScalarValue::Int(val, _), Elem::Int(int_kind)) => {
374                ConstantScalarValue::Int(*val, int_kind)
375            }
376            (ConstantScalarValue::Int(val, _), Elem::UInt(kind)) => {
377                ConstantScalarValue::UInt(*val as u64, kind)
378            }
379            (ConstantScalarValue::Int(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
380            (ConstantScalarValue::Float(val, _), Elem::Float(float_kind)) => {
381                ConstantScalarValue::Float(*val, float_kind)
382            }
383            (ConstantScalarValue::Float(val, _), Elem::Int(int_kind)) => {
384                ConstantScalarValue::Int(*val as i64, int_kind)
385            }
386            (ConstantScalarValue::Float(val, _), Elem::UInt(kind)) => {
387                ConstantScalarValue::UInt(*val as u64, kind)
388            }
389            (ConstantScalarValue::Float(val, _), Elem::Bool) => {
390                ConstantScalarValue::Bool(*val == 0.0)
391            }
392            (ConstantScalarValue::UInt(val, _), Elem::Float(float_kind)) => {
393                ConstantScalarValue::Float(*val as f64, float_kind)
394            }
395            (ConstantScalarValue::UInt(val, _), Elem::Int(int_kind)) => {
396                ConstantScalarValue::Int(*val as i64, int_kind)
397            }
398            (ConstantScalarValue::UInt(val, _), Elem::UInt(kind)) => {
399                ConstantScalarValue::UInt(*val, kind)
400            }
401            (ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
402            (ConstantScalarValue::Bool(val), Elem::Float(float_kind)) => {
403                ConstantScalarValue::Float(*val as u32 as f64, float_kind)
404            }
405            (ConstantScalarValue::Bool(val), Elem::Int(int_kind)) => {
406                ConstantScalarValue::Int(*val as i64, int_kind)
407            }
408            (ConstantScalarValue::Bool(val), Elem::UInt(kind)) => {
409                ConstantScalarValue::UInt(*val as u64, kind)
410            }
411            (ConstantScalarValue::Bool(val), Elem::Bool) => ConstantScalarValue::Bool(*val),
412            _ => unreachable!(),
413        }
414    }
415}
416
417impl Display for ConstantScalarValue {
418    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
419        match self {
420            ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
421            ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
422            ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
423            ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
424            ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
425            ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
426            ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
427            ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
428            ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
429            ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
430            ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
431            ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
432            ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
433            ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
434            ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
435            ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
436            ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
437            ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
438            ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
439            ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
440            ConstantScalarValue::Bool(val) => write!(f, "{val}"),
441        }
442    }
443}
444
445impl Variable {
446    pub fn vectorization_factor(&self) -> u8 {
447        self.item.vectorization.map(NonZero::get).unwrap_or(1u8)
448    }
449
450    pub fn index(&self) -> Option<Id> {
451        match self.kind {
452            VariableKind::GlobalInputArray(id)
453            | VariableKind::GlobalOutputArray(id)
454            | VariableKind::TensorMap(id)
455            | VariableKind::GlobalScalar(id)
456            | VariableKind::LocalMut { id, .. }
457            | VariableKind::Versioned { id, .. }
458            | VariableKind::LocalConst { id, .. }
459            | VariableKind::ConstantArray { id, .. }
460            | VariableKind::SharedMemory { id, .. }
461            | VariableKind::LocalArray { id, .. }
462            | VariableKind::Matrix { id, .. } => Some(id),
463            _ => None,
464        }
465    }
466
467    pub fn as_const(&self) -> Option<ConstantScalarValue> {
468        match self.kind {
469            VariableKind::ConstantScalar(constant) => Some(constant),
470            _ => None,
471        }
472    }
473}
474
475impl Display for Variable {
476    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
477        match self.kind {
478            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
479            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
480            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
481            VariableKind::TensorMap(id) => write!(f, "tensor_map({id})"),
482            VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
483            VariableKind::LocalMut { id } => write!(f, "local({id})"),
484            VariableKind::Versioned { id, version } => {
485                write!(f, "local({id}).v{version}")
486            }
487            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
488            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
489            VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
490            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
491            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
492            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
493            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
494            VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
495        }
496    }
497}
498
499// Useful with the cube_inline macro.
500impl From<&Variable> for Variable {
501    fn from(value: &Variable) -> Self {
502        *value
503    }
504}