cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, StorageType, TypeHash};
4
5use super::{ElemType, FloatKind, IntKind, Matrix, Type, UIntKind};
6use float_ord::FloatOrd;
7
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
10#[allow(missing_docs)]
11pub struct Variable {
12    pub kind: VariableKind,
13    pub ty: Type,
14}
15
16impl Variable {
17    pub fn new(kind: VariableKind, item: Type) -> Self {
18        Self { kind, ty: item }
19    }
20
21    pub fn builtin(builtin: Builtin) -> Self {
22        Self::new(
23            VariableKind::Builtin(builtin),
24            Type::scalar(ElemType::UInt(UIntKind::U32)),
25        )
26    }
27
28    pub fn constant(scalar: ConstantScalarValue) -> Self {
29        let elem = match scalar {
30            ConstantScalarValue::Int(_, int_kind) => ElemType::Int(int_kind),
31            ConstantScalarValue::Float(_, float_kind) => ElemType::Float(float_kind),
32            ConstantScalarValue::UInt(_, kind) => ElemType::UInt(kind),
33            ConstantScalarValue::Bool(_) => ElemType::Bool,
34        };
35        Self::new(VariableKind::ConstantScalar(scalar), Type::scalar(elem))
36    }
37
38    pub fn elem_type(&self) -> ElemType {
39        self.ty.elem_type()
40    }
41
42    pub fn storage_type(&self) -> StorageType {
43        self.ty.storage_type()
44    }
45}
46
47pub type Id = u32;
48
49#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
50#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
51pub enum VariableKind {
52    GlobalInputArray(Id),
53    GlobalOutputArray(Id),
54    GlobalScalar(Id),
55    TensorMapInput(Id),
56    TensorMapOutput(Id),
57    LocalArray {
58        id: Id,
59        length: u32,
60        unroll_factor: u32,
61    },
62    LocalMut {
63        id: Id,
64    },
65    LocalConst {
66        id: Id,
67    },
68    Versioned {
69        id: Id,
70        version: u16,
71    },
72    ConstantScalar(ConstantScalarValue),
73    ConstantArray {
74        id: Id,
75        length: u32,
76        unroll_factor: u32,
77    },
78    SharedArray {
79        id: Id,
80        length: u32,
81        unroll_factor: u32,
82        alignment: Option<u32>,
83    },
84    Shared {
85        id: Id,
86    },
87    Matrix {
88        id: Id,
89        mat: Matrix,
90    },
91    Builtin(Builtin),
92    Pipeline {
93        id: Id,
94        num_stages: u8,
95    },
96    BarrierToken {
97        id: Id,
98        level: BarrierLevel,
99    },
100}
101
102#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
104#[repr(u8)]
105pub enum Builtin {
106    UnitPos,
107    UnitPosX,
108    UnitPosY,
109    UnitPosZ,
110    CubePosCluster,
111    CubePosClusterX,
112    CubePosClusterY,
113    CubePosClusterZ,
114    CubePos,
115    CubePosX,
116    CubePosY,
117    CubePosZ,
118    CubeDim,
119    CubeDimX,
120    CubeDimY,
121    CubeDimZ,
122    CubeClusterDim,
123    CubeClusterDimX,
124    CubeClusterDimY,
125    CubeClusterDimZ,
126    CubeCount,
127    CubeCountX,
128    CubeCountY,
129    CubeCountZ,
130    PlaneDim,
131    UnitPosPlane,
132    AbsolutePos,
133    AbsolutePosX,
134    AbsolutePosY,
135    AbsolutePosZ,
136}
137
138impl Variable {
139    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
140    /// safe to inline/merge
141    pub fn is_immutable(&self) -> bool {
142        match self.kind {
143            VariableKind::GlobalOutputArray { .. } => false,
144            VariableKind::TensorMapInput(_) => true,
145            VariableKind::TensorMapOutput(_) => false,
146            VariableKind::LocalMut { .. } => false,
147            VariableKind::SharedArray { .. } => false,
148            VariableKind::Shared { .. } => false,
149            VariableKind::Matrix { .. } => false,
150            VariableKind::LocalArray { .. } => false,
151            VariableKind::GlobalInputArray { .. } => false,
152            VariableKind::GlobalScalar { .. } => true,
153            VariableKind::Versioned { .. } => true,
154            VariableKind::LocalConst { .. } => true,
155            VariableKind::ConstantScalar(_) => true,
156            VariableKind::ConstantArray { .. } => true,
157            VariableKind::Builtin(_) => true,
158            VariableKind::Pipeline { .. } => false,
159            VariableKind::BarrierToken { .. } => false,
160        }
161    }
162
163    /// Is this an array type that yields [`Item`]s when indexed, or a scalar/vector that yields
164    /// [`Elem`]s when indexed?
165    pub fn is_array(&self) -> bool {
166        matches!(
167            self.kind,
168            VariableKind::GlobalInputArray { .. }
169                | VariableKind::GlobalOutputArray { .. }
170                | VariableKind::ConstantArray { .. }
171                | VariableKind::SharedArray { .. }
172                | VariableKind::LocalArray { .. }
173                | VariableKind::Matrix { .. }
174        )
175    }
176
177    pub fn has_length(&self) -> bool {
178        matches!(
179            self.kind,
180            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
181        )
182    }
183
184    pub fn has_buffer_length(&self) -> bool {
185        matches!(
186            self.kind,
187            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188        )
189    }
190
191    /// Determines if the value is a constant with the specified value (converted if necessary)
192    pub fn is_constant(&self, value: i64) -> bool {
193        match self.kind {
194            VariableKind::ConstantScalar(ConstantScalarValue::Int(val, _)) => val == value,
195            VariableKind::ConstantScalar(ConstantScalarValue::UInt(val, _)) => val as i64 == value,
196            VariableKind::ConstantScalar(ConstantScalarValue::Float(val, _)) => val == value as f64,
197            _ => false,
198        }
199    }
200
201    /// Determines if the value is a boolean constant with the `true` value
202    pub fn is_true(&self) -> bool {
203        match self.kind {
204            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => val,
205            _ => false,
206        }
207    }
208
209    /// Determines if the value is a boolean constant with the `false` value
210    pub fn is_false(&self) -> bool {
211        match self.kind {
212            VariableKind::ConstantScalar(ConstantScalarValue::Bool(val)) => !val,
213            _ => false,
214        }
215    }
216}
217
218/// The scalars are stored with the highest precision possible, but they might get reduced during
219/// compilation.
220#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
221#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd)]
222#[allow(missing_docs)]
223pub enum ConstantScalarValue {
224    Int(i64, IntKind),
225    Float(f64, FloatKind),
226    UInt(u64, UIntKind),
227    Bool(bool),
228}
229
230impl Eq for ConstantScalarValue {}
231impl Hash for ConstantScalarValue {
232    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
233        core::mem::discriminant(self).hash(ra_expand_state);
234        match self {
235            ConstantScalarValue::Int(f0, f1) => {
236                f0.hash(ra_expand_state);
237                f1.hash(ra_expand_state);
238            }
239            ConstantScalarValue::Float(f0, f1) => {
240                FloatOrd(*f0).hash(ra_expand_state);
241                f1.hash(ra_expand_state);
242            }
243            ConstantScalarValue::UInt(f0, f1) => {
244                f0.hash(ra_expand_state);
245                f1.hash(ra_expand_state);
246            }
247            ConstantScalarValue::Bool(f0) => {
248                f0.hash(ra_expand_state);
249            }
250        }
251    }
252}
253
254impl ConstantScalarValue {
255    /// Returns the element type of the scalar.
256    pub fn elem_type(&self) -> ElemType {
257        match self {
258            ConstantScalarValue::Int(_, kind) => ElemType::Int(*kind),
259            ConstantScalarValue::Float(_, kind) => ElemType::Float(*kind),
260            ConstantScalarValue::UInt(_, kind) => ElemType::UInt(*kind),
261            ConstantScalarValue::Bool(_) => ElemType::Bool,
262        }
263    }
264
265    pub fn storage_type(&self) -> StorageType {
266        self.elem_type().into()
267    }
268
269    /// Returns the value of the scalar as a usize.
270    ///
271    /// It will return [None] if the scalar type is a float or a bool.
272    pub fn try_as_usize(&self) -> Option<usize> {
273        match self {
274            ConstantScalarValue::UInt(val, _) => Some(*val as usize),
275            ConstantScalarValue::Int(val, _) => Some(*val as usize),
276            ConstantScalarValue::Float(_, _) => None,
277            ConstantScalarValue::Bool(_) => None,
278        }
279    }
280
281    /// Returns the value of the scalar as a usize.
282    ///
283    /// It will panic if the scalar type is a float or a bool.
284    pub fn as_usize(&self) -> usize {
285        self.try_as_usize()
286            .expect("Only Int and UInt kind can be made into usize.")
287    }
288
289    /// Returns the value of the scalar as a u32.
290    ///
291    /// It will return [None] if the scalar type is a float or a bool.
292    pub fn try_as_u32(&self) -> Option<u32> {
293        match self {
294            ConstantScalarValue::UInt(val, _) => Some(*val as u32),
295            ConstantScalarValue::Int(val, _) => Some(*val as u32),
296            ConstantScalarValue::Float(_, _) => None,
297            ConstantScalarValue::Bool(_) => None,
298        }
299    }
300
301    /// Returns the value of the scalar as a u32.
302    ///
303    /// It will panic if the scalar type is a float or a bool.
304    pub fn as_u32(&self) -> u32 {
305        self.try_as_u32()
306            .expect("Only Int and UInt kind can be made into u32.")
307    }
308
309    /// Returns the value of the scalar as a u64.
310    ///
311    /// It will return [None] if the scalar type is a float or a bool.
312    pub fn try_as_u64(&self) -> Option<u64> {
313        match self {
314            ConstantScalarValue::UInt(val, _) => Some(*val),
315            ConstantScalarValue::Int(val, _) => Some(*val as u64),
316            ConstantScalarValue::Float(_, _) => None,
317            ConstantScalarValue::Bool(_) => None,
318        }
319    }
320
321    /// Returns the value of the scalar as a u64.
322    ///
323    /// It will panic if the scalar type is a float or a bool.
324    pub fn as_u64(&self) -> u64 {
325        self.try_as_u64()
326            .expect("Only Int and UInt kind can be made into u64.")
327    }
328
329    /// Returns the value of the scalar as a i64.
330    ///
331    /// It will return [None] if the scalar type is a float or a bool.
332    pub fn try_as_i64(&self) -> Option<i64> {
333        match self {
334            ConstantScalarValue::UInt(val, _) => Some(*val as i64),
335            ConstantScalarValue::Int(val, _) => Some(*val),
336            ConstantScalarValue::Float(_, _) => None,
337            ConstantScalarValue::Bool(_) => None,
338        }
339    }
340
341    /// Returns the value of the scalar as a i64.
342    ///
343    /// It will panic if the scalar type is a float or a bool.
344    pub fn as_i64(&self) -> i64 {
345        self.try_as_i64()
346            .expect("Only Int and UInt kind can be made into i64.")
347    }
348
349    /// Returns the value of the scalar as a f64.
350    ///
351    /// It will return [None] if the scalar type is an int or a bool.
352    pub fn try_as_f64(&self) -> Option<f64> {
353        match self {
354            ConstantScalarValue::Float(val, _) => Some(*val),
355            _ => None,
356        }
357    }
358
359    /// Returns the value of the scalar as a f64.
360    ///
361    /// It will panic if the scalar type is an int or a bool.
362    pub fn as_f64(&self) -> f64 {
363        self.try_as_f64()
364            .expect("Only Float kind can be made into f64.")
365    }
366
367    /// Returns the value of the variable as a bool if it actually is a bool.
368    pub fn try_as_bool(&self) -> Option<bool> {
369        match self {
370            ConstantScalarValue::Bool(val) => Some(*val),
371            _ => None,
372        }
373    }
374
375    /// Returns the value of the variable as a bool.
376    ///
377    /// It will panic if the scalar isn't a bool.
378    pub fn as_bool(&self) -> bool {
379        self.try_as_bool()
380            .expect("Only bool can be made into a bool")
381    }
382
383    pub fn is_zero(&self) -> bool {
384        match self {
385            ConstantScalarValue::Int(val, _) => *val == 0,
386            ConstantScalarValue::Float(val, _) => *val == 0.0,
387            ConstantScalarValue::UInt(val, _) => *val == 0,
388            ConstantScalarValue::Bool(_) => false,
389        }
390    }
391
392    pub fn is_one(&self) -> bool {
393        match self {
394            ConstantScalarValue::Int(val, _) => *val == 1,
395            ConstantScalarValue::Float(val, _) => *val == 1.0,
396            ConstantScalarValue::UInt(val, _) => *val == 1,
397            ConstantScalarValue::Bool(_) => false,
398        }
399    }
400
401    pub fn cast_to(&self, other: StorageType) -> ConstantScalarValue {
402        match (self, other.elem_type()) {
403            (ConstantScalarValue::Int(val, _), ElemType::Float(float_kind)) => {
404                ConstantScalarValue::Float(*val as f64, float_kind)
405            }
406            (ConstantScalarValue::Int(val, _), ElemType::Int(int_kind)) => {
407                ConstantScalarValue::Int(*val, int_kind)
408            }
409            (ConstantScalarValue::Int(val, _), ElemType::UInt(kind)) => {
410                ConstantScalarValue::UInt(*val as u64, kind)
411            }
412            (ConstantScalarValue::Int(val, _), ElemType::Bool) => {
413                ConstantScalarValue::Bool(*val == 1)
414            }
415            (ConstantScalarValue::Float(val, _), ElemType::Float(float_kind)) => {
416                ConstantScalarValue::Float(*val, float_kind)
417            }
418            (ConstantScalarValue::Float(val, _), ElemType::Int(int_kind)) => {
419                ConstantScalarValue::Int(*val as i64, int_kind)
420            }
421            (ConstantScalarValue::Float(val, _), ElemType::UInt(kind)) => {
422                ConstantScalarValue::UInt(*val as u64, kind)
423            }
424            (ConstantScalarValue::Float(val, _), ElemType::Bool) => {
425                ConstantScalarValue::Bool(*val == 0.0)
426            }
427            (ConstantScalarValue::UInt(val, _), ElemType::Float(float_kind)) => {
428                ConstantScalarValue::Float(*val as f64, float_kind)
429            }
430            (ConstantScalarValue::UInt(val, _), ElemType::Int(int_kind)) => {
431                ConstantScalarValue::Int(*val as i64, int_kind)
432            }
433            (ConstantScalarValue::UInt(val, _), ElemType::UInt(kind)) => {
434                ConstantScalarValue::UInt(*val, kind)
435            }
436            (ConstantScalarValue::UInt(val, _), ElemType::Bool) => {
437                ConstantScalarValue::Bool(*val == 1)
438            }
439            (ConstantScalarValue::Bool(val), ElemType::Float(float_kind)) => {
440                ConstantScalarValue::Float(*val as u32 as f64, float_kind)
441            }
442            (ConstantScalarValue::Bool(val), ElemType::Int(int_kind)) => {
443                ConstantScalarValue::Int(*val as i64, int_kind)
444            }
445            (ConstantScalarValue::Bool(val), ElemType::UInt(kind)) => {
446                ConstantScalarValue::UInt(*val as u64, kind)
447            }
448            (ConstantScalarValue::Bool(val), ElemType::Bool) => ConstantScalarValue::Bool(*val),
449        }
450    }
451}
452
453impl Display for ConstantScalarValue {
454    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
455        match self {
456            ConstantScalarValue::Int(val, IntKind::I8) => write!(f, "{val}i8"),
457            ConstantScalarValue::Int(val, IntKind::I16) => write!(f, "{val}i16"),
458            ConstantScalarValue::Int(val, IntKind::I32) => write!(f, "{val}i32"),
459            ConstantScalarValue::Int(val, IntKind::I64) => write!(f, "{val}i64"),
460            ConstantScalarValue::Float(val, FloatKind::E2M1) => write!(f, "{val}e2m1"),
461            ConstantScalarValue::Float(val, FloatKind::E2M3) => write!(f, "{val}e2m3"),
462            ConstantScalarValue::Float(val, FloatKind::E3M2) => write!(f, "{val}e3m2"),
463            ConstantScalarValue::Float(val, FloatKind::E4M3) => write!(f, "{val}e4m3"),
464            ConstantScalarValue::Float(val, FloatKind::E5M2) => write!(f, "{val}e5m2"),
465            ConstantScalarValue::Float(val, FloatKind::UE8M0) => write!(f, "{val}ue8m0"),
466            ConstantScalarValue::Float(val, FloatKind::BF16) => write!(f, "{val}bf16"),
467            ConstantScalarValue::Float(val, FloatKind::F16) => write!(f, "{val}f16"),
468            ConstantScalarValue::Float(val, FloatKind::TF32) => write!(f, "{val}tf32"),
469            ConstantScalarValue::Float(val, FloatKind::Flex32) => write!(f, "{val}flex32"),
470            ConstantScalarValue::Float(val, FloatKind::F32) => write!(f, "{val}f32"),
471            ConstantScalarValue::Float(val, FloatKind::F64) => write!(f, "{val}f64"),
472            ConstantScalarValue::UInt(val, UIntKind::U8) => write!(f, "{val}u8"),
473            ConstantScalarValue::UInt(val, UIntKind::U16) => write!(f, "{val}u16"),
474            ConstantScalarValue::UInt(val, UIntKind::U32) => write!(f, "{val}u32"),
475            ConstantScalarValue::UInt(val, UIntKind::U64) => write!(f, "{val}u64"),
476            ConstantScalarValue::Bool(val) => write!(f, "{val}"),
477        }
478    }
479}
480
481impl Variable {
482    pub fn line_size(&self) -> u32 {
483        self.ty.line_size()
484    }
485
486    pub fn index(&self) -> Option<Id> {
487        match self.kind {
488            VariableKind::GlobalInputArray(id)
489            | VariableKind::GlobalOutputArray(id)
490            | VariableKind::TensorMapInput(id)
491            | VariableKind::TensorMapOutput(id)
492            | VariableKind::GlobalScalar(id)
493            | VariableKind::LocalMut { id, .. }
494            | VariableKind::Versioned { id, .. }
495            | VariableKind::LocalConst { id, .. }
496            | VariableKind::ConstantArray { id, .. }
497            | VariableKind::SharedArray { id, .. }
498            | VariableKind::Shared { id, .. }
499            | VariableKind::LocalArray { id, .. }
500            | VariableKind::Matrix { id, .. } => Some(id),
501            _ => None,
502        }
503    }
504
505    pub fn as_const(&self) -> Option<ConstantScalarValue> {
506        match self.kind {
507            VariableKind::ConstantScalar(constant) => Some(constant),
508            _ => None,
509        }
510    }
511}
512
513impl Display for Variable {
514    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
515        match self.kind {
516            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
517            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
518            VariableKind::GlobalScalar(id) => write!(f, "scalar({id})"),
519            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
520            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
521            VariableKind::ConstantScalar(constant) => write!(f, "{constant}"),
522            VariableKind::LocalMut { id } => write!(f, "local({id})"),
523            VariableKind::Versioned { id, version } => {
524                write!(f, "local({id}).v{version}")
525            }
526            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
527            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
528            VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
529            VariableKind::Shared { id } => write!(f, "shared({id})"),
530            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
531            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
532            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
533            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
534            VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
535        }
536    }
537}
538
539// Useful with the cube_inline macro.
540impl From<&Variable> for Variable {
541    fn from(value: &Variable) -> Self {
542        *value
543    }
544}