// cubecl_core/ir/variable.rs

1use std::fmt::Display;
2use std::num::NonZero;
3
4use super::{Elem, FloatKind, IntKind, Item, Matrix, UIntKind};
5use serde::{Deserialize, Serialize};
6
/// A variable of the IR: its [`VariableKind`] together with the [`Item`] type of the values it
/// holds.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct Variable {
    /// Which category of variable this is (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The item type: an element type plus an optional vectorization factor
    /// (see `Variable::vectorization_factor`).
    pub item: Item,
}
13
14impl Variable {
15    pub fn new(kind: VariableKind, item: Item) -> Self {
16        Self { kind, item }
17    }
18
19    pub fn builtin(builtin: Builtin) -> Self {
20        Self::new(
21            VariableKind::Builtin(builtin),
22            Item::new(Elem::UInt(UIntKind::U32)),
23        )
24    }
25
26    pub fn constant(scalar: ConstantScalarValue) -> Self {
27        let elem = match scalar {
28            ConstantScalarValue::Int(_, int_kind) => Elem::Int(int_kind),
29            ConstantScalarValue::Float(_, float_kind) => Elem::Float(float_kind),
30            ConstantScalarValue::UInt(_, kind) => Elem::UInt(kind),
31            ConstantScalarValue::Bool(_) => Elem::Bool,
32        };
33        Self::new(VariableKind::ConstantScalar(scalar), Item::new(elem))
34    }
35}
36
/// Numeric identifier carried by most [`VariableKind`] variants (see [`Variable::index`]).
pub type Id = u32;
38
/// The category of a [`Variable`], together with the data that identifies it.
///
/// NOTE: this enum derives serde's `Serialize`/`Deserialize`. Formats that encode enum variants
/// by index make the declaration order observable, so avoid reordering variants.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum VariableKind {
    /// Global input array; displayed as `input(id)`. Not classified as immutable.
    GlobalInputArray(Id),
    /// Global output array; displayed as `output(id)`.
    GlobalOutputArray(Id),
    /// Global scalar; displayed as `scalar(id)`. Classified as immutable.
    GlobalScalar(Id),
    /// Local array with a declared `length`; displayed as `array(id)`.
    LocalArray { id: Id, length: u32 },
    /// Mutable local; displayed as `local(id)`.
    LocalMut { id: Id },
    /// Immutable local; displayed as `binding(id)`.
    LocalConst { id: Id },
    /// A numbered `version` of local `id` (presumably an SSA version — confirm); displayed as
    /// `local(id).v{version}` and classified as immutable.
    Versioned { id: Id, version: u16 },
    /// Inline constant scalar; displayed as the constant value itself.
    ConstantScalar(ConstantScalarValue),
    /// Constant array with a declared `length`; displayed as `const_array(id)`.
    ConstantArray { id: Id, length: u32 },
    /// Shared-memory array with a declared `length`; displayed as `shared(id)`.
    SharedMemory { id: Id, length: u32 },
    /// Matrix variable described by `mat`; displayed as `matrix(id)`.
    Matrix { id: Id, mat: Matrix },
    /// Slice; displayed as `slice(id)`.
    Slice { id: Id },
    /// Built-in value (see [`Builtin`]); displayed by its variant name.
    Builtin(Builtin),
}
55
/// Built-in variables, referenced through [`VariableKind::Builtin`].
///
/// [`Variable::builtin`] gives every builtin a `u32` element type. The `Display` impl of
/// [`Variable`] prints a builtin as its variant name (via `Debug`).
///
/// NOTE: this enum derives `Ord` and serde, so variant order is observable; avoid reordering.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Builtin {
    // Unit position. The X/Y/Z variants are presumably per-axis components — confirm.
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    // Cube position (presumably the cube's index in the launch — confirm), then per-axis.
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    // Cube dimensions, then per-axis.
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    // Cube count (presumably of the launch — confirm), then per-axis.
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    // Plane size and the unit's position within its plane — inferred from names; confirm.
    PlaneDim,
    UnitPosPlane,
    // Absolute position, then per-axis. Presumably derived from cube and unit positions — confirm.
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
81
82impl Variable {
83    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
84    /// safe to inline/merge
85    pub fn is_immutable(&self) -> bool {
86        match self.kind {
87            VariableKind::GlobalOutputArray { .. } => false,
88            VariableKind::LocalMut { .. } => false,
89            VariableKind::SharedMemory { .. } => false,
90            VariableKind::Matrix { .. } => false,
91            VariableKind::Slice { .. } => false,
92            VariableKind::LocalArray { .. } => false,
93            VariableKind::GlobalInputArray { .. } => false,
94            VariableKind::GlobalScalar { .. } => true,
95            VariableKind::Versioned { .. } => true,
96            VariableKind::LocalConst { .. } => true,
97            VariableKind::ConstantScalar(_) => true,
98            VariableKind::ConstantArray { .. } => true,
99            VariableKind::Builtin(_) => true,
100        }
101    }
102
103    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
104    /// [`Elem`]s when indexed?
105    pub fn is_array(&self) -> bool {
106        matches!(
107            self.kind,
108            VariableKind::GlobalInputArray { .. }
109                | VariableKind::GlobalOutputArray { .. }
110                | VariableKind::ConstantArray { .. }
111                | VariableKind::SharedMemory { .. }
112                | VariableKind::LocalArray { .. }
113                | VariableKind::Matrix { .. }
114                | VariableKind::Slice { .. }
115        )
116    }
117
118    pub fn has_length(&self) -> bool {
119        matches!(
120            self.kind,
121            VariableKind::GlobalInputArray { .. }
122                | VariableKind::GlobalOutputArray { .. }
123                | VariableKind::Slice { .. }
124        )
125    }
126
127    pub fn has_buffer_length(&self) -> bool {
128        matches!(
129            self.kind,
130            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
131        )
132    }
133
134    /// Determines if the value is a constant with the specified value (converted if necessary)
135    pub fn is_constant(&self, value: i64) -> bool {
136        match self.kind {
137            VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
138            VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
139            VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
140            _ => false,
141        }
142    }
143
144    /// Determines if the value is a boolean constant with the `true` value
145    pub fn is_true(&self) -> bool {
146        match self.kind {
147            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
148            _ => false,
149        }
150    }
151
152    /// Determines if the value is a boolean constant with the `false` value
153    pub fn is_false(&self) -> bool {
154        match self.kind {
155            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
156            _ => false,
157        }
158    }
159}
160
/// A constant scalar value tagged with the kind of number it represents.
///
/// The scalars are stored with the highest precision possible (`i64`, `f64`, `u64`), but they
/// might get reduced (presumably to their tagged kind) during compilation.
///
/// NOTE: this enum derives `PartialOrd` and serde, so variant order is observable; avoid
/// reordering.
#[derive(Debug, Clone, PartialEq, Copy, Serialize, Deserialize, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer, stored as `i64`.
    Int(i64, IntKind),
    /// Floating-point value, stored as `f64`.
    Float(f64, FloatKind),
    /// Unsigned integer, stored as `u64`.
    UInt(u64, UIntKind),
    /// Boolean value.
    Bool(bool),
}
171
172impl ConstantScalarValue {
173    /// Returns the element type of the scalar.
174    pub fn elem(&self) -> Elem {
175        match self {
176            ConstantScalarValue::Int(_, kind) => Elem::Int(*kind),
177            ConstantScalarValue::Float(_, kind) => Elem::Float(*kind),
178            ConstantScalarValue::UInt(_, kind) => Elem::UInt(*kind),
179            ConstantScalarValue::Bool(_) => Elem::Bool,
180        }
181    }
182
183    /// Returns the value of the scalar as a usize.
184    ///
185    /// It will return [None] if the scalar type is a float or a bool.
186    pub fn try_as_usize(&self) -> Option<usize> {
187        match self {
188            ConstantScalarValue::UInt(val, _) => Some(*val as usize),
189            ConstantScalarValue::Int(val, _) => Some(*val as usize),
190            ConstantScalarValue::Float(_, _) => None,
191            ConstantScalarValue::Bool(_) => None,
192        }
193    }
194
195    /// Returns the value of the scalar as a usize.
196    ///
197    /// It will panic if the scalar type is a float or a bool.
198    pub fn as_usize(&self) -> usize {
199        self.try_as_usize()
200            .expect("Only Int and UInt kind can be made into usize.")
201    }
202
203    /// Returns the value of the scalar as a u32.
204    ///
205    /// It will return [None] if the scalar type is a float or a bool.
206    pub fn try_as_u32(&self) -> Option<u32> {
207        match self {
208            ConstantScalarValue::UInt(val, _) => Some(*val as u32),
209            ConstantScalarValue::Int(val, _) => Some(*val as u32),
210            ConstantScalarValue::Float(_, _) => None,
211            ConstantScalarValue::Bool(_) => None,
212        }
213    }
214
215    /// Returns the value of the scalar as a u32.
216    ///
217    /// It will panic if the scalar type is a float or a bool.
218    pub fn as_u32(&self) -> u32 {
219        self.try_as_u32()
220            .expect("Only Int and UInt kind can be made into u32.")
221    }
222
223    /// Returns the value of the scalar as a u64.
224    ///
225    /// It will return [None] if the scalar type is a float or a bool.
226    pub fn try_as_u64(&self) -> Option<u64> {
227        match self {
228            ConstantScalarValue::UInt(val, _) => Some(*val),
229            ConstantScalarValue::Int(val, _) => Some(*val as u64),
230            ConstantScalarValue::Float(_, _) => None,
231            ConstantScalarValue::Bool(_) => None,
232        }
233    }
234
235    /// Returns the value of the scalar as a u64.
236    ///
237    /// It will panic if the scalar type is a float or a bool.
238    pub fn as_u64(&self) -> u64 {
239        self.try_as_u64()
240            .expect("Only Int and UInt kind can be made into u64.")
241    }
242
243    /// Returns the value of the scalar as a i64.
244    ///
245    /// It will return [None] if the scalar type is a float or a bool.
246    pub fn try_as_i64(&self) -> Option<i64> {
247        match self {
248            ConstantScalarValue::UInt(val, _) => Some(*val as i64),
249            ConstantScalarValue::Int(val, _) => Some(*val),
250            ConstantScalarValue::Float(_, _) => None,
251            ConstantScalarValue::Bool(_) => None,
252        }
253    }
254
255    /// Returns the value of the scalar as a u32.
256    ///
257    /// It will panic if the scalar type is a float or a bool.
258    pub fn as_i64(&self) -> i64 {
259        self.try_as_i64()
260            .expect("Only Int and UInt kind can be made into i64.")
261    }
262
263    /// Returns the value of the variable as a bool if it actually is a bool.
264    pub fn try_as_bool(&self) -> Option<bool> {
265        match self {
266            ConstantScalarValue::Bool(val) => Some(*val),
267            _ => None,
268        }
269    }
270
271    /// Returns the value of the variable as a bool.
272    ///
273    /// It will panic if the scalar isn't a bool.
274    pub fn as_bool(&self) -> bool {
275        self.try_as_bool()
276            .expect("Only bool can be made into a bool")
277    }
278
279    pub fn is_zero(&self) -> bool {
280        match self {
281            ConstantScalarValue::Int(val, _) => *val == 0,
282            ConstantScalarValue::Float(val, _) => *val == 0.0,
283            ConstantScalarValue::UInt(val, _) => *val == 0,
284            ConstantScalarValue::Bool(_) => false,
285        }
286    }
287
288    pub fn is_one(&self) -> bool {
289        match self {
290            ConstantScalarValue::Int(val, _) => *val == 1,
291            ConstantScalarValue::Float(val, _) => *val == 1.0,
292            ConstantScalarValue::UInt(val, _) => *val == 1,
293            ConstantScalarValue::Bool(_) => false,
294        }
295    }
296
297    pub fn cast_to(&self, other: Elem) -> ConstantScalarValue {
298        match (self, other) {
299            (ConstantScalarValue::Int(val, _), Elem::Float(float_kind)) => {
300                ConstantScalarValue::Float(*val as f64, float_kind)
301            }
302            (ConstantScalarValue::Int(val, _), Elem::Int(int_kind)) => {
303                ConstantScalarValue::Int(*val, int_kind)
304            }
305            (ConstantScalarValue::Int(val, _), Elem::UInt(kind)) => {
306                ConstantScalarValue::UInt(*val as u64, kind)
307            }
308            (ConstantScalarValue::Int(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
309            (ConstantScalarValue::Float(val, _), Elem::Float(float_kind)) => {
310                ConstantScalarValue::Float(*val, float_kind)
311            }
312            (ConstantScalarValue::Float(val, _), Elem::Int(int_kind)) => {
313                ConstantScalarValue::Int(*val as i64, int_kind)
314            }
315            (ConstantScalarValue::Float(val, _), Elem::UInt(kind)) => {
316                ConstantScalarValue::UInt(*val as u64, kind)
317            }
318            (ConstantScalarValue::Float(val, _), Elem::Bool) => {
319                ConstantScalarValue::Bool(*val == 0.0)
320            }
321            (ConstantScalarValue::UInt(val, _), Elem::Float(float_kind)) => {
322                ConstantScalarValue::Float(*val as f64, float_kind)
323            }
324            (ConstantScalarValue::UInt(val, _), Elem::Int(int_kind)) => {
325                ConstantScalarValue::Int(*val as i64, int_kind)
326            }
327            (ConstantScalarValue::UInt(val, _), Elem::UInt(kind)) => {
328                ConstantScalarValue::UInt(*val, kind)
329            }
330            (ConstantScalarValue::UInt(val, _), Elem::Bool) => ConstantScalarValue::Bool(*val == 1),
331            (ConstantScalarValue::Bool(val), Elem::Float(float_kind)) => {
332                ConstantScalarValue::Float(*val as u32 as f64, float_kind)
333            }
334            (ConstantScalarValue::Bool(val), Elem::Int(int_kind)) => {
335                ConstantScalarValue::Int(*val as i64, int_kind)
336            }
337            (ConstantScalarValue::Bool(val), Elem::UInt(kind)) => {
338                ConstantScalarValue::UInt(*val as u64, kind)
339            }
340            (ConstantScalarValue::Bool(val), Elem::Bool) => ConstantScalarValue::Bool(*val),
341            _ => unreachable!(),
342        }
343    }
344}
345
346impl Display for ConstantScalarValue {
347    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348        match self {
349            ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
350            ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
351            ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
352            ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
353            ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
354            ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
355            ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
356            ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
357            ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
358            ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
359            ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
360            ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
361            ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
362            ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
363            ConstantScalarValue::Bool(val) => write!(f, "{val}"),
364        }
365    }
366}
367
368impl Variable {
369    pub fn vectorization_factor(&self) -> u8 {
370        self.item.vectorization.map(NonZero::get).unwrap_or(1u8)
371    }
372
373    pub fn index(&self) -> Option<Id> {
374        match self.kind {
375            VariableKind::GlobalInputArray(id)
376            | VariableKind::GlobalScalar(id)
377            | VariableKind::LocalMut { id, .. }
378            | VariableKind::Versioned { id, .. }
379            | VariableKind::LocalConst { id, .. }
380            | VariableKind::Slice { id, .. }
381            | VariableKind::GlobalOutputArray(id)
382            | VariableKind::ConstantArray { id, .. }
383            | VariableKind::SharedMemory { id, .. }
384            | VariableKind::LocalArray { id, .. }
385            | VariableKind::Matrix { id, .. } => Some(id),
386            _ => None,
387        }
388    }
389
390    pub fn as_const(&self) -> Option<ConstantScalarValue> {
391        match self.kind {
392            VariableKind::ConstantScalar(constant) => Some(constant),
393            _ => None,
394        }
395    }
396}
397
398impl Display for Variable {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        match self.kind {
401            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
402            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
403            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
404            VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
405            VariableKind::LocalMut { id } => write!(f, "local({id})"),
406            VariableKind::Versioned { id, version } => {
407                write!(f, "local({id}).v{version}")
408            }
409            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
410            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
411            VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
412            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
413            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
414            VariableKind::Slice { id } => write!(f, "slice({id})"),
415            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
416        }
417    }
418}
419
420// Useful with the cube_inline macro.
421impl From<&Variable> for Variable {
422    fn from(value: &Variable) -> Self {
423        *value
424    }
425}