cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, StorageType, TypeHash};
4
5use super::{ElemType, FloatKind, IntKind, Matrix, Type, UIntKind};
6use float_ord::FloatOrd;
7
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
10#[allow(missing_docs)]
11pub struct Variable {
12    pub kind: VariableKind,
13    pub ty: Type,
14}
15
16impl Variable {
17    pub fn new(kind: VariableKind, item: Type) -> Self {
18        Self { kind, ty: item }
19    }
20
21    pub fn builtin(builtin: Builtin) -> Self {
22        Self::new(
23            VariableKind::Builtin(builtin),
24            Type::scalar(ElemType::UInt(UIntKind::U32)),
25        )
26    }
27
28    pub fn constant(scalar: ConstantScalarValue) -> Self {
29        let elem = match scalar {
30            ConstantScalarValue::Int(_, int_kind) => ElemType::Int(int_kind),
31            ConstantScalarValue::Float(_, float_kind) => ElemType::Float(float_kind),
32            ConstantScalarValue::UInt(_, kind) => ElemType::UInt(kind),
33            ConstantScalarValue::Bool(_) => ElemType::Bool,
34        };
35        Self::new(VariableKind::ConstantScalar(scalar), Type::scalar(elem))
36    }
37
38    pub fn elem_type(&self) -> ElemType {
39        self.ty.elem_type()
40    }
41
42    pub fn storage_type(&self) -> StorageType {
43        self.ty.storage_type()
44    }
45}
46
/// Identifier carried by most [`VariableKind`] variants; [`Variable::index`] extracts it for
/// the kinds that expose one.
pub type Id = u32;

