cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, StorageType, TypeHash};
4
5use super::{ElemType, FloatKind, IntKind, Matrix, Type, UIntKind};
6use float_ord::FloatOrd;
7
/// A value referenced by the IR: a [`VariableKind`] saying what the variable is, paired with
/// the [`Type`] it carries.
///
/// The struct is `Copy`, so it is passed around by value.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct Variable {
    /// What the variable is (global array, local, constant, builtin, ...).
    pub kind: VariableKind,
    /// The type attached to the variable.
    pub ty: Type,
}
15
16impl Variable {
17    pub fn new(kind: VariableKind, item: Type) -> Self {
18        Self { kind, ty: item }
19    }
20
21    pub fn builtin(builtin: Builtin) -> Self {
22        Self::new(
23            VariableKind::Builtin(builtin),
24            Type::scalar(ElemType::UInt(UIntKind::U32)),
25        )
26    }
27
28    pub fn constant(scalar: ConstantScalarValue) -> Self {
29        let elem = match scalar {
30            ConstantScalarValue::Int(_, int_kind) => ElemType::Int(int_kind),
31            ConstantScalarValue::Float(_, float_kind) => ElemType::Float(float_kind),
32            ConstantScalarValue::UInt(_, kind) => ElemType::UInt(kind),
33            ConstantScalarValue::Bool(_) => ElemType::Bool,
34        };
35        Self::new(VariableKind::ConstantScalar(scalar), Type::scalar(elem))
36    }
37
38    pub fn elem_type(&self) -> ElemType {
39        self.ty.elem_type()
40    }
41
42    pub fn storage_type(&self) -> StorageType {
43        self.ty.storage_type()
44    }
45}
46
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
48
/// What a [`Variable`] refers to.
///
/// Most variants carry an [`Id`]. [`Variable::index`] returns it for the global, tensor-map,
/// local, versioned, constant-array, shared-memory, local-array and matrix kinds, and `None`
/// for everything else.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Global input array. Counted as an array; reported as mutable by
    /// [`Variable::is_immutable`].
    GlobalInputArray(Id),
    /// Global output array. Counted as an array; mutable.
    GlobalOutputArray(Id),
    /// Global scalar. Reported as immutable.
    GlobalScalar(Id),
    /// Tensor map used as an input. Reported as immutable.
    TensorMapInput(Id),
    /// Tensor map used as an output. Reported as mutable.
    TensorMapOutput(Id),
    /// Local array. Counted as an array; mutable.
    // NOTE(review): the unit of `length` and the meaning of `unroll_factor` are not visible in
    // this file — confirm before documenting them further.
    LocalArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Mutable local; displayed as `local(id)`.
    LocalMut {
        id: Id,
    },
    /// Immutable local; displayed as `binding(id)`.
    LocalConst {
        id: Id,
    },
    /// A numbered version of a local; displayed as `local(id).v<version>` and reported as
    /// immutable.
    Versioned {
        id: Id,
        version: u16,
    },
    /// An inline constant scalar value.
    ConstantScalar(ConstantScalarValue),
    /// Constant array. Counted as an array; immutable.
    ConstantArray {
        id: Id,
        length: u32,
        unroll_factor: u32,
    },
    /// Shared memory. Counted as an array; mutable.
    // NOTE(review): `alignment` is optional; its unit (presumably bytes) is not established
    // here — confirm.
    SharedMemory {
        id: Id,
        length: u32,
        unroll_factor: u32,
        alignment: Option<u32>,
    },
    /// Matrix value described by `mat`. Counted as an array; mutable.
    Matrix {
        id: Id,
        mat: Matrix,
    },
    /// One of the [`Builtin`] values. Reported as immutable.
    Builtin(Builtin),
    /// Pipeline object with `num_stages` stages. Reported as mutable.
    Pipeline {
        id: Id,
        num_stages: u8,
    },
    /// Barrier object at the given [`BarrierLevel`]. Reported as mutable.
    Barrier {
        id: Id,
        level: BarrierLevel,
    },
    /// Token associated with a barrier at the given [`BarrierLevel`]. Reported as mutable.
    BarrierToken {
        id: Id,
        level: BarrierLevel,
    },
}
102
/// Built-in values that can be referenced through [`VariableKind::Builtin`].
///
/// [`Variable::builtin`] types every builtin as a scalar `u32`. The enum is `#[repr(u8)]` and
/// derives `Ord`, so variants are ordered by declaration order.
// NOTE(review): the variant names suggest positions, dimensions and counts of units, cubes,
// clusters and planes, each with an aggregate form and per-axis X/Y/Z forms — confirm the exact
// semantics against the frontend documentation before relying on this description.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
#[repr(u8)]
pub enum Builtin {
    UnitPos,
    UnitPosX,
    UnitPosY,
    UnitPosZ,
    CubePosCluster,
    CubePosClusterX,
    CubePosClusterY,
    CubePosClusterZ,
    CubePos,
    CubePosX,
    CubePosY,
    CubePosZ,
    CubeDim,
    CubeDimX,
    CubeDimY,
    CubeDimZ,
    CubeClusterDim,
    CubeClusterDimX,
    CubeClusterDimY,
    CubeClusterDimZ,
    CubeCount,
    CubeCountX,
    CubeCountY,
    CubeCountZ,
    PlaneDim,
    UnitPosPlane,
    AbsolutePos,
    AbsolutePosX,
    AbsolutePosY,
    AbsolutePosZ,
}
138
139impl Variable {
140    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
141    /// safe to inline/merge
142    pub fn is_immutable(&self) -> bool {
143        match self.kind {
144            VariableKind::GlobalOutputArray { .. } => false,
145            VariableKind::TensorMapInput(_) => true,
146            VariableKind::TensorMapOutput(_) => false,
147            VariableKind::LocalMut { .. } => false,
148            VariableKind::SharedMemory { .. } => false,
149            VariableKind::Matrix { .. } => false,
150            VariableKind::LocalArray { .. } => false,
151            VariableKind::GlobalInputArray { .. } => false,
152            VariableKind::GlobalScalar { .. } => true,
153            VariableKind::Versioned { .. } => true,
154            VariableKind::LocalConst { .. } => true,
155            VariableKind::ConstantScalar(_) => true,
156            VariableKind::ConstantArray { .. } => true,
157            VariableKind::Builtin(_) => true,
158            VariableKind::Pipeline { .. } => false,
159            VariableKind::Barrier { .. } => false,
160            VariableKind::BarrierToken { .. } => false,
161        }
162    }
163
164    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
165    /// [`Elem`]s when indexed?
166    pub fn is_array(&self) -> bool {
167        matches!(
168            self.kind,
169            VariableKind::GlobalInputArray { .. }
170                | VariableKind::GlobalOutputArray { .. }
171                | VariableKind::ConstantArray { .. }
172                | VariableKind::SharedMemory { .. }
173                | VariableKind::LocalArray { .. }
174                | VariableKind::Matrix { .. }
175        )
176    }
177
178    pub fn has_length(&self) -> bool {
179        matches!(
180            self.kind,
181            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
182        )
183    }
184
185    pub fn has_buffer_length(&self) -> bool {
186        matches!(
187            self.kind,
188            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
189        )
190    }
191
192    /// Determines if the value is a constant with the specified value (converted if necessary)
193    pub fn is_constant(&self, value: i64) -> bool {
194        match self.kind {
195            VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
196            VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
197            VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
198            _ => false,
199        }
200    }
201
202    /// Determines if the value is a boolean constant with the `true` value
203    pub fn is_true(&self) -> bool {
204        match self.kind {
205            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
206            _ => false,
207        }
208    }
209
210    /// Determines if the value is a boolean constant with the `false` value
211    pub fn is_false(&self) -> bool {
212        match self.kind {
213            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
214            _ => false,
215        }
216    }
217}
218
/// The scalars are stored with the highest precision possible, but they might get reduced during
/// compilation.
///
/// Each numeric variant pairs the widest host representation (`i64`, `f64`, `u64`) with the
/// kind the constant actually has.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
#[allow(missing_docs)]
pub enum ConstantScalarValue {
    /// Signed integer constant, held as `i64`, together with its [`IntKind`].
    Int(i64, IntKind),
    /// Floating-point constant, held as `f64`, together with its [`FloatKind`].
    Float(f64, FloatKind),
    /// Unsigned integer constant, held as `u64`, together with its [`UIntKind`].
    UInt(u64, UIntKind),
    /// Boolean constant.
    Bool(bool),
}
230
231impl Eq for ConstantScalarValue {}
232impl Hash for ConstantScalarValue {
233    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
234        core::mem::discriminant(self).hash(ra_expand_state);
235        match self {
236            ConstantScalarValue::Int(f0, f1) => {
237                f0.hash(ra_expand_state);
238                f1.hash(ra_expand_state);
239            }
240            ConstantScalarValue::Float(f0, f1) => {
241                FloatOrd(*f0).hash(ra_expand_state);
242                f1.hash(ra_expand_state);
243            }
244            ConstantScalarValue::UInt(f0, f1) => {
245                f0.hash(ra_expand_state);
246                f1.hash(ra_expand_state);
247            }
248            ConstantScalarValue::Bool(f0) => {
249                f0.hash(ra_expand_state);
250            }
251        }
252    }
253}
254
255impl ConstantScalarValue {
256    /// Returns the element type of the scalar.
257    pub fn elem_type(&self) -> ElemType {
258        match self {
259            ConstantScalarValue::Int(_, kind) => ElemType::Int(*kind),
260            ConstantScalarValue::Float(_, kind) => ElemType::Float(*kind),
261            ConstantScalarValue::UInt(_, kind) => ElemType::UInt(*kind),
262            ConstantScalarValue::Bool(_) => ElemType::Bool,
263        }
264    }
265
266    pub fn storage_type(&self) -> StorageType {
267        self.elem_type().into()
268    }
269
270    /// Returns the value of the scalar as a usize.
271    ///
272    /// It will return [None] if the scalar type is a float or a bool.
273    pub fn try_as_usize(&self) -> Option<usize> {
274        match self {
275            ConstantScalarValue::UInt(val, _) => Some(*val as usize),
276            ConstantScalarValue::Int(val, _) => Some(*val as usize),
277            ConstantScalarValue::Float(_, _) => None,
278            ConstantScalarValue::Bool(_) => None,
279        }
280    }
281
282    /// Returns the value of the scalar as a usize.
283    ///
284    /// It will panic if the scalar type is a float or a bool.
285    pub fn as_usize(&self) -> usize {
286        self.try_as_usize()
287            .expect("Only Int and UInt kind can be made into usize.")
288    }
289
290    /// Returns the value of the scalar as a u32.
291    ///
292    /// It will return [None] if the scalar type is a float or a bool.
293    pub fn try_as_u32(&self) -> Option<u32> {
294        match self {
295            ConstantScalarValue::UInt(val, _) => Some(*val as u32),
296            ConstantScalarValue::Int(val, _) => Some(*val as u32),
297            ConstantScalarValue::Float(_, _) => None,
298            ConstantScalarValue::Bool(_) => None,
299        }
300    }
301
302    /// Returns the value of the scalar as a u32.
303    ///
304    /// It will panic if the scalar type is a float or a bool.
305    pub fn as_u32(&self) -> u32 {
306        self.try_as_u32()
307            .expect("Only Int and UInt kind can be made into u32.")
308    }
309
310    /// Returns the value of the scalar as a u64.
311    ///
312    /// It will return [None] if the scalar type is a float or a bool.
313    pub fn try_as_u64(&self) -> Option<u64> {
314        match self {
315            ConstantScalarValue::UInt(val, _) => Some(*val),
316            ConstantScalarValue::Int(val, _) => Some(*val as u64),
317            ConstantScalarValue::Float(_, _) => None,
318            ConstantScalarValue::Bool(_) => None,
319        }
320    }
321
322    /// Returns the value of the scalar as a u64.
323    ///
324    /// It will panic if the scalar type is a float or a bool.
325    pub fn as_u64(&self) -> u64 {
326        self.try_as_u64()
327            .expect("Only Int and UInt kind can be made into u64.")
328    }
329
330    /// Returns the value of the scalar as a i64.
331    ///
332    /// It will return [None] if the scalar type is a float or a bool.
333    pub fn try_as_i64(&self) -> Option<i64> {
334        match self {
335            ConstantScalarValue::UInt(val, _) => Some(*val as i64),
336            ConstantScalarValue::Int(val, _) => Some(*val),
337            ConstantScalarValue::Float(_, _) => None,
338            ConstantScalarValue::Bool(_) => None,
339        }
340    }
341
342    /// Returns the value of the scalar as a i64.
343    ///
344    /// It will panic if the scalar type is a float or a bool.
345    pub fn as_i64(&self) -> i64 {
346        self.try_as_i64()
347            .expect("Only Int and UInt kind can be made into i64.")
348    }
349
350    /// Returns the value of the scalar as a f64.
351    ///
352    /// It will return [None] if the scalar type is an int or a bool.
353    pub fn try_as_f64(&self) -> Option<f64> {
354        match self {
355            ConstantScalarValue::Float(val, _) => Some(*val),
356            _ => None,
357        }
358    }
359
360    /// Returns the value of the scalar as a f64.
361    ///
362    /// It will panic if the scalar type is an int or a bool.
363    pub fn as_f64(&self) -> f64 {
364        self.try_as_f64()
365            .expect("Only Float kind can be made into f64.")
366    }
367
368    /// Returns the value of the variable as a bool if it actually is a bool.
369    pub fn try_as_bool(&self) -> Option<bool> {
370        match self {
371            ConstantScalarValue::Bool(val) => Some(*val),
372            _ => None,
373        }
374    }
375
376    /// Returns the value of the variable as a bool.
377    ///
378    /// It will panic if the scalar isn't a bool.
379    pub fn as_bool(&self) -> bool {
380        self.try_as_bool()
381            .expect("Only bool can be made into a bool")
382    }
383
384    pub fn is_zero(&self) -> bool {
385        match self {
386            ConstantScalarValue::Int(val, _) => *val == 0,
387            ConstantScalarValue::Float(val, _) => *val == 0.0,
388            ConstantScalarValue::UInt(val, _) => *val == 0,
389            ConstantScalarValue::Bool(_) => false,
390        }
391    }
392
393    pub fn is_one(&self) -> bool {
394        match self {
395            ConstantScalarValue::Int(val, _) => *val == 1,
396            ConstantScalarValue::Float(val, _) => *val == 1.0,
397            ConstantScalarValue::UInt(val, _) => *val == 1,
398            ConstantScalarValue::Bool(_) => false,
399        }
400    }
401
402    pub fn cast_to(&self, other: StorageType) -> ConstantScalarValue {
403        match (self, other.elem_type()) {
404            (ConstantScalarValue::Int(val, _), ElemType::Float(float_kind)) => {
405                ConstantScalarValue::Float(*val as f64, float_kind)
406            }
407            (ConstantScalarValue::Int(val, _), ElemType::Int(int_kind)) => {
408                ConstantScalarValue::Int(*val, int_kind)
409            }
410            (ConstantScalarValue::Int(val, _), ElemType::UInt(kind)) => {
411                ConstantScalarValue::UInt(*val as u64, kind)
412            }
413            (ConstantScalarValue::Int(val, _), ElemType::Bool) => {
414                ConstantScalarValue::Bool(*val == 1)
415            }
416            (ConstantScalarValue::Float(val, _), ElemType::Float(float_kind)) => {
417                ConstantScalarValue::Float(*val, float_kind)
418            }
419            (ConstantScalarValue::Float(val, _), ElemType::Int(int_kind)) => {
420                ConstantScalarValue::Int(*val as i64, int_kind)
421            }
422            (ConstantScalarValue::Float(val, _), ElemType::UInt(kind)) => {
423                ConstantScalarValue::UInt(*val as u64, kind)
424            }
425            (ConstantScalarValue::Float(val, _), ElemType::Bool) => {
426                ConstantScalarValue::Bool(*val == 0.0)
427            }
428            (ConstantScalarValue::UInt(val, _), ElemType::Float(float_kind)) => {
429                ConstantScalarValue::Float(*val as f64, float_kind)
430            }
431            (ConstantScalarValue::UInt(val, _), ElemType::Int(int_kind)) => {
432                ConstantScalarValue::Int(*val as i64, int_kind)
433            }
434            (ConstantScalarValue::UInt(val, _), ElemType::UInt(kind)) => {
435                ConstantScalarValue::UInt(*val, kind)
436            }
437            (ConstantScalarValue::UInt(val, _), ElemType::Bool) => {
438                ConstantScalarValue::Bool(*val == 1)
439            }
440            (ConstantScalarValue::Bool(val), ElemType::Float(float_kind)) => {
441                ConstantScalarValue::Float(*val as u32 as f64, float_kind)
442            }
443            (ConstantScalarValue::Bool(val), ElemType::Int(int_kind)) => {
444                ConstantScalarValue::Int(*val as i64, int_kind)
445            }
446            (ConstantScalarValue::Bool(val), ElemType::UInt(kind)) => {
447                ConstantScalarValue::UInt(*val as u64, kind)
448            }
449            (ConstantScalarValue::Bool(val), ElemType::Bool) => ConstantScalarValue::Bool(*val),
450        }
451    }
452}
453
454impl Display for ConstantScalarValue {
455    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
456        match self {
457            ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
458            ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
459            ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
460            ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
461            ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
462            ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
463            ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
464            ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
465            ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
466            ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
467            ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
468            ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
469            ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
470            ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
471            ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
472            ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
473            ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
474            ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
475            ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
476            ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
477            ConstantScalarValue::Bool(val) => write!(f, "{val}"),
478        }
479    }
480}
481
482impl Variable {
483    pub fn line_size(&self) -> u32 {
484        self.ty.line_size()
485    }
486
487    pub fn index(&self) -> Option<Id> {
488        match self.kind {
489            VariableKind::GlobalInputArray(id)
490            | VariableKind::GlobalOutputArray(id)
491            | VariableKind::TensorMapInput(id)
492            | VariableKind::TensorMapOutput(id)
493            | VariableKind::GlobalScalar(id)
494            | VariableKind::LocalMut { id, .. }
495            | VariableKind::Versioned { id, .. }
496            | VariableKind::LocalConst { id, .. }
497            | VariableKind::ConstantArray { id, .. }
498            | VariableKind::SharedMemory { id, .. }
499            | VariableKind::LocalArray { id, .. }
500            | VariableKind::Matrix { id, .. } => Some(id),
501            _ => None,
502        }
503    }
504
505    pub fn as_const(&self) -> Option<ConstantScalarValue> {
506        match self.kind {
507            VariableKind::ConstantScalar(constant) => Some(constant),
508            _ => None,
509        }
510    }
511}
512
513impl Display for Variable {
514    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
515        match self.kind {
516            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
517            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
518            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
519            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
520            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
521            VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
522            VariableKind::LocalMut { id } => write!(f, "local({id})"),
523            VariableKind::Versioned { id, version } => {
524                write!(f, "local({id}).v{version}")
525            }
526            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
527            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
528            VariableKind::SharedMemory { id, .. } => write!(f, "shared({id})"),
529            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
530            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
531            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
532            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
533            VariableKind::Barrier { id, .. } => write!(f, "barrier({id})"),
534            VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
535        }
536    }
537}
538
539// Useful with the cube_inline macro.
540impl From<&Variable> for Variable {
541    fn from(value: &Variable) -> Self {
542        *value
543    }
544}