cubecl_ir/
variable.rs

1use core::num::NonZero;
2use core::{fmt::Display, hash::Hash};
3
4use crate::{BarrierLevel, TypeHash};
5
6use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind};
7use float_ord::FloatOrd;
8
/// A value in the IR: what it refers to ([`VariableKind`]) together with its type ([`Item`]).
///
/// The type is `Copy`, so variables are passed around by value.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What this variable refers to (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// Type information; its `elem` and `vectorization` fields are read by
    /// `Variable::elem` and `Variable::vectorization_factor`.
    pub item: Item,
}
16
17impl Variable {
18    pub fn new(kind: VariableKind, item: Item) -> Self {
19        Self { kind, item }
20    }
21
22    pub fn builtin(builtin: Builtin) -> Self {
23        Self::new(
24            VariableKind::Builtin(builtin),
25            Item::new(Elem::UInt(UIntKind::U32)),
26        )
27    }
28
29    pub fn constant(scalar: ConstantScalarValue) -> Self {
30        let elem = match scalar {
31            ConstantScalarValue::Int(_, int_kind) => Elem::Int(int_kind),
32            ConstantScalarValue::Float(_, float_kind) => Elem::Float(float_kind),
33            ConstantScalarValue::UInt(_, kind) => Elem::UInt(kind),
34            ConstantScalarValue::Bool(_) => Elem::Bool,
35        };
36        Self::new(VariableKind::ConstantScalar(scalar), Item::new(elem))
37    }
38
39    pub fn elem(&self) -> Elem {
40        self.item.elem
41    }
42}
43
/// Numeric identifier carried by the id-bearing [`VariableKind`] variants.
pub type Id = u32;
45
/// The different things a [`Variable`] can refer to.
///
/// Most variants carry an [`Id`]. See `Variable::index` for which ids are exposed, and
/// `Variable::is_immutable` / `Variable::is_array` for how each variant is classified.
/// The per-variant notes below only restate what those methods and the `Display`
/// implementation show.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Displayed as `input(id)`; classified as an array with a length and a buffer length.
    GlobalInputArray(Id),
    /// Displayed as `output(id)`; classified as an array with a length and a buffer length.
    GlobalOutputArray(Id),
    /// Displayed as `scalar(id)`; classified as immutable.
    GlobalScalar(Id),
    /// Displayed as `tensor_map(id)`.
    TensorMap(Id),
    /// Displayed as `array(id)`; classified as an array.
    // NOTE(review): `length` is presumably the number of items — confirm.
    LocalArray {
        id: Id,
        length: u32,
    },
    /// Displayed as `local(id)`; not classified as immutable.
    LocalMut {
        id: Id,
    },
    /// Displayed as `binding(id)`; classified as immutable.
    LocalConst {
        id: Id,
    },
    /// Displayed as `local(id).v{version}`; classified as immutable.
    Versioned {
        id: Id,
        version: u16,
    },
    /// An inline constant; displayed as the constant itself and classified as immutable.
    ConstantScalar(ConstantScalarValue),
    /// Displayed as `const_array(id)`; classified as an immutable array.
    ConstantArray {
        id: Id,
        length: u32,
    },
    /// Displayed as `shared(id)`; classified as an array.
    // NOTE(review): `alignment` is optional; its unit (bytes vs. elements) is not visible
    // here — confirm against the code that consumes it.
    SharedMemory {
        id: Id,
        length: u32,
        alignment: Option<u32>,
    },
    /// Displayed as `matrix(id)`; classified as an array.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// Displayed as `slice(id)`; classified as an array with a length.
    Slice {
        id: Id,
    },
    /// A [`Builtin`] value; displayed with the builtin's `Debug` name, classified as immutable.
    Builtin(Builtin),
    /// Displayed as `pipeline(id)`. `Variable::index` returns `None` for this variant.
    Pipeline {
        id: Id,
        item: Item,
        num_stages: u8,
    },
    /// Displayed as `barrier(id)`. `Variable::index` returns `None` for this variant.
    Barrier {
        id: Id,
        item: Item,
        level: BarrierLevel,
    },
}
96
/// Built-in values that can be referenced as a [`Variable`] through
/// [`VariableKind::Builtin`].
///
/// `Variable::builtin` gives every builtin a `u32` item, and `Display for Variable` prints
/// the variant's `Debug` name.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by declaration
/// order, so reordering variants changes comparison results.
// NOTE(review): the meaning of each variant (unit/cube/plane positions and dimensions) is
// only suggested by its name here — confirm against the runtime documentation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
131
132impl Variable {
133    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
134    /// safe to inline/merge
135    pub fn is_immutable(&self) -> bool {
136        match self.kind {
137            VariableKind::GlobalOutputArray { .. } => false,
138            VariableKind::TensorMap(_) => false,
139            VariableKind::LocalMut { .. } => false,
140            VariableKind::SharedMemory { .. } => false,
141            VariableKind::Matrix { .. } => false,
142            VariableKind::Slice { .. } => false,
143            VariableKind::LocalArray { .. } => false,
144            VariableKind::GlobalInputArray { .. } => false,
145            VariableKind::GlobalScalar { .. } => true,
146            VariableKind::Versioned { .. } => true,
147            VariableKind::LocalConst { .. } => true,
148            VariableKind::ConstantScalar(_) => true,
149            VariableKind::ConstantArray { .. } => true,
150            VariableKind::Builtin(_) => true,
151            VariableKind::Pipeline { .. } => false,
152            VariableKind::Barrier { .. } => false,
153        }
154    }
155
156    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
157    /// [`Elem`]s when indexed?
158    pub fn is_array(&self) -> bool {
159        matches!(
160            self.kind,
161            VariableKind::GlobalInputArray { .. }
162                | VariableKind::GlobalOutputArray { .. }
163                | VariableKind::ConstantArray { .. }
164                | VariableKind::SharedMemory { .. }
165                | VariableKind::LocalArray { .. }
166                | VariableKind::Matrix { .. }
167                | VariableKind::Slice { .. }
168        )
169    }
170
171    pub fn has_length(&self) -> bool {
172        matches!(
173            self.kind,
174            VariableKind::GlobalInputArray { .. }
175                | VariableKind::GlobalOutputArray { .. }
176                | VariableKind::Slice { .. }
177        )
178    }
179
180    pub fn has_buffer_length(&self) -> bool {
181        matches!(
182            self.kind,
183            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
184        )
185    }
186
187    /// Determines if the value is a constant with the specified value (converted if necessary)
188    pub fn is_constant(&self, value: i64) -> bool {
189        match self.kind {
190            VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
191            VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
192            VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
193            _ => false,
194        }
195    }
196
197    /// Determines if the value is a boolean constant with the `true` value
198    pub fn is_true(&self) -> bool {
199        match self.kind {
200            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
201            _ => false,
202        }
203    }
204
205    /// Determines if the value is a boolean constant with the `false` value
206    pub fn is_false(&self) -> bool {
207        match self.kind {
208            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
209            _ => false,
210        }
211    }
212}
213
/// The scalars are stored with the highest precision possible, but they might get reduced during
/// compilation.
///
/// Each numeric variant pairs the widest host representation (`i64`, `f64`, `u64`) with the
/// kind the constant is declared as. `PartialEq` and `PartialOrd` are derived; `Eq` and
/// `Hash` are implemented by hand because `f64` provides neither.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer value and its declared kind.
    Int(i64, IntKind),
    /// Floating-point value and its declared kind.
    Float(f64, FloatKind),
    /// Unsigned integer value and its declared kind.
    UInt(u64, UIntKind),
    /// Boolean value.
    Bool(bool),
}
225
226impl Eq for ConstantScalarValue {}
227impl Hash for ConstantScalarValue {
228    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
229        core::mem::discriminant(self).hash(ra_expand_state);
230        match self {
231            ConstantScalarValue::Int(f0, f1) => {
232                f0.hash(ra_expand_state);
233                f1.hash(ra_expand_state);
234            }
235            ConstantScalarValue::Float(f0, f1) => {
236                FloatOrd(*f0).hash(ra_expand_state);
237                f1.hash(ra_expand_state);
238            }
239            ConstantScalarValue::UInt(f0, f1) => {
240                f0.hash(ra_expand_state);
241                f1.hash(ra_expand_state);
242            }
243            ConstantScalarValue::Bool(f0) => {
244                f0.hash(ra_expand_state);
245            }
246        }
247    }
248}
249
250impl ConstantScalarValue {
251    /// Returns the element type of the scalar.
252    pub fn elem(&self) -> Elem {
253        match self {
254            ConstantScalarValue::Int(_, kind) => Elem::Int(*kind),
255            ConstantScalarValue::Float(_, kind) => Elem::Float(*kind),
256            ConstantScalarValue::UInt(_, kind) => Elem::UInt(*kind),
257            ConstantScalarValue::Bool(_) => Elem::Bool,
258        }
259    }
260
261    /// Returns the value of the scalar as a usize.
262    ///
263    /// It will return [None] if the scalar type is a float or a bool.
264    pub fn try_as_usize(&self) -> Option<usize> {
265        match self {
266            ConstantScalarValue::UInt(val, _) => Some(*val as usize),
267            ConstantScalarValue::Int(val, _) => Some(*val as usize),
268            ConstantScalarValue::Float(_, _) => None,
269            ConstantScalarValue::Bool(_) => None,
270        }
271    }
272
273    /// Returns the value of the scalar as a usize.
274    ///
275    /// It will panic if the scalar type is a float or a bool.
276    pub fn as_usize(&self) -> usize {
277        self.try_as_usize()
278            .expect("Only Int and UInt kind can be made into usize.")
279    }
280
281    /// Returns the value of the scalar as a u32.
282    ///
283    /// It will return [None] if the scalar type is a float or a bool.
284    pub fn try_as_u32(&self) -> Option<u32> {
285        match self {
286            ConstantScalarValue::UInt(val, _) => Some(*val as u32),
287            ConstantScalarValue::Int(val, _) => Some(*val as u32),
288            ConstantScalarValue::Float(_, _) => None,
289            ConstantScalarValue::Bool(_) => None,
290        }
291    }
292
293    /// Returns the value of the scalar as a u32.
294    ///
295    /// It will panic if the scalar type is a float or a bool.
296    pub fn as_u32(&self) -> u32 {
297        self.try_as_u32()
298            .expect("Only Int and UInt kind can be made into u32.")
299    }
300
301    /// Returns the value of the scalar as a u64.
302    ///
303    /// It will return [None] if the scalar type is a float or a bool.
304    pub fn try_as_u64(&self) -> Option<u64> {
305        match self {
306            ConstantScalarValue::UInt(val, _) => Some(*val),
307            ConstantScalarValue::Int(val, _) => Some(*val as u64),
308            ConstantScalarValue::Float(_, _) => None,
309            ConstantScalarValue::Bool(_) => None,
310        }
311    }
312
313    /// Returns the value of the scalar as a u64.
314    ///
315    /// It will panic if the scalar type is a float or a bool.
316    pub fn as_u64(&self) -> u64 {
317        self.try_as_u64()
318            .expect("Only Int and UInt kind can be made into u64.")
319    }
320
321    /// Returns the value of the scalar as a i64.
322    ///
323    /// It will return [None] if the scalar type is a float or a bool.
324    pub fn try_as_i64(&self) -> Option<i64> {
325        match self {
326            ConstantScalarValue::UInt(val, _) => Some(*val as i64),
327            ConstantScalarValue::Int(val, _) => Some(*val),
328            ConstantScalarValue::Float(_, _) => None,
329            ConstantScalarValue::Bool(_) => None,
330        }
331    }
332
333    /// Returns the value of the scalar as a u32.
334    ///
335    /// It will panic if the scalar type is a float or a bool.
336    pub fn as_i64(&self) -> i64 {
337        self.try_as_i64()
338            .expect("Only Int and UInt kind can be made into i64.")
339    }
340
341    /// Returns the value of the variable as a bool if it actually is a bool.
342    pub fn try_as_bool(&self) -> Option<bool> {
343        match self {
344            ConstantScalarValue::Bool(val) => Some(*val),
345            _ => None,
346        }
347    }
348
349    /// Returns the value of the variable as a bool.
350    ///
351    /// It will panic if the scalar isn't a bool.
352    pub fn as_bool(&self) -> bool {
353        self.try_as_bool()
354            .expect("Only bool can be made into a bool")
355    }
356
357    pub fn is_zero(&self) -> bool {
358        match self {
359            ConstantScalarValue::Int(val, _) => *val == 0,
360            ConstantScalarValue::Float(val, _) => *val == 0.0,
361            ConstantScalarValue::UInt(val, _) => *val == 0,
362            ConstantScalarValue::Bool(_) => false,
363        }
364    }
365
366    pub fn is_one(&self) -> bool {
367        match self {
368            ConstantScalarValue::Int(val, _) => *val == 1,
369            ConstantScalarValue::Float(val, _) => *val == 1.0,
370            ConstantScalarValue::UInt(val, _) => *val == 1,
371            ConstantScalarValue::Bool(_) => false,
372        }
373    }
374
375    pub fn cast_to(&self, other: Elem) -> ConstantScalarValue {
376        match (self, other) {
377            (ConstantScalarValue::Int(val, _), Elem::Float(float_kind)) => {
378                ConstantScalarValue::Float(*val as f64, float_kind)
379            }
380            (ConstantScalarValue::Int(val, _), Elem::Int(int_kind)) => {
381                ConstantScalarValue::Int(*val, int_kind)
382            }
383            (ConstantScalarValue::Int(val, _), Elem::UInt(kind)) => {
384                ConstantScalarValue::UInt(*val as u64, kind)
385            }
386            (ConstantScalarValue::Int(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
387            (ConstantScalarValue::Float(val, _), Elem::Float(float_kind)) => {
388                ConstantScalarValue::Float(*val, float_kind)
389            }
390            (ConstantScalarValue::Float(val, _), Elem::Int(int_kind)) => {
391                ConstantScalarValue::Int(*val as i64, int_kind)
392            }
393            (ConstantScalarValue::Float(val, _), Elem::UInt(kind)) => {
394                ConstantScalarValue::UInt(*val as u64, kind)
395            }
396            (ConstantScalarValue::Float(val, _), Elem::Bool) => {
397                ConstantScalarValue::Bool(*val == 0.0)
398            }
399            (ConstantScalarValue::UInt(val, _), Elem::Float(float_kind)) => {
400                ConstantScalarValue::Float(*val as f64, float_kind)
401            }
402            (ConstantScalarValue::UInt(val, _), Elem::Int(int_kind)) => {
403                ConstantScalarValue::Int(*val as i64, int_kind)
404            }
405            (ConstantScalarValue::UInt(val, _), Elem::UInt(kind)) => {
406                ConstantScalarValue::UInt(*val, kind)
407            }
408            (ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
409            (ConstantScalarValue::Bool(val), Elem::Float(float_kind)) => {
410                ConstantScalarValue::Float(*val as u32 as f64, float_kind)
411            }
412            (ConstantScalarValue::Bool(val), Elem::Int(int_kind)) => {
413                ConstantScalarValue::Int(*val as i64, int_kind)
414            }
415            (ConstantScalarValue::Bool(val), Elem::UInt(kind)) => {
416                ConstantScalarValue::UInt(*val as u64, kind)
417            }
418            (ConstantScalarValue::Bool(val), Elem::Bool) => ConstantScalarValue::Bool(*val),
419            _ => unreachable!(),
420        }
421    }
422}
423
424impl Display for ConstantScalarValue {
425    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
426        match self {
427            ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
428            ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
429            ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
430            ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
431            ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
432            ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
433            ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
434            ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
435            ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
436            ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
437            ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
438            ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
439            ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
440            ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
441            ConstantScalarValue::Bool(val) => write!(f, "{val}"),
442        }
443    }
444}
445
446impl Variable {
447    pub fn vectorization_factor(&self) -> u8 {
448        self.item.vectorization.map(NonZero::get).unwrap_or(1u8)
449    }
450
451    pub fn index(&self) -> Option<Id> {
452        match self.kind {
453            VariableKind::GlobalInputArray(id)
454            | VariableKind::GlobalOutputArray(id)
455            | VariableKind::TensorMap(id)
456            | VariableKind::GlobalScalar(id)
457            | VariableKind::LocalMut { id, .. }
458            | VariableKind::Versioned { id, .. }
459            | VariableKind::LocalConst { id, .. }
460            | VariableKind::Slice { id, .. }
461            | VariableKind::ConstantArray { id, .. }
462            | VariableKind::SharedMemory { id, .. }
463            | VariableKind::LocalArray { id, .. }
464            | VariableKind::Matrix { id, .. } => Some(id),
465            _ => None,
466        }
467    }
468
469    pub fn as_const(&self) -> Option<ConstantScalarValue> {
470        match self.kind {
471            VariableKind::ConstantScalar(constant) => Some(constant),
472            _ => None,
473        }
474    }
475}
476
477impl Display for Variable {
478    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
479        match self.kind {
480            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
481            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
482            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
483            VariableKind::TensorMap(id) => write!(f, "tensor_map({id})"),
484            VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
485            VariableKind::LocalMut { id } => write!(f, "local({id})"),
486            VariableKind::Versioned { id, version } => {
487                write!(f, "local({id}).v{version}")
488            }
489            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
490            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
491            VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
492            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
493            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
494            VariableKind::Slice { id } => write!(f, "slice({id})"),
495            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
496            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
497            VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
498        }
499    }
500}
501
502// Useful with the cube_inline macro.
503impl From<&Variable> for Variable {
504    fn from(value: &Variable) -> Self {
505        *value
506    }
507}