/// The kind of entity a [`Variable`] refers to.
///
/// The textual form of each variant comes from the `Display` impl for [`Variable`] (shown in
/// backticks below). Which variants count as immutable is decided by [`Variable::is_immutable`].
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array, `input(id)`. Not reported as immutable.
    GlobalInputArray(Id),
    /// Global output array, `output(id)`. Mutable.
    GlobalOutputArray(Id),
    /// Global scalar, `scalar(id)`. Immutable.
    GlobalScalar(Id),
    /// Input tensor map, `tensor_map(id)`. Immutable.
    TensorMapInput(Id),
    /// Output tensor map, also printed as `tensor_map(id)`. Mutable.
    TensorMapOutput(Id),
    /// Local array, `array(id)`. Mutable and indexable (see [`Variable::is_array`]).
    LocalArray {
        id: Id,
        // NOTE(review): presumably the element count — confirm.
        length: u32,
        // NOTE(review): unroll semantics are not visible in this file — confirm.
        unroll_factor: u32,
    },
    /// Mutable local, `local(id)`.
    LocalMut {
        id: Id,
    },
    /// Immutable local binding, `binding(id)`.
    LocalConst {
        id: Id,
    },
    /// Versioned local, `local(id).v{version}`. Treated as immutable.
    // NOTE(review): presumably an SSA-style version of a `LocalMut` with the same id — confirm.
    Versioned {
        id: Id,
        version: u16,
    },
    /// Inline constant scalar, printed via the constant's own `Display`.
    ConstantScalar(ConstantScalarValue),
    /// Constant array, `const_array(id)`. Immutable and indexable.
    ConstantArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Shared memory, `shared(id)`. Mutable and indexable.
    SharedMemory {
        id: Id,
        length: u32,
        unroll_factor: u32,
        // NOTE(review): presumably an explicit alignment, with `None` meaning default — confirm.
        alignment: Option<u32>,
    },
    /// Matrix, `matrix(id)`. Mutable and indexable.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// Builtin value, printed by its `Debug` name.
    Builtin(Builtin),
    /// Pipeline, `pipeline(id)`. Its id is not returned by [`Variable::index`].
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier, `barrier(id)`. Its id is not returned by [`Variable::index`].
    Barrier {
        id: Id,
        level: BarrierLevel,
    },
}
98
/// Builtin values that can be referenced through [`VariableKind::Builtin`].
///
/// [`Variable::builtin`] types every builtin as a scalar `u32`, and `Display` for [`Variable`]
/// prints a builtin by its `Debug` name.
///
/// Variant order is significant. `#[repr(u8)]` assigns discriminants by position, `Ord` is
/// derived from declaration order, and serde may encode variant indices. Append new variants
/// instead of reordering existing ones.
// NOTE(review): the `X`/`Y`/`Z` variants presumably are the per-axis components of the
// preceding aggregate variant — confirm.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
134
135impl Variable {
136    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
137    /// safe to inline/merge
138    pub fn is_immutable(&self) -> bool {
139        match self.kind {
140            VariableKind::GlobalOutputArray { .. } => false,
141            VariableKind::TensorMapInput(_) => true,
142            VariableKind::TensorMapOutput(_) => false,
143            VariableKind::LocalMut { .. } => false,
144            VariableKind::SharedMemory { .. } => false,
145            VariableKind::Matrix { .. } => false,
146            VariableKind::LocalArray { .. } => false,
147            VariableKind::GlobalInputArray { .. } => false,
148            VariableKind::GlobalScalar { .. } => true,
149            VariableKind::Versioned { .. } => true,
150            VariableKind::LocalConst { .. } => true,
151            VariableKind::ConstantScalar(_) => true,
152            VariableKind::ConstantArray { .. } => true,
153            VariableKind::Builtin(_) => true,
154            VariableKind::Pipeline { .. } => false,
155            VariableKind::Barrier { .. } => false,
156        }
157    }
158
159    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
160    /// [`Elem`]s when indexed?
161    pub fn is_array(&self) -> bool {
162        matches!(
163            self.kind,
164            VariableKind::GlobalInputArray { .. }
165                | VariableKind::GlobalOutputArray { .. }
166                | VariableKind::ConstantArray { .. }
167                | VariableKind::SharedMemory { .. }
168                | VariableKind::LocalArray { .. }
169                | VariableKind::Matrix { .. }
170        )
171    }
172
173    pub fn has_length(&self) -> bool {
174        matches!(
175            self.kind,
176            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
177        )
178    }
179
180    pub fn has_buffer_length(&self) -> bool {
181        matches!(
182            self.kind,
183            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
184        )
185    }
186
187    /// Determines if the value is a constant with the specified value (converted if necessary)
188    pub fn is_constant(&self, value: i64) -> bool {
189        match self.kind {
190            VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
191            VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
192            VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
193            _ => false,
194        }
195    }
196
197    /// Determines if the value is a boolean constant with the `true` value
198    pub fn is_true(&self) -> bool {
199        match self.kind {
200            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
201            _ => false,
202        }
203    }
204
205    /// Determines if the value is a boolean constant with the `false` value
206    pub fn is_false(&self) -> bool {
207        match self.kind {
208            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
209            _ => false,
210        }
211    }
212}
213
214/// The scalars are stored with the highest precision possible, but they might get reduced during
215/// compilation.
216#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
217#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
218#[allow(missing_docs)]
219pub enum ConstantScalarValue {
220    Int(i64, IntKind),
221    Float(f64, FloatKind),
222    UInt(u64, UIntKind),
223    Bool(bool),
224}
225
226impl Eq for ConstantScalarValue {}
227impl Hash for ConstantScalarValue {
228    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
229        core::mem::discriminant(self).hash(ra_expand_state);
230        match self {
231            ConstantScalarValue::Int(f0, f1) => {
232                f0.hash(ra_expand_state);
233                f1.hash(ra_expand_state);
234            }
235            ConstantScalarValue::Float(f0, f1) => {
236                FloatOrd(*f0).hash(ra_expand_state);
237                f1.hash(ra_expand_state);
238            }
239            ConstantScalarValue::UInt(f0, f1) => {
240                f0.hash(ra_expand_state);
241                f1.hash(ra_expand_state);
242            }
243            ConstantScalarValue::Bool(f0) => {
244                f0.hash(ra_expand_state);
245            }
246        }
247    }
248}
249
250impl ConstantScalarValue {
251    /// Returns the element type of the scalar.
252    pub fn elem_type(&self) -> ElemType {
253        match self {
254            ConstantScalarValue::Int(_, kind) => ElemType::Int(*kind),
255            ConstantScalarValue::Float(_, kind) => ElemType::Float(*kind),
256            ConstantScalarValue::UInt(_, kind) => ElemType::UInt(*kind),
257            ConstantScalarValue::Bool(_) => ElemType::Bool,
258        }
259    }
260
261    pub fn storage_type(&self) -> StorageType {
262        self.elem_type().into()
263    }
264
265    /// Returns the value of the scalar as a usize.
266    ///
267    /// It will return [None] if the scalar type is a float or a bool.
268    pub fn try_as_usize(&self) -> Option<usize> {
269        match self {
270            ConstantScalarValue::UInt(val, _) => Some(*val as usize),
271            ConstantScalarValue::Int(val, _) => Some(*val as usize),
272            ConstantScalarValue::Float(_, _) => None,
273            ConstantScalarValue::Bool(_) => None,
274        }
275    }
276
277    /// Returns the value of the scalar as a usize.
278    ///
279    /// It will panic if the scalar type is a float or a bool.
280    pub fn as_usize(&self) -> usize {
281        self.try_as_usize()
282            .expect("Only Int and UInt kind can be made into usize.")
283    }
284
285    /// Returns the value of the scalar as a u32.
286    ///
287    /// It will return [None] if the scalar type is a float or a bool.
288    pub fn try_as_u32(&self) -> Option<u32> {
289        match self {
290            ConstantScalarValue::UInt(val, _) => Some(*val as u32),
291            ConstantScalarValue::Int(val, _) => Some(*val as u32),
292            ConstantScalarValue::Float(_, _) => None,
293            ConstantScalarValue::Bool(_) => None,
294        }
295    }
296
297    /// Returns the value of the scalar as a u32.
298    ///
299    /// It will panic if the scalar type is a float or a bool.
300    pub fn as_u32(&self) -> u32 {
301        self.try_as_u32()
302            .expect("Only Int and UInt kind can be made into u32.")
303    }
304
305    /// Returns the value of the scalar as a u64.
306    ///
307    /// It will return [None] if the scalar type is a float or a bool.
308    pub fn try_as_u64(&self) -> Option<u64> {
309        match self {
310            ConstantScalarValue::UInt(val, _) => Some(*val),
311            ConstantScalarValue::Int(val, _) => Some(*val as u64),
312            ConstantScalarValue::Float(_, _) => None,
313            ConstantScalarValue::Bool(_) => None,
314        }
315    }
316
317    /// Returns the value of the scalar as a u64.
318    ///
319    /// It will panic if the scalar type is a float or a bool.
320    pub fn as_u64(&self) -> u64 {
321        self.try_as_u64()
322            .expect("Only Int and UInt kind can be made into u64.")
323    }
324
325    /// Returns the value of the scalar as a i64.
326    ///
327    /// It will return [None] if the scalar type is a float or a bool.
328    pub fn try_as_i64(&self) -> Option<i64> {
329        match self {
330            ConstantScalarValue::UInt(val, _) => Some(*val as i64),
331            ConstantScalarValue::Int(val, _) => Some(*val),
332            ConstantScalarValue::Float(_, _) => None,
333            ConstantScalarValue::Bool(_) => None,
334        }
335    }
336
337    /// Returns the value of the scalar as a u32.
338    ///
339    /// It will panic if the scalar type is a float or a bool.
340    pub fn as_i64(&self) -> i64 {
341        self.try_as_i64()
342            .expect("Only Int and UInt kind can be made into i64.")
343    }
344
345    /// Returns the value of the variable as a bool if it actually is a bool.
346    pub fn try_as_bool(&self) -> Option<bool> {
347        match self {
348            ConstantScalarValue::Bool(val) => Some(*val),
349            _ => None,
350        }
351    }
352
353    /// Returns the value of the variable as a bool.
354    ///
355    /// It will panic if the scalar isn't a bool.
356    pub fn as_bool(&self) -> bool {
357        self.try_as_bool()
358            .expect("Only bool can be made into a bool")
359    }
360
361    pub fn is_zero(&self) -> bool {
362        match self {
363            ConstantScalarValue::Int(val, _) => *val == 0,
364            ConstantScalarValue::Float(val, _) => *val == 0.0,
365            ConstantScalarValue::UInt(val, _) => *val == 0,
366            ConstantScalarValue::Bool(_) => false,
367        }
368    }
369
370    pub fn is_one(&self) -> bool {
371        match self {
372            ConstantScalarValue::Int(val, _) => *val == 1,
373            ConstantScalarValue::Float(val, _) => *val == 1.0,
374            ConstantScalarValue::UInt(val, _) => *val == 1,
375            ConstantScalarValue::Bool(_) => false,
376        }
377    }
378
379    pub fn cast_to(&self, other: StorageType) -> ConstantScalarValue {
380        match (self, other.elem_type()) {
381            (ConstantScalarValue::Int(val, _), ElemType::Float(float_kind)) => {
382                ConstantScalarValue::Float(*val as f64, float_kind)
383            }
384            (ConstantScalarValue::Int(val, _), ElemType::Int(int_kind)) => {
385                ConstantScalarValue::Int(*val, int_kind)
386            }
387            (ConstantScalarValue::Int(val, _), ElemType::UInt(kind)) => {
388                ConstantScalarValue::UInt(*val as u64, kind)
389            }
390            (ConstantScalarValue::Int(val, _), ElemType::Bool) => {
391                ConstantScalarValue::Bool(*val == 1)
392            }
393            (ConstantScalarValue::Float(val, _), ElemType::Float(float_kind)) => {
394                ConstantScalarValue::Float(*val, float_kind)
395            }
396            (ConstantScalarValue::Float(val, _), ElemType::Int(int_kind)) => {
397                ConstantScalarValue::Int(*val as i64, int_kind)
398            }
399            (ConstantScalarValue::Float(val, _), ElemType::UInt(kind)) => {
400                ConstantScalarValue::UInt(*val as u64, kind)
401            }
402            (ConstantScalarValue::Float(val, _), ElemType::Bool) => {
403                ConstantScalarValue::Bool(*val == 0.0)
404            }
405            (ConstantScalarValue::UInt(val, _), ElemType::Float(float_kind)) => {
406                ConstantScalarValue::Float(*val as f64, float_kind)
407            }
408            (ConstantScalarValue::UInt(val, _), ElemType::Int(int_kind)) => {
409                ConstantScalarValue::Int(*val as i64, int_kind)
410            }
411            (ConstantScalarValue::UInt(val, _), ElemType::UInt(kind)) => {
412                ConstantScalarValue::UInt(*val, kind)
413            }
414            (ConstantScalarValue::UInt(val, _), ElemType::Bool) => {
415                ConstantScalarValue::Bool(*val == 1)
416            }
417            (ConstantScalarValue::Bool(val), ElemType::Float(float_kind)) => {
418                ConstantScalarValue::Float(*val as u32 as f64, float_kind)
419            }
420            (ConstantScalarValue::Bool(val), ElemType::Int(int_kind)) => {
421                ConstantScalarValue::Int(*val as i64, int_kind)
422            }
423            (ConstantScalarValue::Bool(val), ElemType::UInt(kind)) => {
424                ConstantScalarValue::UInt(*val as u64, kind)
425            }
426            (ConstantScalarValue::Bool(val), ElemType::Bool) => ConstantScalarValue::Bool(*val),
427        }
428    }
429}
430
431impl Display for ConstantScalarValue {
432    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
433        match self {
434            ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
435            ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
436            ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
437            ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
438            ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
439            ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
440            ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
441            ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
442            ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
443            ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
444            ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
445            ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
446            ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
447            ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
448            ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
449            ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
450            ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
451            ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
452            ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
453            ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
454            ConstantScalarValue::Bool(val) => write!(f, "{val}"),
455        }
456    }
457}
458
459impl Variable {
460    pub fn line_size(&self) -> u32 {
461        self.ty.line_size()
462    }
463
464    pub fn index(&self) -> Option<Id> {
465        match self.kind {
466            VariableKind::GlobalInputArray(id)
467            | VariableKind::GlobalOutputArray(id)
468            | VariableKind::TensorMapInput(id)
469            | VariableKind::TensorMapOutput(id)
470            | VariableKind::GlobalScalar(id)
471            | VariableKind::LocalMut { id, .. }
472            | VariableKind::Versioned { id, .. }
473            | VariableKind::LocalConst { id, .. }
474            | VariableKind::ConstantArray { id, .. }
475            | VariableKind::SharedMemory { id, .. }
476            | VariableKind::LocalArray { id, .. }
477            | VariableKind::Matrix { id, .. } => Some(id),
478            _ => None,
479        }
480    }
481
482    pub fn as_const(&self) -> Option<ConstantScalarValue> {
483        match self.kind {
484            VariableKind::ConstantScalar(constant) => Some(constant),
485            _ => None,
486        }
487    }
488}
489
490impl Display for Variable {
491    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
492        match self.kind {
493            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
494            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
495            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
496            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
497            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
498            VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
499            VariableKind::LocalMut { id } => write!(f, "local({id})"),
500            VariableKind::Versioned { id, version } => {
501                write!(f, "local({id}).v{version}")
502            }
503            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
504            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
505            VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
506            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
507            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
508            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
509            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
510            VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
511        }
512    }
513}
514
515// Useful with the cube_inline macro.
516impl From<&Variable> for Variable {
517    fn from(value: &Variable) -> Self {
518        *value
519    }
520}